Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
543d5693
Unverified
Commit
543d5693
authored
Nov 06, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Nov 06, 2020
Browse files
[fix] OSS tests - remove concurrent dist inits (#177)
parent
cc766aa5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
185 additions
and
167 deletions
+185
-167
tests/optim/test_oss.py
tests/optim/test_oss.py
+185
-167
No files found.
tests/optim/test_oss.py
View file @
543d5693
...
@@ -7,7 +7,9 @@
...
@@ -7,7 +7,9 @@
# pylint: disable=missing-class-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
# pylint: disable=missing-function-docstring
import
os
import
tempfile
import
unittest
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -23,163 +25,156 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
...
@@ -23,163 +25,156 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
def
setup_module
(
module
):
def
dist_init
(
rank
,
world_size
,
tempfile_name
,
backend
=
BACKEND
):
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
url
=
"file://"
+
tempfile_name
os
.
environ
[
"MASTER_PORT"
]
=
"29500"
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
)
def
teardown_module
(
module
):
torch
.
distributed
.
destroy_process_group
()
def
dist_init
(
rank
,
world_size
):
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
rank
,
world_size
=
world_size
)
def
test_create
():
class
TestSingleRank
(
unittest
.
TestCase
):
params
=
[
torch
.
rand
(
1
)]
"""
o
=
optim
.
OSS
(
params
,
lr
=
0.01
)
All the following tests do not check for inter-process communication
"""
def
setUp
(
self
):
dist_init
(
0
,
1
,
tempfile
.
mkstemp
()[
1
])
def
test_state_dict
():
def
tearDown
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
torch
.
distributed
.
destroy_process_group
()
o
=
optim
.
OSS
([
x
],
lr
=
0.1
,
momentum
=
0.9
)
x
.
backward
()
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.0
],
device
=
DEVICE
)
o
.
zero_grad
()
o
.
consolidate_state_dict
()
# Sync state dict in between replicas - even if there are none
state_dict
=
o
.
state_dict
()
# Check that the state dict is pytorch-compliant key wise
assert
"param_groups"
in
state_dict
.
keys
()
assert
"state"
in
state_dict
.
keys
()
# Check that the pulled state is what we expect, and that we have all the expected keys
assert
state_dict
[
"param_groups"
][
0
][
"lr"
]
==
0.1
assert
state_dict
[
"param_groups"
][
0
][
"momentum"
]
==
0.9
assert
not
state_dict
[
"param_groups"
][
0
][
"nesterov"
]
assert
state_dict
[
"param_groups"
][
0
][
"weight_decay"
]
==
0.0
assert
state_dict
[
"param_groups"
][
0
][
"dampening"
]
==
0.0
# Check that the pulled state and the .param_groups attribute are in sync
for
k
in
state_dict
[
"param_groups"
][
0
].
keys
():
if
k
!=
"params"
:
assert
state_dict
[
"param_groups"
][
0
][
k
]
==
o
.
param_groups
[
0
][
k
]
# Check that it's correctly loaded
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
o
.
load_state_dict
(
state_dict
)
# Check that state is correct and on proper device
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.0
],
device
=
DEVICE
)
# We should now be using a lr of 0.1, both within the optimizer
# and as exposed by the .param_groups attribute
assert
o
.
param_groups
[
0
][
"lr"
]
==
0.1
x
.
backward
()
o
.
step
()
assert
x
==
torch
.
tensor
([
0.71
],
device
=
DEVICE
)
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.9
],
device
=
DEVICE
)
# Check that the exposed param_groups are on the proper device
assert
o
.
param_groups
[
0
][
"params"
][
0
].
device
==
x
.
device
def
test_create
(
self
):
params
=
[
torch
.
rand
(
1
)]
o
=
optim
.
OSS
(
params
,
lr
=
0.01
)
def
test_lr_scheduler
():
def
test_state_dict
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
x2
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
,
momentum
=
0.9
)
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
o2
=
torch
.
optim
.
SGD
([
x2
],
lr
=
0.01
)
s
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
o
,
1
)
s2
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
o2
,
1
)
for
_
in
range
(
5
):
x
.
backward
()
x
.
backward
()
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.0
],
device
=
DEVICE
)
o
.
zero_grad
()
o
.
zero_grad
()
o
.
consolidate_state_dict
()
# Sync state dict in between replicas - even if there are none
state_dict
=
o
.
state_dict
()
# Check that the state dict is pytorch-compliant key wise
assert
"param_groups"
in
state_dict
.
keys
()
assert
"state"
in
state_dict
.
keys
()
# Check that the pulled state is what we expect, and that we have all the expected keys
assert
state_dict
[
"param_groups"
][
0
][
"lr"
]
==
0.1
assert
state_dict
[
"param_groups"
][
0
][
"momentum"
]
==
0.9
assert
not
state_dict
[
"param_groups"
][
0
][
"nesterov"
]
assert
state_dict
[
"param_groups"
][
0
][
"weight_decay"
]
==
0.0
assert
state_dict
[
"param_groups"
][
0
][
"dampening"
]
==
0.0
# Check that the pulled state and the .param_groups attribute are in sync
for
k
in
state_dict
[
"param_groups"
][
0
].
keys
():
if
k
!=
"params"
:
assert
state_dict
[
"param_groups"
][
0
][
k
]
==
o
.
param_groups
[
0
][
k
]
# Check that it's correctly loaded
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
o
.
load_state_dict
(
state_dict
)
# Check that state is correct and on proper device
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.0
],
device
=
DEVICE
)
# We should now be using a lr of 0.1, both within the optimizer
# and as exposed by the .param_groups attribute
assert
o
.
param_groups
[
0
][
"lr"
]
==
0.1
x
.
backward
()
o
.
step
()
o
.
step
()
s
.
step
()
assert
x
==
torch
.
tensor
([
0.71
],
device
=
DEVICE
)
x2
.
backward
()
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.9
],
device
=
DEVICE
)
o2
.
zero_grad
()
o2
.
step
()
# Check that the exposed param_groups are on the proper device
s2
.
step
()
assert
o
.
param_groups
[
0
][
"params"
][
0
].
device
==
x
.
device
assert
x
==
x2
def
test_lr_scheduler
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
def
test_step_with_kwargs
():
x2
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
class
SGDWithStepKWArg
(
torch
.
optim
.
SGD
):
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
def
step
(
self
,
closure
=
None
,
kwarg
=
[]):
o2
=
torch
.
optim
.
SGD
([
x2
],
lr
=
0.01
)
super
().
step
()
s
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
o
,
1
)
kwarg
.
append
(
5
)
s2
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
o2
,
1
)
for
_
in
range
(
5
):
kwarg
=
[]
x
.
backward
()
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
.
zero_grad
()
o
=
optim
.
OSS
([
x
],
SGDWithStepKWArg
,
lr
=
0.1
)
o
.
step
()
x
.
backward
()
s
.
step
()
o
.
step
(
0
,
kwarg
=
kwarg
)
x2
.
backward
()
assert
kwarg
==
[
5
]
o2
.
zero_grad
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
o2
.
step
()
s2
.
step
()
assert
x
==
x2
def
test_step_with_extra_inner_key
():
class
SGDWithNewKey
(
torch
.
optim
.
SGD
):
def
test_step_with_kwargs
(
self
):
# Dummy optimizer which adds a new key to the param groups
class
SGDWithStepKWArg
(
torch
.
optim
.
SGD
):
def
step
(
self
,
closure
=
None
):
def
step
(
self
,
closure
=
None
,
kwarg
=
[]):
super
().
step
()
super
().
step
()
self
.
param_groups
[
0
][
"new_key"
]
=
0.1
kwarg
.
append
(
5
)
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
kwarg
=
[]
o
=
optim
.
OSS
([
x
],
SGDWithNewKey
,
lr
=
0.1
)
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
x
.
backward
()
o
=
optim
.
OSS
([
x
],
SGDWithStepKWArg
,
lr
=
0.1
)
o
.
step
()
x
.
backward
()
assert
o
.
param_groups
[
0
][
"new_key"
]
==
0.1
o
.
step
(
0
,
kwarg
=
kwarg
)
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
kwarg
==
[
5
]
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_step_with_extra_inner_key
(
self
):
class
SGDWithNewKey
(
torch
.
optim
.
SGD
):
# Dummy optimizer which adds a new key to the param groups
def
step
(
self
,
closure
=
None
):
super
().
step
()
self
.
param_groups
[
0
][
"new_key"
]
=
0.1
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
SGDWithNewKey
,
lr
=
0.1
)
x
.
backward
()
o
.
step
()
assert
o
.
param_groups
[
0
][
"new_key"
]
==
0.1
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_step_without_closure
():
def
test_step_without_closure
(
self
):
class
SGDWithoutClosure
(
torch
.
optim
.
SGD
):
class
SGDWithoutClosure
(
torch
.
optim
.
SGD
):
def
step
(
self
):
def
step
(
self
):
return
super
().
step
()
return
super
().
step
()
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
SGDWithoutClosure
,
lr
=
0.1
)
o
=
optim
.
OSS
([
x
],
SGDWithoutClosure
,
lr
=
0.1
)
x
.
backward
()
x
.
backward
()
o
.
step
()
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_local_state_dict
(
self
):
def
test_local_state_dict
():
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
)
local_state_dict
=
o
.
local_state_dict
()
local_state_dict
=
o
.
local_state_dict
()
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
o
.
load_local_state_dict
(
local_state_dict
)
o
.
load_local_state_dict
(
local_state_dict
)
# We should now be using a lr of 0.1.
# We should now be using a lr of 0.1.
assert
o
.
optim
.
param_groups
[
0
][
"lr"
]
==
0.1
assert
o
.
optim
.
param_groups
[
0
][
"lr"
]
==
0.1
assert
o
.
param_groups
[
0
][
"lr"
]
==
0.1
assert
o
.
param_groups
[
0
][
"lr"
]
==
0.1
x
.
backward
()
x
.
backward
()
o
.
step
()
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_implicit_local_state_dict
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
def
test_implicit_local_state_dict
():
o
=
optim
.
OSS
([
x
],
lr
=
0.1
)
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
local_state_dict
=
o
.
state_dict
()
o
=
optim
.
OSS
([
x
],
lr
=
0.1
)
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
local_state_dict
=
o
.
state_dict
()
o
.
load_state_dict
(
local_state_dict
)
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
# We should now be using a lr of 0.1.
o
.
load_state_dict
(
local_state_dict
)
assert
o
.
optim
.
param_groups
[
0
][
"lr"
]
==
0.1
# We should now be using a lr of 0.1.
assert
o
.
param_groups
[
0
][
"lr"
]
==
0.1
assert
o
.
optim
.
param_groups
[
0
][
"lr"
]
==
0.1
x
.
backward
()
assert
o
.
param_groups
[
0
][
"lr"
]
==
0.1
o
.
step
()
x
.
backward
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
run_test_add_param_group
(
rank
,
world_size
):
def
run_test_add_param_group
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
params
=
[]
params
=
[]
for
size
in
[
4
,
5
,
2
,
6
,
4
]:
for
size
in
[
4
,
5
,
2
,
6
,
4
]:
params
.
append
(
torch
.
rand
(
size
,
1
))
params
.
append
(
torch
.
rand
(
size
,
1
))
...
@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
...
@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
assert
sum
([
x
.
numel
()
for
g
in
o
.
optim
.
param_groups
for
x
in
g
[
"params"
]])
==
8
assert
sum
([
x
.
numel
()
for
g
in
o
.
optim
.
param_groups
for
x
in
g
[
"params"
]])
==
8
assert
len
(
o
.
optim
.
param_groups
)
==
2
assert
len
(
o
.
optim
.
param_groups
)
==
2
dist
.
destroy_process_group
()
def
test_add_param_group
():
def
test_add_param_group
():
world_size
=
3
world_size
=
3
mp
.
spawn
(
run_test_add_param_group
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_add_param_group
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_zero_grad
(
rank
,
world_size
):
def
run_test_zero_grad
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
x
=
torch
.
rand
(
1
)
x
=
torch
.
rand
(
1
)
m
=
torch
.
nn
.
Linear
(
1
,
1
)
m
=
torch
.
nn
.
Linear
(
1
,
1
)
o
=
optim
.
OSS
(
m
.
parameters
(),
lr
=
0.1
)
o
=
optim
.
OSS
(
m
.
parameters
(),
lr
=
0.1
)
...
@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
...
@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
assert
not
m
.
weight
.
grad
assert
not
m
.
weight
.
grad
assert
not
m
.
bias
.
grad
assert
not
m
.
bias
.
grad
dist
.
destroy_process_group
()
def
test_zero_grad
():
def
test_zero_grad
():
world_size
=
2
world_size
=
2
mp
.
spawn
(
run_test_zero_grad
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_zero_grad
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_step
(
rank
,
world_size
):
def
run_test_step
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
)
dist_init
(
rank
,
world_size
,
tempfile_name
,
backend
=
"gloo"
)
x
=
torch
.
tensor
([
float
(
rank
+
1
)],
device
=
rank
)
x
=
torch
.
tensor
([
float
(
rank
+
1
)],
device
=
rank
)
m
=
torch
.
nn
.
Linear
(
1
,
1
)
m
=
torch
.
nn
.
Linear
(
1
,
1
)
m
.
weight
.
data
=
torch
.
tensor
([[
1.0
]])
m
.
weight
.
data
=
torch
.
tensor
([[
1.0
]])
...
@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
...
@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
assert
m
.
weight
==
torch
.
tensor
([[
0.75
]],
device
=
rank
)
assert
m
.
weight
==
torch
.
tensor
([[
0.75
]],
device
=
rank
)
assert
m
.
bias
==
torch
.
tensor
([
1.85
],
device
=
rank
)
assert
m
.
bias
==
torch
.
tensor
([
1.85
],
device
=
rank
)
dist
.
destroy_process_group
()
@
skip_if_no_cuda
@
skip_if_no_cuda
def
test_step
():
def
test_step
():
world_size
=
min
(
2
,
torch
.
cuda
.
device_count
())
world_size
=
min
(
2
,
torch
.
cuda
.
device_count
())
mp
.
spawn
(
run_test_step
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_step
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_step_with_closure
(
rank
,
world_size
,
optimizer
=
None
):
def
run_test_step_with_closure
(
rank
,
world_size
,
tempfile_name
,
optimizer
=
None
):
dist_init
(
rank
,
world_size
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
x_val
=
rank
+
1
x_val
=
rank
+
1
weight
=
1.0
weight
=
1.0
...
@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
...
@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
assert
m
.
weight
==
torch
.
tensor
([[
1.1
]],
device
=
rank
)
assert
m
.
weight
==
torch
.
tensor
([[
1.1
]],
device
=
rank
)
assert
m
.
bias
==
torch
.
tensor
([
2.1
],
device
=
rank
)
assert
m
.
bias
==
torch
.
tensor
([
2.1
],
device
=
rank
)
dist
.
destroy_process_group
()
@
skip_if_no_cuda
@
skip_if_no_cuda
def
test_step_with_closure
():
def
test_step_with_closure
():
world_size
=
min
(
2
,
torch
.
cuda
.
device_count
())
world_size
=
min
(
2
,
torch
.
cuda
.
device_count
())
mp
.
spawn
(
run_test_step_with_closure
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_step_with_closure
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_sharding
(
rank
,
world_size
):
def
run_test_sharding
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
params
=
[]
params
=
[]
for
size
in
[
5
,
4
,
2
,
6
,
4
,
3
]:
for
size
in
[
5
,
4
,
2
,
6
,
4
,
3
]:
params
.
append
(
torch
.
rand
(
size
,
1
))
params
.
append
(
torch
.
rand
(
size
,
1
))
o
=
optim
.
OSS
(
params
,
lr
=
0.1
)
o
=
optim
.
OSS
(
params
,
lr
=
0.1
)
assert
sum
([
x
.
numel
()
for
x
in
o
.
optim
.
param_groups
[
0
][
"params"
]])
==
8
assert
sum
([
x
.
numel
()
for
x
in
o
.
optim
.
param_groups
[
0
][
"params"
]])
==
8
dist
.
destroy_process_group
()
def
test_sharding
():
def
test_sharding
():
world_size
=
3
world_size
=
3
mp
.
spawn
(
run_test_sharding
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_sharding
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
):
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
dist_init
(
rank
,
world_size
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
# Run a dummy step so that the optimizer state dict exists
# Run a dummy step so that the optimizer state dict exists
batch
,
input_width
,
hidden
,
target_width
=
3
,
20
,
10
,
5
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
...
@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
...
@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
# Load the optimizer state dict
# Load the optimizer state dict
optimizer
.
load_state_dict
(
optimizer_state_dict
)
optimizer
.
load_state_dict
(
optimizer_state_dict
)
dist
.
destroy_process_group
()
def
test_collect_shards
():
def
test_collect_shards
():
world_size
=
3
world_size
=
3
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
world_size
=
min
(
world_size
,
torch
.
cuda
.
device_count
())
world_size
=
min
(
world_size
,
torch
.
cuda
.
device_count
())
reference_rank
=
0
reference_rank
=
0
mp
.
spawn
(
mp
.
spawn
(
run_test_collect_shards
,
args
=
(
world_size
,
reference_rank
),
nprocs
=
world_size
,
join
=
True
,
run_test_collect_shards
,
args
=
(
world_size
,
reference_rank
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
)
def
run_test_multiple_groups
(
rank
,
world_size
):
def
run_test_multiple_groups
(
rank
,
world_size
,
tempfile_name
):
# Only work with the even ranks, to check that the global_rank indexing is properly used
# Only work with the even ranks, to check that the global_rank indexing is properly used
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
dist_init
(
rank
=
rank
,
world_size
=
world_size
,
tempfile_name
=
tempfile_name
,
backend
=
"gloo"
)
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
dist
.
init_process_group
(
backend
=
"gloo"
,
rank
=
rank
,
world_size
=
world_size
)
sub_group_ranks
=
[
0
,
2
,
4
]
sub_group_ranks
=
[
0
,
2
,
4
]
process_group
=
torch
.
distributed
.
new_group
(
ranks
=
sub_group_ranks
,
backend
=
"gloo"
)
process_group
=
torch
.
distributed
.
new_group
(
ranks
=
sub_group_ranks
,
backend
=
"gloo"
)
...
@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
...
@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
optimizer
=
optim
.
OSS
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
,
group
=
process_group
,
broadcast_buffer_size
=
0
)
optimizer
=
optim
.
OSS
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
,
group
=
process_group
,
broadcast_buffer_size
=
0
)
check
(
optimizer
)
check
(
optimizer
)
dist
.
destroy_process_group
(
process_group
)
dist
.
destroy_process_group
()
def
test_multiple_groups
():
def
test_multiple_groups
():
world_size
=
6
world_size
=
6
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
mp
.
spawn
(
run_test_multiple_groups
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
,
run_test_multiple_groups
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment