Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
543d5693
Unverified
Commit
543d5693
authored
Nov 06, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Nov 06, 2020
Browse files
[fix] OSS tests - remove concurrent dist inits (#177)
parent
cc766aa5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
185 additions
and
167 deletions
+185
-167
tests/optim/test_oss.py
tests/optim/test_oss.py
+185
-167
No files found.
tests/optim/test_oss.py
View file @
543d5693
...
...
@@ -7,7 +7,9 @@
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
import
os
import
tempfile
import
unittest
import
numpy
as
np
import
pytest
...
...
@@ -23,28 +25,27 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
def
setup_module
(
module
):
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29500"
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
)
def
dist_init
(
rank
,
world_size
,
tempfile_name
,
backend
=
BACKEND
):
url
=
"file://"
+
tempfile_name
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
def
teardown_module
(
module
):
torch
.
distributed
.
destroy_process_group
()
class
TestSingleRank
(
unittest
.
TestCase
):
"""
All the following tests do not check for inter-process communication
"""
def
dist_init
(
rank
,
world_size
):
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
rank
,
world_size
=
world_size
)
def
setUp
(
self
):
dist_init
(
0
,
1
,
tempfile
.
mkstemp
()[
1
])
def
tearDown
(
self
):
torch
.
distributed
.
destroy_process_group
()
def
test_create
():
def
test_create
(
self
):
params
=
[
torch
.
rand
(
1
)]
o
=
optim
.
OSS
(
params
,
lr
=
0.01
)
def
test_state_dict
():
def
test_state_dict
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
,
momentum
=
0.9
)
x
.
backward
()
...
...
@@ -88,8 +89,7 @@ def test_state_dict():
# Check that the exposed param_groups are on the proper device
assert
o
.
param_groups
[
0
][
"params"
][
0
].
device
==
x
.
device
def
test_lr_scheduler
():
def
test_lr_scheduler
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
x2
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
...
...
@@ -107,8 +107,7 @@ def test_lr_scheduler():
s2
.
step
()
assert
x
==
x2
def
test_step_with_kwargs
():
def
test_step_with_kwargs
(
self
):
class
SGDWithStepKWArg
(
torch
.
optim
.
SGD
):
def
step
(
self
,
closure
=
None
,
kwarg
=
[]):
super
().
step
()
...
...
@@ -122,8 +121,7 @@ def test_step_with_kwargs():
assert
kwarg
==
[
5
]
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_step_with_extra_inner_key
():
def
test_step_with_extra_inner_key
(
self
):
class
SGDWithNewKey
(
torch
.
optim
.
SGD
):
# Dummy optimizer which adds a new key to the param groups
def
step
(
self
,
closure
=
None
):
...
...
@@ -137,8 +135,7 @@ def test_step_with_extra_inner_key():
assert
o
.
param_groups
[
0
][
"new_key"
]
==
0.1
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_step_without_closure
():
def
test_step_without_closure
(
self
):
class
SGDWithoutClosure
(
torch
.
optim
.
SGD
):
def
step
(
self
):
return
super
().
step
()
...
...
@@ -149,8 +146,7 @@ def test_step_without_closure():
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_local_state_dict
():
def
test_local_state_dict
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
)
local_state_dict
=
o
.
local_state_dict
()
...
...
@@ -163,8 +159,7 @@ def test_local_state_dict():
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
test_implicit_local_state_dict
():
def
test_implicit_local_state_dict
(
self
):
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
)
local_state_dict
=
o
.
state_dict
()
...
...
@@ -178,8 +173,8 @@ def test_implicit_local_state_dict():
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
def
run_test_add_param_group
(
rank
,
world_size
):
dist_init
(
rank
,
world_size
)
def
run_test_add_param_group
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
params
=
[]
for
size
in
[
4
,
5
,
2
,
6
,
4
]:
params
.
append
(
torch
.
rand
(
size
,
1
))
...
...
@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
assert
sum
([
x
.
numel
()
for
g
in
o
.
optim
.
param_groups
for
x
in
g
[
"params"
]])
==
8
assert
len
(
o
.
optim
.
param_groups
)
==
2
dist
.
destroy_process_group
()
def
test_add_param_group
():
world_size
=
3
mp
.
spawn
(
run_test_add_param_group
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_add_param_group
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_zero_grad
(
rank
,
world_size
):
dist_init
(
rank
,
world_size
)
def
run_test_zero_grad
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
x
=
torch
.
rand
(
1
)
m
=
torch
.
nn
.
Linear
(
1
,
1
)
o
=
optim
.
OSS
(
m
.
parameters
(),
lr
=
0.1
)
...
...
@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
assert
not
m
.
weight
.
grad
assert
not
m
.
bias
.
grad
dist
.
destroy_process_group
()
def
test_zero_grad
():
world_size
=
2
mp
.
spawn
(
run_test_zero_grad
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_zero_grad
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_step
(
rank
,
world_size
):
dist_init
(
rank
,
world_size
)
def
run_test_step
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
,
backend
=
"gloo"
)
x
=
torch
.
tensor
([
float
(
rank
+
1
)],
device
=
rank
)
m
=
torch
.
nn
.
Linear
(
1
,
1
)
m
.
weight
.
data
=
torch
.
tensor
([[
1.0
]])
...
...
@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
assert
m
.
weight
==
torch
.
tensor
([[
0.75
]],
device
=
rank
)
assert
m
.
bias
==
torch
.
tensor
([
1.85
],
device
=
rank
)
dist
.
destroy_process_group
()
@
skip_if_no_cuda
def
test_step
():
world_size
=
min
(
2
,
torch
.
cuda
.
device_count
())
mp
.
spawn
(
run_test_step
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_step
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_step_with_closure
(
rank
,
world_size
,
optimizer
=
None
):
dist_init
(
rank
,
world_size
)
def
run_test_step_with_closure
(
rank
,
world_size
,
tempfile_name
,
optimizer
=
None
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
x_val
=
rank
+
1
weight
=
1.0
...
...
@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
assert
m
.
weight
==
torch
.
tensor
([[
1.1
]],
device
=
rank
)
assert
m
.
bias
==
torch
.
tensor
([
2.1
],
device
=
rank
)
dist
.
destroy_process_group
()
@
skip_if_no_cuda
def
test_step_with_closure
():
world_size
=
min
(
2
,
torch
.
cuda
.
device_count
())
mp
.
spawn
(
run_test_step_with_closure
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_step_with_closure
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_sharding
(
rank
,
world_size
):
dist_init
(
rank
,
world_size
)
def
run_test_sharding
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
params
=
[]
for
size
in
[
5
,
4
,
2
,
6
,
4
,
3
]:
params
.
append
(
torch
.
rand
(
size
,
1
))
o
=
optim
.
OSS
(
params
,
lr
=
0.1
)
assert
sum
([
x
.
numel
()
for
x
in
o
.
optim
.
param_groups
[
0
][
"params"
]])
==
8
dist
.
destroy_process_group
()
def
test_sharding
():
world_size
=
3
mp
.
spawn
(
run_test_sharding
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
)
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_sharding
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
):
dist_init
(
rank
,
world_size
)
def
run_test_collect_shards
(
rank
,
world_size
,
reference_rank
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
# Run a dummy step so that the optimizer state dict exists
batch
,
input_width
,
hidden
,
target_width
=
3
,
20
,
10
,
5
batch
,
input_width
,
hidden
,
target_width
=
3
,
3
,
3
,
5
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
...
...
@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
# Load the optimizer state dict
optimizer
.
load_state_dict
(
optimizer_state_dict
)
dist
.
destroy_process_group
()
def
test_collect_shards
():
world_size
=
3
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
if
torch
.
cuda
.
is_available
():
world_size
=
min
(
world_size
,
torch
.
cuda
.
device_count
())
reference_rank
=
0
mp
.
spawn
(
run_test_collect_shards
,
args
=
(
world_size
,
reference_rank
),
nprocs
=
world_size
,
join
=
True
,
run_test_collect_shards
,
args
=
(
world_size
,
reference_rank
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_multiple_groups
(
rank
,
world_size
):
def
run_test_multiple_groups
(
rank
,
world_size
,
tempfile_name
):
# Only work with the even ranks, to check that the global_rank indexing is properly used
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
dist
.
init_process_group
(
backend
=
"gloo"
,
rank
=
rank
,
world_size
=
world_size
)
dist_init
(
rank
=
rank
,
world_size
=
world_size
,
tempfile_name
=
tempfile_name
,
backend
=
"gloo"
)
sub_group_ranks
=
[
0
,
2
,
4
]
process_group
=
torch
.
distributed
.
new_group
(
ranks
=
sub_group_ranks
,
backend
=
"gloo"
)
...
...
@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
optimizer
=
optim
.
OSS
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
,
group
=
process_group
,
broadcast_buffer_size
=
0
)
check
(
optimizer
)
dist
.
destroy_process_group
(
process_group
)
dist
.
destroy_process_group
()
def
test_multiple_groups
():
world_size
=
6
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_multiple_groups
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
,
run_test_multiple_groups
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment