"docs/source/vscode:/vscode.git/clone" did not exist on "40658be4615d7d30ad6519618ee984cdba263098"
Unverified Commit 56e6487c authored by lmagne's avatar lmagne Committed by GitHub
Browse files

add dataset split and config to model-index in TrainingSummary.from_trainer (#18064)



* added metadata to training summary

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent fde22c75
......@@ -384,6 +384,7 @@ class TrainingSummary:
dataset: Optional[Union[str, List[str]]] = None
dataset_tags: Optional[Union[str, List[str]]] = None
dataset_args: Optional[Union[str, List[str]]] = None
dataset_metadata: Optional[Dict[str, Any]] = None
eval_results: Optional[Dict[str, float]] = None
eval_lines: Optional[List[str]] = None
hyperparameters: Optional[Dict[str, Any]] = None
......@@ -412,10 +413,12 @@ class TrainingSummary:
dataset_names = _listify(self.dataset)
dataset_tags = _listify(self.dataset_tags)
dataset_args = _listify(self.dataset_args)
dataset_metadata = _listify(self.dataset_metadata)
if len(dataset_args) < len(dataset_tags):
dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
dataset_mapping = {tag: name for tag, name in zip(dataset_tags, dataset_names)}
dataset_arg_mapping = {tag: arg for tag, arg in zip(dataset_tags, dataset_args)}
dataset_metadata_mapping = {tag: metadata for tag, metadata in zip(dataset_tags, dataset_metadata)}
task_mapping = {
task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
......@@ -438,7 +441,12 @@ class TrainingSummary:
result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
if ds_tag is not None:
result["dataset"] = {"name": dataset_mapping[ds_tag], "type": ds_tag}
metadata = dataset_metadata_mapping.get(ds_tag, {})
result["dataset"] = {
"name": dataset_mapping[ds_tag],
"type": ds_tag,
**metadata,
}
if dataset_arg_mapping[ds_tag] is not None:
result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
......@@ -565,6 +573,7 @@ class TrainingSummary:
finetuned_from=None,
tasks=None,
dataset_tags=None,
dataset_metadata=None,
dataset=None,
dataset_args=None,
):
......@@ -574,6 +583,8 @@ class TrainingSummary:
default_tag = one_dataset.builder_name
# Those are not real datasets from the Hub so we exclude them.
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
if dataset_metadata is None:
dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
if dataset_tags is None:
dataset_tags = [default_tag]
if dataset_args is None:
......@@ -618,9 +629,10 @@ class TrainingSummary:
model_name=model_name,
finetuned_from=finetuned_from,
tasks=tasks,
dataset_tags=dataset_tags,
dataset=dataset,
dataset_tags=dataset_tags,
dataset_args=dataset_args,
dataset_metadata=dataset_metadata,
eval_results=eval_results,
eval_lines=eval_lines,
hyperparameters=hyperparameters,
......@@ -751,7 +763,7 @@ def parse_log_history(log_history):
while idx >= 0 and "eval_loss" not in log_history[idx]:
idx -= 1
if idx > 0:
if idx >= 0:
return None, None, log_history[idx]
else:
return None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment