"src/vscode:/vscode.git/clone" did not exist on "aa2ce41b99a7990a8eb03bf2bf9253a40909b31e"
Unverified Commit 0f72988d authored by Vidush Vishwanath's avatar Vidush Vishwanath Committed by GitHub
Browse files

Specify num_replicas and rank when creating sampler (#216)

parent e1ad8803
...@@ -16,13 +16,17 @@ class DeepSpeedDataLoader(object): ...@@ -16,13 +16,17 @@ class DeepSpeedDataLoader(object):
tput_timer, tput_timer,
collate_fn=None, collate_fn=None,
num_local_io_workers=None, num_local_io_workers=None,
data_sampler=None): data_sampler=None,
data_parallel_world_size=None,
data_parallel_rank=None):
self.tput_timer = tput_timer self.tput_timer = tput_timer
self.batch_size = batch_size self.batch_size = batch_size
if local_rank >= 0: if local_rank >= 0:
if data_sampler is None: if data_sampler is None:
data_sampler = DistributedSampler(dataset) data_sampler = DistributedSampler(dataset=dataset,
num_replicas=data_parallel_world_size,
rank=data_parallel_rank)
device_count = 1 device_count = 1
else: else:
if data_sampler is None: if data_sampler is None:
......
...@@ -620,6 +620,13 @@ class DeepSpeedLight(Module): ...@@ -620,6 +620,13 @@ class DeepSpeedLight(Module):
if route == ROUTE_TRAIN: if route == ROUTE_TRAIN:
deepspeed_io_timer = self.tput_timer deepspeed_io_timer = self.tput_timer
# If mpu is provided, forward world size and parallel rank to sampler.
data_parallel_world_size = None
data_parallel_rank = None
if self.mpu is not None:
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()
return DeepSpeedDataLoader(dataset=dataset, return DeepSpeedDataLoader(dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
pin_memory=pin_memory, pin_memory=pin_memory,
...@@ -627,7 +634,9 @@ class DeepSpeedLight(Module): ...@@ -627,7 +634,9 @@ class DeepSpeedLight(Module):
local_rank=self.local_rank, local_rank=self.local_rank,
tput_timer=deepspeed_io_timer, tput_timer=deepspeed_io_timer,
num_local_io_workers=num_local_io_workers, num_local_io_workers=num_local_io_workers,
data_sampler=data_sampler) data_sampler=data_sampler,
data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank)
def train(self): def train(self):
r""" r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment