OpenDAS / fairscale · Commits

Commit 543d5693 (Unverified)
[fix] OSS tests - remove concurrent dist inits (#177)

Authored Nov 06, 2020 by Benjamin Lefaudeux; committed by GitHub on Nov 06, 2020.
Parent: cc766aa5

Showing 1 changed file with 185 additions and 167 deletions.

tests/optim/test_oss.py (+185, -167)
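The change below replaces TCP-port rendezvous (a fixed MASTER_ADDR/MASTER_PORT pair, which collides when several test processes call init_process_group concurrently) with file-based rendezvous: each test creates its own temporary file, and every rank of that test meets on it. A minimal sketch of the pattern, mirroring the dist_init helper the commit introduces; the driver code under __main__ and the name rendezvous_file are illustrative, not part of the commit:

    import tempfile

    import torch.distributed as dist


    def dist_init(rank, world_size, tempfile_name, backend="gloo"):
        # file:// rendezvous: all ranks agree on a unique file on a shared
        # filesystem instead of racing for one TCP port.
        url = "file://" + tempfile_name
        dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)


    if __name__ == "__main__":
        rendezvous_file = tempfile.mkstemp()[1]  # a fresh file per test run
        dist_init(rank=0, world_size=1, tempfile_name=rendezvous_file)
        dist.destroy_process_group()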
@@ -7,7 +7,9 @@
 # pylint: disable=missing-class-docstring
 # pylint: disable=missing-function-docstring
 
 import os
+import tempfile
+import unittest
 
 import numpy as np
 import pytest
@@ -23,28 +25,27 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
 DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
 
 
-def setup_module(module):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29500"
-    dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
+def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
+    url = "file://" + tempfile_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
 
 
-def teardown_module(module):
-    torch.distributed.destroy_process_group()
+class TestSingleRank(unittest.TestCase):
+    """
+    All the following tests do not check for inter-process communication
+    """
 
-def dist_init(rank, world_size):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
+    def setUp(self):
+        dist_init(0, 1, tempfile.mkstemp()[1])
 
+    def tearDown(self):
+        torch.distributed.destroy_process_group()
 
-def test_create():
+    def test_create(self):
         params = [torch.rand(1)]
         o = optim.OSS(params, lr=0.01)
 
-def test_state_dict():
+    def test_state_dict(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.1, momentum=0.9)
         x.backward()
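With this hunk, the single-rank tests live inside a unittest.TestCase, so process-group init and teardown run around every test method rather than once per module via setup_module/teardown_module. A self-contained sketch of that lifecycle, under the assumption of a CPU-friendly gloo backend; the class and test names are hypothetical:

    import tempfile
    import unittest

    import torch.distributed as dist


    class ExampleSingleRankTest(unittest.TestCase):
        def setUp(self):
            # One rendezvous file per test: no shared port to collide on.
            url = "file://" + tempfile.mkstemp()[1]
            dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)

        def tearDown(self):
            # Always tear the group down, even when the test body fails.
            dist.destroy_process_group()

        def test_initialized(self):
            assert dist.is_initialized()


    if __name__ == "__main__":
        unittest.main()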
@@ -88,8 +89,7 @@ def test_state_dict():
         # Check that the exposed param_groups are on the proper device
         assert o.param_groups[0]["params"][0].device == x.device
 
-def test_lr_scheduler():
+    def test_lr_scheduler(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.01)
@@ -107,8 +107,7 @@ def test_lr_scheduler():
         s2.step()
         assert x == x2
 
-def test_step_with_kwargs():
+    def test_step_with_kwargs(self):
         class SGDWithStepKWArg(torch.optim.SGD):
             def step(self, closure=None, kwarg=[]):
                 super().step()
@@ -122,8 +121,7 @@ def test_step_with_kwargs():
         assert kwarg == [5]
         assert x == torch.tensor([0.9], device=DEVICE)
 
-def test_step_with_extra_inner_key():
+    def test_step_with_extra_inner_key(self):
         class SGDWithNewKey(torch.optim.SGD):
             # Dummy optimizer which adds a new key to the param groups
             def step(self, closure=None):
@@ -137,8 +135,7 @@ def test_step_with_extra_inner_key():
         assert o.param_groups[0]["new_key"] == 0.1
         assert x == torch.tensor([0.9], device=DEVICE)
 
-def test_step_without_closure():
+    def test_step_without_closure(self):
         class SGDWithoutClosure(torch.optim.SGD):
             def step(self):
                 return super().step()
@@ -149,8 +146,7 @@ def test_step_without_closure():
         o.step()
         assert x == torch.tensor([0.9], device=DEVICE)
 
-def test_local_state_dict():
+    def test_local_state_dict(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.1)
         local_state_dict = o.local_state_dict()
@@ -163,8 +159,7 @@ def test_local_state_dict():
         o.step()
         assert x == torch.tensor([0.9], device=DEVICE)
 
-def test_implicit_local_state_dict():
+    def test_implicit_local_state_dict(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.1)
         local_state_dict = o.state_dict()
@@ -178,8 +173,8 @@ def test_implicit_local_state_dict():
         assert x == torch.tensor([0.9], device=DEVICE)
 
 
-def run_test_add_param_group(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_add_param_group(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     params = []
     for size in [4, 5, 2, 6, 4]:
         params.append(torch.rand(size, 1))
@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
     assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
     assert len(o.optim.param_groups) == 2
 
+    dist.destroy_process_group()
+
 
 def test_add_param_group():
     world_size = 3
-    mp.spawn(run_test_add_param_group, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_zero_grad(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_zero_grad(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     x = torch.rand(1)
     m = torch.nn.Linear(1, 1)
     o = optim.OSS(m.parameters(), lr=0.1)
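The multi-process tests all follow the same shape after this change: the parent creates one rendezvous file and passes it through mp.spawn's args, so every rank of one test meets on the same file while concurrently running tests use different files. A sketch of that shape with a hypothetical worker; the gloo backend is assumed so it runs without GPUs:

    import tempfile

    import torch.distributed as dist
    import torch.multiprocessing as mp


    def worker(rank, world_size, tempfile_name):
        # mp.spawn passes the rank as the first argument automatically.
        dist.init_process_group(
            init_method="file://" + tempfile_name, backend="gloo", rank=rank, world_size=world_size
        )
        # ... the actual test body would run here ...
        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        temp_file_name = tempfile.mkstemp()[1]  # one file shared by both ranks
        mp.spawn(worker, args=(world_size, temp_file_name), nprocs=world_size, join=True)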
@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
     assert not m.weight.grad
     assert not m.bias.grad
 
+    dist.destroy_process_group()
+
 
 def test_zero_grad():
     world_size = 2
-    mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_step(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_step(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name, backend="gloo")
     x = torch.tensor([float(rank + 1)], device=rank)
     m = torch.nn.Linear(1, 1)
     m.weight.data = torch.tensor([[1.0]])
@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
     assert m.weight == torch.tensor([[0.75]], device=rank)
     assert m.bias == torch.tensor([1.85], device=rank)
 
+    dist.destroy_process_group()
+
 
 @skip_if_no_cuda
 def test_step():
     world_size = min(2, torch.cuda.device_count())
-    mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_step, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_step_with_closure(rank, world_size, optimizer=None):
-    dist_init(rank, world_size)
+def run_test_step_with_closure(rank, world_size, tempfile_name, optimizer=None):
+    dist_init(rank, world_size, tempfile_name)
     x_val = rank + 1
     weight = 1.0
@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
     assert m.weight == torch.tensor([[1.1]], device=rank)
     assert m.bias == torch.tensor([2.1], device=rank)
 
+    dist.destroy_process_group()
+
 
 @skip_if_no_cuda
 def test_step_with_closure():
     world_size = min(2, torch.cuda.device_count())
-    mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_step_with_closure, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_sharding(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_sharding(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     params = []
     for size in [5, 4, 2, 6, 4, 3]:
         params.append(torch.rand(size, 1))
     o = optim.OSS(params, lr=0.1)
     assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
 
+    dist.destroy_process_group()
+
 
 def test_sharding():
     world_size = 3
-    mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_collect_shards(rank, world_size, reference_rank):
-    dist_init(rank, world_size)
+def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
 
     # Run a dummy step so that the optimizer state dict exists
-    batch, input_width, hidden, target_width = 3, 20, 10, 5
+    batch, input_width, hidden, target_width = 3, 3, 3, 5
     target = torch.rand((batch, target_width), device=device)
     inputs = torch.rand((batch, input_width), device=device)
@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
     # Load the optimizer state dict
     optimizer.load_state_dict(optimizer_state_dict)
 
+    dist.destroy_process_group()
+
 
 def test_collect_shards():
     world_size = 3
+    temp_file_name = tempfile.mkstemp()[1]
     if torch.cuda.is_available():
         world_size = min(world_size, torch.cuda.device_count())
     reference_rank = 0
 
     mp.spawn(
-        run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
+        run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
     )
 
 
-def run_test_multiple_groups(rank, world_size):
+def run_test_multiple_groups(rank, world_size, tempfile_name):
     # Only work with the even ranks, to check that the global_rank indexing is properly used
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo")
 
     sub_group_ranks = [0, 2, 4]
     process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
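One detail worth noting in the hunk above: torch.distributed.new_group must be called collectively by every rank of the default group, including the odd ranks that are not in sub_group_ranks, and the default group itself is still initialized first through the file-based dist_init. A short sketch under those assumptions, with a hypothetical function name:

    import torch.distributed as dist


    def make_even_subgroup():
        # Requires an already-initialized default group with world_size >= 5;
        # every rank must execute this call, but only ranks 0, 2 and 4 get a
        # handle usable in collectives (e.g. passed to OSS via `group=`).
        return dist.new_group(ranks=[0, 2, 4], backend="gloo")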
@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
     optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0)
     check(optimizer)
 
+    dist.destroy_process_group(process_group)
+    dist.destroy_process_group()
+
 
 def test_multiple_groups():
     world_size = 6
+    temp_file_name = tempfile.mkstemp()[1]
 
     mp.spawn(
-        run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
+        run_test_multiple_groups, args=(world_size, temp_file_name), nprocs=world_size, join=True,
    )
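Finally, the teardown added in the last hunk destroys the sub-group handle before the default group. A condensed sketch of that ordering, with a hypothetical helper name:

    import torch.distributed as dist


    def cleanup(process_group):
        dist.destroy_process_group(process_group)  # release the sub-group first
        dist.destroy_process_group()  # then the default group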