Unverified commit a0a027c2 authored by Sylvain Gugger, committed by GitHub

Add DistributedSamplerWithLoop (#10746)

* Add DistributedSamplerWithLoop

* Fix typo

* Test and small fix
parent 14492222
@@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model
 from ..trainer import Trainer
 from ..trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     SequentialDistributedSampler,
     nested_detach,
     nested_numpify,
@@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer):
                 return DistributedLengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                 )
+            elif not self.args.dataloader_drop_last:
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    self.args.per_device_train_batch_size,
+                    num_replicas=smp.dp_size(),
+                    rank=smp.dp_rank(),
+                )
             else:
                 return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
         else:
...
@@ -77,6 +77,7 @@ from .trainer_callback import (
 )
 from .trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     DistributedTensorGatherer,
     LabelSmoother,
     LengthGroupedSampler,
@@ -491,24 +492,10 @@ class Trainer:
         ):
             return None
 
-        # Gather the number of processes and this process index.
-        if self.args.parallel_mode == ParallelMode.TPU:
-            num_processes = xm.xrt_world_size()
-            process_index = xm.get_ordinal()
-        elif (
-            self.args.parallel_mode == ParallelMode.DISTRIBUTED
-            or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
-        ):
-            num_processes = dist.get_world_size()
-            process_index = dist.get_rank()
-        else:
-            num_processes = 1
-            process_index = 0
-
         # Build the sampler.
         if self.args.group_by_length:
             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return LengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
                 )
@@ -516,16 +503,26 @@ class Trainer:
                 return DistributedLengthGroupedSampler(
                     self.train_dataset,
                     self.args.train_batch_size,
-                    num_replicas=num_processes,
-                    rank=process_index,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
                     model_input_name=model_input_name,
                 )
 
         else:
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return RandomSampler(self.train_dataset)
+            elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
+                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    batch_size=self.args.per_device_train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                )
             else:
-                return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
+                return DistributedSampler(
+                    self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
+                )
 
     def get_train_dataloader(self) -> DataLoader:
         """
...
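The new TPU branch exists so that every batch keeps the same shape even when dataloader_drop_last is False (on TPUs, a shorter final batch would typically trigger an XLA recompilation). Below is a minimal sketch, not part of the diff, of what the looping sampler changes compared to a plain DistributedSampler; the toy 23-example dataset, the batch size of 5 and the two simulated replicas are made up for illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    from transformers.trainer_pt_utils import DistributedSamplerWithLoop

    dataset = TensorDataset(torch.arange(23))  # toy dataset with 23 examples

    plain = DistributedSampler(dataset, num_replicas=2, rank=0)
    looped = DistributedSamplerWithLoop(dataset, batch_size=5, num_replicas=2, rank=0)
    plain.set_epoch(0)
    looped.set_epoch(0)

    # The plain sampler leaves a short final batch on each process; the looping
    # sampler pads its shard to a multiple of the batch size, so every batch of
    # the epoch has the same shape.
    print([len(b[0]) for b in DataLoader(dataset, batch_size=5, sampler=plain)])   # [5, 5, 2]
    print([len(b[0]) for b in DataLoader(dataset, batch_size=5, sampler=looped)])  # [5, 5, 5]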
@@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int):
         dist.barrier()
 
 
+class DistributedSamplerWithLoop(DistributedSampler):
+    """
+    Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the
+    shuffled samples to make each process have a round multiple of batch_size samples.
+
+    Args:
+        dataset (:obj:`torch.utils.data.Dataset`):
+            Dataset used for sampling.
+        batch_size (:obj:`int`):
+            The batch size used with this sampler.
+        kwargs:
+            All other keyword arguments passed to :obj:`DistributedSampler`.
+    """
+
+    def __init__(self, dataset, batch_size, **kwargs):
+        super().__init__(dataset, **kwargs)
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        indices = list(super().__iter__())
+        remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
+        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
+        # of the world size, so we skip those.
+        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
+        indices += indices[start_remainder : start_remainder + remainder]
+        return iter(indices)
+
+
 class SequentialDistributedSampler(Sampler):
     """
     Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
@@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler):
         return self.num_samples
 
 
-def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
+def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, batch_size: int):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
     return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
...
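As a worked illustration of the arithmetic in __iter__ above (a sketch with made-up numbers, mirroring the unit test added in this commit): with 23 samples, 2 replicas and batch_size=16, the parent DistributedSampler hands each rank ceil(23 / 2) = 12 indices, having already re-used one index from the head of the shuffle to round the total up to 24; the loop then appends 16 - 12 = 4 more indices, and on rank 0 it skips that already re-used head index (start_remainder = 1 because 0 < 23 % 2).

    from transformers.trainer_pt_utils import DistributedSamplerWithLoop

    dataset = list(range(23))  # only len() is used, so a plain list works here
    shard0 = DistributedSamplerWithLoop(dataset, batch_size=16, num_replicas=2, rank=0)
    shard0.set_epoch(0)

    indices = list(shard0)
    assert len(indices) == 16            # 12 from DistributedSampler + 4 looped back
    assert len(indices) % 16 == 0        # always a round multiple of batch_size
    assert indices[12:] == indices[1:5]  # the loop restarts after the re-used head index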
@@ -742,6 +742,20 @@ class TrainingArguments:
             return torch.distributed.get_world_size()
         return 1
 
+    @property
+    @torch_required
+    def process_index(self):
+        """
+        The index of the current process used.
+        """
+        if is_torch_tpu_available():
+            return xm.get_ordinal()
+        elif is_sagemaker_distributed_available():
+            return sm_dist.get_rank()
+        elif self.local_rank != -1:
+            return torch.distributed.get_rank()
+        return 0
+
     @property
     def place_model_on_device(self):
         """
...
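A quick hedged sketch of the two TrainingArguments properties working together on a plain single-process setup (no TPU, SageMaker or torch.distributed backend initialized; the output directory name is arbitrary): both fall through to their defaults, which is what makes the world_size <= 1 branch in the Trainer pick a simple RandomSampler.

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="tmp_output")  # arbitrary example path
    # With no distributed backend initialized, the properties hit their fallbacks.
    print(args.world_size, args.process_index)  # 1 0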
@@ -27,6 +27,7 @@ if is_torch_available():
     from transformers.modeling_outputs import SequenceClassifierOutput
     from transformers.trainer_pt_utils import (
         DistributedLengthGroupedSampler,
+        DistributedSamplerWithLoop,
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,
@@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase):
             ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
         )
         # fmt: on
+
+    def test_distributed_sampler_with_loop(self):
+        batch_size = 16
+        for length in [23, 64, 123]:
+            dataset = list(range(length))
+            shard1 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=0)
+            shard2 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=1)
+
+            # Set seeds
+            shard1.set_epoch(0)
+            shard2.set_epoch(0)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            self.assertTrue(len(samples1) % batch_size == 0)
+            self.assertTrue(len(samples2) % batch_size == 0)
+
+            total = []
+            for sample1, sample2 in zip(samples1, samples2):
+                total += [sample1, sample2]
+
+            self.assertEqual(set(total[:length]), set(dataset))
+            self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))