Specify num_replicas and rank when creating sampler (#216)

Project: OpenDAS/deepspeed
Commit: 0f72988d (unverified), parent e1ad8803
Authored Jun 04, 2020 by Vidush Vishwanath; committed via GitHub on Jun 04, 2020

Showing 2 changed files with 16 additions and 3 deletions:

deepspeed/pt/deepspeed_dataloader.py  +6 -2
deepspeed/pt/deepspeed_light.py       +10 -1
deepspeed/pt/deepspeed_dataloader.py

```diff
@@ -16,13 +16,17 @@ class DeepSpeedDataLoader(object):
                  tput_timer,
                  collate_fn=None,
                  num_local_io_workers=None,
-                 data_sampler=None):
+                 data_sampler=None,
+                 data_parallel_world_size=None,
+                 data_parallel_rank=None):
         self.tput_timer = tput_timer
         self.batch_size = batch_size
         if local_rank >= 0:
             if data_sampler is None:
-                data_sampler = DistributedSampler(dataset)
+                data_sampler = DistributedSampler(dataset=dataset,
+                                                  num_replicas=data_parallel_world_size,
+                                                  rank=data_parallel_rank)
             device_count = 1
         else:
             if data_sampler is None:
```
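For context on why the two new arguments matter: when `num_replicas` and `rank` are left as `None`, PyTorch's `DistributedSampler` falls back to the global process group's world size and rank. Under model parallelism the global rank differs from the data-parallel rank, so model-parallel peers would wrongly receive different shards. A minimal standalone sketch (not DeepSpeed code) of how the two arguments drive sharding:

```python
# A standalone sketch (not DeepSpeed code) of how DistributedSampler shards a
# dataset: each of num_replicas consumers takes indices[rank::num_replicas].
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))

# Pretend there are 2 data-parallel replicas and build each replica's sampler.
# Passing num_replicas/rank explicitly means no process group is needed here.
for dp_rank in range(2):
    sampler = DistributedSampler(dataset,
                                 num_replicas=2,  # data-parallel world size
                                 rank=dp_rank,    # data-parallel rank, not global rank
                                 shuffle=False)
    loader = DataLoader(dataset, sampler=sampler, batch_size=2)
    print(dp_rank, [batch[0].tolist() for batch in loader])
# rank 0 sees [[0, 2], [4, 6]]; rank 1 sees [[1, 3], [5, 7]]
```

Two workers handed the same `rank` value here would draw identical shards, which is exactly the behavior model-parallel peers need.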
deepspeed/pt/deepspeed_light.py

```diff
@@ -620,6 +620,13 @@ class DeepSpeedLight(Module):
         if route == ROUTE_TRAIN:
             deepspeed_io_timer = self.tput_timer
 
+        # If mpu is provided, forward world size and parallel rank to sampler.
+        data_parallel_world_size = None
+        data_parallel_rank = None
+        if self.mpu is not None:
+            data_parallel_world_size = self.mpu.get_data_parallel_world_size()
+            data_parallel_rank = self.mpu.get_data_parallel_rank()
+
         return DeepSpeedDataLoader(dataset=dataset,
                                    batch_size=batch_size,
                                    pin_memory=pin_memory,
```
```diff
@@ -627,7 +634,9 @@ class DeepSpeedLight(Module):
                                    local_rank=self.local_rank,
                                    tput_timer=deepspeed_io_timer,
                                    num_local_io_workers=num_local_io_workers,
-                                   data_sampler=data_sampler)
+                                   data_sampler=data_sampler,
+                                   data_parallel_world_size=data_parallel_world_size,
+                                   data_parallel_rank=data_parallel_rank)
 
     def train(self):
         r"""
```
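The first hunk forwards the mpu's data-parallel coordinates instead of torch.distributed's global ones. The distinction matters because, with model parallelism, several global ranks cooperate on the same replica and must draw the same batches. A hypothetical sketch of such a mapping (the `SimpleMPU` class below is illustrative, not DeepSpeed's mpu API, though the two getter names match the methods the commit calls):

```python
# Hypothetical SimpleMPU for illustration only; DeepSpeed expects a user-supplied
# mpu object exposing get_data_parallel_world_size() / get_data_parallel_rank().
class SimpleMPU:
    """Maps global ranks to data-parallel coordinates, assuming model-parallel
    peers sit on consecutive global ranks (e.g. {0,1}, {2,3}, ... for size 2)."""

    def __init__(self, global_rank, global_world_size, model_parallel_size):
        assert global_world_size % model_parallel_size == 0
        self._dp_world_size = global_world_size // model_parallel_size
        self._dp_rank = global_rank // model_parallel_size

    def get_data_parallel_world_size(self):
        return self._dp_world_size

    def get_data_parallel_rank(self):
        return self._dp_rank

# 8 GPUs with 2-way model parallelism -> 4 data-parallel replicas.
# Global ranks 0 and 1 share dp rank 0, so their samplers yield identical shards.
for g in range(8):
    mpu = SimpleMPU(g, global_world_size=8, model_parallel_size=2)
    print(g, mpu.get_data_parallel_rank())  # prints 0, 0, 1, 1, 2, 2, 3, 3
```

Before this change, the sampler defaulted to the global rank, so ranks 0 and 1 above would have received different shards despite belonging to the same replica.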