Commit fe1ea898 authored by mohammad

Merge branch 'master' into master_params_sharing

parents b1ac9fd3 3ee811be
@@ -80,10 +80,20 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
-    sampler."""
+    sampler.
+
+    The `interleave` argument specifies how to distribute a batch. A value
+    of True combined with the above random sampler is equivalent to pytorch's
+    torch.utils.data.distributed.DistributedSampler.
+
+    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2
+    specifying True will result in the following samples for each gpu:
+        GPU0: [0,2,4,6] GPU1: [1,3,5,7]
+    specifying False will result in the following samples:
+        GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
 
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
-                 world_size=2, wrap_last=False):
+                 world_size=2, wrap_last=False, interleave=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                       drop_last)
         if rank == -1:
@@ -95,6 +105,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
         self.wrap_around = 0
         self.wrap_last = wrap_last
         self.start_iter = 0
+        self.interleave = interleave
 
     def __iter__(self):
         batch = []
@@ -130,6 +141,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
 
     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
+        if self.interleave:
+            return batch[self.rank:self.batch_size:self.world_size]
         start = self.rank * self.batch_size // self.world_size
         end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
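
For reference, a minimal standalone sketch (not part of the commit) of the batch-splitting behavior the new docstring describes. `split_batch` is a hypothetical helper that mirrors the slicing in `_batch` above; it is not defined in the repository.

# Standalone sketch of the two slicing strategies in DistributedBatchSampler._batch.
# `split_batch` is a hypothetical helper used only for illustration.

def split_batch(batch, rank, world_size, interleave):
    """Return the portion of `batch` that the given rank would receive."""
    batch_size = len(batch)
    if interleave:
        # Strided slice: rank, rank + world_size, rank + 2 * world_size, ...
        return batch[rank:batch_size:world_size]
    # Contiguous slice: each rank gets one consecutive chunk of the batch.
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    return batch[start:end]

batch = [0, 1, 2, 3, 4, 5, 6, 7]
print([split_batch(batch, r, 2, interleave=True) for r in range(2)])
# -> [[0, 2, 4, 6], [1, 3, 5, 7]]  (the assignment the docstring equates with
#    pytorch's torch.utils.data.distributed.DistributedSampler)
print([split_batch(batch, r, 2, interleave=False) for r in range(2)])
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]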