Unverified Commit 0f72988d authored by Vidush Vishwanath's avatar Vidush Vishwanath Committed by GitHub
Browse files

Specify num_replicas and rank when creating sampler (#216)

parent e1ad8803
......@@ -16,13 +16,17 @@ class DeepSpeedDataLoader(object):
tput_timer,
collate_fn=None,
num_local_io_workers=None,
data_sampler=None):
data_sampler=None,
data_parallel_world_size=None,
data_parallel_rank=None):
self.tput_timer = tput_timer
self.batch_size = batch_size
if local_rank >= 0:
if data_sampler is None:
data_sampler = DistributedSampler(dataset)
data_sampler = DistributedSampler(dataset=dataset,
num_replicas=data_parallel_world_size,
rank=data_parallel_rank)
device_count = 1
else:
if data_sampler is None:
......
......@@ -620,6 +620,13 @@ class DeepSpeedLight(Module):
if route == ROUTE_TRAIN:
deepspeed_io_timer = self.tput_timer
# If mpu is provided, forward world size and parallel rank to sampler.
data_parallel_world_size = None
data_parallel_rank = None
if self.mpu is not None:
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()
return DeepSpeedDataLoader(dataset=dataset,
batch_size=batch_size,
pin_memory=pin_memory,
......@@ -627,7 +634,9 @@ class DeepSpeedLight(Module):
local_rank=self.local_rank,
tput_timer=deepspeed_io_timer,
num_local_io_workers=num_local_io_workers,
data_sampler=data_sampler)
data_sampler=data_sampler,
data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank)
def train(self):
r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment