Commit 13a3f1d6 authored by Baber's avatar Baber
Browse files

better error handling

parent 7489a342
......@@ -33,21 +33,22 @@ class JudgeTask(ConfigurableTask):
resps = []
# load json
with open(self.output_path, "r") as f:
for line in f:
resp = json.loads(line)
resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]})
resps.sort(key=lambda x: x["doc"])
# TODO: add filter name to resps
resps = resps[::2]
self.dataset["test"] = self.dataset["test"].add_column(
"resp", [resp["resp"] for resp in resps]
)
self.dataset["train"] = self.dataset["train"].add_column(
"resp", self.dataset["train"]["answer"]
)
print("resp columns added")
if self.output_path is not None:
with open(self.output_path, "r") as f:
for line in f:
resp = json.loads(line)
resps.append({"resp": resp["resps"][0][0], "doc": resp["doc_id"]})
resps.sort(key=lambda x: x["doc"])
# TODO: add filter name to resps
resps = resps[::2]
self.dataset["test"] = self.dataset["test"].add_column(
"resp", [resp["resp"] for resp in resps]
)
self.dataset["train"] = self.dataset["train"].add_column(
"resp", self.dataset["train"]["answer"]
)
print("resp columns added")
# def process_docs(self, dataset: datasets.Dataset):
# resps = []
......
......@@ -265,25 +265,20 @@ class TaskManager:
),
**config,
}
if self._config_is_python_task(config):
if self._class_has_config_in_constructor(config["class"]):
task_object = config["class"](config=config)
else:
task_object = config["class"]()
if isinstance(task_object, ConfigurableTask):
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = config["task"]
if "output_type" in config:
task_object = JudgeTask(
config=config, output_path=config.get("output_path", None)
)
# if self._config_is_python_task(config):
# if self._class_has_config_in_constructor(config["class"]):
# task_object = config["class"](config=config)
# else:
# task_object = config["class"]()
# if isinstance(task_object, ConfigurableTask):
# # very scuffed: set task name here. TODO: fixme?
# task_object.config.task = config["task"]
else:
try:
if "output_type" in config:
task_object = JudgeTask(
config=config, output_path=config.get("output_path")
)
except Exception:
config.pop("output_type")
task_object = ConfigurableTask(config=config)
else:
task_object = ConfigurableTask(config=config)
task_object = ConfigurableTask(config=config)
return {task: task_object}
......@@ -502,7 +497,7 @@ class TaskManager:
"`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. "
"`tag` will be used to allow to call a collection of tasks just like `group`. "
"`group` will be removed in order to not cause confusion with the new ConfigurableGroup "
"which will be the offical way to create groups with addition of group-wide configuations."
"which will be the official way to create groups with addition of group-wide configurations."
)
print_info = False
# attr = "tag"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment