Commit 2275b1c4 authored by Baber

add resps to `download`

parent 30430b60
@@ -1,3 +1,4 @@
 import json
+from typing import Any, Dict, Optional
 
 import datasets
@@ -6,17 +7,53 @@ from lm_eval.api.task import ConfigurableTask
 
 class JudgeTask(ConfigurableTask):
-    def __init__(self, config, output_path):
-        super().__init__(config)
+    def __init__(
+        self,
+        data_dir=None,
+        cache_dir=None,
+        download_mode=None,
+        config: Optional[dict] = None,
+        output_path: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        self._config = config
+        # self.config["process_docs"] = self.process_docs
         self.output_path = output_path
+        super().__init__(
+            data_dir=None, cache_dir=None, download_mode=None, config=self.config
+        )
+        # self.output_path = output_path
+
+    def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
+        self.dataset = datasets.load_dataset(
+            path=self.DATASET_PATH,
+            name=self.DATASET_NAME,
+            **dataset_kwargs if dataset_kwargs is not None else {},
+        )
 
     def process_docs(self, dataset: datasets.Dataset):
         resps = []
         # load json
         with open(self.output_path, "r") as f:
-            resp = json.load(f)
-            resps.append({"resp": resp["resps"], "doc": resp["doc"]})
+            for line in f:
+                resp = json.loads(line)
+                resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]})
 
         resps.sort(key=lambda x: x["doc"])
-        dataset.add_column("resp", resps)
-        return resps
+        # TODO: add filter name to resps
+        resps = resps[::2]
+        self.dataset["test"] = self.dataset["test"].add_column(
+            "resp", [resp["resp"] for resp in resps]
+        )
+        print("done")
+
+    # def process_docs(self, dataset: datasets.Dataset):
+    #     resps = []
+    #     # load json
+    #     with open(self.output_path, "r") as f:
+    #         for line in f:
+    #             resp = json.loads(line)
+    #             resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]})
+    #
+    #     resps.sort(key=lambda x: x["doc"])
+    #     dataset.add_column("resp", resps)
+    #     return resps
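For reference, a minimal sketch of the samples file the reworked `process_docs` consumes: one JSON object per line, re-sorted into document order via `doc_id`, with `resps[0][0]` as the response text. The two-records-per-doc layout that motivates `resps[::2]` (one record per filter, per the TODO) is an inference from the diff, as are the file name and response values below.

import json

# Hypothetical samples-file writer; the field names ("doc_id", "resps")
# come from the diff, everything else here is illustrative.
records = [
    {"doc_id": 0, "resps": [["judged answer for doc 0"]]},  # first filter
    {"doc_id": 0, "resps": [["judged answer for doc 0"]]},  # second filter, dropped by resps[::2]
    {"doc_id": 1, "resps": [["judged answer for doc 1"]]},
    {"doc_id": 1, "resps": [["judged answer for doc 1"]]},
]
with open("samples.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")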
@@ -53,6 +53,7 @@ eval_logger = logging.getLogger("lm-eval")
 @dataclass
 class TaskConfig(dict):
+    output_path: Optional[str] = None
     # task naming/registry
     task: Optional[str] = None
     task_alias: Optional[str] = None
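With `output_path` promoted to a first-class TaskConfig field, the samples-file location travels with the rest of the task configuration instead of being passed out of band. A rough usage sketch, with placeholder values:

from lm_eval.api.task import TaskConfig

# Placeholder task name and path; output_path is the field this commit adds,
# the other fields already exist on TaskConfig.
cfg = TaskConfig(
    task="my_judge_task",
    output_path="out/samples.jsonl",
)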
@@ -274,9 +274,9 @@ class TaskManager:
             # very scuffed: set task name here. TODO: fixme?
             task_object.config.task = config["task"]
         else:
-            if "resp_to_doc" in config:
+            if "output_type" in config:
                 task_object = JudgeTask(
-                    config=config, output_path=config.get("output_path", None)
+                    config=config, output_path=config.get("output_path")
                 )
             else:
                 task_object = ConfigurableTask(config=config)
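Condensed, the new routing behaves like the sketch below: any config dict carrying an `output_type` key is built as a JudgeTask (with `output_path` defaulting to None when absent), and everything else falls through to ConfigurableTask. Names mirror the diff and assume JudgeTask and ConfigurableTask are in scope, as they are in the TaskManager module; the surrounding registry plumbing is omitted.

def _build_task(config: dict):
    # "output_type" replaces the old "resp_to_doc" marker as the dispatch key.
    if "output_type" in config:
        return JudgeTask(config=config, output_path=config.get("output_path"))
    return ConfigurableTask(config=config)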