Unverified Commit a135f595 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Auto modelcard (#11599)



* Autogenerate model cards from the Trainer

* ModelCard deprecated

* Fix test

* Style

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments

* Quality

* With all metadata

* Metadata

* Post-merge conflict mess

* Data args and all examples

* Default license and languages when possible
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent b3429ab6
......@@ -447,7 +447,16 @@ def main():
trainer.save_metrics("eval", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-generation"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
trainer.push_to_hub(**kwargs)
def _mp_fn(index):
......
......@@ -476,7 +476,16 @@ def main():
trainer.save_metrics("eval", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "fill-mask"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
trainer.push_to_hub(**kwargs)
def _mp_fn(index):
......
......@@ -452,7 +452,16 @@ def main():
trainer.save_metrics("eval", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "language-modeling"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
trainer.push_to_hub(**kwargs)
def _mp_fn(index):
......
......@@ -428,7 +428,14 @@ def main():
trainer.save_metrics("eval", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(
finetuned_from=model_args.model_name_or_path,
tags="multiple-choice",
dataset_tags="swag",
dataset_args="regular",
dataset="SWAG",
language="en",
)
def _mp_fn(index):
......
......@@ -601,7 +601,16 @@ def main():
trainer.save_metrics("predict", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "question-answering"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
trainer.push_to_hub(**kwargs)
def _mp_fn(index):
......
......@@ -640,7 +640,16 @@ def main():
trainer.save_metrics("predict", metrics)
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "question-answering"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
trainer.push_to_hub(**kwargs)
def _mp_fn(index):
......
......@@ -583,7 +583,16 @@ def main():
writer.write("\n".join(predictions))
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "summarization"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
trainer.push_to_hub(**kwargs)
return results
......
......@@ -516,7 +516,14 @@ def main():
writer.write(f"{index}\t{item}\n")
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-classification"}
if data_args.task_name is not None:
kwargs["language"] = "en"
kwargs["dataset_tags"] = "glue"
kwargs["dataset_args"] = data_args.task_name
kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"
trainer.push_to_hub(**kwargs)
def _mp_fn(index):
......
......@@ -491,7 +491,16 @@ def main():
writer.write(" ".join(prediction) + "\n")
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "token-classification"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
trainer.push_to_hub(**kwargs)
def _mp_fn(index):
......
......@@ -575,7 +575,20 @@ def main():
writer.write("\n".join(predictions))
if training_args.push_to_hub:
trainer.push_to_hub()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "translation"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
trainer.push_to_hub(**kwargs)
return results
......
......@@ -18,7 +18,15 @@
import copy
import json
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import requests
from huggingface_hub import HfApi
from . import __version__
from .file_utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
......@@ -26,9 +34,14 @@ from .file_utils import (
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
is_datasets_available,
is_offline_mode,
is_remote_url,
is_tokenizers_available,
is_torch_available,
)
from .models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .training_args import ParallelMode
from .utils import logging
......@@ -49,6 +62,9 @@ class ModelCard:
"""
def __init__(self, **kwargs):
warnings.warn(
"The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
)
# Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
self.model_details = kwargs.pop("model_details", {})
self.intended_use = kwargs.pop("intended_use", {})
......@@ -218,3 +234,403 @@ class ModelCard:
"""Save this instance to a json file."""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
# Comment inserted at the top of every autogenerated model card, warning users to
# proofread and complete it before relying on its contents.
AUTOGENERATED_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""

# Maps a task tag (as used in model card / pipeline metadata) to its human-readable task name.
TASK_TAG_TO_NAME_MAPPING = {
    "fill-mask": "Masked Language Modeling",
    "multiple-choice": "Multiple Choice",
    "question-answering": "Question Answering",
    "summarization": "Summarization",
    "text-classification": "Text Classification",
    "text-generation": "Causal Language Modeling",
    "text2text-generation": "Sequence-to-sequence Language Modeling",
    "token-classification": "Token Classification",
    "translation": "Translation",
    "zero-shot-classification": "Zero Shot Classification",
}

# Metric names (lower-cased, spaces replaced by underscores) recognized when inferring
# metric tags from a Trainer's evaluation results.
METRIC_TAGS = [
    "accuracy",
    "bleu",
    "f1",
    "matthews_correlation",
    "pearsonr",
    "precision",
    "recall",
    "rouge",
    "sacrebleu",
    "spearmanr",
]
def _listify(obj):
if obj is None:
return []
elif isinstance(obj, str):
return [obj]
else:
return obj
def _list_possibilities(name, tags):
if tags is None:
return ""
if isinstance(tags, str):
tags = [tags]
if len(tags) == 0:
return ""
name_tags = [f"- {tag}" for tag in tags]
return f"{name}:\n" + "\n".join(name_tags) + "\n"
def infer_metric_tags_from_eval_results(eval_results):
    """Map recognized metric tags (see ``METRIC_TAGS``) to the matching key names in *eval_results*."""
    if eval_results is None:
        return {}
    mapping = {}
    for name in eval_results:
        # Normalize once: lower-case with spaces turned into underscores.
        normalized = name.lower().replace(" ", "_")
        if normalized in METRIC_TAGS:
            mapping[normalized] = name
        elif normalized == "rouge1":
            mapping["rouge"] = name
    return mapping
@dataclass
class TrainingSummary:
    """
    Holds all the information a `Trainer` gathered about a training run and renders it
    as an autogenerated model card (YAML metadata header + Markdown body).
    """

    model_name: str
    language: Optional[Union[str, List[str]]] = None
    license: Optional[str] = None
    tags: Optional[Union[str, List[str]]] = None
    finetuned_from: Optional[str] = None
    dataset: Optional[Union[str, List[str]]] = None
    dataset_tags: Optional[Union[str, List[str]]] = None
    dataset_args: Optional[Union[str, List[str]]] = None
    eval_results: Optional[Dict[str, float]] = None
    eval_lines: Optional[List[str]] = None
    hyperparameters: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Infer default license from the checkpoint used, if possible.
        if self.license is None and not is_offline_mode() and self.finetuned_from is not None:
            try:
                model_info = HfApi().model_info(self.finetuned_from)
                for tag in model_info.tags:
                    if tag.startswith("license:"):
                        self.license = tag[8:]
            except requests.exceptions.HTTPError:
                # Best effort only: leave the license unset if the Hub cannot be reached.
                pass

    def create_model_index(self, metric_mapping):
        """
        Build the YAML ``model-index`` section, crossing every recognized task tag with
        every dataset tag and attaching the metrics in *metric_mapping* (tag -> eval key).
        """
        model_index = f"model-index:\n- name: {self.model_name}\n"

        # Dataset mapping tag -> name
        dataset_names = _listify(self.dataset)
        dataset_tags = _listify(self.dataset_tags)
        dataset_args = _listify(self.dataset_args)
        # Pad the args so every dataset tag has a (possibly None) config entry.
        if len(dataset_args) < len(dataset_tags):
            dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
        dataset_mapping = {tag: name for tag, name in zip(dataset_tags, dataset_names)}
        dataset_arg_mapping = {tag: arg for tag, arg in zip(dataset_tags, dataset_args)}

        task_mapping = {
            tag: TASK_TAG_TO_NAME_MAPPING[tag] for tag in _listify(self.tags) if tag in TASK_TAG_TO_NAME_MAPPING
        }

        if len(task_mapping) == 0 and len(dataset_mapping) == 0:
            return model_index
        # Use a {None: None} placeholder so the cartesian product below still yields entries.
        if len(task_mapping) == 0:
            task_mapping = {None: None}
        if len(dataset_mapping) == 0:
            dataset_mapping = {None: None}
        all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]

        model_index += " results:\n"
        for task_tag, ds_tag in all_possibilities:
            result = ""
            if task_tag is not None:
                result += f" - task:\n name: {task_mapping[task_tag]}\n type: {task_tag}\n"
            if ds_tag is not None:
                prefix = " - " if task_tag is None else " "
                result += f"{prefix}dataset:\n name: {dataset_mapping[ds_tag]}\n type: {ds_tag}\n"
                if dataset_arg_mapping[ds_tag] is not None:
                    result += f" args: {dataset_arg_mapping[ds_tag]}\n"
            if len(metric_mapping) > 0:
                result += " metrics:\n"
                for metric_tag, metric_name in metric_mapping.items():
                    value = self.eval_results[metric_name]
                    result += f" - name: {metric_name}\n type: {metric_tag}\n value: {value}\n"
            model_index += result

        return model_index

    def to_model_card(self):
        """Render the full model card as a string: metadata block first, then the Markdown sections."""
        model_card = ""
        metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)

        # Metadata
        metadata = ""
        metadata += _list_possibilities("language", self.language)
        if self.license is not None:
            metadata += f"license: {self.license}\n"
        metadata += _list_possibilities("tags", self.tags)
        metadata += _list_possibilities("datasets", self.dataset_tags)
        metadata += _list_possibilities("metrics", list(metric_mapping.keys()))
        metadata += "\n" + self.create_model_index(metric_mapping)
        if len(metadata) > 0:
            model_card = f"---\n{metadata}---\n"

        # Now the model card for realsies.
        model_card += AUTOGENERATED_COMMENT
        model_card += f"\n# {self.model_name}\n\n"

        if self.finetuned_from is None:
            model_card += "This model was trained from scratch on "
        else:
            model_card += f"This model is a fine-tuned version of [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on "

        if self.dataset is None:
            # Fixed typo in the generated text: "unkown" -> "unknown".
            model_card += "an unknown dataset."
        else:
            if isinstance(self.dataset, str):
                model_card += f"the {self.dataset} dataset."
            else:
                model_card += (
                    ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
                )

        if self.eval_results is not None:
            model_card += "\nIt achieves the following results on the evaluation set:\n"
            model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()])
            model_card += "\n"

        model_card += "\n## Model description\n\nMore information needed\n"
        model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
        model_card += "\n## Training and evaluation data\n\nMore information needed\n"

        model_card += "\n## Training procedure\n"
        model_card += "\n### Training hyperparameters\n"
        if self.hyperparameters is not None:
            model_card += "\nThe following hyperparameters were used during training:\n"
            model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()])
            model_card += "\n"
        else:
            model_card += "\nMore information needed\n"

        if self.eval_lines is not None:
            model_card += "\n### Training results\n\n"
            model_card += make_markdown_table(self.eval_lines)
            model_card += "\n"

        # Framework versions: only report libraries that are actually installed.
        model_card += "\n### Framework versions\n\n"
        model_card += f"- Transformers {__version__}\n"
        if is_torch_available():
            import torch

            model_card += f"- Pytorch {torch.__version__}\n"
        if is_datasets_available():
            import datasets

            model_card += f"- Datasets {datasets.__version__}\n"
        if is_tokenizers_available():
            import tokenizers

            model_card += f"- Tokenizers {tokenizers.__version__}\n"

        return model_card

    @classmethod
    def from_trainer(
        cls,
        trainer,
        language=None,
        license=None,
        tags=None,
        model_name=None,
        finetuned_from=None,
        dataset_tags=None,
        dataset=None,
        dataset_args=None,
    ):
        """Build a `TrainingSummary` from a `Trainer`, parsing its log history and hyperparameters."""
        # TODO (Sylvain) Add a default for `pipeline-tag` inferred from the model.
        if model_name is None:
            # Default the card title to the name of the output directory.
            model_name = Path(trainer.args.output_dir).name
        _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
        hyperparameters = extract_hyperparameters_from_trainer(trainer)
        return cls(
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
            eval_results=eval_results,
            eval_lines=eval_lines,
            hyperparameters=hyperparameters,
        )
def parse_log_history(log_history):
    """
    Parse the `log_history` of a Trainer to get the intermediate and final evaluation results.

    Returns a tuple ``(train_log, eval_lines, eval_results)``:
    - ``train_log``: the dict logged at the end of training, or ``None`` if no training log exists,
    - ``eval_lines``: one dict per intermediate evaluation, formatted for a Markdown table,
    - ``eval_results``: the final evaluation metrics, or ``None`` if no evaluation was logged.
    """
    idx = 0
    while idx < len(log_history) and "train_runtime" not in log_history[idx]:
        idx += 1

    # If there are no training logs
    if idx == len(log_history):
        idx -= 1
        while idx >= 0 and "eval_loss" not in log_history[idx]:
            idx -= 1
        # Bug fix: `idx >= 0` (was `idx > 0`), otherwise an evaluation logged at
        # index 0 was silently dropped.
        if idx >= 0:
            return None, None, log_history[idx]
        else:
            return None, None, None

    # From now on we can assume we have training logs:
    train_log = log_history[idx]
    lines = []
    training_loss = "No log"
    for i in range(idx):
        if "loss" in log_history[i]:
            training_loss = log_history[i]["loss"]
        if "eval_loss" in log_history[i]:
            metrics = log_history[i].copy()
            # Drop bookkeeping entries that don't belong in the results table.
            _ = metrics.pop("total_flos", None)
            epoch = metrics.pop("epoch", None)
            step = metrics.pop("step", None)
            _ = metrics.pop("eval_runtime", None)
            _ = metrics.pop("eval_samples_per_second", None)
            values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
            for k, v in metrics.items():
                if k == "eval_loss":
                    values["Validation Loss"] = v
                else:
                    # "eval_some_metric" -> "Some Metric"
                    splits = k.split("_")
                    name = " ".join([part.capitalize() for part in splits[1:]])
                    values[name] = v
            lines.append(values)

    # The last evaluation entry in the history holds the final results.
    idx = len(log_history) - 1
    while idx >= 0 and "eval_loss" not in log_history[idx]:
        idx -= 1

    # Bug fix: `idx >= 0` (was `idx > 0`) so a final evaluation at index 0 is kept.
    if idx >= 0:
        eval_results = {}
        for key, value in log_history[idx].items():
            if key.startswith("eval_"):
                key = key[5:]
                if key not in ["runtime", "samples_per_second", "epoch", "step"]:
                    camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
                    eval_results[camel_cased_key] = value
        return train_log, lines, eval_results
    else:
        return train_log, lines, None
def _maybe_round(v, decimals=4):
if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
return f"{v:.{decimals}f}"
return str(v)
def _regular_table_line(values, col_widths):
values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
return "".join(values_with_space) + "|\n"
def _second_table_line(col_widths):
values = ["|:" + "-" * w + ":" for w in col_widths]
return "".join(values) + "|\n"
def make_markdown_table(lines):
    """
    Create a nice Markdown table from the results in `lines` (a list of dicts sharing the same keys).
    """
    if lines is None or len(lines) == 0:
        return ""
    # Each column is as wide as its widest formatted cell, or its header if that is wider.
    col_widths = {key: len(str(key)) for key in lines[0].keys()}
    for line in lines:
        for key, value in line.items():
            col_widths[key] = max(col_widths[key], len(_maybe_round(value)))

    table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
    table += _second_table_line(list(col_widths.values()))
    for line in lines:
        table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
    return table
# TrainingArguments attributes that are always reported in the "Training hyperparameters"
# section of the autogenerated model card.
_TRAINING_ARGS_KEYS = [
    "learning_rate",
    "train_batch_size",
    "eval_batch_size",
    "seed",
]
def extract_hyperparameters_from_trainer(trainer):
    """Collect the hyperparameters of a `Trainer` run into an ordered dict for the model card."""
    args = trainer.args
    hyperparameters = {key: getattr(args, key) for key in _TRAINING_ARGS_KEYS}

    if args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
        if args.parallel_mode == ParallelMode.DISTRIBUTED:
            hyperparameters["distributed_type"] = "multi-GPU"
        else:
            hyperparameters["distributed_type"] = args.parallel_mode.value
    if args.world_size > 1:
        hyperparameters["num_devices"] = args.world_size
    if args.gradient_accumulation_steps > 1:
        hyperparameters["gradient_accumulation_steps"] = args.gradient_accumulation_steps

    # Only report effective batch sizes when they differ from the per-device ones.
    total_train_batch_size = args.train_batch_size * args.world_size * args.gradient_accumulation_steps
    if total_train_batch_size != hyperparameters["train_batch_size"]:
        hyperparameters["total_train_batch_size"] = total_train_batch_size
    total_eval_batch_size = args.eval_batch_size * args.world_size
    if total_eval_batch_size != hyperparameters["eval_batch_size"]:
        hyperparameters["total_eval_batch_size"] = total_eval_batch_size

    if args.adafactor:
        hyperparameters["optimizer"] = "Adafactor"
    else:
        hyperparameters["optimizer"] = (
            f"Adam with betas=({args.adam_beta1},{args.adam_beta2}) and epsilon={args.adam_epsilon}"
        )

    hyperparameters["lr_scheduler_type"] = args.lr_scheduler_type.value
    if args.warmup_ratio != 0.0:
        hyperparameters["lr_scheduler_warmup_ratio"] = args.warmup_ratio
    if args.warmup_steps != 0.0:
        hyperparameters["lr_scheduler_warmup_steps"] = args.warmup_steps
    if args.max_steps != -1:
        hyperparameters["training_steps"] = args.max_steps
    else:
        hyperparameters["num_epochs"] = args.num_train_epochs

    if args.fp16:
        if trainer.use_amp:
            hyperparameters["mixed_precision_training"] = "Native AMP"
        elif trainer._use_apex:
            hyperparameters["mixed_precision_training"] = f"Apex, opt level {args.fp16_opt_level}"

    if args.label_smoothing_factor != 0.0:
        hyperparameters["label_smoothing_factor"] = args.label_smoothing_factor
    return hyperparameters
......@@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
......@@ -384,12 +383,6 @@ def pipeline(
model = get_default_model(targeted_task, framework, task_options)
model_name = model if isinstance(model, str) else None
modelcard = None
# Try to infer modelcard from model or config name (if provided as str)
if isinstance(model, str):
modelcard = model
elif isinstance(config, str):
modelcard = config
# Infer the framework form the model
if framework is None:
......@@ -404,10 +397,6 @@ def pipeline(
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
# Instantiate modelcard if needed
if isinstance(modelcard, str):
modelcard = ModelCard.from_pretrained(modelcard, revision=revision, _from_pipeline=task)
# Instantiate model if needed
if isinstance(model, str):
# Handle transparent TF/PT model conversion
......@@ -504,10 +493,4 @@ def pipeline(
if feature_extractor is not None:
kwargs["feature_extractor"] = feature_extractor
return task_class(
model=model,
modelcard=modelcard,
framework=framework,
task=task,
**kwargs,
)
return task_class(model=model, framework=framework, task=task, **kwargs)
......@@ -74,6 +74,7 @@ from .file_utils import (
is_torch_tpu_available,
is_training_run_on_sagemaker,
)
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
......@@ -2381,25 +2382,49 @@ class Trainer:
else:
return 0
def create_model_card(
    self,
    language: Optional[str] = None,
    license: Optional[str] = None,
    tags: Optional[str] = None,
    model_name: Optional[str] = None,
    finetuned_from: Optional[str] = None,
    dataset_tags: Optional[Union[str, List[str]]] = None,
    dataset: Optional[Union[str, List[str]]] = None,
    dataset_args: Optional[Union[str, List[str]]] = None,
):
    """
    Autogenerate a model card (``README.md``) in ``self.args.output_dir`` from the
    information gathered during training plus the metadata passed as keyword arguments.
    """
    summary = TrainingSummary.from_trainer(
        self,
        language=language,
        license=license,
        tags=tags,
        model_name=model_name,
        finetuned_from=finetuned_from,
        dataset_tags=dataset_tags,
        dataset=dataset,
        dataset_args=dataset_args,
    )
    with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
        f.write(summary.to_model_card())
def push_to_hub(
self,
save_directory: Optional[str] = None,
repo_name: Optional[str] = None,
repo_url: Optional[str] = None,
commit_message: Optional[str] = "add model",
organization: Optional[str] = None,
private: bool = None,
use_auth_token: Optional[Union[bool, str]] = None,
**kwargs,
):
"""
Upload `self.model` to the 🤗 model hub.
Parameters:
save_directory (:obj:`str` or :obj:`os.PathLike`):
Folder containing the model weights and config. Will default to :obj:`self.args.output_dir`.
repo_name (:obj:`str`, `optional`):
Repository name for your model or tokenizer in the hub. If not specified, the repository name will be
the stem of :obj:`save_directory`.
Repository name for your model or tokenizer in the hub. If not specified and :obj:`repo_url` is not
specified either, will default to the stem of :obj:`self.args.output_dir`.
repo_url (:obj:`str`, `optional`):
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
repository will be created in your namespace (unless you specify an :obj:`organization`) with
......@@ -2415,6 +2440,8 @@ class Trainer:
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to
:obj:`True` if :obj:`repo_url` is not specified.
kwargs:
Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.
Returns:
The url of the commit of your model in the given repository.
......@@ -2426,15 +2453,23 @@ class Trainer:
raise ValueError(
"The `upload_model_to_hub` method only works for models that inherit from `PushToHubMixin` models."
)
if save_directory is None:
save_directory = self.args.output_dir
# To avoid pushing all checkpoints, we just copy all the files in save_directory in a tmp dir.
if repo_url is None and repo_name is None:
repo_name = Path(self.args.output_dir).name
if repo_name is not None:
model_name = repo_name
elif repo_url is not None:
model_name = repo_url.split("/")[-1]
else:
model_name = None
self.create_model_card(model_name=model_name, **kwargs)
with tempfile.TemporaryDirectory() as tmp_dir:
for f in os.listdir(save_directory):
fname = os.path.join(save_directory, f)
if os.path.isfile(fname):
shutil.copy(fname, os.path.join(tmp_dir, f))
shutil.copy(os.path.join(self.args.output_dir, "README.md"), os.path.join(tmp_dir, "README.md"))
unwrap_model(self.model).save_pretrained(tmp_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(tmp_dir)
return unwrap_model(self.model)._push_to_hub(
save_directory=tmp_dir,
......
......@@ -1168,7 +1168,6 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(output_dir=tmp_dir)
trainer.save_model()
url = trainer.push_to_hub(repo_name="test-trainer", use_auth_token=self._token)
# Extract repo_name from the url
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment