Unverified Commit d9c62047 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Trainer support for IterableDataset for evaluation and predict (#11286)

* Bulk of the work

* Polish and tests

* Update QA Trainer

* Avoid breaking the predict method

* Deprecation warnings

* Store real eval dataloder

* Get eval dataset reference before wrap
parent e783ea73
This diff is collapsed.
......@@ -15,7 +15,7 @@
"""
Callbacks to use with the Trainer class and customize the training loop.
"""
import collections
import dataclasses
import json
from dataclasses import dataclass
......@@ -469,7 +469,7 @@ class ProgressCallback(TrainerCallback):
self.current_step = state.global_step
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_local_process_zero:
if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized):
if self.prediction_bar is None:
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
self.prediction_bar.update(1)
......
......@@ -102,6 +102,26 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
def find_batch_size(tensors):
"""
Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
"""
if isinstance(tensors, (list, tuple)):
for t in tensors:
result = find_batch_size(t)
if result is not None:
return result
elif isinstance(tensors, dict):
for key, value in tensors.items():
result = find_batch_size(value)
if result is not None:
return result
elif isinstance(tensors, torch.Tensor):
return tensors.shape[0] if len(tensors.shape) >= 1 else None
elif isinstance(tensors, np.ndarray):
return tensors.shape[0] if len(tensors.shape) >= 1 else None
def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
......@@ -222,6 +242,10 @@ class SequentialDistributedSampler(Sampler):
"""
def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
warnings.warn(
"SequentialDistributedSampler is deprecated and will be removed in v5 of Tranformers.",
FutureWarning,
)
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
......@@ -338,6 +362,10 @@ class DistributedTensorGatherer:
"""
def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
warnings.warn(
"DistributedTensorGatherer is deprecated and will be removed in v5 of Tranformers.",
FutureWarning,
)
self.world_size = world_size
self.num_samples = num_samples
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
......@@ -576,6 +604,55 @@ class DistributedLengthGroupedSampler(DistributedSampler):
return iter(indices)
class ShardSampler(Sampler):
"""
Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch
size 4, the first two batches are :obj:`[0, 1, 2, 3, 4, 5, 6, 7]` and :obj:`[8, 9, 10, 11, 12, 13, 14, 15]`, which
shard into :obj:`[0, 1, 2, 3]` and :obj:`[8, 9, 10, 11]` for GPU-0 and :obj:`[4, 5, 6, 7]` and :obj:`[12, 13, 14,
15]` for GPU-1.
The sampler thus yields :obj:`[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and :obj:`[4, 5, 6, 7, 12, 13, 14, 15]` on
GPU-1.
"""
def __init__(
self,
dataset: Dataset,
batch_size: int = 1,
drop_last: bool = False,
num_processes: int = 1,
process_index: int = 0,
):
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.num_processes = num_processes
self.process_index = process_index
self.total_batch_size = total_batch_size = batch_size * num_processes
num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
self.total_num_samples = num_batches * total_batch_size
def __iter__(self):
indices = list(range(len(self.dataset)))
# Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
# and it needs to be done several times.
while len(indices) < self.total_num_samples:
indices += indices[: (self.total_num_samples - len(indices))]
result = []
for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
result += indices[batch_start : batch_start + self.batch_size]
return iter(result)
def __len__(self):
# Each shard only sees a fraction of total_num_samples.
return self.total_num_samples // self.num_processes
class IterableDatasetShard(IterableDataset):
"""
Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
......@@ -634,6 +711,7 @@ class IterableDatasetShard(IterableDataset):
self.process_index = process_index
self.seed = seed
self.epoch = 0
self.num_examples = 0
def set_epoch(self, epoch):
self.epoch = epoch
......@@ -641,6 +719,7 @@ class IterableDatasetShard(IterableDataset):
self.dataset.set_epoch(epoch)
def __iter__(self):
self.num_examples = 0
if (
not hasattr(self.dataset, "set_epoch")
and hasattr(self.dataset, "generator")
......@@ -653,6 +732,7 @@ class IterableDatasetShard(IterableDataset):
first_batch = None
current_batch = []
for element in self.dataset:
self.num_examples += 1
current_batch.append(element)
# Wait to have a full batch before yielding elements.
if len(current_batch) == real_batch_size:
......
......@@ -77,6 +77,13 @@ class EvalPrediction(NamedTuple):
label_ids: np.ndarray
class EvalLoopOutput(NamedTuple):
predictions: Union[np.ndarray, Tuple[np.ndarray]]
label_ids: Optional[np.ndarray]
metrics: Optional[Dict[str, float]]
num_samples: Optional[int]
class PredictionOutput(NamedTuple):
predictions: Union[np.ndarray, Tuple[np.ndarray]]
label_ids: Optional[np.ndarray]
......
......@@ -524,6 +524,9 @@ class TrainingArguments:
skip_memory_metrics: bool = field(
default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
)
use_legacy_prediction_loop: bool = field(
default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
)
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import time
from typing import Optional
......@@ -286,6 +287,8 @@ class NotebookProgressCallback(TrainerCallback):
self._force_next_update = False
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
return
if self.prediction_bar is None:
if self.training_tracker is not None:
self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
......
......@@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
)
self.assertEqual(len(dataset), 31)
def test_trainer_iterable_dataset(self):
def test_training_iterable_dataset(self):
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
train_dataset = SampleIterableDataset()
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
self.assertEqual(trainer.state.global_step, 4)
loader = trainer.get_train_dataloader()
self.assertIsInstance(loader, torch.utils.data.DataLoader)
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
# Exception if giving iterable dataset and no max_steps
with self.assertRaises(ValueError):
args1 = RegressionTrainingArguments(output_dir="./examples")
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
def test_evaluation_iterable_dataset(self):
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
results = trainer.evaluate()
# Exception if eval_dataset is iterable in __init__
with self.assertRaises(ValueError):
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
# Exception if predicting with iterable dataset
with self.assertRaises(ValueError):
trainer.predict(train_dataset)
# With a number of elements not a round multiple of the batch size
eval_dataset = SampleIterableDataset(length=66)
results = trainer.evaluate(eval_dataset)
# Exception if evaluating with iterable dataset
with self.assertRaises(ValueError):
trainer.evaluate(train_dataset)
x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
def test_predict_iterable_dataset(self):
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
preds = trainer.predict(trainer.eval_dataset).predictions
x = eval_dataset.dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
test_dataset = SampleIterableDataset(length=66)
preds = trainer.predict(test_dataset).predictions
x = test_dataset.dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
def test_num_train_epochs_in_training(self):
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
......@@ -34,6 +35,7 @@ if is_torch_available():
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
ShardSampler,
get_parameter_names,
)
......@@ -283,6 +285,10 @@ class TrainerUtilsTest(unittest.TestCase):
# All shards have the same number of samples
self.assertEqual(len(shard), len(shard_lists[0]))
for shard in shards:
# All shards know the total number of samples
self.assertEqual(shard.num_examples, len(reference))
observed = []
for idx in range(0, len(shard_lists[0]), batch_size):
for shard in shard_lists:
......@@ -295,11 +301,62 @@ class TrainerUtilsTest(unittest.TestCase):
reference += reference
self.assertListEqual(observed, reference[: len(observed)])
# Check equivalence between IterableDataset and ShardSampler
dataset.generator.manual_seed(epoch)
reference = list(dataset)
sampler_shards = [
ShardSampler(
reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
)
for i in range(num_processes)
]
for shard, sampler_shard in zip(shard_lists, sampler_shards):
self.assertListEqual(shard, list(sampler_shard))
def test_iterable_dataset_shard(self):
dataset = RandomIterableDataset()
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
shards = [
ShardSampler(
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
)
for i in range(num_processes)
]
shard_lists = [list(shard) for shard in shards]
for shard in shard_lists:
# All shards have a number of samples that is a round multiple of batch size
self.assertTrue(len(shard) % batch_size == 0)
# All shards have the same number of samples
self.assertEqual(len(shard), len(shard_lists[0]))
observed = []
for idx in range(0, len(shard_lists[0]), batch_size):
for shard in shard_lists:
observed += shard[idx : idx + batch_size]
# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
# batch_size
reference = copy.copy(dataset)
if not drop_last:
while len(reference) < len(observed):
reference += reference
self.assertListEqual(observed, reference[: len(observed)])
def test_shard_sampler(self):
for n_elements in [64, 123]:
dataset = list(range(n_elements))
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=2)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=2)
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment