Commit dba96127 authored by Baber
Browse files

add metadata to TaskManager

parent fabd0d90
......@@ -317,7 +317,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
metadata = (
simple_parse_args_string(args.model_args)
if isinstance(args.model_args, str)
else {} | parse_keyed_list_string(args.metadata)
)
task_manager = TaskManager(
args.verbosity, include_path=args.include_path, metadata=metadata
)
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
eval_logger.warning(
......
......@@ -931,6 +931,9 @@ class ConfigurableTask(Task):
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
) -> None:
if isinstance(self.config.download_dataset, Callable):
eval_logger.warning(
f"Custom kwargs used for the {self.config.task} can be passed to `--metadata` in console or to the TaskManager. For example --metadata=max_seq_lengths=4096,8192. For details see task Readme."
)
self.dataset = self.config.download_dataset(
**self.config.metadata,
**self.config.dataset_kwargs
......
......@@ -232,17 +232,16 @@ def simple_evaluate(
)
if task_manager is None:
task_manager = TaskManager(verbosity)
metadata = (
simple_parse_args_string(model_args)
if isinstance(model_args, str)
else model_args
) | (metadata or {})
task_manager = TaskManager(verbosity, metadata=metadata)
task_dict = get_task_dict(
tasks,
task_manager,
metadata=(
simple_parse_args_string(model_args)
if isinstance(model_args, str)
else model_args
)
| (metadata or {}),
)
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
......
......@@ -25,11 +25,13 @@ class TaskManager:
verbosity="INFO",
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
metadata: Optional[dict] = None,
) -> None:
self.verbosity = verbosity
self.include_path = include_path
self.logger = utils.eval_logger
self.logger.setLevel(getattr(logging, f"{verbosity}"))
self.metadata = metadata
self._task_index = self.initialize_tasks(
include_path=include_path, include_defaults=include_defaults
......@@ -257,7 +259,6 @@ class TaskManager:
name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
metadata: Optional[dict] = None,
) -> Mapping:
def _load_task(config, task):
if "include" in config:
......@@ -279,8 +280,8 @@ class TaskManager:
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = task
else:
if metadata is not None:
config["metadata"] = config.get("metadata", {}) | metadata
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
else:
config["metadata"] = config.get("metadata", {})
task_object = ConfigurableTask(config=config)
......@@ -290,8 +291,8 @@ class TaskManager:
def _get_group_and_subtask_from_config(
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
if metadata is not None:
config["metadata"] = config.get("metadata", {}) | metadata
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config)
subtask_list = []
for task in group_name.config["task"]:
......@@ -403,7 +404,6 @@ class TaskManager:
fn = partial(
self._load_individual_task_or_group,
metadata=metadata,
parent_name=group_name,
update_config=update_config,
)
......@@ -411,9 +411,7 @@ class TaskManager:
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
}
def load_task_or_group(
self, task_list: Optional[Union[str, list]] = None, metadata=None
) -> dict:
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None
......@@ -428,17 +426,15 @@ class TaskManager:
all_loaded_tasks = dict(
collections.ChainMap(
*map(
lambda task: self._load_individual_task_or_group(
task, metadata=metadata
),
lambda task: self._load_individual_task_or_group(task),
task_list,
)
)
)
return all_loaded_tasks
def load_config(self, config: Dict, metadata=Optional[dict]):
return self._load_individual_task_or_group(config, metadata=metadata)
def load_config(self, config: Dict):
return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: str):
"""Creates a dictionary of tasks index with the following metadata,
......@@ -598,7 +594,6 @@ def _check_duplicates(task_dict: dict) -> None:
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None,
metadata: dict = None,
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
......@@ -639,7 +634,7 @@ def get_task_dict(
task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list, metadata=metadata
string_task_name_list
)
for task_element in others_task_name_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment