"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e631383d4fb55ed2dff53f2b80a29e51cd2f563d"
Unverified Commit aaaed56f authored by Sylvain Gugger, committed by GitHub

Trainer iterable dataset (#11254)



* IterableDatasetShard

* Test and integration in Trainer

* Update src/transformers/trainer_pt_utils.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Style
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 83206ca6
@@ -81,6 +81,7 @@ from .trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
@@ -493,9 +494,7 @@ class Trainer:
dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
self.train_dataset, collections.abc.Sized
):
if not isinstance(self.train_dataset, collections.abc.Sized):
return None
# Build the sampler.
@@ -553,6 +552,26 @@
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
if isinstance(self.train_dataset, torch.utils.data.dataset.IterableDataset):
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
self.train_dataset,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.world_size,
process_index=self.args.process_index,
)
else:
train_dataset = self.train_dataset
return DataLoader(
train_dataset,
batch_size=self.args.train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
train_sampler = self._get_train_sampler()
return DataLoader(
......
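With this branch in place, get_train_dataloader returns a plain DataLoader (no sampler) for a single process and wraps the dataset in IterableDatasetShard when world_size > 1. Below is a minimal user-side sketch of the new code path; the toy streaming dataset and regression model are invented for illustration and do not appear in this diff. Note that max_steps must be set because the dataset has no __len__.

import torch
from torch.utils.data import IterableDataset
from transformers import Trainer, TrainingArguments

class ToyStream(IterableDataset):
    # Invented streaming dataset: yields dicts the default data collator can stack.
    def __iter__(self):
        for i in range(1000):
            x = torch.tensor([float(i)])
            yield {"input_x": x, "labels": 2.0 * x + 3.0}

class ToyRegressor(torch.nn.Module):
    # Invented model; Trainer reads the loss from the first element of a tuple output.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, input_x, labels=None):
        preds = self.linear(input_x)
        if labels is None:
            return (preds,)
        return (torch.nn.functional.mse_loss(preds, labels), preds)

args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
trainer = Trainer(model=ToyRegressor(), args=args, train_dataset=ToyStream())
trainer.train()  # without max_steps, Trainer raises a ValueError (see the test below)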
@@ -28,7 +28,7 @@ from typing import Dict, Iterator, List, Optional, Union
import numpy as np
import torch
from packaging import version
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
@@ -576,6 +576,96 @@ class DistributedLengthGroupedSampler(DistributedSampler):
return iter(indices)
class IterableDatasetShard(IterableDataset):
"""
Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
will always yield a number of samples that is a round multiple of the actual batch size (which is :obj:`batch_size
x num_processes`). Depending on the value of the :obj:`drop_last` attribute, it will either stop the iteration at
the first batch that would be too small or loop with indices from the beginning.
On two processes with an iterable dataset yielding the elements of :obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch
size of 2:
- the shard on process 0 will yield :obj:`[0, 1, 4, 5, 8, 9]` so will see batches :obj:`[0, 1]`, :obj:`[4, 5]`,
:obj:`[8, 9]`
- the shard on process 1 will yield :obj:`[2, 3, 6, 7, 10, 11]` so will see batches :obj:`[2, 3]`, :obj:`[6, 7]`,
:obj:`[10, 11]`
.. warning::
If your IterableDataset implements some randomization that needs to be applied the same way on all processes
(for instance, shuffling), you should use a :obj:`torch.Generator` in a :obj:`generator` attribute of the
:obj:`dataset` to generate your random numbers and call the
:meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom
logic.
Args:
dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
The dataset to split into several shards.
batch_size (:obj:`int`, `optional`, defaults to 1):
The size of the batches per shard.
drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to drop the last incomplete batch, or to complete it with samples from the beginning of the
dataset.
num_processes (:obj:`int`, `optional`, defaults to 1):
The number of processes running concurrently.
process_index (:obj:`int`, `optional`, defaults to 0):
The index of the current process.
seed (:obj:`int`, `optional`, defaults to 0):
A random seed that will be used for the random number generation in
:meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch`.
"""
def __init__(
self,
dataset: IterableDataset,
batch_size: int = 1,
drop_last: bool = False,
num_processes: int = 1,
process_index: int = 0,
seed: int = 0,
):
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.num_processes = num_processes
self.process_index = process_index
self.seed = seed
self.epoch = 0
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
self.dataset.generator.manual_seed(self.seed + self.epoch)
real_batch_size = self.batch_size * self.num_processes
process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
first_batch = None
current_batch = []
for element in self.dataset:
current_batch.append(element)
# Wait to have a full batch before yielding elements.
if len(current_batch) == real_batch_size:
for i in process_slice:
yield current_batch[i]
if first_batch is None:
first_batch = current_batch.copy()
current_batch = []
# Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
if not self.drop_last and len(current_batch) > 0:
if first_batch is None:
first_batch = current_batch.copy()
while len(current_batch) < real_batch_size:
current_batch += first_batch
for i in process_slice:
yield current_batch[i]
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# helper methods here
......
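A quick sketch (not part of this diff) that reproduces the two-process example from the docstring above; the TinyIterable dataset below is hypothetical and simply yields the integers 0 to 11.

from torch.utils.data import IterableDataset
from transformers.trainer_pt_utils import IterableDatasetShard

class TinyIterable(IterableDataset):
    # Hypothetical dataset yielding 0..11, matching the docstring example.
    def __iter__(self):
        yield from range(12)

for rank in range(2):
    shard = IterableDatasetShard(TinyIterable(), batch_size=2, num_processes=2, process_index=rank)
    print(rank, list(shard))
# Expected per the docstring: process 0 sees [0, 1, 4, 5, 8, 9], process 1 sees [2, 3, 6, 7, 10, 11].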
@@ -44,9 +44,7 @@ if is_torch_available():
from torch.utils.data import IterableDataset
from transformers import (
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
DataCollatorForLanguageModeling,
EarlyStoppingCallback,
GlueDataset,
GlueDataTrainingArguments,
@@ -54,7 +52,6 @@ if is_torch_available():
GPT2LMHeadModel,
LineByLineTextDataset,
PreTrainedModel,
TextDataset,
Trainer,
TrainerState,
)
@@ -138,16 +135,12 @@ class RegressionModelConfig(PretrainedConfig):
if is_torch_available():
class SampleIterableDataset(IterableDataset):
"""
The criterion is not whether the dataset is an IterableDataset, but whether __len__ is implemented
"""
def __init__(self, file_path, tokenizer):
self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names)
def __iter__(self):
for i in range(len(self.ds)):
yield self.ds[i]
for i in range(len(self.dataset)):
yield self.dataset[i]
class RegressionModel(torch.nn.Module):
def __init__(self, a=0, b=0, double_output=False):
@@ -827,18 +820,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(len(dataset), 31)
def test_trainer_iterable_dataset(self):
# Simulate Language Modeling with an IterableDataset, with no __len__ method
# Pick-up a tiny model, so it works on CPU
# See Issue #5990: https://github.com/huggingface/transformers/issues/5990
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
train_dataset = SampleIterableDataset()
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
loader = trainer.get_train_dataloader()
@@ -847,30 +834,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Exception if giving iterable dataset and no max_steps
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
_ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
args1 = RegressionTrainingArguments(output_dir="./examples")
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
# Exception if eval_dataset is iterable in __init__
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
_ = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=train_dataset,
data_collator=data_collator,
)
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
# Exception if predicting with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.predict(train_dataset)
# Exception if evaluating with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.evaluate(train_dataset)
def test_num_train_epochs_in_training(self):
......
@@ -23,12 +23,14 @@ from transformers.testing_utils import require_torch
if is_torch_available():
import torch
from torch.utils.data import IterableDataset
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
@@ -49,6 +51,22 @@ if is_torch_available():
h = torch.nn.functional.relu(self.linear2(x))
return self.ln2(x + h + self.bias)
class RandomIterableDataset(IterableDataset):
# For testing, an iterable dataset of random length
def __init__(self, p_stop=0.01, max_length=1000):
self.p_stop = p_stop
self.max_length = max_length
self.generator = torch.Generator()
def __iter__(self):
count = 0
stop = False
while not stop and count < self.max_length:
yield count
count += 1
number = torch.rand(1, generator=self.generator).item()
stop = number < self.p_stop
@require_torch
class TrainerUtilsTest(unittest.TestCase):
@@ -243,3 +261,45 @@ class TrainerUtilsTest(unittest.TestCase):
self.assertListEqual(total[:length], dataset)
self.assertListEqual(total[length:], dataset[: (len(total) - length)])
def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_processes=2, epoch=0):
# Set the seed for the base dataset to get the proper reference.
dataset.generator.manual_seed(epoch)
reference = list(dataset)
shards = [
IterableDatasetShard(
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
)
for i in range(num_processes)
]
for shard in shards:
shard.set_epoch(epoch)
shard_lists = [list(shard) for shard in shards]
for shard in shard_lists:
# All shards have a number of samples that is a round multiple of batch size
self.assertTrue(len(shard) % batch_size == 0)
# All shards have the same number of samples
self.assertEqual(len(shard), len(shard_lists[0]))
observed = []
for idx in range(0, len(shard_lists[0]), batch_size):
for shard in shard_lists:
observed += shard[idx : idx + batch_size]
# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
# batch_size
if not drop_last:
while len(reference) < len(observed):
reference += reference
self.assertListEqual(observed, reference[: len(observed)])
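To make the drop_last=False branch of the check above concrete, here is a small sketch under stated assumptions (a made-up ten-element dataset, two processes, batch size 2) showing how the incomplete last batch gets completed with samples from the beginning:

from torch.utils.data import IterableDataset
from transformers.trainer_pt_utils import IterableDatasetShard

class TenInts(IterableDataset):
    # Made-up dataset with 10 samples; 10 is not a multiple of batch_size x num_processes = 4.
    def __iter__(self):
        yield from range(10)

for rank in range(2):
    shard = IterableDatasetShard(TenInts(), batch_size=2, drop_last=False, num_processes=2, process_index=rank)
    print(rank, list(shard))
# Prints 0 [0, 1, 4, 5, 8, 9] and 1 [2, 3, 6, 7, 0, 1]: the incomplete last batch [8, 9]
# was completed with [0, 1] from the start, so both shards keep the same length.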
def test_iterable_dataset_shard(self):
dataset = RandomIterableDataset()
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
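Finally, a hedged sketch of the epoch/seed synchronization described in the class docstring: a dataset that exposes a generator attribute can be re-seeded identically on every process by calling set_epoch before each epoch. The ShuffledRange class below is invented for illustration.

import torch
from torch.utils.data import IterableDataset
from transformers.trainer_pt_utils import IterableDatasetShard

class ShuffledRange(IterableDataset):
    # Invented dataset with a `generator` attribute, so IterableDatasetShard can seed it
    # with seed + epoch at the start of each iteration.
    def __init__(self, length=16):
        self.length = length
        self.generator = torch.Generator()

    def __iter__(self):
        yield from torch.randperm(self.length, generator=self.generator).tolist()

shard = IterableDatasetShard(ShuffledRange(), batch_size=4, num_processes=2, process_index=0, seed=123)
for epoch in range(2):
    shard.set_epoch(epoch)  # __iter__ then calls dataset.generator.manual_seed(123 + epoch)
    print(epoch, list(shard))  # same permutation on every process for a given epoch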