Unverified Commit d9c62047 authored by Sylvain Gugger, committed by GitHub

Trainer support for IterableDataset for evaluation and predict (#11286)

* Bulk of the work

* Polish and tests

* Update QA Trainer

* Avoid breaking the predict method

* Deprecation warnings

* Store real eval dataloader

* Get eval dataset reference before wrap
parent e783ea73
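The gist of the change, as a hedged sketch: `Trainer.evaluate` and `Trainer.predict` previously rejected a `torch.utils.data.IterableDataset` with a `ValueError`; this commit makes both work. The snippet below reuses the regression helpers from `tests/test_trainer.py` (`RegressionModelConfig`, `RegressionPreTrainedModel`, `SampleIterableDataset`, `RegressionTrainingArguments`, `AlmostAccuracy`), so it only runs in that test context and is illustrative rather than part of the diff itself.

```python
from transformers import Trainer

# The helpers below come from tests/test_trainer.py and model a simple regression.
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(
    model=model,
    args=args,
    eval_dataset=SampleIterableDataset(),
    compute_metrics=AlmostAccuracy(),
)

metrics = trainer.evaluate()                                  # now works on an IterableDataset
preds = trainer.predict(SampleIterableDataset()).predictions  # and so does predict()
```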
@@ -15,7 +15,7 @@
"""
Callbacks to use with the Trainer class and customize the training loop.
"""
import collections
import dataclasses
import json
from dataclasses import dataclass
@@ -469,7 +469,7 @@ class ProgressCallback(TrainerCallback):
        self.current_step = state.global_step

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
        if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized):
            if self.prediction_bar is None:
                self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
            self.prediction_bar.update(1)
...
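Why the `Sized` check: a `DataLoader` wrapping an `IterableDataset` generally has no usable `len()`, so the progress bar (which needs a total) is simply skipped. A minimal, self-contained illustration, not part of the diff:

```python
import collections.abc

from torch.utils.data import DataLoader, IterableDataset


class Stream(IterableDataset):
    # No __len__, like most streaming datasets.
    def __iter__(self):
        yield from range(10)


loader = DataLoader(Stream(), batch_size=4)
print(isinstance(loader.dataset, collections.abc.Sized))  # False
# len(loader) would raise TypeError here, which is exactly what the guard above avoids.
```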
@@ -102,6 +102,26 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")

def find_batch_size(tensors):
    """
    Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
    """
    if isinstance(tensors, (list, tuple)):
        for t in tensors:
            result = find_batch_size(t)
            if result is not None:
                return result
    elif isinstance(tensors, dict):
        for key, value in tensors.items():
            result = find_batch_size(value)
            if result is not None:
                return result
    elif isinstance(tensors, torch.Tensor):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None
    elif isinstance(tensors, np.ndarray):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None

def nested_numpify(tensors):
    "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
    if isinstance(tensors, (list, tuple)):
@@ -222,6 +242,10 @@ class SequentialDistributedSampler(Sampler):
    """

    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
        warnings.warn(
            "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
@@ -338,6 +362,10 @@ class DistributedTensorGatherer:
    """

    def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
        warnings.warn(
            "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.world_size = world_size
        self.num_samples = num_samples
        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
@@ -576,6 +604,55 @@ class DistributedLengthGroupedSampler(DistributedSampler):
        return iter(indices)

class ShardSampler(Sampler):
    """
    Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with
    batch size 4, the first two batches are :obj:`[0, 1, 2, 3, 4, 5, 6, 7]` and :obj:`[8, 9, 10, 11, 12, 13, 14, 15]`,
    which shard into :obj:`[0, 1, 2, 3]` and :obj:`[8, 9, 10, 11]` for GPU-0 and :obj:`[4, 5, 6, 7]` and
    :obj:`[12, 13, 14, 15]` for GPU-1.

    The sampler thus yields :obj:`[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and :obj:`[4, 5, 6, 7, 12, 13, 14, 15]` on
    GPU-1.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index

        self.total_batch_size = total_batch_size = batch_size * num_processes

        num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
        self.total_num_samples = num_batches * total_batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
        # and it needs to be done several times.
        while len(indices) < self.total_num_samples:
            indices += indices[: (self.total_num_samples - len(indices))]

        result = []
        for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
            result += indices[batch_start : batch_start + self.batch_size]

        return iter(result)

    def __len__(self):
        # Each shard only sees a fraction of total_num_samples.
        return self.total_num_samples // self.num_processes
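A small sketch reproducing the docstring example, assuming `ShardSampler` is importable from `transformers.trainer_pt_utils` once this change lands; a plain list stands in for the dataset since the sampler only needs its length:

```python
from transformers.trainer_pt_utils import ShardSampler

dataset = list(range(16))
shard0 = list(ShardSampler(dataset, batch_size=4, num_processes=2, process_index=0))
shard1 = list(ShardSampler(dataset, batch_size=4, num_processes=2, process_index=1))
print(shard0)  # [0, 1, 2, 3, 8, 9, 10, 11]
print(shard1)  # [4, 5, 6, 7, 12, 13, 14, 15]
```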

class IterableDatasetShard(IterableDataset):
    """
    Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
@@ -634,6 +711,7 @@ class IterableDatasetShard(IterableDataset):
        self.process_index = process_index
        self.seed = seed
        self.epoch = 0
        self.num_examples = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
@@ -641,6 +719,7 @@ class IterableDatasetShard(IterableDataset):
            self.dataset.set_epoch(epoch)

    def __iter__(self):
        self.num_examples = 0
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
@@ -653,6 +732,7 @@ class IterableDatasetShard(IterableDataset):
        first_batch = None
        current_batch = []
        for element in self.dataset:
            self.num_examples += 1
            current_batch.append(element)
            # Wait to have a full batch before yielding elements.
            if len(current_batch) == real_batch_size:
...
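The new `num_examples` counter records how many examples the underlying dataset actually produced, which the new evaluation loop needs in order to drop the padding added when the stream does not divide evenly across processes. A rough sketch, assuming `IterableDatasetShard` is importable from `transformers.trainer_pt_utils`:

```python
from torch.utils.data import IterableDataset

from transformers.trainer_pt_utils import IterableDatasetShard


class Stream(IterableDataset):
    def __iter__(self):
        yield from range(10)


# 2 processes x batch size 4: the 10-example stream is padded out to a full global batch,
# but num_examples still reports the true count of 10 after iteration.
shard = IterableDatasetShard(Stream(), batch_size=4, num_processes=2, process_index=0)
samples = list(shard)
print(len(samples), shard.num_examples)  # 8 10
```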
@@ -77,6 +77,13 @@ class EvalPrediction(NamedTuple):
    label_ids: np.ndarray


class EvalLoopOutput(NamedTuple):
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]


class PredictionOutput(NamedTuple):
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
...
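`EvalLoopOutput` is what the new `evaluation_loop` returns internally; compared to `PredictionOutput` it adds `num_samples`, so callers know how many examples were really evaluated even when the dataset has no `len()`. A hypothetical construction, assuming the class is importable from `transformers.trainer_utils`:

```python
import numpy as np

from transformers.trainer_utils import EvalLoopOutput

output = EvalLoopOutput(
    predictions=np.zeros((10, 2)),
    label_ids=np.zeros(10, dtype=np.int64),
    metrics={"eval_loss": 0.25},
    num_samples=10,
)
print(output.num_samples)  # 10
```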
@@ -524,6 +524,9 @@ class TrainingArguments:
    skip_memory_metrics: bool = field(
        default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
    )
    use_legacy_prediction_loop: bool = field(
        default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
    )
    _n_gpu: int = field(init=False, repr=False, default=-1)
    mp_parameters: str = field(
        default="",
...
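Opting back into the old code path is a one-flag change. The sketch below only shows the new flag; the other arguments are illustrative:

```python
from transformers import TrainingArguments

# Fall back to the deprecated prediction_loop instead of the new evaluation_loop.
args = TrainingArguments(output_dir="out", use_legacy_prediction_loop=True)
```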
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import time
from typing import Optional
@@ -286,6 +287,8 @@ class NotebookProgressCallback(TrainerCallback):
        self._force_next_update = False

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
        if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
            return
        if self.prediction_bar is None:
            if self.training_tracker is not None:
                self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
...
@@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        )
        self.assertEqual(len(dataset), 31)

    def test_training_iterable_dataset(self):
        config = RegressionModelConfig()
        model = RegressionPreTrainedModel(config)
        train_dataset = SampleIterableDataset()
        args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
        trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
        trainer.train()
        self.assertEqual(trainer.state.global_step, 4)

        loader = trainer.get_train_dataloader()
        self.assertIsInstance(loader, torch.utils.data.DataLoader)
        self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)

    def test_evaluation_iterable_dataset(self):
        config = RegressionModelConfig(a=1.5, b=2.5)
        model = RegressionPreTrainedModel(config)
        eval_dataset = SampleIterableDataset()
        args = RegressionTrainingArguments(output_dir="./examples")
        trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

        # With a number of elements not a round multiple of the batch size
        eval_dataset = SampleIterableDataset(length=66)
        results = trainer.evaluate(eval_dataset)

        x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

    def test_predict_iterable_dataset(self):
        config = RegressionModelConfig(a=1.5, b=2.5)
        model = RegressionPreTrainedModel(config)
        eval_dataset = SampleIterableDataset()
        args = RegressionTrainingArguments(output_dir="./examples")
        trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())

        preds = trainer.predict(trainer.eval_dataset).predictions
        x = eval_dataset.dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With a number of elements not a round multiple of the batch size
        test_dataset = SampleIterableDataset(length=66)
        preds = trainer.predict(test_dataset).predictions
        x = test_dataset.dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

    def test_num_train_epochs_in_training(self):
        # len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np
@@ -34,6 +35,7 @@ if is_torch_available():
        LabelSmoother,
        LengthGroupedSampler,
        SequentialDistributedSampler,
        ShardSampler,
        get_parameter_names,
    )
@@ -283,6 +285,10 @@ class TrainerUtilsTest(unittest.TestCase):
            # All shards have the same number of samples
            self.assertEqual(len(shard), len(shard_lists[0]))

        for shard in shards:
            # All shards know the total number of samples
            self.assertEqual(shard.num_examples, len(reference))

        observed = []
        for idx in range(0, len(shard_lists[0]), batch_size):
            for shard in shard_lists:
@@ -295,11 +301,62 @@ class TrainerUtilsTest(unittest.TestCase):
                reference += reference
        self.assertListEqual(observed, reference[: len(observed)])

        # Check equivalence between IterableDataset and ShardSampler
        dataset.generator.manual_seed(epoch)
        reference = list(dataset)

        sampler_shards = [
            ShardSampler(
                reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
            )
            for i in range(num_processes)
        ]
        for shard, sampler_shard in zip(shard_lists, sampler_shards):
            self.assertListEqual(shard, list(sampler_shard))

    def test_iterable_dataset_shard(self):
        dataset = RandomIterableDataset()

        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)
        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)

    def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
        shards = [
            ShardSampler(
                dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
            )
            for i in range(num_processes)
        ]
        shard_lists = [list(shard) for shard in shards]

        for shard in shard_lists:
            # All shards have a number of samples that is a round multiple of batch size
            self.assertTrue(len(shard) % batch_size == 0)
            # All shards have the same number of samples
            self.assertEqual(len(shard), len(shard_lists[0]))

        observed = []
        for idx in range(0, len(shard_lists[0]), batch_size):
            for shard in shard_lists:
                observed += shard[idx : idx + batch_size]

        # If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
        # batch_size
        reference = copy.copy(dataset)
        if not drop_last:
            while len(reference) < len(observed):
                reference += reference
        self.assertListEqual(observed, reference[: len(observed)])

    def test_shard_sampler(self):
        for n_elements in [64, 123]:
            dataset = list(range(n_elements))

            self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=2)
            self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=2)
            self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
            self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)