Commit 3ee811be authored by Mohammad Shoeybi

Merge branch 'batch_sampler_update' into 'master'

batch sampler interleave flag

See merge request ADLR/megatron-lm!81
parents ad299490 5cb0c807
@@ -80,10 +80,20 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
-    sampler."""
+    sampler.
+
+    The `interleave` argument specifies how to distribute a batch. A value
+    of True combined with the above random sampler is equivalent to pytorch's
+    torch.utils.data.distributed.DistributedSampler.
+
+    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2,
+    specifying True will result in the following samples for each gpu:
+        GPU0: [0,2,4,6] GPU1: [1,3,5,7]
+    specifying False will result in the following samples:
+        GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
 
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
-                 world_size=2, wrap_last=False):
+                 world_size=2, wrap_last=False, interleave=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                       drop_last)
         if rank == -1:
@@ -95,6 +105,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
         self.wrap_around = 0
         self.wrap_last = wrap_last
         self.start_iter = 0
+        self.interleave = interleave
 
     def __iter__(self):
         batch = []
@@ -130,6 +141,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
+        if self.interleave:
+            return batch[self.rank:self.batch_size:self.world_size]
         start = self.rank * self.batch_size // self.world_size
         end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
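
The change to `_batch` is just a strided slice over the per-iteration batch. Below is a minimal standalone sketch of the two distribution strategies described in the new docstring. The `split_batch` helper is hypothetical (not part of megatron-lm); its slice expressions mirror the `_batch` method in the diff above.

# Sketch only: `split_batch` is a hypothetical helper, not megatron-lm API.
def split_batch(batch, rank, world_size, interleave=False):
    """Return the portion of `batch` assigned to this `rank`."""
    batch_size = len(batch)
    if interleave:
        # Strided slice: rank r gets indices r, r + world_size, r + 2*world_size, ...
        return batch[rank:batch_size:world_size]
    # Contiguous slice: each rank gets one consecutive chunk of the batch.
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    return batch[start:end]

if __name__ == "__main__":
    batch = [0, 1, 2, 3, 4, 5, 6, 7]
    # interleave=True  -> GPU0: [0, 2, 4, 6], GPU1: [1, 3, 5, 7]
    # interleave=False -> GPU0: [0, 1, 2, 3], GPU1: [4, 5, 6, 7]
    for interleave in (True, False):
        parts = [split_batch(batch, r, 2, interleave) for r in range(2)]
        print("interleave=%s -> %s" % (interleave, parts))

With interleave=True the assignment matches what pytorch's torch.utils.data.distributed.DistributedSampler produces for a random sampler, which is what the docstring in the diff calls out.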