"docs/vscode:/vscode.git/clone" did not exist on "7d6354e04794f3246bf9a0faf4fead080edeebb6"
Unverified commit d9c62047, authored by Sylvain Gugger, committed via GitHub

Trainer support for IterableDataset for evaluation and predict (#11286)

* Bulk of the work

* Polish and tests

* Update QA Trainer

* Avoid breaking the predict method

* Deprecation warnings

* Store real eval dataloader

* Get eval dataset reference before wrap
parent e783ea73
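What this change enables, in brief: `Trainer.evaluate` and `Trainer.predict` now accept datasets that do not implement `__len__`. A minimal sketch of the new usage (the model, tokenizer, and dataset below are illustrative stand-ins, not part of this diff):

```python
import torch
from torch.utils.data import IterableDataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

class StreamingEvalDataset(IterableDataset):
    """Yields tokenized examples one at a time; deliberately has no __len__."""

    def __init__(self, texts, labels, tokenizer):
        self.texts, self.labels, self.tokenizer = texts, labels, tokenizer

    def __iter__(self):
        for text, label in zip(self.texts, self.labels):
            enc = self.tokenizer(text, truncation=True, max_length=32, padding="max_length")
            yield {**{k: torch.tensor(v) for k, v in enc.items()}, "labels": torch.tensor(label)}

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
eval_dataset = StreamingEvalDataset(["great movie", "terrible movie"], [1, 0], tokenizer)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    eval_dataset=eval_dataset,
)
metrics = trainer.evaluate()  # before this commit: ValueError("eval_dataset must implement __len__")
```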
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -49,7 +49,7 @@ import torch
 from packaging import version
 from torch import nn
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Dataset
+from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
@@ -85,18 +85,22 @@ from .trainer_pt_utils import (
     LabelSmoother,
     LengthGroupedSampler,
     SequentialDistributedSampler,
+    ShardSampler,
     distributed_broadcast_scalars,
     distributed_concat,
+    find_batch_size,
     get_parameter_names,
     nested_concat,
     nested_detach,
     nested_numpify,
+    nested_truncate,
     nested_xla_mesh_reduce,
     reissue_pt_warnings,
 )
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
     BestRun,
+    EvalLoopOutput,
     EvalPrediction,
     HPSearchBackend,
     PredictionOutput,
@@ -381,11 +385,8 @@ class Trainer:
         if args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
 
-        # Enforce rules on using datasets with no __len__
         if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
             raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
-        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
-            raise ValueError("eval_dataset must implement __len__")
 
         self._signature_columns = None
         if is_datasets_available():
@@ -591,19 +592,33 @@ class Trainer:
         )
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
-        if is_torch_tpu_available():
-            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
-        elif is_sagemaker_mp_enabled():
-            return SequentialDistributedSampler(
-                eval_dataset,
-                num_replicas=smp.dp_size(),
-                rank=smp.dp_rank(),
-                batch_size=self.args.per_device_eval_batch_size,
-            )
-        elif self.args.local_rank != -1:
-            return SequentialDistributedSampler(eval_dataset)
-        else:
-            return SequentialSampler(eval_dataset)
+        # Deprecated code
+        if self.args.use_legacy_prediction_loop:
+            if is_torch_tpu_available():
+                return SequentialDistributedSampler(
+                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+                )
+            elif is_sagemaker_mp_enabled():
+                return SequentialDistributedSampler(
+                    eval_dataset,
+                    num_replicas=smp.dp_size(),
+                    rank=smp.dp_rank(),
+                    batch_size=self.args.per_device_eval_batch_size,
+                )
+            elif self.args.local_rank != -1:
+                return SequentialDistributedSampler(eval_dataset)
+            else:
+                return SequentialSampler(eval_dataset)
+
+        if self.args.world_size <= 1:
+            return SequentialSampler(eval_dataset)
+        else:
+            return ShardSampler(
+                eval_dataset,
+                batch_size=self.args.per_device_eval_batch_size,
+                num_processes=self.args.world_size,
+                process_index=self.args.process_index,
+            )
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         """
@@ -618,11 +633,27 @@ class Trainer:
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
-        elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
-            raise ValueError("eval_dataset must implement __len__")
         elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
             self._remove_unused_columns(eval_dataset, description="evaluation")
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+        if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
+            if self.args.world_size > 1:
+                eval_dataset = IterableDatasetShard(
+                    eval_dataset,
+                    batch_size=self.args.eval_batch_size,
+                    drop_last=self.args.dataloader_drop_last,
+                    num_processes=self.args.world_size,
+                    process_index=self.args.process_index,
+                )
+            return DataLoader(
+                eval_dataset,
+                batch_size=self.args.eval_batch_size,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                pin_memory=self.args.dataloader_pin_memory,
+            )
+
         eval_sampler = self._get_eval_sampler(eval_dataset)
 
         return DataLoader(
@@ -646,10 +677,26 @@ class Trainer:
             The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
-        if not isinstance(test_dataset, collections.abc.Sized):
-            raise ValueError("test_dataset must implement __len__")
-        elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
             self._remove_unused_columns(test_dataset, description="test")
+
+        if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
+            if self.args.world_size > 1:
+                test_dataset = IterableDatasetShard(
+                    test_dataset,
+                    batch_size=self.args.eval_batch_size,
+                    drop_last=self.args.dataloader_drop_last,
+                    num_processes=self.args.world_size,
+                    process_index=self.args.process_index,
+                )
+            return DataLoader(
+                test_dataset,
+                batch_size=self.args.eval_batch_size,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                pin_memory=self.args.dataloader_pin_memory,
+            )
+
         test_sampler = self._get_eval_sampler(test_dataset)
 
         # We use the same batch_size as for eval.
@@ -983,7 +1030,7 @@ class Trainer:
         else:
             # see __init__. max_steps is set when the dataset has no __len__
             max_steps = self.args.max_steps
-            num_train_epochs = 1
+            num_train_epochs = int(self.args.num_train_epochs)
             num_update_steps_per_epoch = max_steps
 
         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
@@ -1794,13 +1841,11 @@ class Trainer:
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
 
-        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
-            raise ValueError("eval_dataset must implement __len__")
-
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
         start_time = time.time()
 
-        output = self.prediction_loop(
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
             eval_dataloader,
             description="Evaluation",
             # No point gathering the predictions if there are no metrics, otherwise we defer to
@@ -1810,8 +1855,7 @@ class Trainer:
             metric_key_prefix=metric_key_prefix,
         )
 
-        n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
-        output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
         self.log(output.metrics)
 
         if self.args.tpu_metrics_debug or self.args.debug:
@@ -1860,36 +1904,32 @@ class Trainer:
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
 
-        if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
-            raise ValueError("test_dataset must implement __len__")
-
         test_dataloader = self.get_test_dataloader(test_dataset)
         start_time = time.time()
 
-        output = self.prediction_loop(
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
         )
-        output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
 
         self._memory_tracker.stop_and_update_metrics(output.metrics)
 
-        return output
+        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
 
-    def prediction_loop(
+    def evaluation_loop(
         self,
         dataloader: DataLoader,
         description: str,
         prediction_loss_only: Optional[bool] = None,
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "eval",
-    ) -> PredictionOutput:
+    ) -> EvalLoopOutput:
         """
         Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
 
         Works both with or without labels.
         """
-        if not isinstance(dataloader.dataset, collections.abc.Sized):
-            raise ValueError("dataset must implement __len__")
         prediction_loss_only = (
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )
@@ -1917,53 +1957,75 @@ class Trainer:
             model = model.half().to(self.args.device)
 
         batch_size = dataloader.batch_size
-        num_examples = self.num_examples(dataloader)
         logger.info(f"***** Running {description} *****")
-        logger.info(f"  Num examples = {num_examples}")
+        if isinstance(dataloader.dataset, collections.abc.Sized):
+            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
+        else:
+            logger.info("  Num examples: Unknown")
         logger.info(f"  Batch size = {batch_size}")
-        losses_host: torch.Tensor = None
-        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
-        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
-
-        world_size = max(1, self.args.world_size)
-
-        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
-        if not prediction_loss_only:
-            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
-            # a batch size to the sampler)
-            make_multiple_of = None
-            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
-                make_multiple_of = dataloader.sampler.batch_size
-            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
-            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
 
         model.eval()
 
+        self.callback_handler.eval_dataloader = dataloader
+        # Do this before wrapping.
+        eval_dataset = dataloader.dataset
+
         if is_torch_tpu_available():
             dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
 
         if self.args.past_index >= 0:
             self._past = None
 
-        self.callback_handler.eval_dataloader = dataloader
+        # Initialize containers
+        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
+        losses_host = None
+        preds_host = None
+        labels_host = None
+        # losses/preds/labels on CPU (final containers)
+        all_losses = None
+        all_preds = None
+        all_labels = None
+        # Will be useful when we have an iterable dataset so don't know its length.
+        observed_num_examples = 0
 
+        # Main evaluation loop
         for step, inputs in enumerate(dataloader):
+            # Update the observed num examples
+            observed_batch_size = find_batch_size(inputs)
+            if observed_batch_size is not None:
+                observed_num_examples += observed_batch_size
+
+            # Prediction step
             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+
+            # Update containers on host
             if loss is not None:
-                losses = loss.repeat(batch_size)
+                losses = self._nested_gather(loss.repeat(batch_size))
                 losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
             if logits is not None:
+                logits = self._pad_across_processes(logits)
+                logits = self._nested_gather(logits)
                 preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
             if labels is not None:
+                labels = self._pad_across_processes(labels)
+                labels = self._nested_gather(labels)
                 labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
             self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
             if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
-                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
-                if not prediction_loss_only:
-                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
-                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+                if losses_host is not None:
+                    losses = nested_numpify(losses_host)
+                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+                if preds_host is not None:
+                    logits = nested_numpify(preds_host)
+                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+                if labels_host is not None:
+                    labels = nested_numpify(labels_host)
+                    all_labels = (
+                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+                    )
 
                 # Set back to None to begin a new accumulation
                 losses_host, preds_host, labels_host = None, None, None
@@ -1973,34 +2035,53 @@ class Trainer:
             delattr(self, "_past")
 
         # Gather all remaining tensors and put them back on the CPU
-        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
-        if not prediction_loss_only:
-            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
-            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
-
-        eval_loss = eval_losses_gatherer.finalize()
-        preds = preds_gatherer.finalize() if not prediction_loss_only else None
-        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
+        if losses_host is not None:
+            losses = nested_numpify(losses_host)
+            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+        if preds_host is not None:
+            logits = nested_numpify(preds_host)
+            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+        if labels_host is not None:
+            labels = nested_numpify(labels_host)
+            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
 
-        if self.compute_metrics is not None and preds is not None and label_ids is not None:
-            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
+        # Number of samples
+        if not isinstance(eval_dataset, IterableDataset):
+            num_samples = len(eval_dataset)
+        elif isinstance(eval_dataset, IterableDatasetShard):
+            num_samples = eval_dataset.num_examples
+        else:
+            num_samples = observed_num_examples
+
+        # The number of losses has been rounded to a multiple of batch_size and, in distributed training, the number
+        # of samples has been rounded to a multiple of batch_size, so we truncate.
+        if all_losses is not None:
+            all_losses = all_losses[:num_samples]
+        if all_preds is not None:
+            all_preds = nested_truncate(all_preds, num_samples)
+        if all_labels is not None:
+            all_labels = nested_truncate(all_labels, num_samples)
+
+        # Metrics!
+        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
+            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
         else:
             metrics = {}
 
         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)
 
-        if eval_loss is not None:
-            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
+        if all_losses is not None:
+            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
 
         # Prefix all keys with metric_key_prefix + '_'
         for key in list(metrics.keys()):
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
 
-        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
+        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
 
-    def _gather_and_numpify(self, tensors, name):
+    def _nested_gather(self, tensors, name=None):
         """
         Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
         concatenating them to `gathered`
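An editorial note on why the truncation above is needed: in distributed evaluation the sampler pads the dataset to a round multiple of the total batch size, so the gathered arrays can end with a few repeated samples that must be dropped before computing metrics. A standalone numpy sketch with invented values:

```python
import numpy as np

# 2 processes x batch size 4 = 8 gathered predictions, but only 6 real samples:
# the sampler looped back to the first two samples to fill the last batch.
gathered_preds = np.array([10, 11, 12, 13, 14, 15, 10, 11])
num_samples = 6

print(gathered_preds[:num_samples])  # [10 11 12 13 14 15] -- duplicates removed
```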
@@ -2008,13 +2089,47 @@ class Trainer:
         if tensors is None:
             return
         if is_torch_tpu_available():
+            if name is None:
+                name = "nested_gather"
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
             tensors = smp_gather(tensors)
         elif self.args.local_rank != -1:
             tensors = distributed_concat(tensors)
-        return nested_numpify(tensors)
+        return tensors
+
+    # Copied from Accelerate.
+    def _pad_across_processes(self, tensor, pad_index=-100):
+        """
+        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
+        they can safely be gathered.
+        """
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
+        elif not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
+            )
+
+        if len(tensor.shape) < 2:
+            return tensor
+
+        # Gather all sizes
+        size = torch.tensor(tensor.shape, device=tensor.device)[None]
+        sizes = self._nested_gather(size).cpu()
+
+        max_size = max(s[1] for s in sizes)
+        if tensor.shape[1] == max_size:
+            return tensor
+
+        # Then pad to the maximum size
+        old_size = tensor.shape
+        new_size = list(old_size)
+        new_size[1] = max_size
+        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
+        new_tensor[:, : old_size[1]] = tensor
+        return new_tensor
+
     def prediction_step(
         self,
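The padding step can be seen in isolation: tensors from different processes may disagree in their second dimension (e.g. sequence length), so each one is extended with `pad_index` up to the maximum before gathering. A self-contained sketch of that single-tensor core step (not the Trainer method itself):

```python
import torch

def pad_dim1(tensor: torch.Tensor, max_size: int, pad_index: int = -100) -> torch.Tensor:
    # Mirrors the single-tensor core of _pad_across_processes: grow dim 1 to max_size.
    if tensor.shape[1] == max_size:
        return tensor
    new_size = list(tensor.shape)
    new_size[1] = max_size
    new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
    new_tensor[:, : tensor.shape[1]] = tensor
    return new_tensor

print(pad_dim1(torch.ones(2, 3), max_size=5))
# tensor([[   1.,    1.,    1., -100., -100.],
#         [   1.,    1.,    1., -100., -100.]])
```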
@@ -2131,3 +2246,148 @@ class Trainer:
             return self.model.floating_point_ops(inputs)
         else:
             return 0
+
+    #
+    # Deprecated code
+    #
+
+    def prediction_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> PredictionOutput:
+        """
+        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+        if not isinstance(dataloader.dataset, collections.abc.Sized):
+            raise ValueError("dataset must implement __len__")
+        prediction_loss_only = (
+            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
+        )
+
+        # if eval is called w/o train init deepspeed here
+        if self.args.deepspeed and not self.deepspeed:
+            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
+            # from the checkpoint eventually
+            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
+            self.model = deepspeed_engine.module
+            self.model_wrapped = deepspeed_engine
+            self.deepspeed = deepspeed_engine
+            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
+            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
+            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
+            deepspeed_engine.optimizer.optimizer = None
+            deepspeed_engine.lr_scheduler = None
+
+        model = self._wrap_model(self.model, training=False)
+
+        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
+        # ``train`` is running, half it first and then put on device
+        if not self.is_in_train and self.args.fp16_full_eval:
+            model = model.half().to(self.args.device)
+
+        batch_size = dataloader.batch_size
+        num_examples = self.num_examples(dataloader)
+        logger.info(f"***** Running {description} *****")
+        logger.info(f"  Num examples = {num_examples}")
+        logger.info(f"  Batch size = {batch_size}")
+        losses_host: torch.Tensor = None
+        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
+        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
+
+        world_size = max(1, self.args.world_size)
+
+        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+        if not prediction_loss_only:
+            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
+            # a batch size to the sampler)
+            make_multiple_of = None
+            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
+                make_multiple_of = dataloader.sampler.batch_size
+            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+
+        model.eval()
+
+        if is_torch_tpu_available():
+            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
+
+        if self.args.past_index >= 0:
+            self._past = None
+
+        self.callback_handler.eval_dataloader = dataloader
+
+        for step, inputs in enumerate(dataloader):
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            if loss is not None:
+                losses = loss.repeat(batch_size)
+                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+            if logits is not None:
+                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+            if labels is not None:
+                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
+
+            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
+                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+                if not prediction_loss_only:
+                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+
+                # Set back to None to begin a new accumulation
+                losses_host, preds_host, labels_host = None, None, None
+
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+        if not prediction_loss_only:
+            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+
+        eval_loss = eval_losses_gatherer.finalize()
+        preds = preds_gatherer.finalize() if not prediction_loss_only else None
+        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
+
+        if self.compute_metrics is not None and preds is not None and label_ids is not None:
+            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
+        else:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if eval_loss is not None:
+            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
+
+    def _gather_and_numpify(self, tensors, name):
+        """
+        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+        concatenating them to `gathered`
+        """
+        if tensors is None:
+            return
+        if is_torch_tpu_available():
+            tensors = nested_xla_mesh_reduce(tensors, name)
+        elif is_sagemaker_mp_enabled():
+            tensors = smp_gather(tensors)
+        elif self.args.local_rank != -1:
+            tensors = distributed_concat(tensors)
+
+        return nested_numpify(tensors)
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -15,7 +15,7 @@
 """
 Callbacks to use with the Trainer class and customize the training loop.
 """
+import collections
 import dataclasses
 import json
 from dataclasses import dataclass
@@ -469,7 +469,7 @@ class ProgressCallback(TrainerCallback):
             self.current_step = state.global_step
 
     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
-        if state.is_local_process_zero:
+        if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized):
             if self.prediction_bar is None:
                 self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
             self.prediction_bar.update(1)
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -102,6 +102,26 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
     raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
 
 
+def find_batch_size(tensors):
+    """
+    Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
+    """
+    if isinstance(tensors, (list, tuple)):
+        for t in tensors:
+            result = find_batch_size(t)
+            if result is not None:
+                return result
+    elif isinstance(tensors, dict):
+        for key, value in tensors.items():
+            result = find_batch_size(value)
+            if result is not None:
+                return result
+    elif isinstance(tensors, torch.Tensor):
+        return tensors.shape[0] if len(tensors.shape) >= 1 else None
+    elif isinstance(tensors, np.ndarray):
+        return tensors.shape[0] if len(tensors.shape) >= 1 else None
+
+
 def nested_numpify(tensors):
     "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
     if isinstance(tensors, (list, tuple)):
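`find_batch_size` walks a nested batch and returns the first dimension of the first tensor or array it finds; this is how the new `evaluation_loop` counts samples when the dataset has no length. A usage sketch (assumes the function as defined in the hunk above):

```python
import numpy as np
import torch

batch = {
    "input_ids": torch.zeros(8, 128, dtype=torch.long),
    "attention_mask": torch.ones(8, 128, dtype=torch.long),
    "labels": np.zeros(8),
}
print(find_batch_size(batch))        # 8
print(find_batch_size([(), batch]))  # 8 -- recurses through nested containers
print(find_batch_size({}))           # None -- no tensor or array found
```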
@@ -222,6 +242,10 @@ class SequentialDistributedSampler(Sampler):
     """
 
     def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
+        warnings.warn(
+            "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.",
+            FutureWarning,
+        )
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -338,6 +362,10 @@ class DistributedTensorGatherer:
     """
 
     def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
+        warnings.warn(
+            "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.",
+            FutureWarning,
+        )
         self.world_size = world_size
         self.num_samples = num_samples
         total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
@@ -576,6 +604,55 @@ class DistributedLengthGroupedSampler(DistributedSampler):
         return iter(indices)
 
 
+class ShardSampler(Sampler):
+    """
+    Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with
+    batch size 4, the first two batches are :obj:`[0, 1, 2, 3, 4, 5, 6, 7]` and :obj:`[8, 9, 10, 11, 12, 13, 14, 15]`,
+    which shard into :obj:`[0, 1, 2, 3]` and :obj:`[8, 9, 10, 11]` for GPU-0 and :obj:`[4, 5, 6, 7]` and
+    :obj:`[12, 13, 14, 15]` for GPU-1.
+
+    The sampler thus yields :obj:`[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and :obj:`[4, 5, 6, 7, 12, 13, 14, 15]` on
+    GPU-1.
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int = 1,
+        drop_last: bool = False,
+        num_processes: int = 1,
+        process_index: int = 0,
+    ):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.num_processes = num_processes
+        self.process_index = process_index
+
+        self.total_batch_size = total_batch_size = batch_size * num_processes
+
+        num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
+        self.total_num_samples = num_batches * total_batch_size
+
+    def __iter__(self):
+        indices = list(range(len(self.dataset)))
+
+        # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
+        # and it needs to be done several times.
+        while len(indices) < self.total_num_samples:
+            indices += indices[: (self.total_num_samples - len(indices))]
+
+        result = []
+        for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
+            result += indices[batch_start : batch_start + self.batch_size]
+
+        return iter(result)
+
+    def __len__(self):
+        # Each shard only sees a fraction of total_num_samples.
+        return self.total_num_samples // self.num_processes
+
+
 class IterableDatasetShard(IterableDataset):
     """
     Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
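The docstring's example can be checked directly: on a 16-element dataset with batch size 4 and 2 processes, each shard receives alternating batches. A quick sketch (assumes `ShardSampler` as defined above):

```python
dataset = list(range(16))
shard0 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=0)
shard1 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=1)

print(list(shard0))  # [0, 1, 2, 3, 8, 9, 10, 11]
print(list(shard1))  # [4, 5, 6, 7, 12, 13, 14, 15]
print(len(shard0))   # 8 -- each process sees half of the 16 dispatched samples
```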
@@ -634,6 +711,7 @@ class IterableDatasetShard(IterableDataset):
         self.process_index = process_index
         self.seed = seed
         self.epoch = 0
+        self.num_examples = 0
 
     def set_epoch(self, epoch):
         self.epoch = epoch
@@ -641,6 +719,7 @@ class IterableDatasetShard(IterableDataset):
             self.dataset.set_epoch(epoch)
 
     def __iter__(self):
+        self.num_examples = 0
         if (
             not hasattr(self.dataset, "set_epoch")
             and hasattr(self.dataset, "generator")
@@ -653,6 +732,7 @@ class IterableDatasetShard(IterableDataset):
         first_batch = None
         current_batch = []
         for element in self.dataset:
+            self.num_examples += 1
             current_batch.append(element)
             # Wait to have a full batch before yielding elements.
             if len(current_batch) == real_batch_size:
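The new `num_examples` counter records how many examples the wrapped dataset yielded during the last iteration, before any sharding or dropping, and `evaluation_loop` reads it to size its outputs. A toy sketch (assumes the wrapper's constructor signature shown in the surrounding diff context):

```python
from torch.utils.data import IterableDataset

class TinyStream(IterableDataset):
    def __iter__(self):
        yield from range(10)

shard = IterableDatasetShard(TinyStream(), batch_size=4, drop_last=True, num_processes=2, process_index=0)
consumed = list(shard)     # this process's slice of each full 8-element super-batch
print(consumed)            # [0, 1, 2, 3] -- the leftover [8, 9] is dropped
print(shard.num_examples)  # 10 -- counts every example seen, even dropped or sent elsewhere
```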
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -77,6 +77,13 @@ class EvalPrediction(NamedTuple):
     label_ids: np.ndarray
 
 
+class EvalLoopOutput(NamedTuple):
+    predictions: Union[np.ndarray, Tuple[np.ndarray]]
+    label_ids: Optional[np.ndarray]
+    metrics: Optional[Dict[str, float]]
+    num_samples: Optional[int]
+
+
 class PredictionOutput(NamedTuple):
     predictions: Union[np.ndarray, Tuple[np.ndarray]]
     label_ids: Optional[np.ndarray]
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -524,6 +524,9 @@ class TrainingArguments:
     skip_memory_metrics: bool = field(
         default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
+    use_legacy_prediction_loop: bool = field(
+        default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
     mp_parameters: str = field(
         default="",
--- a/src/transformers/utils/notebook.py
+++ b/src/transformers/utils/notebook.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections
 import time
 from typing import Optional
@@ -286,6 +287,8 @@ class NotebookProgressCallback(TrainerCallback):
         self._force_next_update = False
 
     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
+        if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
+            return
         if self.prediction_bar is None:
             if self.training_tracker is not None:
                 self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         )
         self.assertEqual(len(dataset), 31)
 
-    def test_trainer_iterable_dataset(self):
+    def test_training_iterable_dataset(self):
         config = RegressionModelConfig()
         model = RegressionPreTrainedModel(config)
         train_dataset = SampleIterableDataset()
-        args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
+        args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
         trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
         trainer.train()
+        self.assertEqual(trainer.state.global_step, 4)
 
         loader = trainer.get_train_dataloader()
         self.assertIsInstance(loader, torch.utils.data.DataLoader)
         self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
 
-        # Exception if giving iterable dataset and no max_steps
-        with self.assertRaises(ValueError):
-            args1 = RegressionTrainingArguments(output_dir="./examples")
-            _ = Trainer(model=model, args=args1, train_dataset=train_dataset)
-
-        # Exception if eval_dataset is iterable in __init__
-        with self.assertRaises(ValueError):
-            _ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
-
-        # Exception if predicting with iterable dataset
-        with self.assertRaises(ValueError):
-            trainer.predict(train_dataset)
-
-        # Exception if evaluating with iterable dataset
-        with self.assertRaises(ValueError):
-            trainer.evaluate(train_dataset)
+    def test_evaluation_iterable_dataset(self):
+        config = RegressionModelConfig(a=1.5, b=2.5)
+        model = RegressionPreTrainedModel(config)
+        eval_dataset = SampleIterableDataset()
+        args = RegressionTrainingArguments(output_dir="./examples")
+        trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
+        results = trainer.evaluate()
+
+        x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
+        pred = 1.5 * x + 2.5
+        expected_loss = ((pred - y) ** 2).mean()
+        self.assertAlmostEqual(results["eval_loss"], expected_loss)
+        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+        # With a number of elements not a round multiple of the batch size
+        eval_dataset = SampleIterableDataset(length=66)
+        results = trainer.evaluate(eval_dataset)
+
+        x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
+        pred = 1.5 * x + 2.5
+        expected_loss = ((pred - y) ** 2).mean()
+        self.assertAlmostEqual(results["eval_loss"], expected_loss)
+        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+    def test_predict_iterable_dataset(self):
+        config = RegressionModelConfig(a=1.5, b=2.5)
+        model = RegressionPreTrainedModel(config)
+        eval_dataset = SampleIterableDataset()
+        args = RegressionTrainingArguments(output_dir="./examples")
+        trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
+
+        preds = trainer.predict(trainer.eval_dataset).predictions
+        x = eval_dataset.dataset.x
+        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+        # With a number of elements not a round multiple of the batch size
+        test_dataset = SampleIterableDataset(length=66)
+        preds = trainer.predict(test_dataset).predictions
+        x = test_dataset.dataset.x
+        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
 
     def test_num_train_epochs_in_training(self):
         # len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
--- a/tests/test_trainer_utils.py
+++ b/tests/test_trainer_utils.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import unittest
 
 import numpy as np
@@ -34,6 +35,7 @@ if is_torch_available():
         LabelSmoother,
         LengthGroupedSampler,
         SequentialDistributedSampler,
+        ShardSampler,
         get_parameter_names,
     )
@@ -283,6 +285,10 @@ class TrainerUtilsTest(unittest.TestCase):
             # All shards have the same number of samples
             self.assertEqual(len(shard), len(shard_lists[0]))
 
+        for shard in shards:
+            # All shards know the total number of samples
+            self.assertEqual(shard.num_examples, len(reference))
+
         observed = []
         for idx in range(0, len(shard_lists[0]), batch_size):
             for shard in shard_lists:
@@ -295,11 +301,62 @@ class TrainerUtilsTest(unittest.TestCase):
             reference += reference
         self.assertListEqual(observed, reference[: len(observed)])
 
+        # Check equivalence between IterableDataset and ShardSampler
+        dataset.generator.manual_seed(epoch)
+        reference = list(dataset)
+
+        sampler_shards = [
+            ShardSampler(
+                reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
+            )
+            for i in range(num_processes)
+        ]
+        for shard, sampler_shard in zip(shard_lists, sampler_shards):
+            self.assertListEqual(shard, list(sampler_shard))
+
     def test_iterable_dataset_shard(self):
         dataset = RandomIterableDataset()
 
         self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
-        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
+        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)
         self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
-        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
+        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
+
+    def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
+        shards = [
+            ShardSampler(
+                dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
+            )
+            for i in range(num_processes)
+        ]
+        shard_lists = [list(shard) for shard in shards]
+
+        for shard in shard_lists:
+            # All shards have a number of samples that is a round multiple of batch size
+            self.assertTrue(len(shard) % batch_size == 0)
+            # All shards have the same number of samples
+            self.assertEqual(len(shard), len(shard_lists[0]))
+
+        observed = []
+        for idx in range(0, len(shard_lists[0]), batch_size):
+            for shard in shard_lists:
+                observed += shard[idx : idx + batch_size]
+
+        # If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
+        # batch_size
+        reference = copy.copy(dataset)
+        if not drop_last:
+            while len(reference) < len(observed):
+                reference += reference
+        self.assertListEqual(observed, reference[: len(observed)])
+
+    def test_shard_sampler(self):
+        for n_elements in [64, 123]:
+            dataset = list(range(n_elements))
+
+            self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=2)
+            self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=2)
+            self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
+            self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)