Unverified Commit 5506d049 authored by Nathan Fradet, committed by GitHub

Seq2seq trainer generation config arg (#22323)



* seq2seq trainer and training arguments accepting GenerationConfig arg

* seq2seq Trainer and training arguments docstring fixes

* Update training_args_seq2seq.py docstring
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Fixing trainer_seq2seq.py docstring
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* seq2seq trainer: legacy gen args back & GenerationConfig created at init

* Seq2seq trainer: fix in case gen_config.max_new_tokens is None
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* seq2seq trainer: adding legacy arg retrocompatibility

* seq2seq trainer: evaluate and predict untouched

* Apply suggestions from code review
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* seq2seq trainer: adding init args, keeping IDEs hints

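As a usage illustration (not part of the diff below): a minimal sketch of how the new argument can be passed, assuming a placeholder checkpoint name and output directory.

```python
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GenerationConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Placeholder checkpoint; any seq2seq model works the same way.
checkpoint = "t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# The new argument accepts a GenerationConfig object, a model id,
# or a path to a directory / json file.
gen_config = GenerationConfig(max_new_tokens=64, num_beams=4)

args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,
    generation_config=gen_config,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    # train_dataset=..., eval_dataset=..., compute_metrics=...,
)
# After init, trainer.model.generation_config carries the values from gen_config.
```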
---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 03966cac
src/transformers/trainer_seq2seq.py
@@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import Dataset
from .data.data_collator import DataCollator
from .deepspeed import is_deepspeed_zero3_enabled
from .generation.configuration_utils import GenerationConfig
from .modeling_utils import PreTrainedModel
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer import Trainer
from .trainer_utils import PredictionOutput
from .trainer_callback import TrainerCallback
from .trainer_utils import EvalPrediction, PredictionOutput
from .training_args import TrainingArguments
from .utils import logging
@@ -28,6 +36,76 @@ logger = logging.get_logger(__name__)
class Seq2SeqTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Override self.model.generation_config if a GenerationConfig is specified in args.
# Priority: args.generation_config > model.generation_config > default GenerationConfig.
if self.args.generation_config is not None:
gen_config = self.load_generation_config(self.args.generation_config)
self.model.generation_config = gen_config
@staticmethod
def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig:
"""
Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` argument.
Args:
gen_config_arg (`str` or [`~generation.GenerationConfig`]):
`Seq2SeqTrainingArguments.generation_config` argument.
Returns:
A `~generation.GenerationConfig`.
"""
# GenerationConfig provided, nothing to do
if isinstance(gen_config_arg, GenerationConfig):
return deepcopy(gen_config_arg)
# str or Path
pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg
config_file_name = None
# Figure out whether it is a path pointing to a file, a path pointing to a directory, or else a model id or URL
# This step is required in order to determine config_file_name
if pretrained_model_name.is_file():
config_file_name = pretrained_model_name.name
pretrained_model_name = pretrained_model_name.parent
# dir path
elif pretrained_model_name.is_dir():
pass
# model id or URL
else:
pretrained_model_name = gen_config_arg
gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name)
return gen_config
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
@@ -171,6 +249,8 @@ class Seq2SeqTrainer(Trainer):
inputs = self._prepare_inputs(inputs)
# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# gen_kwargs > model.generation_config > default GenerationConfig()
gen_kwargs = self._gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.model.config.max_length
@@ -192,13 +272,14 @@
# removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
if self.model.generation_config._from_model_config:
self.model.generation_config._from_model_config = False
# Retrieves GenerationConfig from model.generation_config
gen_config = model.generation_config
# in case the batch is shorter than max length, the output should be padded
if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
gen_kwargs["max_new_tokens"] + 1
):
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
if generated_tokens.shape[-1] < gen_config.max_length:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
with torch.no_grad():
if has_labels:
@@ -212,20 +293,18 @@
loss = None
if self.args.prediction_loss_only:
return (loss, None, None)
return loss, None, None
if has_labels:
labels = inputs["labels"]
if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
gen_kwargs["max_new_tokens"] + 1
):
labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
if labels.shape[-1] < gen_config.max_length:
labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
else:
labels = None
return (loss, generated_tokens, labels)
return loss, generated_tokens, labels
def _pad_tensors_to_max_len(self, tensor, max_length):
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
src/transformers/training_args_seq2seq.py
@@ -14,8 +14,10 @@
import logging
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
from typing import Optional, Union
from .generation.configuration_utils import GenerationConfig
from .training_args import TrainingArguments
from .utils import add_start_docstrings
@@ -42,6 +44,15 @@ class Seq2SeqTrainingArguments(TrainingArguments):
generation_num_beams (`int`, *optional*):
The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
`num_beams` value of the model configuration.
generation_config (`str` or `Path` or [`~generation.GenerationConfig`], *optional*):
Allows loading a [`~generation.GenerationConfig`] via the `from_pretrained` method. This can be either:
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
under a user or organization name, like `dbmdz/bert-base-german-cased`.
- a path to a *directory* containing a configuration file saved using the
[`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
- a [`~generation.GenerationConfig`] object.
"""
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
@@ -66,3 +77,9 @@ class Seq2SeqTrainingArguments(TrainingArguments):
)
},
)
generation_config: Optional[Union[str, Path, GenerationConfig]] = field(
default=None,
metadata={
"help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction."
},
)
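For reference, a sketch of the three accepted forms described in the docstring above; the model id and paths are illustrative placeholders, not part of this commit.

```python
from transformers import GenerationConfig, Seq2SeqTrainingArguments

# 1. A GenerationConfig object: deep-copied and used as-is.
args_obj = Seq2SeqTrainingArguments(
    output_dir="out",
    generation_config=GenerationConfig(num_beams=2, max_new_tokens=32),
)

# 2. A model id on the Hub: resolved with GenerationConfig.from_pretrained(...)
#    when the Seq2SeqTrainer is created.
args_hub = Seq2SeqTrainingArguments(output_dir="out", generation_config="t5-small")

# 3. A local directory, or a direct path to a json file such as one written by
#    GenerationConfig.save_pretrained("./my_model_directory/").
args_local = Seq2SeqTrainingArguments(
    output_dir="out",
    generation_config="./my_model_directory/generation_config.json",
)
```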