Commit dba96127 authored by Baber
Browse files

add metadata to TaskManager

parent fabd0d90
...@@ -317,7 +317,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -317,7 +317,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.include_path is not None: if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}") eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path) metadata = (
simple_parse_args_string(args.model_args)
if isinstance(args.model_args, str)
else {} | parse_keyed_list_string(args.metadata)
)
task_manager = TaskManager(
args.verbosity, include_path=args.include_path, metadata=metadata
)
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples: if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
eval_logger.warning( eval_logger.warning(
......
...@@ -931,6 +931,9 @@ class ConfigurableTask(Task): ...@@ -931,6 +931,9 @@ class ConfigurableTask(Task):
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
) -> None: ) -> None:
if isinstance(self.config.download_dataset, Callable): if isinstance(self.config.download_dataset, Callable):
eval_logger.warning(
f"Custom kwargs used for the {self.config.task} can be passed to `--metadata` in console or to the TaskManager. For example --metadata=max_seq_lengths=4096,8192. For details see task Readme."
)
self.dataset = self.config.download_dataset( self.dataset = self.config.download_dataset(
**self.config.metadata, **self.config.metadata,
**self.config.dataset_kwargs **self.config.dataset_kwargs
......
...@@ -232,17 +232,16 @@ def simple_evaluate( ...@@ -232,17 +232,16 @@ def simple_evaluate(
) )
if task_manager is None: if task_manager is None:
task_manager = TaskManager(verbosity) metadata = (
simple_parse_args_string(model_args)
if isinstance(model_args, str)
else model_args
) | (metadata or {})
task_manager = TaskManager(verbosity, metadata=metadata)
task_dict = get_task_dict( task_dict = get_task_dict(
tasks, tasks,
task_manager, task_manager,
metadata=(
simple_parse_args_string(model_args)
if isinstance(model_args, str)
else model_args
)
| (metadata or {}),
) )
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups. # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
......
...@@ -25,11 +25,13 @@ class TaskManager: ...@@ -25,11 +25,13 @@ class TaskManager:
verbosity="INFO", verbosity="INFO",
include_path: Optional[Union[str, List]] = None, include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True, include_defaults: bool = True,
metadata: Optional[dict] = None,
) -> None: ) -> None:
self.verbosity = verbosity self.verbosity = verbosity
self.include_path = include_path self.include_path = include_path
self.logger = utils.eval_logger self.logger = utils.eval_logger
self.logger.setLevel(getattr(logging, f"{verbosity}")) self.logger.setLevel(getattr(logging, f"{verbosity}"))
self.metadata = metadata
self._task_index = self.initialize_tasks( self._task_index = self.initialize_tasks(
include_path=include_path, include_defaults=include_defaults include_path=include_path, include_defaults=include_defaults
...@@ -257,7 +259,6 @@ class TaskManager: ...@@ -257,7 +259,6 @@ class TaskManager:
name_or_config: Optional[Union[str, dict]] = None, name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[dict] = None, update_config: Optional[dict] = None,
metadata: Optional[dict] = None,
) -> Mapping: ) -> Mapping:
def _load_task(config, task): def _load_task(config, task):
if "include" in config: if "include" in config:
...@@ -279,8 +280,8 @@ class TaskManager: ...@@ -279,8 +280,8 @@ class TaskManager:
# very scuffed: set task name here. TODO: fixme? # very scuffed: set task name here. TODO: fixme?
task_object.config.task = task task_object.config.task = task
else: else:
if metadata is not None: if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | metadata config["metadata"] = config.get("metadata", {}) | self.metadata
else: else:
config["metadata"] = config.get("metadata", {}) config["metadata"] = config.get("metadata", {})
task_object = ConfigurableTask(config=config) task_object = ConfigurableTask(config=config)
...@@ -290,8 +291,8 @@ class TaskManager: ...@@ -290,8 +291,8 @@ class TaskManager:
def _get_group_and_subtask_from_config( def _get_group_and_subtask_from_config(
config: dict, config: dict,
) -> tuple[ConfigurableGroup, list[str]]: ) -> tuple[ConfigurableGroup, list[str]]:
if metadata is not None: if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | metadata config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config) group_name = ConfigurableGroup(config=config)
subtask_list = [] subtask_list = []
for task in group_name.config["task"]: for task in group_name.config["task"]:
...@@ -403,7 +404,6 @@ class TaskManager: ...@@ -403,7 +404,6 @@ class TaskManager:
fn = partial( fn = partial(
self._load_individual_task_or_group, self._load_individual_task_or_group,
metadata=metadata,
parent_name=group_name, parent_name=group_name,
update_config=update_config, update_config=update_config,
) )
...@@ -411,9 +411,7 @@ class TaskManager: ...@@ -411,9 +411,7 @@ class TaskManager:
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
} }
def load_task_or_group( def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
self, task_list: Optional[Union[str, list]] = None, metadata=None
) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None :param task_list: Union[str, list] = None
...@@ -428,17 +426,15 @@ class TaskManager: ...@@ -428,17 +426,15 @@ class TaskManager:
all_loaded_tasks = dict( all_loaded_tasks = dict(
collections.ChainMap( collections.ChainMap(
*map( *map(
lambda task: self._load_individual_task_or_group( lambda task: self._load_individual_task_or_group(task),
task, metadata=metadata
),
task_list, task_list,
) )
) )
) )
return all_loaded_tasks return all_loaded_tasks
def load_config(self, config: Dict, metadata=Optional[dict]): def load_config(self, config: Dict):
return self._load_individual_task_or_group(config, metadata=metadata) return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: str): def _get_task_and_group(self, task_dir: str):
"""Creates a dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
...@@ -598,7 +594,6 @@ def _check_duplicates(task_dict: dict) -> None: ...@@ -598,7 +594,6 @@ def _check_duplicates(task_dict: dict) -> None:
def get_task_dict( def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]], task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
metadata: dict = None,
): ):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
...@@ -639,7 +634,7 @@ def get_task_dict( ...@@ -639,7 +634,7 @@ def get_task_dict(
task_manager = TaskManager() task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group( task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list, metadata=metadata string_task_name_list
) )
for task_element in others_task_name_list: for task_element in others_task_name_list:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment