Commit 26d6aff4 authored by Raul Puri's avatar Raul Puri
Browse files

batch sampler interleave flag

parent 6106127c
...@@ -80,10 +80,17 @@ class DistributedBatchSampler(data.sampler.BatchSampler): ...@@ -80,10 +80,17 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
implementation is at the batch sampler level, instead of just the implementation is at the batch sampler level, instead of just the
sampler level. This allows wrapping of arbitrary data samplers sampler level. This allows wrapping of arbitrary data samplers
(sequential, random, WeightedRandomSampler, etc.) with this batch (sequential, random, WeightedRandomSampler, etc.) with this batch
sampler.""" sampler.
The `interleave` argument specifies how to distribute a batch.
For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2
specifying True will result in the following samples for each gpu:
GPU0: [0,2,4,6] GPU1: [1,3,5,7]
specifying False will result in the following samples:
GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
def __init__(self, sampler, batch_size, drop_last, rank=-1, def __init__(self, sampler, batch_size, drop_last, rank=-1,
world_size=2, wrap_last=False): world_size=2, wrap_last=False, interleave=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size, super(DistributedBatchSampler, self).__init__(sampler, batch_size,
drop_last) drop_last)
if rank == -1: if rank == -1:
...@@ -95,6 +102,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler): ...@@ -95,6 +102,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
self.wrap_around = 0 self.wrap_around = 0
self.wrap_last = wrap_last self.wrap_last = wrap_last
self.start_iter = 0 self.start_iter = 0
self.interleave = interleave
def __iter__(self): def __iter__(self):
batch = [] batch = []
...@@ -130,6 +138,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler): ...@@ -130,6 +138,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
def _batch(self, batch): def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch""" """extracts samples only pertaining to this worker's batch"""
if self.interleave:
return batch[self.rank:self.batch_size:self.world_size]
start = self.rank * self.batch_size // self.world_size start = self.rank * self.batch_size // self.world_size
end = (self.rank + 1) * self.batch_size // self.world_size end = (self.rank + 1) * self.batch_size // self.world_size
return batch[start:end] return batch[start:end]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment