Commit 13a3f1d6 authored by Baber
Browse files

better error handling

parent 7489a342
...@@ -33,21 +33,22 @@ class JudgeTask(ConfigurableTask): ...@@ -33,21 +33,22 @@ class JudgeTask(ConfigurableTask):
resps = [] resps = []
# load json # load json
with open(self.output_path, "r") as f: if self.output_path is not None:
for line in f: with open(self.output_path, "r") as f:
resp = json.loads(line) for line in f:
resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]}) resp = json.loads(line)
resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]})
resps.sort(key=lambda x: x["doc"])
# TODO: add filter name to resps resps.sort(key=lambda x: x["doc"])
resps = resps[::2] # TODO: add filter name to resps
self.dataset["test"] = self.dataset["test"].add_column( resps = resps[::2]
"resp", [resp["resp"] for resp in resps] self.dataset["test"] = self.dataset["test"].add_column(
) "resp", [resp["resp"] for resp in resps]
self.dataset["train"] = self.dataset["train"].add_column( )
"resp", self.dataset["train"]["answer"] self.dataset["train"] = self.dataset["train"].add_column(
) "resp", self.dataset["train"]["answer"]
print("resp columns added") )
print("resp columns added")
# def process_docs(self, dataset: datasets.Dataset): # def process_docs(self, dataset: datasets.Dataset):
# resps = [] # resps = []
......
...@@ -265,25 +265,20 @@ class TaskManager: ...@@ -265,25 +265,20 @@ class TaskManager:
), ),
**config, **config,
} }
if self._config_is_python_task(config): if "output_type" in config:
if self._class_has_config_in_constructor(config["class"]): task_object = JudgeTask(
task_object = config["class"](config=config) config=config, output_path=config.get("output_path", None)
else: )
task_object = config["class"]() # if self._config_is_python_task(config):
if isinstance(task_object, ConfigurableTask): # if self._class_has_config_in_constructor(config["class"]):
# very scuffed: set task name here. TODO: fixme? # task_object = config["class"](config=config)
task_object.config.task = config["task"] # else:
# task_object = config["class"]()
# if isinstance(task_object, ConfigurableTask):
# # very scuffed: set task name here. TODO: fixme?
# task_object.config.task = config["task"]
else: else:
try: task_object = ConfigurableTask(config=config)
if "output_type" in config:
task_object = JudgeTask(
config=config, output_path=config.get("output_path")
)
except Exception:
config.pop("output_type")
task_object = ConfigurableTask(config=config)
else:
task_object = ConfigurableTask(config=config)
return {task: task_object} return {task: task_object}
...@@ -502,7 +497,7 @@ class TaskManager: ...@@ -502,7 +497,7 @@ class TaskManager:
"`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. " "`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. "
"`tag` will be used to allow to call a collection of tasks just like `group`. " "`tag` will be used to allow to call a collection of tasks just like `group`. "
"`group` will be removed in order to not cause confusion with the new ConfigurableGroup " "`group` will be removed in order to not cause confusion with the new ConfigurableGroup "
"which will be the offical way to create groups with addition of group-wide configuations." "which will be the official way to create groups with addition of group-wide configurations."
) )
print_info = False print_info = False
# attr = "tag" # attr = "tag"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment