Commit 2275b1c4 authored by Baber's avatar Baber
Browse files

add resps to `download`

parent 30430b60
import json
from typing import Any, Dict, Optional
import datasets
......@@ -6,17 +7,53 @@ from lm_eval.api.task import ConfigurableTask
class JudgeTask(ConfigurableTask):
def __init__(self, config, output_path):
super().__init__(config)
def __init__(
self,
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
output_path: Optional[str] = None,
**kwargs,
) -> None:
self._config = config
# self.config["process_docs"] = self.process_docs
self.output_path = output_path
super().__init__(
data_dir=None, cache_dir=None, download_mode=None, config=self.config
)
# self.output_path = output_path
def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
)
def process_docs(self, dataset: datasets.Dataset):
resps = []
# load json
with open(self.output_path, "r") as f:
resp = json.load(f)
resps.append({"resp": resp["resps"], "doc": resp["doc"]})
for line in f:
resp = json.loads(line)
resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]})
resps.sort(key=lambda x: x["doc"])
dataset.add_column("resp", resps)
return resps
# TODO: add filter name to resps
resps = resps[::2]
self.dataset["test"] = self.dataset["test"].add_column(
"resp", [resp["resp"] for resp in resps]
)
print("done")
# def process_docs(self, dataset: datasets.Dataset):
# resps = []
# # load json
# with open(self.output_path, "r") as f:
# for line in f:
# resp = json.loads(line)
# resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]})
#
# resps.sort(key=lambda x: x["doc"])
# dataset.add_column("resp", resps)
# return resps
......@@ -53,6 +53,7 @@ eval_logger = logging.getLogger("lm-eval")
@dataclass
class TaskConfig(dict):
output_path: Optional[str] = None
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
......
......@@ -274,9 +274,9 @@ class TaskManager:
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = config["task"]
else:
if "resp_to_doc" in config:
if "output_type" in config:
task_object = JudgeTask(
config=config, output_path=config.get("output_path", None)
config=config, output_path=config.get("output_path")
)
else:
task_object = ConfigurableTask(config=config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment