"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "496207a166213d9a54cb2dd13bbfcbe5089a2a1c"
Unverified Commit 0282e24e authored by Mansi Mane, committed by GitHub

Smmp batch not divisible by microbatches fix (#10778)



* Added debug prints

* Added config

* Added prints

* Added prints

* Added extra samples to SequentialDistributedSampler

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

* Added debug prints

* Removed extra prints

* Making predictions and labels a multiple of batch size

* Updated number of microbatches

* Removed extra prints

* Made start_remainder similar to DistributedSamplerWithLoop

* Minor spacing update

* Added debug prints

Added config

Added prints

Added prints

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

Added extra samples to SequentialDistributedSampler

Added debug prints

Removed extra prints

Making predictions and labels a multiple of batch size

Updated number of microbatches

Removed extra prints

Squashing redundant commits

* Made start_remainder similar to DistributedSamplerWithLoop

Minor spacing update

Made start_remainder similar to DistributedSamplerWithLoop

* Test and styling

* Rename test
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
parent 40b049c7
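In short: with SageMaker model parallelism (smp) enabled, the evaluation loop could produce a final batch shorter than per_device_eval_batch_size, and such a batch cannot be split evenly into smp's pipeline microbatches. The fix below pads each replica's shard of the eval dataset up to a multiple of the eval batch size and truncates the gathered predictions and labels back to the true number of examples afterwards. A minimal sketch of the failure mode, assuming a hypothetical microbatch count of 4 (the actual splitting happens inside the smp runtime, not in user code):

    # Illustrative sketch only: `microbatches` is a hypothetical pipeline setting;
    # the real batch splitting is done by the smp runtime.
    def splits_evenly(batch_len, microbatches):
        # Pipeline parallelism needs each batch to split into equal microbatches.
        return batch_len % microbatches == 0

    per_device_eval_batch_size = 16
    microbatches = 4
    shard_len = 23  # an eval shard that is not a multiple of the batch size

    last_batch = shard_len % per_device_eval_batch_size  # 7 leftover samples
    print(splits_evenly(per_device_eval_batch_size, microbatches))  # True
    print(splits_evenly(last_batch, microbatches))                  # False -> the failure this PR avoids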
@@ -112,7 +112,12 @@ class SageMakerTrainer(Trainer):
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
         if self.is_model_parallel_enabled:
-            return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=smp.dp_size(),
+                rank=smp.dp_rank(),
+                batch_size=self.args.per_device_eval_batch_size,
+            )
         else:
             return super()._get_eval_sampler(eval_dataset)
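Note that the padded sampler is only used on the model-parallel path: when is_model_parallel_enabled is false, the call falls through to super()._get_eval_sampler(), so standard (non-smp) evaluation keeps its previous behavior. The padding unit is args.per_device_eval_batch_size, i.e. the per-replica eval batch that smp will later split into microbatches.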
@@ -1812,8 +1812,8 @@ class Trainer:
         eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
         if not prediction_loss_only:
-            preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
-            labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
+            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
 
         model.eval()
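Because each replica now yields a number of eval samples rounded up to a multiple of the batch size, the gathered predictions and labels are longer than num_examples; passing make_multiple_of=batch_size keeps the gatherers' preallocated storage in step with what the sampler produces, and the final result is cut back to num_examples. A rough sketch of that pad-then-truncate bookkeeping, with hypothetical helper names and assumed behavior (the real logic lives in DistributedTensorGatherer):

    import math
    import numpy as np

    # Hypothetical stand-in for the storage sizing done when make_multiple_of is passed.
    def padded_storage(num_examples, world_size, batch_size):
        per_process = int(math.ceil(num_examples / (world_size * batch_size))) * batch_size
        return per_process * world_size

    num_examples, world_size, batch_size = 23, 2, 16
    slots = padded_storage(num_examples, world_size, batch_size)  # 32 slots reserved
    gathered = np.arange(slots)                                   # stand-in for gathered predictions
    final = gathered[:num_examples]                               # finalization keeps only the real 23 examples
    print(slots, final.shape)                                     # 32 (23,)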
@@ -220,7 +220,7 @@ class SequentialDistributedSampler(Sampler):
     or `reduce` resulting tensors at the end of the loop.
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None):
+    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -232,8 +232,14 @@ class SequentialDistributedSampler(Sampler):
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        num_samples = len(self.dataset)
+        # Add extra samples to make num_samples a multiple of batch_size if passed
+        if batch_size is not None:
+            self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
+        else:
+            self.num_samples = int(math.ceil(num_samples / num_replicas))
         self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size
 
     def __iter__(self):
         indices = list(range(len(self.dataset)))
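Putting the new __init__ arithmetic together with wrap-around padding in __iter__ (the hunk only shows the first line of __iter__; the wrap-around is an assumption based on the commit's reference to DistributedSamplerWithLoop and on the test expectations below), each replica's shard becomes a multiple of batch_size and the extra indices repeat the start of the dataset. A self-contained sketch with a hypothetical helper name:

    import math

    # Minimal sketch of the sharding this change produces; not the library code itself.
    def shard_indices(length, num_replicas, rank, batch_size=None):
        if batch_size is not None:
            num_samples = int(math.ceil(length / (batch_size * num_replicas))) * batch_size
        else:
            num_samples = int(math.ceil(length / num_replicas))
        total_size = num_samples * num_replicas
        indices = list(range(length))
        indices += indices[: total_size - length]  # pad by looping back to the start
        return indices[rank * num_samples : (rank + 1) * num_samples]

    # length=23, 2 replicas, batch_size=16 -> 16 indices per replica, 32 in total
    shard0 = shard_indices(23, 2, 0, batch_size=16)  # [0, 1, ..., 15]
    shard1 = shard_indices(23, 2, 1, batch_size=16)  # [16, ..., 22, 0, 1, ..., 8]
    assert len(shard0) % 16 == 0 and len(shard1) % 16 == 0
    assert (shard0 + shard1)[:23] == list(range(23))
    assert (shard0 + shard1)[23:] == list(range(9))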
@@ -31,6 +31,7 @@ if is_torch_available():
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,
+        SequentialDistributedSampler,
         get_parameter_names,
     )
@@ -167,3 +168,35 @@ class TrainerUtilsTest(unittest.TestCase):
             self.assertEqual(set(total[:length]), set(dataset))
             self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
 
+    def test_sequential_distributed_sampler(self):
+        batch_size = 16
+        for length in [23, 64, 123]:
+            dataset = list(range(length))
+            shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0)
+            shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            total = samples1 + samples2
+
+            self.assertListEqual(total[:length], dataset)
+            self.assertListEqual(total[length:], dataset[: (len(total) - length)])
+
+            # With a batch_size passed
+            shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0, batch_size=batch_size)
+            shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1, batch_size=batch_size)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            self.assertTrue(len(samples1) % batch_size == 0)
+            self.assertTrue(len(samples2) % batch_size == 0)
+
+            total = samples1 + samples2
+
+            self.assertListEqual(total[:length], dataset)
+            self.assertListEqual(total[length:], dataset[: (len(total) - length)])
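Worked through the formula above, the three test lengths cover the interesting cases: with batch_size=16 and two replicas, length 23 gives 16 indices per shard (9 of them wrapped around from the start of the dataset), length 64 gives exactly 32 per shard with no padding, and length 123 gives 64 per shard with 5 wrapped-around indices, so both the padded and unpadded paths are exercised.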