Unverified Commit a0a027c2 authored by Sylvain Gugger, committed by GitHub

Add DistributedSamplerWithLoop (#10746)

* Add DistributedSamplerWithLoop

* Fix typo

* Test and small fix
parent 14492222
@@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model
 from ..trainer import Trainer
 from ..trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     SequentialDistributedSampler,
     nested_detach,
     nested_numpify,
@@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer):
                 return DistributedLengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                 )
+            elif not self.args.dataloader_drop_last:
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    self.args.per_device_train_batch_size,
+                    num_replicas=smp.dp_size(),
+                    rank=smp.dp_rank(),
+                )
             else:
                 return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
         else:
@@ -77,6 +77,7 @@ from .trainer_callback import (
 )
 from .trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     DistributedTensorGatherer,
     LabelSmoother,
     LengthGroupedSampler,
@@ -491,24 +492,10 @@ class Trainer:
         ):
             return None

-        # Gather the number of processes and this process index.
-        if self.args.parallel_mode == ParallelMode.TPU:
-            num_processes = xm.xrt_world_size()
-            process_index = xm.get_ordinal()
-        elif (
-            self.args.parallel_mode == ParallelMode.DISTRIBUTED
-            or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
-        ):
-            num_processes = dist.get_world_size()
-            process_index = dist.get_rank()
-        else:
-            num_processes = 1
-            process_index = 0
-
         # Build the sampler.
         if self.args.group_by_length:
             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return LengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
                 )
@@ -516,16 +503,26 @@
                 return DistributedLengthGroupedSampler(
                     self.train_dataset,
                     self.args.train_batch_size,
-                    num_replicas=num_processes,
-                    rank=process_index,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
                     model_input_name=model_input_name,
                 )

         else:
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return RandomSampler(self.train_dataset)
+            elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
+                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    batch_size=self.args.per_device_train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                )
             else:
-                return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
+                return DistributedSampler(
+                    self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
+                )

     def get_train_dataloader(self) -> DataLoader:
         """
@@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int):
         dist.barrier()


+class DistributedSamplerWithLoop(DistributedSampler):
+    """
+    Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the
+    shuffled samples to make each process have a round multiple of batch_size samples.
+
+    Args:
+        dataset (:obj:`torch.utils.data.Dataset`):
+            Dataset used for sampling.
+        batch_size (:obj:`int`):
+            The batch size used with this sampler.
+        kwargs:
+            All other keyword arguments passed to :obj:`DistributedSampler`.
+    """
+
+    def __init__(self, dataset, batch_size, **kwargs):
+        super().__init__(dataset, **kwargs)
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        indices = list(super().__iter__())
+        remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
+        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
+        # of the world size, so we skip those.
+        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
+        indices += indices[start_remainder : start_remainder + remainder]
+        return iter(indices)
+
+
 class SequentialDistributedSampler(Sampler):
     """
     Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
@@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler):
         return self.num_samples


-def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
+def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, batch_size: int):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
     return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
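For a concrete view of what the __iter__ above produces, here is a small worked example, not part of the diff, with 23 samples, two replicas, and batch_size=16 (shuffle=False so the indices are predictable). DistributedSampler first pads the global list to 24 entries by recycling index 0 onto rank 1; the loop then extends each 12-index shard to 16, and rank 0 starts its loop at its second index because its first one was already recycled; that is what the start_remainder offset accounts for.

# Worked example (not part of the diff): 23 samples, 2 replicas, batch_size=16, shuffle=False.
from transformers.trainer_pt_utils import DistributedSamplerWithLoop

dataset = list(range(23))
shard0 = DistributedSamplerWithLoop(dataset, batch_size=16, num_replicas=2, rank=0, shuffle=False)
shard1 = DistributedSamplerWithLoop(dataset, batch_size=16, num_replicas=2, rank=1, shuffle=False)

# DistributedSampler alone yields 12 indices per rank (rank 1's last index, 0, is the
# padding it recycled from the start); the loop then appends 4 more indices to each shard.
print(list(shard0))  # [0, 2, 4, ..., 22, 2, 4, 6, 8]     -> 16 indices, loop skips 0 (start_remainder=1)
print(list(shard1))  # [1, 3, 5, ..., 21, 0, 1, 3, 5, 7]  -> 16 indices, loop starts at 1 (start_remainder=0)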
@@ -742,6 +742,20 @@ class TrainingArguments:
             return torch.distributed.get_world_size()
         return 1

+    @property
+    @torch_required
+    def process_index(self):
+        """
+        The index of the current process used.
+        """
+        if is_torch_tpu_available():
+            return xm.get_ordinal()
+        elif is_sagemaker_distributed_available():
+            return sm_dist.get_rank()
+        elif self.local_rank != -1:
+            return torch.distributed.get_rank()
+        return 0
+
     @property
     def place_model_on_device(self):
         """
@@ -27,6 +27,7 @@ if is_torch_available():
     from transformers.modeling_outputs import SequenceClassifierOutput
     from transformers.trainer_pt_utils import (
         DistributedLengthGroupedSampler,
+        DistributedSamplerWithLoop,
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,
@@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase):
             ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
         )
         # fmt: on
+
+    def test_distributed_sampler_with_loop(self):
+        batch_size = 16
+        for length in [23, 64, 123]:
+            dataset = list(range(length))
+            shard1 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=0)
+            shard2 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=1)
+
+            # Set seeds
+            shard1.set_epoch(0)
+            shard2.set_epoch(0)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            self.assertTrue(len(samples1) % batch_size == 0)
+            self.assertTrue(len(samples2) % batch_size == 0)
+
+            total = []
+            for sample1, sample2 in zip(samples1, samples2):
+                total += [sample1, sample2]
+
+            self.assertEqual(set(total[:length]), set(dataset))
+            self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))