Commit 7e186d53 authored by lintangsutawika's avatar lintangsutawika
Browse files

add method to have 1 yaml with multiple tasks

parent 88486e57
...@@ -128,6 +128,13 @@ class TaskManager: ...@@ -128,6 +128,13 @@ class TaskManager:
raise ValueError raise ValueError
return self.task_index[name]["yaml_path"] return self.task_index[name]["yaml_path"]
def _get_idx(self, name):
if name not in self.task_index:
raise ValueError
if "idx" not in self.task_index[name]:
return None
return self.task_index[name]["idx"]
def _get_config(self, name): def _get_config(self, name):
if name not in self.task_index: if name not in self.task_index:
raise ValueError raise ValueError
...@@ -135,6 +142,9 @@ class TaskManager: ...@@ -135,6 +142,9 @@ class TaskManager:
if yaml_path == -1: if yaml_path == -1:
return {} return {}
else: else:
idx = self._get_idx(name)
if idx is not None:
return utils.load_yaml_config(yaml_path, mode="full")[idx]
return utils.load_yaml_config(yaml_path, mode="full") return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name): def _get_tasklist(self, name):
...@@ -373,7 +383,14 @@ class TaskManager: ...@@ -373,7 +383,14 @@ class TaskManager:
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
yaml_path = os.path.join(root, f) yaml_path = os.path.join(root, f)
config = utils.load_yaml_config(yaml_path, mode="simple") config_list = utils.load_yaml_config(yaml_path, mode="simple")
use_idx = False
if isinstance(config_list, dict):
config_list = [config_list]
else:
use_idx = True
for idx, config in enumerate(config_list):
if self._config_is_python_task(config): if self._config_is_python_task(config):
# This is a python class config # This is a python class config
tasks_and_groups[config["task"]] = { tasks_and_groups[config["task"]] = {
...@@ -409,6 +426,12 @@ class TaskManager: ...@@ -409,6 +426,12 @@ class TaskManager:
"yaml_path": yaml_path, "yaml_path": yaml_path,
} }
if use_idx:
tasks_and_groups[task] = {
**tasks_and_groups[task],
**{"idx": idx},
}
# TODO: remove group in next release # TODO: remove group in next release
for attr in ["tag", "group"]: for attr in ["tag", "group"]:
if attr in config: if attr in config:
......
...@@ -431,6 +431,9 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full ...@@ -431,6 +431,9 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
with open(yaml_path, "rb") as file: with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file) yaml_config = yaml.full_load(file)
if isinstance(yaml_config, list):
return [load_yaml_config(yaml_path, config) for config in yaml_config]
if yaml_dir is None: if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path) yaml_dir = os.path.dirname(yaml_path)
......
# Arc Easy 2-shots
- task: arc_easy-2-shot
include: lm_eval/tasks/arc/arc_easy.yaml
num_fewshot: 2
# Arc Easy 4-shots
- task: arc_easy-4-shot
include: lm_eval/tasks/arc/arc_easy.yaml
num_fewshot: 4
# Arc Easy 8-shots
- task: arc_easy-8-shot
include: lm_eval/tasks/arc/arc_easy.yaml
num_fewshot: 8
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment