Unverified Commit d714fc95 authored by Lintang Sutawika, committed by GitHub

Faster Task and Group Loading, Allow Recursive Groups (#1321)



* add trust_remote_code as default

* task for testing recursive

* changed source of ALL_TASKS

* tasks should only accept TaskObjects

* initialize_tasks returns list of tasks and groups

* remove trust_remote_code for now

* moved constructor process to inside load_yaml_config

* more comprehensive way to index tasks and groups

* pre-commit format

* add exit after error

* adjust how task objects are called

* no need to use get_task_dict

* load_task_or_group works but only for tasks

* pre-commit format

* half working for nested groups

* changed variable names

* allow groups and tasks to work

* temp save

* indexing and loading are part of a task_manager object

* adapted initialize_tasks

* iron out bugs

* fixed typo

* fixed typo

* simplified code

* further tidy up

* remove lines for testing

* removed test lines

* removed unused code

* remove unused import

* fixed bug

* removed comments

* group in a list of group can accept parameter changes like `num_fewshot`

* check if config is task update

* add GroupConfig object

* edit test yaml

* remove args

* testing returning to python task list

* add weight_by_size config

* describe weight_by_size in docs

* fix weight by size potential error

* can load individual custom python class task

* moved import_function into the config loading file

* remove print lines

* add squadv2 yaml

* temporary scroll implementation

* revert back to use load_yaml_config but with modes

* fix group being loaded with a None

* reformat

* can load unregistered tasks from a group

* update scrolls

* edit scrolls multiplechoice task

* adjust class initialization

* fix initialization

* changed how to identify group and python tasks, fix logger

* allow loading "include" that is nested in a group config

* reworked flan benchmark

* allow duplicate task in the same group to co-exist

* process group_alias

* removed group_alias

* allow parameters set in group_config to apply to all tasks in tasklist

* add function, but comment for now

* reworked processing dict-base config

* fixed how configs in group are processed

* update to allow root group to have its alias used

* remove unused classes

* remove unused classes

* revert some parts to original

* forgot to change one variable

* adapt the new process to use get_task_dict

* fix for singular group call

* fix variable names

* add TaskManager into the evaluator

* format

* changed how dict tasks are loaded

* add docs

* Update docs/new_task_guide.md
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update evaluator.py

* Update evaluator.py

* remove groupconfig for now

* changed _config to config

* update interface.md to explain TaskManager

* added property functions

* adjusted logger

* update write_out.py

* updated tests

* added documentation and some modifications

* added docstring documentation

* precommit format

* updated task loading for tests

* updates tests

* changed arg order for load_yaml_config

* update to handle scrolls and edit log message

* remove unused lines

* return a list of task classes and not a dict

* Update __init__.py

* Delete lm_eval/tasks/benchmarks/test.yaml

* Update task.py

* Update lm_eval/utils.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/utils.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update utils.py

* re-added old functions with new log message

* Update docs/new_task_guide.md
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update new_task_guide.md

* added info regarding `get_task_dict` and documentation

* add get_config for Task

* pre-commit formatting

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 17191063
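The headline feature, recursive groups, means a group's task list may itself name other groups, which must be expanded until only concrete tasks remain. A minimal sketch of that idea (the `GROUPS` table, group names, and `resolve` helper are illustrative, not the actual lm-eval-harness API):

```python
# Hypothetical sketch of recursive group resolution: a group entry may list
# plain task names or further group names, so loading recurses until only
# concrete tasks remain.

GROUPS = {
    "multimedqa": ["medqa_4options", "mmlu_med"],  # a group containing a group
    "mmlu_med": ["mmlu_anatomy", "mmlu_clinical_knowledge"],
}

def resolve(name, seen=None):
    """Expand a task-or-group name into the flat list of concrete tasks."""
    seen = set() if seen is None else seen
    if name in seen:  # guard against cyclic group definitions
        raise ValueError(f"cyclic group reference: {name}")
    if name not in GROUPS:  # not a group, so it is a concrete task
        return [name]
    tasks = []
    for member in GROUPS[name]:
        tasks.extend(resolve(member, seen | {name}))
    return tasks

print(resolve("multimedqa"))
# ['medqa_4options', 'mmlu_anatomy', 'mmlu_clinical_knowledge']
```

The `seen` set is the important design point: without a cycle guard, two groups that reference each other would recurse forever.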
@@ -5,19 +5,13 @@ task:
 - medqa_4options
 - task: mmlu_anatomy
   task_alias: "anatomy (mmlu)"
-  group_alias: null
 - task: mmlu_clinical_knowledge
   task_alias: "clinical_knowledge (mmlu)"
-  group_alias: null
 - task: mmlu_college_medicine
   task_alias: "college_medicine (mmlu)"
-  group_alias: null
 - task: mmlu_medical_genetics
   task_alias: "medical_genetics (mmlu)"
-  group_alias: null
 - task: mmlu_professional_medicine
   task_alias: "professional_medicine (mmlu)"
-  group_alias: null
 - task: mmlu_college_biology
   task_alias: "college_biology (mmlu)"
-  group_alias: null
 group: scrolls
 task:
-  - scrolls_qasper
-  - scrolls_quality
-  - scrolls_narrativeqa
-  - scrolls_contractnli
-  - scrolls_govreport
-  - scrolls_summscreenfd
-  - scrolls_qmsum
+  - task: scrolls_qasper
+    class: !function task.Qasper
+  - task: scrolls_quality
+    class: !function task.QuALITY
+  - task: scrolls_narrativeqa
+    class: !function task.NarrativeQA
+  - task: scrolls_contractnli
+    class: !function task.ContractNLI
+  - task: scrolls_govreport
+    class: !function task.GovReport
+  - task: scrolls_summscreenfd
+    class: !function task.SummScreenFD
+  - task: scrolls_qmsum
+    class: !function task.QMSum
@@ -115,8 +115,10 @@ class _SCROLLSTask(Task):
     PRUNE_MAX_TOKENS = None
     PRUNE_NUM_PROC = None

-    def __post_init__(self):
-        self.metric = load_metric(_download_metric(), config_name=self.DATASET_NAME)
+    def __init__(self):
+        super().__init__()
+        if self.DATASET_NAME is not None:
+            self.metric = load_metric(_download_metric(), config_name=self.DATASET_NAME)

     def has_training_docs(self):
         return True
@@ -224,9 +226,10 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
     def process_results(self, doc, results):
         gold = doc["gold"]

-        acc = 1.0 if np.argmax(results) == gold else 0.0
+        lls, _ = zip(*results)
+        acc = 1.0 if np.argmax(lls) == gold else 0.0
         completion_len = np.array([float(len(i)) for i in doc["choices"]])
-        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
+        acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0

         return {
             "acc": acc,
@@ -279,7 +282,6 @@ class _SCROLLSSummaryTask(_SCROLLSTask):
         return f"{doc['input']}\n\nQuestion: What is a summary of the preceding text?\nAnswer:"

-@register_task("scrolls_qasper")
 class Qasper(_SCROLLSTask):
     """A Dataset of Information-Seeking Questions and Answers Anchored in Research Papers
     https://arxiv.org/abs/2105.03011

@@ -337,7 +339,6 @@ class Qasper(_SCROLLSTask):
         )

-@register_task("scrolls_quality")
 class QuALITY(_SCROLLSMultipleChoiceTask):
     """QuALITY: Question Answering with Long Input Texts, Yes!
     https://arxiv.org/abs/2112.08608

@@ -366,7 +367,6 @@ class QuALITY(_SCROLLSMultipleChoiceTask):
         return [doc]

-@register_task("scrolls_narrativeqa")
 class NarrativeQA(_SCROLLSTask):
     """The NarrativeQA Reading Comprehension Challenge
     https://arxiv.org/abs/1712.07040

@@ -400,7 +400,6 @@ class NarrativeQA(_SCROLLSTask):
         )

-@register_task("scrolls_contractnli")
 class ContractNLI(_SCROLLSMultipleChoiceTask):
     """ContractNLI: A Dataset for Document-level Natural Language Inference for Contracts
     https://arxiv.org/abs/1712.07040

@@ -419,7 +418,6 @@ class ContractNLI(_SCROLLSMultipleChoiceTask):
         return f"{doc['text']}\n\nHypothesis: {doc['question']}\nConclusion:"

-@register_task("scrolls_govreport")
 class GovReport(_SCROLLSSummaryTask):
     """Efficient Attentions for Long Document Summarization
     https://arxiv.org/abs/2104.02112

@@ -433,7 +431,6 @@ class GovReport(_SCROLLSSummaryTask):
     DATASET_NAME = "gov_report"

-@register_task("scrolls_summscreenfd")
 class SummScreenFD(_SCROLLSSummaryTask):
     """SummScreen: A Dataset for Abstractive Screenplay Summarization
     https://arxiv.org/abs/2104.07091

@@ -442,7 +439,6 @@ class SummScreenFD(_SCROLLSSummaryTask):
     DATASET_NAME = "summ_screen_fd"

-@register_task("scrolls_qmsum")
 class QMSum(_SCROLLSSummaryTask):
     """QMSum: A New Benchmark for Query-based Multi-domain
     Meeting Summarization
......
task: squadv2
class: !function task.SQuAD2
@@ -21,7 +21,6 @@ from packaging import version
 from lm_eval.api.task import Task
 from lm_eval.api.instance import Instance
-from lm_eval.api.registry import register_task

 _CITATION = """
 @misc{rajpurkar2018know,

@@ -47,7 +46,6 @@ def _squad_agg(key, items):
     return _squad_metric(predictions=predictions, references=references).get(key, 0)

-@register_task("squadv2")
 class SQuAD2(Task):
     VERSION = 3
     DATASET_PATH = "squad_v2"
......
@@ -472,6 +472,10 @@ def get_git_commit_hash():
     return git_hash

+def ignore_constructor(loader, node):
+    return node
+
 def import_function(loader, node):
     function_name = loader.construct_scalar(node)
     yaml_path = os.path.dirname(loader.name)

@@ -489,11 +493,14 @@ def import_function(loader, node):
     return function

-# Add the import_function constructor to the YAML loader
-yaml.add_constructor("!function", import_function)
-
-def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
+def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
+    if mode == "simple":
+        constructor_fn = ignore_constructor
+    elif mode == "full":
+        constructor_fn = import_function
+
+    # Add the import_function constructor to the YAML loader
+    yaml.add_constructor("!function", constructor_fn)
+
     if yaml_config is None:
         with open(yaml_path, "rb") as file:
             yaml_config = yaml.full_load(file)

@@ -521,7 +528,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
             path = os.path.join(yaml_dir, path)
         try:
-            included_yaml_config = load_yaml_config(path)
+            included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
             final_yaml_config.update(included_yaml_config)
         except Exception as ex:
             # If failed to load, ignore
......
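The `load_yaml_config` change above introduces two loading modes: "simple" swaps in `ignore_constructor` so `!function` tags are left unresolved (fast, metadata-only indexing of tasks), while "full" uses `import_function` to resolve them to callables. A minimal stand-alone sketch of that constructor-swap pattern (without PyYAML; `load_config` and the lookup-table "import" are illustrative stand-ins, not the real lm-eval-harness code):

```python
# Two-mode config loading: "simple" keeps function references as raw strings
# so configs can be indexed cheaply; "full" resolves them to callables.

def ignore_constructor(value):
    # "simple" mode: leave the reference unresolved
    return value

def import_function(value):
    # "full" mode: resolve the reference to a callable. The real code imports
    # a module next to the YAML file; this lookup table is a stand-in.
    registry = {"task.Qasper": lambda: "Qasper instance"}
    return registry[value]

def load_config(config, mode="full"):
    constructor_fn = ignore_constructor if mode == "simple" else import_function
    return {
        key: constructor_fn(value) if key == "class" else value
        for key, value in config.items()
    }

cfg = {"task": "scrolls_qasper", "class": "task.Qasper"}
print(load_config(cfg, mode="simple")["class"])  # 'task.Qasper' (unresolved)
print(load_config(cfg, mode="full")["class"]())  # 'Qasper instance'
```

The payoff is the fast path: a task index can be built from hundreds of YAML files without paying the cost of importing every referenced module.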
@@ -5,7 +5,7 @@ import random
 import numpy as np

 from lm_eval import tasks
-from lm_eval.tasks import include_path, initialize_tasks
+from lm_eval.tasks import TaskManager
 from lm_eval.utils import eval_logger, join_iters

@@ -39,22 +39,21 @@ def main():
     args = parse_args()
     np.random.seed(args.seed)

-    initialize_tasks(args.verbosity)
-
     if args.include_path is not None:
         eval_logger.info(f"Including path: {args.include_path}")
-        include_path(args.include_path)
+
+    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

     if args.tasks == "all_tasks":
-        task_names = tasks.ALL_TASKS
+        task_names = task_manager.all_tasks
     else:
         task_names = args.tasks.split(",")

-    task_dict = tasks.get_task_dict(task_names)
+    task_dict = tasks.get_task_dict(task_names, task_manager)

     os.makedirs(args.output_base_path, exist_ok=True)
     for task_name, task in task_dict.items():
         if isinstance(task, tuple):
-            group_name, task = task
+            _, task = task

         rnd = random.Random()
         rnd.seed(args.seed)
......
@@ -11,20 +11,21 @@ from lm_eval.api.instance import Instance
 from lm_eval.models.huggingface import HFLM

-tasks.initialize_tasks()
+task_manager = tasks.TaskManager()

 class Test_HFLM:
     torch.use_deterministic_algorithms(True)
+    task_list = task_manager.load_task_or_group(["arc_easy", "gsm8k", "wikitext"])
     version_minor = sys.version_info.minor
-    multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")()  # type: ignore
+    multiple_choice_task = task_list["arc_easy"]  # type: ignore
     multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
     MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
-    generate_until_task = tasks.TASK_REGISTRY.get("gsm8k")()  # type: ignore
+    generate_until_task = task_list["gsm8k"]  # type: ignore
     generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
     generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
     generate_until: list[Instance] = generate_until_task.instances
-    rolling_task = tasks.TASK_REGISTRY.get("wikitext")()  # type: ignore
+    rolling_task = task_list["wikitext"]  # type: ignore
     rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
     ROLLING: list[Instance] = rolling_task.instances
......
@@ -6,12 +6,9 @@ from optimum.intel import OVModelForCausalLM
 from transformers import AutoTokenizer

 import lm_eval.evaluator as evaluator
-import lm_eval.tasks as tasks
 from lm_eval.api.registry import get_model

-tasks.initialize_tasks()
-
 SUPPORTED_ARCHITECTURES_TASKS = {
     "facebook/opt-125m": "lambada_openai",
     "hf-internal-testing/tiny-random-gpt2": "wikitext",
......
@@ -7,6 +7,9 @@ import lm_eval.tasks as tasks
 from lm_eval.api.instance import Instance

+task_manager = tasks.TaskManager()
+
 @pytest.mark.skip(reason="requires CUDA")
 class TEST_VLLM:
     vllm = pytest.importorskip("vllm")

@@ -17,15 +20,15 @@ class TEST_VLLM:
     except ModuleNotFoundError:
         pass
     torch.use_deterministic_algorithms(True)
-    tasks.initialize_tasks()
-    multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")()  # type: ignore
+    task_list = task_manager.load_task_or_group(["arc_easy", "gsm8k", "wikitext"])
+    multiple_choice_task = task_list["arc_easy"]  # type: ignore
     multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
     MULTIPLE_CH: List[Instance] = multiple_choice_task.instances
-    generate_until_task = tasks.TASK_REGISTRY.get("gsm8k")()  # type: ignore
+    generate_until_task = task_list["gsm8k"]  # type: ignore
     generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
     generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
     generate_until: List[Instance] = generate_until_task.instances
-    rolling_task = tasks.TASK_REGISTRY.get("wikitext")()  # type: ignore
+    rolling_task = task_list["wikitext"]  # type: ignore
     rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
     ROLLING: List[Instance] = rolling_task.instances
......
@@ -6,11 +6,9 @@ import pytest
 # import lm_eval.models as models
 import lm_eval.api as api
 import lm_eval.evaluator as evaluator
-import lm_eval.tasks as tasks
+from lm_eval import tasks

-tasks.initialize_tasks()
-
 # TODO: more fine grained unit tests rather than this big honking integration
 # test once we break evaluator into smaller, more manageable pieces

@@ -46,7 +44,8 @@ def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str):
             "device": None,
         },
     )
-    task_dict = tasks.get_task_dict(task_name, num_fewshot=0)
+    task_manager = tasks.TaskManager()
+    task_dict = tasks.get_task_dict(task_name, task_manager)

     e2 = evaluator.evaluate(
         lm=lm,
......
@@ -8,7 +8,7 @@ from lm_eval.api.task import ConfigurableTask
 from .utils import new_tasks

-tasks.initialize_tasks()
+task_manager = tasks.TaskManager()

 # Default Task
 TASKS = ["arc_easy"]

@@ -19,9 +19,9 @@ def task_class():
     task_classes = new_tasks()
     # Check if task_classes is empty
     if task_classes:
-        return [tasks.TASK_REGISTRY.get(x)() for x in task_classes]
+        return list(task_manager.load_task_or_group(task_classes).values())
     else:
-        return [tasks.TASK_REGISTRY.get(x)() for x in TASKS]
+        return list(task_manager.load_task_or_group(TASKS).values())

 @pytest.fixture()
......
 import os
-from pathlib import Path
 from typing import List, Union

 from lm_eval.utils import load_yaml_config

@@ -20,17 +19,18 @@ def load_changed_files(file_path: str) -> List[str]:

 # checks the txt file for list of changed files.
-# if file ends with .yaml then check yaml for task name
-# if file ends with .py then parse the folder for all yaml files
-# skips benchmarks folder
+# if file ends with .yaml then check yaml and load the config.
+# if the config task is a string, it's a task config.
+# if the config task is a list, it's a group config.
 def parser(full_path: List[str]) -> List[str]:
     _output = set()
     for x in full_path:
-        if x.endswith(".yaml") and "benchmarks" not in x:
-            _output.add(load_yaml_config(x)["task"])
-        elif x.endswith(".py") and "benchmarks" not in x:
-            path = [str(x) for x in (list(Path(x).parent.glob("*.yaml")))]
-            _output |= {load_yaml_config(x)["task"] for x in path}
+        if os.path.exists(x) and x.endswith(".yaml"):
+            config = load_yaml_config(x, mode="simple")
+            if isinstance(config["task"], str):
+                _output.add(config["task"])
+            elif isinstance(config["task"], list):
+                _output.add(config["group"])
     return list(_output)
......
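The parser change above relies on a config-shape convention: after a "simple"-mode load, a task config carries a string under `"task"`, while a group config carries a list of sub-entries under `"task"` and names itself under `"group"`. A small sketch of that discrimination (hypothetical helper name and sample configs, not the repo's code):

```python
# Distinguish task configs from group configs by the type of the "task" key.

def changed_task_names(configs):
    names = set()
    for config in configs:
        if isinstance(config["task"], str):      # single-task config
            names.add(config["task"])
        elif isinstance(config["task"], list):   # group config: record the group
            names.add(config["group"])
    return sorted(names)

configs = [
    {"task": "scrolls_qasper", "class": "task.Qasper"},
    {"group": "scrolls", "task": [{"task": "scrolls_qasper"}]},
]
print(changed_task_names(configs))  # ['scrolls', 'scrolls_qasper']
```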