Unverified Commit 7186ca62 authored by Sylvain Gugger, committed by GitHub

Multi predictions trainer (#7126)

* Allow multiple outputs

* Formatting

* Move the unwrapping before metrics

* Fix typo

* Add test for non-supported config options
parent 52d250f6
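In short, the prediction loop now carries model outputs as a tuple end to end, so `Trainer.predict` can surface every head of a multi-output model instead of only the first. A minimal sketch of the user-facing effect, assuming a hypothetical two-head model (`TwoHeadModel`, `head_a`, and `head_b` are illustrative names, not from this commit):

import torch

class TwoHeadModel(torch.nn.Module):
    """Hypothetical model returning two outputs besides the loss."""

    def __init__(self):
        super().__init__()
        self.config = None  # the new asserts read config flags; None passes them
        self.head_a = torch.nn.Linear(4, 2)
        self.head_b = torch.nn.Linear(4, 2)

    def forward(self, input_x=None, labels=None):
        a, b = self.head_a(input_x), self.head_b(input_x)
        if labels is None:
            return (a, b)  # Trainer.predict now returns both heads as a tuple
        loss = torch.nn.functional.mse_loss(a, labels)
        return (loss, a, b)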
@@ -1269,6 +1269,13 @@ class Trainer:
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )
+        assert not getattr(
+            self.model.config, "output_attentions", False
+        ), "The prediction loop does not work with `output_attentions=True`."
+        assert not getattr(
+            self.model.config, "output_hidden_states", False
+        ), "The prediction loop does not work with `output_hidden_states=True`."
+
         model = self.model
         # multi-gpu eval
         if self.args.n_gpu > 1:
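Why the guard: with `output_attentions=True` (or `output_hidden_states=True`) the forward pass returns extra per-layer tensors whose shapes depend on the batch's sequence length, and the loop below would try to concatenate them across batches. A small illustration with made-up shapes:

import torch

# Two batches of attention maps with different sequence lengths cannot be
# concatenated along the batch dimension, which is what the loop would attempt.
attn_batch_1 = torch.randn(8, 12, 10, 10)  # (batch, heads, seq, seq)
attn_batch_2 = torch.randn(8, 12, 7, 7)
try:
    torch.cat((attn_batch_1, attn_batch_2), dim=0)
except RuntimeError as err:
    print("cat fails on mismatched shapes:", err)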
@@ -1300,7 +1307,7 @@ class Trainer:
                 if loss is not None:
                     eval_losses.extend([loss] * batch_size)
                 if logits is not None:
-                    preds = logits if preds is None else torch.cat((preds, logits), dim=0)
+                    preds = logits if preds is None else tuple(torch.cat((p, l), dim=0) for p, l in zip(preds, logits))
                 if labels is not None:
                     label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)
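The accumulation is now element-wise over the tuple of outputs. A standalone sketch of the same pattern (the helper name `accumulate` is illustrative):

import torch

def accumulate(preds, logits):
    # `logits` is a tuple of tensors for the current batch; concatenate each
    # element with its running counterpart along the batch dimension.
    return logits if preds is None else tuple(
        torch.cat((p, l), dim=0) for p, l in zip(preds, logits)
    )

preds = None
for _ in range(3):
    preds = accumulate(preds, (torch.randn(4, 2), torch.randn(4, 5)))
print([tuple(p.shape) for p in preds])  # [(12, 2), (12, 5)]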
@@ -1311,13 +1318,13 @@ class Trainer:
         if self.args.local_rank != -1:
             # In distributed mode, concatenate all results from all nodes:
             if preds is not None:
-                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
+                preds = tuple(distributed_concat(p, num_total_examples=self.num_examples(dataloader)) for p in preds)
             if label_ids is not None:
                 label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
         elif is_torch_tpu_available():
             # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
             if preds is not None:
-                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
+                preds = tuple(xm.mesh_reduce(f"eval_preds_{i}", p, torch.cat) for i, p in enumerate(preds))
             if label_ids is not None:
                 label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
         if eval_losses is not None:
@@ -1325,7 +1332,9 @@ class Trainer:
         # Finally, turn the aggregated tensors into numpy arrays.
         if preds is not None:
-            preds = preds.cpu().numpy()
+            preds = tuple(p.cpu().numpy() for p in preds)
+            if len(preds) == 1:
+                preds = preds[0]
         if label_ids is not None:
             label_ids = label_ids.cpu().numpy()
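The single-element unwrap keeps backward compatibility: a model with one output still yields a bare array rather than a one-element tuple. For example:

import numpy as np
import torch

preds = (torch.randn(12, 2),)                 # what a single-output model produces
preds = tuple(p.cpu().numpy() for p in preds)
if len(preds) == 1:
    preds = preds[0]                          # callers still see a plain np.ndarray
print(type(preds))                            # <class 'numpy.ndarray'>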
@@ -1380,11 +1389,13 @@ class Trainer:
         with torch.no_grad():
             outputs = model(**inputs)
             if has_labels:
-                loss, logits = outputs[:2]
-                loss = loss.mean().item()
+                # The .mean() is to reduce in case of distributed training
+                loss = outputs[0].mean().item()
+                logits = outputs[1:]
             else:
                 loss = None
-                logits = outputs[0]
+                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
+                logits = outputs[:]
             if self.args.past_index >= 0:
                 self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
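The `outputs[:]` slice relies on tuple slicing semantics: slicing a tuple subclass returns a plain tuple, so `logits` is always a tuple even when the model has a single output. A quick check using `namedtuple` as a stand-in for a one-field `ModelOutput`:

from collections import namedtuple
import torch

Output = namedtuple("Output", ["logits"])     # stand-in, not the real ModelOutput
outputs = Output(logits=torch.randn(4, 2))
logits = outputs[:]                           # slicing a tuple subclass -> plain tuple
print(type(logits), len(logits))              # <class 'tuple'> 1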
@@ -1394,7 +1405,7 @@ class Trainer:
         labels = inputs.get("labels")
         if labels is not None:
             labels = labels.detach()
-        return (loss, logits.detach(), labels)
+        return (loss, tuple(l.detach() for l in logits), labels)

     def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
         """
......
 import random
-from typing import Any, Dict, List, NamedTuple, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

 import numpy as np

@@ -42,12 +42,12 @@ class EvalPrediction(NamedTuple):
         label_ids (:obj:`np.ndarray`): Targets to be matched.
     """

-    predictions: np.ndarray
+    predictions: Union[np.ndarray, Tuple[np.ndarray]]
     label_ids: np.ndarray


 class PredictionOutput(NamedTuple):
-    predictions: np.ndarray
+    predictions: Union[np.ndarray, Tuple[np.ndarray]]
     label_ids: Optional[np.ndarray]
     metrics: Optional[Dict[str, float]]
......
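With the `predictions` type widened above, a `compute_metrics` function should be ready to receive either a single array or a tuple of arrays. A hedged sketch (the head-selection logic is illustrative, not prescribed by this commit):

import numpy as np
from transformers import EvalPrediction

def compute_metrics(eval_pred: EvalPrediction) -> dict:
    preds = eval_pred.predictions
    if isinstance(preds, tuple):   # multi-output model: pick which head to score
        preds = preds[0]
    mse = float(np.mean((np.squeeze(preds) - eval_pred.label_ids) ** 2))
    return {"mse": mse}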
@@ -61,22 +61,24 @@ if is_torch_available():
             return iter(self.parse_file())

     class RegressionModel(torch.nn.Module):
-        def __init__(self, a=0, b=0):
+        def __init__(self, a=0, b=0, double_output=False):
             super().__init__()
             self.a = torch.nn.Parameter(torch.tensor(a).float())
             self.b = torch.nn.Parameter(torch.tensor(b).float())
+            self.double_output = double_output
+            self.config = None

         def forward(self, input_x=None, labels=None):
             y = input_x * self.a + self.b
             if labels is None:
-                return (y,)
+                return (y, y) if self.double_output else (y,)
             loss = torch.nn.functional.mse_loss(y, labels)
-            return (loss, y)
+            return (loss, y, y) if self.double_output else (loss, y)

-    def get_regression_trainer(a=0, b=0, train_len=64, eval_len=64, **kwargs):
+    def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
         train_dataset = RegressionDataset(length=train_len)
         eval_dataset = RegressionDataset(length=eval_len)
-        model = RegressionModel(a, b)
+        model = RegressionModel(a, b, double_output)
         compute_metrics = kwargs.pop("compute_metrics", None)
         data_collator = kwargs.pop("data_collator", None)
         optimizers = kwargs.pop("optimizers", (None, None))
@@ -202,6 +204,14 @@ class TrainerIntegrationTest(unittest.TestCase):
         x = trainer.eval_dataset.x
         self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

+        # With more than one output of the model
+        trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True)
+        preds = trainer.predict(trainer.eval_dataset).predictions
+        x = trainer.eval_dataset.x
+        self.assertEqual(len(preds), 2)
+        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+
     def test_trainer_with_datasets(self):
         np.random.seed(42)
         x = np.random.normal(size=(64,)).astype(np.float32)
......