Unverified Commit 492bb6aa authored by Sylvain Gugger, committed by GitHub

Trainer multi label (#7191)

* Trainer accepts multiple labels

* Missing import

* Fix docstrings
parent 70974592
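In short: `Trainer` previously hard-coded the keys it treated as labels; after this change it reads them from `args.label_names`, so models with several label arguments (e.g. question answering with `start_positions`/`end_positions`) work out of the box. A minimal sketch of the behavioral difference in plain Python (`inputs` is an illustrative stand-in batch dict, not a real model input):

```python
# Stand-in batch for a question-answering model (illustrative values only).
inputs = {"input_ids": [0, 1, 2], "start_positions": [0], "end_positions": [1]}

# Old behavior: label detection hard-coded to these three keys.
has_labels_old = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

# New behavior: driven by args.label_names (here the QA default set in Trainer.__init__).
label_names = ["start_positions", "end_positions"]
has_labels_new = all(inputs.get(k) is not None for k in label_names)

print(has_labels_old, has_labels_new)  # False True
```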
src/transformers/trainer.py

@@ -31,6 +31,7 @@ from .integrations import (
     run_hp_search_optuna,
     run_hp_search_ray,
 )
+from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
 from .tokenization_utils_base import PreTrainedTokenizerBase
@@ -45,6 +46,9 @@ from .trainer_utils import (
     default_hp_space,
     distributed_broadcast_scalars,
     distributed_concat,
+    nested_concat,
+    nested_numpify,
+    nested_xla_mesh_reduce,
     set_seed,
 )
 from .training_args import TrainingArguments
@@ -293,6 +297,12 @@ class Trainer:
             self.scaler = torch.cuda.amp.GradScaler()
         self.hp_search_backend = None
         self.use_tune_checkpoints = False
+        if self.args.label_names is None:
+            self.args.label_names = (
+                ["start_positions", "end_positions"]
+                if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
+                else ["labels"]
+            )
 
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
@@ -1307,9 +1317,9 @@ class Trainer:
             if loss is not None:
                 eval_losses.extend([loss] * batch_size)
             if logits is not None:
-                preds = logits if preds is None else tuple(torch.cat((p, l), dim=0) for p, l in zip(preds, logits))
+                preds = logits if preds is None else nested_concat(preds, logits, dim=0)
             if labels is not None:
-                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)
+                label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
         if self.args.past_index and hasattr(self, "_past"):
             # Clean the state at the end of the evaluation loop
@@ -1318,25 +1328,23 @@ class Trainer:
         if self.args.local_rank != -1:
             # In distributed mode, concatenate all results from all nodes:
             if preds is not None:
-                preds = tuple(distributed_concat(p, num_total_examples=self.num_examples(dataloader)) for p in preds)
+                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
             if label_ids is not None:
                 label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
         elif is_torch_tpu_available():
             # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
             if preds is not None:
-                preds = tuple(xm.mesh_reduce(f"eval_preds_{i}", p, torch.cat) for i, p in enumerate(preds))
+                preds = nested_xla_mesh_reduce(preds, "eval_preds")
             if label_ids is not None:
-                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
+                label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
             if eval_losses is not None:
                 eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
 
         # Finally, turn the aggregated tensors into numpy arrays.
         if preds is not None:
-            preds = tuple(p.cpu().numpy() for p in preds)
-            if len(preds) == 1:
-                preds = preds[0]
+            preds = nested_numpify(preds)
         if label_ids is not None:
-            label_ids = label_ids.cpu().numpy()
+            label_ids = nested_numpify(label_ids)
 
         if self.compute_metrics is not None and preds is not None and label_ids is not None:
             metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
@@ -1382,8 +1390,7 @@ class Trainer:
             Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
             A tuple with the loss, logits and labels (each being optional).
         """
-        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
+        has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
         inputs = self._prepare_inputs(inputs)
 
         with torch.no_grad():
@@ -1402,10 +1409,18 @@ class Trainer:
         if prediction_loss_only:
             return (loss, None, None)
 
-        labels = inputs.get("labels")
-        if labels is not None:
-            labels = labels.detach()
-        return (loss, tuple(l.detach() for l in logits), labels)
+        logits = tuple(logit.detach() for logit in logits)
+        if len(logits) == 1:
+            logits = logits[0]
+
+        if has_labels:
+            labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
+            if len(labels) == 1:
+                labels = labels[0]
+        else:
+            labels = None
+
+        return (loss, logits, labels)
 
     def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
         """
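The upshot of the `prediction_step` change: logits and labels now come back as tuples when there are several of them, and are unwrapped to a single tensor when there is only one. A small sketch of that contract using only `torch` (the tensors are made up, no model call involved):

```python
import torch

# Pretend outputs/labels gathered by prediction_step for a two-label model.
logits = (torch.randn(4, 2), torch.randn(4, 2))
label_names = ["start_positions", "end_positions"]
inputs = {name: torch.zeros(4, dtype=torch.long) for name in label_names}

logits = tuple(logit.detach() for logit in logits)
if len(logits) == 1:  # a single head would be unwrapped back to a tensor
    logits = logits[0]

labels = tuple(inputs[name].detach() for name in label_names)
if len(labels) == 1:
    labels = labels[0]

print(type(logits), type(labels))  # <class 'tuple'> <class 'tuple'>
```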
src/transformers/trainer_utils.py

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
 import numpy as np
 
-from .file_utils import is_tf_available, is_torch_available
+from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
 from .tokenization_utils_base import ExplicitEnum
@@ -132,9 +132,49 @@ default_hp_space = {
 }
 
 
+def nested_concat(tensors, new_tensors, dim=0):
+    "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
+    if is_torch_available():
+        assert type(tensors) == type(
+            new_tensors
+        ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
+        if isinstance(tensors, (list, tuple)):
+            return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
+        return torch.cat((tensors, new_tensors), dim=dim)
+    else:
+        raise ImportError("Torch must be installed to use `nested_concat`")
+
+
+def nested_numpify(tensors):
+    "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_numpify(t) for t in tensors)
+    return tensors.cpu().numpy()
+
+
+def nested_detach(tensors):
+    "Detach `tensors` (even if it's a nested list/tuple of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t) for t in tensors)
+    return tensors.detach()
+
+
+def nested_xla_mesh_reduce(tensors, name):
+    "Reduce `tensors` across TPU shards (even if it's a nested list/tuple of tensors)."
+    if is_torch_tpu_available():
+        import torch_xla.core.xla_model as xm
+
+        if isinstance(tensors, (list, tuple)):
+            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
+        return xm.mesh_reduce(name, tensors, torch.cat)
+    else:
+        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
 def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
     if is_torch_available():
         try:
+            if isinstance(tensor, (tuple, list)):
+                return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
             output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
             torch.distributed.all_gather(output_tensors, tensor)
             concat = torch.cat(output_tensors, dim=0)
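Usage sketch for the new helpers (this assumes the commit is installed so they are importable from `transformers.trainer_utils`; only the CPU-friendly ones are exercised, since `nested_xla_mesh_reduce` and `distributed_concat` need TPU or distributed setups):

```python
import torch
from transformers.trainer_utils import nested_concat, nested_detach, nested_numpify

# Two eval batches, each a tuple of two logit tensors (the multi-label case).
batch1 = (torch.randn(2, 3), torch.randn(2, 3))
batch2 = (torch.randn(2, 3), torch.randn(2, 3))

preds = nested_concat(batch1, batch2, dim=0)  # tuple of two (4, 3) tensors
preds = nested_detach(preds)                  # detaches every tensor in the tuple
arrays = nested_numpify(preds)                # tuple of two numpy arrays

print(preds[0].shape, type(arrays[0]))  # torch.Size([4, 3]) <class 'numpy.ndarray'>
```

The helpers recurse on lists/tuples and preserve the container type, so the same code path handles single-tensor and multi-tensor models.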
src/transformers/training_args.py

@@ -2,7 +2,7 @@ import dataclasses
 import json
 import os
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
 from .utils import logging
@@ -128,6 +128,12 @@ class TrainingArguments:
             forward method.
             (Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
+        label_names (:obj:`List[str]`, `optional`):
+            The list of keys in your dictionary of inputs that correspond to the labels.
+
+            Left as :obj:`None` here, it will default to :obj:`["labels"]`, except if the model used is one
+            of the :obj:`XxxForQuestionAnswering` models, in which case it will default to
+            :obj:`["start_positions", "end_positions"]`.
     """
 
     output_dir: str = field(
@@ -253,13 +259,16 @@ class TrainingArguments:
         default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
     )
-
-    def __post_init__(self):
-        if self.disable_tqdm is None:
-            self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
-
     remove_unused_columns: Optional[bool] = field(
         default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
     )
+    label_names: Optional[List[str]] = field(
+        default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
+    )
+
+    def __post_init__(self):
+        if self.disable_tqdm is None:
+            self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
 
     @property
     def train_batch_size(self) -> int:
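A quick sketch of setting the new argument explicitly (the `output_dir` value is illustrative):

```python
from transformers import TrainingArguments

# Override the label keys for, e.g., an extractive QA setup; leaving
# label_names=None lets Trainer pick the default based on the model class.
args = TrainingArguments(
    output_dir="test_output",
    label_names=["start_positions", "end_positions"],
)
print(args.label_names)  # ['start_positions', 'end_positions']
```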
tests/test_trainer.py

@@ -24,17 +24,21 @@ PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
 class RegressionDataset:
-    def __init__(self, a=2, b=3, length=64, seed=42):
+    def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
         np.random.seed(seed)
+        self.label_names = ["labels"] if label_names is None else label_names
         self.length = length
         self.x = np.random.normal(size=(length,)).astype(np.float32)
-        self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,))
+        self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names]
+        self.ys = [y.astype(np.float32) for y in self.ys]
 
     def __len__(self):
         return self.length
 
     def __getitem__(self, i):
-        return {"input_x": self.x[i], "label": self.y[i]}
+        result = {name: y[i] for name, y in zip(self.label_names, self.ys)}
+        result["input_x"] = self.x[i]
+        return result
 
 class AlmostAccuracy:
@@ -68,7 +72,7 @@ if is_torch_available():
             self.double_output = double_output
             self.config = None
 
-        def forward(self, input_x=None, labels=None):
+        def forward(self, input_x=None, labels=None, **kwargs):
             y = input_x * self.a + self.b
             if labels is None:
                 return (y, y) if self.double_output else (y,)

@@ -76,8 +80,9 @@ if is_torch_available():
             return (loss, y, y) if self.double_output else (loss, y)
 
     def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
-        train_dataset = RegressionDataset(length=train_len)
-        eval_dataset = RegressionDataset(length=eval_len)
+        label_names = kwargs.get("label_names", None)
+        train_dataset = RegressionDataset(length=train_len, label_names=label_names)
+        eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
         model = RegressionModel(a, b, double_output)
         compute_metrics = kwargs.pop("compute_metrics", None)
         data_collator = kwargs.pop("data_collator", None)
@@ -174,7 +179,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy())
         results = trainer.evaluate()
 
-        x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
+        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
         pred = 1.5 * x + 2.5
         expected_loss = ((pred - y) ** 2).mean()
         self.assertAlmostEqual(results["eval_loss"], expected_loss)
@@ -185,7 +190,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy())
         results = trainer.evaluate()
 
-        x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
+        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
         pred = 1.5 * x + 2.5
         expected_loss = ((pred - y) ** 2).mean()
         self.assertAlmostEqual(results["eval_loss"], expected_loss)
@@ -212,6 +217,18 @@ class TrainerIntegrationTest(unittest.TestCase):
         self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
         self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
 
+        # With more than one output/label of the model
+        trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"])
+        outputs = trainer.predict(trainer.eval_dataset)
+        preds = outputs.predictions
+        labels = outputs.label_ids
+        x = trainer.eval_dataset.x
+        self.assertEqual(len(preds), 2)
+        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+        self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
+        self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+
     def test_trainer_with_datasets(self):
         np.random.seed(42)
         x = np.random.normal(size=(64,)).astype(np.float32)
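The same behavior outside the test harness, as an end-to-end sketch. The dataset and model below are hypothetical stand-ins modeled on the test fixtures above, not library classes, and the example assumes this commit is installed:

```python
import numpy as np
import torch
from torch import nn
from transformers import Trainer, TrainingArguments

class TwoLabelDataset(torch.utils.data.Dataset):
    """Regression data exposed under two label keys."""
    def __init__(self, length=16):
        self.x = np.random.normal(size=(length,)).astype(np.float32)
        self.y = (1.5 * self.x + 2.5).astype(np.float32)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, i):
        return {"input_x": self.x[i], "labels": self.y[i], "labels_2": self.y[i]}

class TwoOutputModel(nn.Module):
    """Linear model returning (loss, y, y), mirroring double_output=True."""
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(1.5))
        self.b = nn.Parameter(torch.tensor(2.5))
        self.config = None  # Trainer tolerates a bare nn.Module with this attribute
    def forward(self, input_x=None, labels=None, labels_2=None):
        y = input_x * self.a + self.b
        loss = nn.functional.mse_loss(y, labels)
        return (loss, y, y)

args = TrainingArguments(output_dir="test_output", label_names=["labels", "labels_2"])
trainer = Trainer(model=TwoOutputModel(), args=args, eval_dataset=TwoLabelDataset())
out = trainer.predict(trainer.eval_dataset)
print(len(out.predictions), len(out.label_ids))  # 2 2
```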