OpenDAS / fairscale · Commits

Unverified commit 8c8a625a, authored Apr 29, 2021 by Benjamin Lefaudeux, committed by GitHub on Apr 29, 2021.

[test][minor] Improving SDP test coverage (#639)

* Improving test coverage on SDP
* using pytest exception catcher

Parent: 21cba91b

Showing 4 changed files with 39 additions and 7 deletions (+39 / -7).
.pre-commit-config.yaml                               +1  -1
fairscale/nn/data_parallel/sharded_ddp.py             +4  -2
fairscale/optim/utils.py                              +2  -2
tests/nn/data_parallel/test_sharded_ddp_features.py  +32  -2
.pre-commit-config.yaml

@@ -40,6 +40,6 @@ repos:
       additional_dependencies: [toml]
 
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v0.770'
+    rev: 'v0.790'
     hooks:
       - id: mypy
fairscale/nn/data_parallel/sharded_ddp.py

@@ -256,9 +256,11 @@ class ShardedDataParallel(nn.Module):
             Module: self.
         """
-        assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
         assert (
-            len(self._buckets.keys()) == 1
+            len(self._buckets.keys()) == 0 or device in self._buckets.keys()
+        ), "Changing devices is not supported, because this would break OSSs state"
+        assert (
+            len(self._buckets.keys()) < 2
         ), "Several devices specified to begin with, incompatible with setting a single device here"
         for _device in self._buckets.keys():
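The relaxed assertion accepts an empty bucket table (nothing registered yet) or a repeated call with the same device, while still rejecting a second, different device. A minimal standalone sketch of that check; the `check_device_change` helper and the plain dict standing in for `_buckets` are illustrations, not fairscale internals:

import torch


def check_device_change(buckets: dict, device: torch.device) -> None:
    # First call (no bucket registered yet) or a repeat call with the same
    # device is accepted; switching to a different device is not.
    assert (
        len(buckets.keys()) == 0 or device in buckets.keys()
    ), "Changing devices is not supported, because this would break OSSs state"
    # More than one device registered up front is rejected outright.
    assert (
        len(buckets.keys()) < 2
    ), "Several devices specified to begin with, incompatible with setting a single device here"


check_device_change({}, torch.device("cpu"))                         # fine: nothing registered yet
check_device_change({torch.device("cpu"): []}, torch.device("cpu"))  # fine: same device again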
fairscale/optim/utils.py

@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-import collections
+from collections import abc
 import io
 from math import inf
 from typing import Any, Callable, Dict, List, Optional
@@ -46,7 +46,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
         return values if isinstance(value, list) else tuple(values)
 
-    if isinstance(value, collections.abc.Mapping):
+    if isinstance(value, abc.Mapping):
         device_val: Dict[str, Any] = {}
         for key, val in value.items():
             device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
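This change is an import cleanup only: the `abc` name imported from `collections` is the same module as `collections.abc`, so the `isinstance` check behaves identically. A small standalone illustration (not the fairscale helper itself):

import collections.abc
from collections import abc

# Both spellings resolve to the same abstract base class, so the
# isinstance() check in recursive_copy_to_device keeps its behaviour.
assert abc.Mapping is collections.abc.Mapping
assert isinstance({"a": 1}, abc.Mapping)
assert not isinstance([1, 2], abc.Mapping)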
tests/nn/data_parallel/test_sharded_ddp_features.py

@@ -72,6 +72,7 @@ def run_one_step(
     grad_accumulation,
     reduce_buffer_size,
     optimizer_type,
+    reduce_fp16=False,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     if device == torch.device("cuda"):
@@ -93,7 +94,11 @@ def run_one_step(
     optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
     ddp_model = ShardedDataParallel(
-        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
+        model,
+        optimizer,
+        broadcast_buffers=broadcast_buffers,
+        reduce_buffer_size=reduce_buffer_size,
+        reduce_fp16=reduce_fp16,
     )
 
     # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
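The new `reduce_fp16` flag is simply threaded from the test parameters into the `ShardedDataParallel` constructor, so gradient reduction can also be exercised in fp16. A minimal single-process sketch of that construction, not taken from the test file; the gloo/world-size-1 setup, toy model, and learning rate are placeholders:

import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-rank process group, enough to construct and step the wrapped model.
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
# reduce_fp16=True asks ShardedDataParallel to perform the gradient reduction in fp16.
ddp_model = ShardedDataParallel(model, optimizer, reduce_fp16=True)

ddp_model(torch.rand(8, 2)).sum().backward()
optimizer.step()

dist.destroy_process_group()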
@@ -144,6 +149,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
 @pytest.mark.parametrize("grad_accumulation", [True, False])
 @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
 @pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
+@pytest.mark.parametrize("reduce_fp16", [False, True])
 @pytest.mark.parametrize(
     "setup",
     [
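Stacked `pytest.mark.parametrize` decorators multiply, so adding the `reduce_fp16` axis doubles the number of generated test cases. A toy illustration of the mechanism, independent of this test file (the test name and assertions are hypothetical):

import pytest


# Each stacked parametrize adds an axis; pytest runs the full cross-product,
# so this toy test generates 2 * 2 = 4 cases.
@pytest.mark.parametrize("reduce_fp16", [False, True])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_cross_product(reduce_fp16, reduce_buffer_size):
    assert isinstance(reduce_fp16, bool)
    assert reduce_buffer_size in (0, 2 ** 20)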
@@ -152,7 +158,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
         [dist.Backend.GLOO, torch.device("cuda")],
     ],
 )
-def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, setup):
+def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
@@ -167,6 +173,7 @@ def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimize
             grad_accumulation,
             reduce_buffer_size,
             optimizer_type,
+            reduce_fp16,
         ),
         nprocs=world_size,
         join=True,
@@ -248,6 +255,26 @@ def test_random_attributes():
     dist.destroy_process_group()
 
 
+def test_catch_grad_grad():
+    # Check that ShardedDDP exposes the original module's attributes
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+
+    model = Sequential(Linear(2, 3), Linear(3, 3))
+    model.train()
+    chained_grad = torch.zeros_like(next(model.parameters()))
+    chained_grad.requires_grad = True
+    next(model.parameters()).grad = chained_grad
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    inputs = torch.rand(100, 2)
+    with pytest.raises(RuntimeError):
+        _ = ddp_model(inputs)
+
+    dist.destroy_process_group()
+
+
 def test_mixed_types():
     # Check that ShardedDDP exposes the original module's attributes
     dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
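The "pytest exception catcher" mentioned in the commit message presumably refers to `pytest.raises`, used above so the test passes only if the forward call actually raises. A generic sketch of the pattern outside of ShardedDDP; the non-scalar `backward()` call is just a convenient way to provoke a RuntimeError:

import pytest
import torch


def test_pytest_raises_pattern():
    # The block passes only if the expected exception type is raised inside it;
    # no exception, or a different one, fails the test.
    with pytest.raises(RuntimeError):
        # backward() on a non-scalar tensor without an explicit gradient raises RuntimeError.
        torch.zeros(2, 2, requires_grad=True).backward()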
@@ -312,6 +339,9 @@ def run_test_device_change(rank, world_size, backend, device, temp_file_name, re
     except AssertionError:
         pass
 
+    # Check that we can change the data type
+    ddp_model.to(device=torch.device("cpu"), dtype=torch.float16)
+
     dist.destroy_process_group()
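The added check relies on `nn.Module.to` accepting both a device and a dtype, so the wrapped model can be moved and down-cast in one call. A standalone illustration with a plain module, no ShardedDDP involved:

import torch
from torch.nn import Linear, Sequential

model = Sequential(Linear(2, 3), Linear(3, 3))
# A single .to() call can move the module and change the floating-point parameter dtype.
model.to(device=torch.device("cpu"), dtype=torch.float16)
assert next(model.parameters()).dtype == torch.float16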