Unverified Commit 4208f496 authored by Sylvain Gugger, committed by GitHub

Better filtering of the model outputs in Trainer (#8633)

* Better filtering of the model outputs in Trainer

* Fix examples tests

* Add test for Lysandre
parent f2e07e72
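
The overall behavior introduced by this commit can be summarized with a short usage sketch. This is illustrative only and mirrors the new test_evaluation_with_keys_to_drop added at the bottom of this diff; the tiny GPT-2 configuration and the repeat dataset are stand-ins, not library fixtures.

import numpy as np
import torch
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments

class RepeatDataset(torch.utils.data.Dataset):
    # Toy dataset that feeds the same input_ids/labels pair over and over.
    def __init__(self, x, length=64):
        self.x, self.length = x, length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.x, "labels": self.x}

config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
eval_dataset = RepeatDataset(torch.randint(0, 100, (128,)))
trainer = Trainer(tiny_gpt2, TrainingArguments("./test"), eval_dataset=eval_dataset)

# GPT2Config now declares keys_to_ignore_at_inference = ["past_key_values"],
# so by default only the logits are gathered into predictions.
assert isinstance(trainer.predict(eval_dataset).predictions, np.ndarray)

# Passing ignore_keys=[] keeps every output key (logits and past_key_values).
assert isinstance(trainer.predict(eval_dataset, ignore_keys=[]).predictions, tuple)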
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -153,7 +153,11 @@ class Seq2SeqTrainer(Trainer):
         return loss
 
     def prediction_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Perform an evaluation step on :obj:`model` using :obj:`inputs`.
...
@@ -43,6 +43,8 @@ class PretrainedConfig(object):
         - **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case
           the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig`
           like: :class:`~transformers.EncoderDecoderConfig` or :class:`~RagConfig`.
+        - **keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking at
+          dictionary outputs of the model during inference.
 
     Args:
         name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
...
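
As a minimal sketch (the class name and values below are hypothetical; only the keys_to_ignore_at_inference attribute documented above comes from this commit), a custom config opts into this filtering simply by declaring the attribute:

from transformers import PretrainedConfig

class MyCausalConfig(PretrainedConfig):
    # Hypothetical config class, not part of the library.
    # Keys listed here are dropped from dictionary model outputs when the Trainer
    # gathers predictions, unless the caller passes an explicit ignore_keys.
    model_type = "my-causal"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vocab_size=1000, hidden_size=64, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size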
@@ -110,6 +110,7 @@ class BartConfig(PretrainedConfig):
             :obj:`True` for `bart-large-cnn`.
     """
     model_type = "bart"
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
...
@@ -77,6 +77,7 @@ class CTRLConfig(PretrainedConfig):
     """
 
     model_type = "ctrl"
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
...
@@ -120,6 +120,7 @@ class GPT2Config(PretrainedConfig):
     """
 
     model_type = "gpt2"
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
...
@@ -97,3 +97,4 @@ class MarianConfig(BartConfig):
     """
 
     model_type = "marian"
+    keys_to_ignore_at_inference = ["past_key_values"]
@@ -102,3 +102,4 @@ class MBartConfig(BartConfig):
     """
 
     model_type = "mbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
@@ -62,6 +62,7 @@ class MT5Config(PretrainedConfig):
             Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
     """
 
     model_type = "mt5"
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
...
@@ -141,4 +141,5 @@ class PegasusConfig(BartConfig):
     """
 
     model_type = "pegasus"
+    keys_to_ignore_at_inference = ["past_key_values"]
     # The implementation of the config object is in BartConfig
@@ -92,6 +92,7 @@ class ProphetNetConfig(PretrainedConfig):
             smoothing is performed.
     """
     model_type = "prophetnet"
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
...
@@ -153,6 +153,7 @@ class ReformerConfig(PretrainedConfig):
         >>> configuration = model.config
     """
     model_type = "reformer"
+    keys_to_ignore_at_inference = ["past_buckets_states"]
 
     def __init__(
         self,
...
@@ -71,6 +71,7 @@ class T5Config(PretrainedConfig):
             the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
     """
     model_type = "t5"
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
...
@@ -105,6 +105,7 @@ class TransfoXLConfig(PretrainedConfig):
     """
 
     model_type = "transfo-xl"
+    keys_to_ignore_at_inference = ["mems"]
 
     def __init__(
         self,
...
@@ -128,6 +128,7 @@ class XLNetConfig(PretrainedConfig):
     """
 
     model_type = "xlnet"
+    keys_to_ignore_at_inference = ["mems"]
 
     def __init__(
         self,
...
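
A quick sanity check of the per-model defaults added above (a sketch, assuming these config classes are importable from the top-level transformers package):

from transformers import BartConfig, GPT2Config, ReformerConfig, TransfoXLConfig, XLNetConfig

# The new attribute is a plain class attribute, so it can be read without instantiating.
assert BartConfig.keys_to_ignore_at_inference == ["past_key_values"]
assert GPT2Config.keys_to_ignore_at_inference == ["past_key_values"]
assert ReformerConfig.keys_to_ignore_at_inference == ["past_buckets_states"]
assert TransfoXLConfig.keys_to_ignore_at_inference == ["mems"]
assert XLNetConfig.keys_to_ignore_at_inference == ["mems"]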
@@ -1098,10 +1098,11 @@ class Trainer:
         """
         outputs = model(**inputs)
         # Save past state if it exists
+        # TODO: this needs to be fixed and made cleaner later.
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
         # We don't use .loss here since the model may return tuples instead of ModelOutput.
-        return outputs[0]
+        return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
 
     def is_local_process_zero(self) -> bool:
         """
@@ -1220,7 +1221,9 @@ class Trainer:
             logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
             shutil.rmtree(checkpoint)
 
-    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
+    def evaluate(
+        self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
+    ) -> Dict[str, float]:
         """
         Run evaluation and returns metrics.
@@ -1234,6 +1237,9 @@ class Trainer:
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                 columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
                 :obj:`__len__` method.
+            ignore_keys (:obj:`List[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
 
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -1250,6 +1256,7 @@ class Trainer:
             # No point gathering the predictions if there are no metrics, otherwise we defer to
             # self.args.prediction_loss_only
             prediction_loss_only=True if self.compute_metrics is None else None,
+            ignore_keys=ignore_keys,
         )
 
         self.log(output.metrics)
@@ -1261,7 +1268,7 @@ class Trainer:
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
         return output.metrics
 
-    def predict(self, test_dataset: Dataset) -> PredictionOutput:
+    def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
         """
         Run prediction and returns predictions and potential metrics.
@@ -1272,6 +1279,9 @@ class Trainer:
             test_dataset (:obj:`Dataset`):
                 Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                 ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
+            ignore_keys (:obj:`List[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
 
         .. note::
@@ -1291,10 +1301,14 @@ class Trainer:
         test_dataloader = self.get_test_dataloader(test_dataset)
 
-        return self.prediction_loop(test_dataloader, description="Prediction")
+        return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
 
     def prediction_loop(
-        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
     ) -> PredictionOutput:
         """
         Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1346,7 +1360,7 @@ class Trainer:
         self.callback_handler.eval_dataloader = dataloader
 
         for step, inputs in enumerate(dataloader):
-            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
             if loss is not None:
                 losses = loss.repeat(batch_size)
                 losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
@@ -1410,7 +1424,11 @@ class Trainer:
         return nested_numpify(tensors)
 
     def prediction_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Perform an evaluation step on :obj:`model` using :obj:`inputs`.
@@ -1427,6 +1445,9 @@ class Trainer:
                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
             prediction_loss_only (:obj:`bool`):
                 Whether or not to return the loss only.
+            ignore_keys (:obj:`List[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
 
         Return:
             Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
@@ -1434,6 +1455,11 @@ class Trainer:
         """
         has_labels = all(inputs.get(k) is not None for k in self.label_names)
         inputs = self._prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
 
         with torch.no_grad():
             if self.args.fp16 and _use_native_amp:
@@ -1442,16 +1468,21 @@ class Trainer:
             else:
                 outputs = model(**inputs)
             if has_labels:
-                loss = outputs[0].mean().detach()
-                logits = outputs[1:]
+                if isinstance(outputs, dict):
+                    loss = outputs["loss"].mean().detach()
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                else:
+                    loss = outputs[0].mean().detach()
+                    logits = outputs[1:]
             else:
                 loss = None
-                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
-                logits = outputs[:]
+                if isinstance(outputs, dict):
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                else:
+                    logits = outputs
+            # TODO: this needs to be fixed and made cleaner later.
             if self.args.past_index >= 0:
                 self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
-                # Remove the past from the logits.
-                logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
 
         if prediction_loss_only:
             return (loss, None, None)
...
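
The key-filtering logic added to prediction_step can be illustrated in isolation; this sketch uses a plain dictionary as a stand-in for a real ModelOutput:

import torch

# Stand-in for a dictionary model output and for the resolved ignore_keys.
outputs = {
    "loss": torch.tensor(0.5),
    "logits": torch.randn(2, 10),
    "past_key_values": (torch.randn(2, 4),),
}
ignore_keys = ["past_key_values"]  # e.g. taken from config.keys_to_ignore_at_inference
has_labels = True

if has_labels:
    loss = outputs["loss"].mean().detach()
    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
    loss = None
    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)

print(loss)         # tensor(0.5000)
print(len(logits))  # 1 -> only the logits tensor is kept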
@@ -44,6 +44,8 @@ if is_torch_available():
         DataCollatorForLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
+        GPT2Config,
+        GPT2LMHeadModel,
         LineByLineTextDataset,
         PreTrainedModel,
         TextDataset,
@@ -73,6 +75,18 @@ class RegressionDataset:
             return result
 
+    class RepeatDataset:
+        def __init__(self, x, length=64):
+            self.x = x
+            self.length = length
+
+        def __len__(self):
+            return self.length
+
+        def __getitem__(self, i):
+            return {"input_ids": self.x, "labels": self.x}
+
     class DynamicShapesDataset:
         def __init__(self, length=64, seed=42, batch_size=8):
             self.length = length
@@ -136,6 +150,20 @@ if is_torch_available():
             loss = torch.nn.functional.mse_loss(y, labels)
             return (loss, y, y) if self.double_output else (loss, y)
 
+    class RegressionDictModel(torch.nn.Module):
+        def __init__(self, a=0, b=0):
+            super().__init__()
+            self.a = torch.nn.Parameter(torch.tensor(a).float())
+            self.b = torch.nn.Parameter(torch.tensor(b).float())
+            self.config = None
+
+        def forward(self, input_x=None, labels=None, **kwargs):
+            y = input_x * self.a + self.b
+            result = {"output": y}
+            if labels is not None:
+                result["loss"] = torch.nn.functional.mse_loss(y, labels)
+            return result
+
     class RegressionPreTrainedModel(PreTrainedModel):
         config_class = RegressionModelConfig
         base_model_prefix = "regression"
...
@@ -236,6 +264,33 @@ class TrainerIntegrationTest(unittest.TestCase):
             metrics = trainer.evaluate()
             self.assertEqual(metrics[metric], best_value)
 
+    def test_trainer_works_with_dict(self):
+        # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
+        # anything.
+        train_dataset = RegressionDataset()
+        eval_dataset = RegressionDataset()
+        model = RegressionDictModel()
+        args = TrainingArguments("./regression")
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+        _ = trainer.evaluate()
+        _ = trainer.predict(eval_dataset)
+
+    def test_evaluation_with_keys_to_drop(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        eval_dataset = RepeatDataset(x)
+        args = TrainingArguments("./test")
+        trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset)
+        # By default the past_key_values are removed
+        result = trainer.predict(eval_dataset)
+        self.assertTrue(isinstance(result.predictions, np.ndarray))
+        # We can still get them by setting ignore_keys to []
+        result = trainer.predict(eval_dataset, ignore_keys=[])
+        self.assertTrue(isinstance(result.predictions, tuple))
+        self.assertEqual(len(result.predictions), 2)
+
     def test_training_arguments_are_left_untouched(self):
         trainer = get_regression_trainer()
         trainer.train()
...