OpenDAS / fairscale · Commit 88553373 (Unverified)
Authored Apr 03, 2021 by Benjamin Lefaudeux · Committed by GitHub, Apr 03, 2021
[fix] OSS - enforce cuda parameters for state consolidation if NCCL backend (#573)
parent 04001e76
Showing 2 changed files with 13 additions and 5 deletions (+13 -5):
  fairscale/optim/oss.py   +8 -5
  tests/optim/test_oss.py  +5 -0
fairscale/optim/oss.py

...
@@ -326,7 +326,10 @@ class OSS(Optimizer):
         self._all_states = []
         should_collect_state = self.rank == recipient_rank or recipient_rank == -1
-        should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
+        should_send_state = self.rank != recipient_rank
+
+        # NCCL requires CUDA tensors for all communication primitives
+        dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device
 
         for rank in range(self.world_size):
             if rank == self.rank:
...
@@ -340,18 +343,18 @@ class OSS(Optimizer):
                 state_to_share = (
                     self.optim.state_dict()
                     if should_send_state
-                    else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
+                    else torch.tensor([0], dtype=torch.uint8, device=dist_device)
                 )
                 broadcast_object(
-                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=self._default_device,
+                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device,
                 )
             else:
                 # Fetch the optim state from the other replicas
                 replica_state = broadcast_object(
-                    torch.tensor([0], dtype=torch.uint8, device=self._default_device),
+                    torch.tensor([0], dtype=torch.uint8, device=dist_device),
                     src_rank=self._local_to_global_rank[rank],
                     group=self.group,
-                    dist_device=self._default_device,
+                    dist_device=dist_device,
                 )
 
                 if should_collect_state:
...
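Context for the change above: the NCCL backend of torch.distributed only operates on CUDA tensors, so any placeholder tensor taking part in a collective must live on a CUDA device (gloo, by contrast, accepts CPU tensors). Below is a minimal, hypothetical sketch of that device-selection logic, not taken from this commit; the helper name select_dist_device is illustrative.

import torch
import torch.distributed as dist


def select_dist_device(default_device: torch.device) -> torch.device:
    # NCCL collectives (broadcast, all_reduce, ...) require CUDA tensors,
    # so use a CUDA device whenever the process group runs on NCCL.
    if dist.get_backend() == dist.Backend.NCCL:
        return torch.device("cuda")
    return default_device


# Example (assumes an initialized process group): even if the model and
# optimizer live on CPU, the dummy "nothing to send" marker must be allocated
# on the communication device.
# dist_device = select_dist_device(torch.device("cpu"))
# placeholder = torch.tensor([0], dtype=torch.uint8, device=dist_device)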
tests/optim/test_oss.py

...
@@ -470,6 +470,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     _ = optimizer.step(closure=closure)
     check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)
 
+    # Check that if the model is moved to cpu, the optimizer consolidation still works
+    model.cpu()
+    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)
+    optimizer.consolidate_state_dict(recipient_rank=reference_rank)
+
     dist.destroy_process_group()
...
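As a standalone illustration of the scenario the new test lines exercise, here is a hedged sketch. It assumes a process group has already been initialized on every rank (e.g. via dist.init_process_group) and that fairscale is installed; the toy model and the recipient rank value are illustrative, not from the commit.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Assumes dist.init_process_group(...) has already been called on every rank.
model = torch.nn.Linear(4, 4)  # illustrative toy model
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)

# Moving the model (and hence its sharded optimizer state) to CPU should no
# longer break consolidation: with an NCCL backend, this fix forces the
# placeholder tensors exchanged during the broadcast onto a CUDA device.
model.cpu()
optimizer.consolidate_state_dict(recipient_rank=0)  # rank that gathers the full state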