Commit ec05e561 authored by lintangsutawika's avatar lintangsutawika
Browse files

Merge branch 'recursive-groups' of...

Merge branch 'recursive-groups' of https://github.com/EleutherAI/lm-evaluation-harness into t5v2-alt-plus
parents 74857aa7 be5472a9
...@@ -40,6 +40,7 @@ ALL_OUTPUT_TYPES = [ ...@@ -40,6 +40,7 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
@dataclass @dataclass
class GroupConfig(dict): class GroupConfig(dict):
group: str = None group: str = None
......
...@@ -165,10 +165,9 @@ class TaskManager(abc.ABC): ...@@ -165,10 +165,9 @@ class TaskManager(abc.ABC):
group_name = name_or_config["group"] group_name = name_or_config["group"]
subtask_list = name_or_config["task"] subtask_list = name_or_config["task"]
if (self._name_is_registered(group_name) is False) or (self._get_yaml_path(group_name) == -1): all_subtasks = {}
if (parent_name is not None) and ((self._name_is_registered(group_name) is False) or (self._get_yaml_path(group_name) == -1)):
all_subtasks = {group_name: (parent_name, None)} all_subtasks = {group_name: (parent_name, None)}
else:
all_subtasks = {}
fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config) fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config)
all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))} all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))}
......
task: squadv2 task: squadv2
class: !function task.SQuAD2 class: !function task.SQuAD2
\ No newline at end of file
...@@ -483,6 +483,7 @@ def get_git_commit_hash(): ...@@ -483,6 +483,7 @@ def get_git_commit_hash():
def ignore_constructor(loader, node): def ignore_constructor(loader, node):
return node return node
def import_function(loader, node): def import_function(loader, node):
function_name = loader.construct_scalar(node) function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name) yaml_path = os.path.dirname(loader.name)
...@@ -490,9 +491,7 @@ def import_function(loader, node): ...@@ -490,9 +491,7 @@ def import_function(loader, node):
*module_name, function_name = function_name.split(".") *module_name, function_name = function_name.split(".")
if isinstance(module_name, list): if isinstance(module_name, list):
module_name = ".".join(module_name) module_name = ".".join(module_name)
module_path = os.path.normpath( module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))
os.path.join(yaml_path, "{}.py".format(module_name))
)
spec = importlib.util.spec_from_file_location(module_name, module_path) spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
...@@ -503,7 +502,6 @@ def import_function(loader, node): ...@@ -503,7 +502,6 @@ def import_function(loader, node):
def load_yaml_config(mode="simple", yaml_path=None, yaml_config=None, yaml_dir=None): def load_yaml_config(mode="simple", yaml_path=None, yaml_config=None, yaml_dir=None):
if mode == "simple": if mode == "simple":
constuctor_fn = ignore_constructor constuctor_fn = ignore_constructor
elif mode == "full": elif mode == "full":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment