Unverified commit 56e6487c, authored by lmagne and committed by GitHub
Browse files

add dataset split and config to model-index in TrainingSummary.from_trainer (#18064)



* added metadata to training summary

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent fde22c75
...@@ -384,6 +384,7 @@ class TrainingSummary: ...@@ -384,6 +384,7 @@ class TrainingSummary:
dataset: Optional[Union[str, List[str]]] = None dataset: Optional[Union[str, List[str]]] = None
dataset_tags: Optional[Union[str, List[str]]] = None dataset_tags: Optional[Union[str, List[str]]] = None
dataset_args: Optional[Union[str, List[str]]] = None dataset_args: Optional[Union[str, List[str]]] = None
dataset_metadata: Optional[Dict[str, Any]] = None
eval_results: Optional[Dict[str, float]] = None eval_results: Optional[Dict[str, float]] = None
eval_lines: Optional[List[str]] = None eval_lines: Optional[List[str]] = None
hyperparameters: Optional[Dict[str, Any]] = None hyperparameters: Optional[Dict[str, Any]] = None
...@@ -412,10 +413,12 @@ class TrainingSummary: ...@@ -412,10 +413,12 @@ class TrainingSummary:
dataset_names = _listify(self.dataset) dataset_names = _listify(self.dataset)
dataset_tags = _listify(self.dataset_tags) dataset_tags = _listify(self.dataset_tags)
dataset_args = _listify(self.dataset_args) dataset_args = _listify(self.dataset_args)
dataset_metadata = _listify(self.dataset_metadata)
if len(dataset_args) < len(dataset_tags): if len(dataset_args) < len(dataset_tags):
dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args)) dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
dataset_mapping = {tag: name for tag, name in zip(dataset_tags, dataset_names)} dataset_mapping = {tag: name for tag, name in zip(dataset_tags, dataset_names)}
dataset_arg_mapping = {tag: arg for tag, arg in zip(dataset_tags, dataset_args)} dataset_arg_mapping = {tag: arg for tag, arg in zip(dataset_tags, dataset_args)}
dataset_metadata_mapping = {tag: metadata for tag, metadata in zip(dataset_tags, dataset_metadata)}
task_mapping = { task_mapping = {
task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
...@@ -438,7 +441,12 @@ class TrainingSummary: ...@@ -438,7 +441,12 @@ class TrainingSummary:
result["task"] = {"name": task_mapping[task_tag], "type": task_tag} result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
if ds_tag is not None: if ds_tag is not None:
result["dataset"] = {"name": dataset_mapping[ds_tag], "type": ds_tag} metadata = dataset_metadata_mapping.get(ds_tag, {})
result["dataset"] = {
"name": dataset_mapping[ds_tag],
"type": ds_tag,
**metadata,
}
if dataset_arg_mapping[ds_tag] is not None: if dataset_arg_mapping[ds_tag] is not None:
result["dataset"]["args"] = dataset_arg_mapping[ds_tag] result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
...@@ -565,6 +573,7 @@ class TrainingSummary: ...@@ -565,6 +573,7 @@ class TrainingSummary:
finetuned_from=None, finetuned_from=None,
tasks=None, tasks=None,
dataset_tags=None, dataset_tags=None,
dataset_metadata=None,
dataset=None, dataset=None,
dataset_args=None, dataset_args=None,
): ):
...@@ -574,6 +583,8 @@ class TrainingSummary: ...@@ -574,6 +583,8 @@ class TrainingSummary:
default_tag = one_dataset.builder_name default_tag = one_dataset.builder_name
# Those are not real datasets from the Hub so we exclude them. # Those are not real datasets from the Hub so we exclude them.
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]: if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
if dataset_metadata is None:
dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
if dataset_tags is None: if dataset_tags is None:
dataset_tags = [default_tag] dataset_tags = [default_tag]
if dataset_args is None: if dataset_args is None:
...@@ -618,9 +629,10 @@ class TrainingSummary: ...@@ -618,9 +629,10 @@ class TrainingSummary:
model_name=model_name, model_name=model_name,
finetuned_from=finetuned_from, finetuned_from=finetuned_from,
tasks=tasks, tasks=tasks,
dataset_tags=dataset_tags,
dataset=dataset, dataset=dataset,
dataset_tags=dataset_tags,
dataset_args=dataset_args, dataset_args=dataset_args,
dataset_metadata=dataset_metadata,
eval_results=eval_results, eval_results=eval_results,
eval_lines=eval_lines, eval_lines=eval_lines,
hyperparameters=hyperparameters, hyperparameters=hyperparameters,
...@@ -751,7 +763,7 @@ def parse_log_history(log_history): ...@@ -751,7 +763,7 @@ def parse_log_history(log_history):
while idx >= 0 and "eval_loss" not in log_history[idx]: while idx >= 0 and "eval_loss" not in log_history[idx]:
idx -= 1 idx -= 1
if idx > 0: if idx >= 0:
return None, None, log_history[idx] return None, None, log_history[idx]
else: else:
return None, None, None return None, None, None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment