OpenDAS / fairscale · Commit 8c8a625a

Commit 8c8a625a (unverified), authored Apr 29, 2021 by Benjamin Lefaudeux, committed by GitHub on Apr 29, 2021.
Parent: 21cba91b

[test][minor] Improving SDP test coverage (#639)

* Improving test coverage on SDP
* using pytest exception catcher
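The "pytest exception catcher" mentioned in the commit message is `pytest.raises`, which the new `test_catch_grad_grad` below uses to assert that the ShardedDDP forward raises. As a generic, self-contained illustration of the pattern (a sketch, not code from this commit):

import pytest


def divide(a: float, b: float) -> float:
    return a / b


def test_divide_by_zero_raises() -> None:
    # The test passes only if the wrapped block raises the expected exception,
    # and fails if nothing (or a different exception) is raised.
    with pytest.raises(ZeroDivisionError):
        divide(1.0, 0.0)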
Showing 4 changed files with 39 additions and 7 deletions (+39 -7)
.pre-commit-config.yaml                               (+1 -1)
fairscale/nn/data_parallel/sharded_ddp.py             (+4 -2)
fairscale/optim/utils.py                              (+2 -2)
tests/nn/data_parallel/test_sharded_ddp_features.py   (+32 -2)
.pre-commit-config.yaml

@@ -40,6 +40,6 @@ repos:
         additional_dependencies: [toml]

 -   repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v0.770'
+    rev: 'v0.790'
     hooks:
     -   id: mypy
fairscale/nn/data_parallel/sharded_ddp.py

@@ -256,9 +256,11 @@ class ShardedDataParallel(nn.Module):
             Module: self.
         """
-        assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
         assert (
-            len(self._buckets.keys()) == 1
+            len(self._buckets.keys()) == 0 or device in self._buckets.keys()
+        ), "Changing devices is not supported, because this would break OSSs state"
+        assert (
+            len(self._buckets.keys()) < 2
         ), "Several devices specified to begin with, incompatible with setting a single device here"
         for _device in self._buckets.keys():
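The change above relaxes the device check in `ShardedDataParallel.to()`: moving is now also allowed when no buckets have been created yet, not only when the target device matches the single device the buckets already live on. A standalone sketch of the same check, with a hypothetical `check_single_device` helper standing in for the method-internal asserts:

from typing import Any, Dict

import torch


def check_single_device(buckets: Dict[torch.device, Any], device: torch.device) -> None:
    # Fine if no buckets exist yet, or if the target device is the one already in use.
    assert (
        len(buckets) == 0 or device in buckets
    ), "Changing devices is not supported, because this would break OSSs state"
    # At most one bucketed device is expected at this point.
    assert len(buckets) < 2, "Several devices specified to begin with, incompatible with setting a single device here"


check_single_device({}, torch.device("cpu"))                         # ok: no buckets yet
check_single_device({torch.device("cpu"): []}, torch.device("cpu"))  # ok: same device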
fairscale/optim/utils.py

@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-import collections
+from collections import abc
 import io
 from math import inf
 from typing import Any, Callable, Dict, List, Optional

@@ -46,7 +46,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
         return values if isinstance(value, list) else tuple(values)

-    if isinstance(value, collections.abc.Mapping):
+    if isinstance(value, abc.Mapping):
         device_val: Dict[str, Any] = {}
         for key, val in value.items():
             device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
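Behaviour is unchanged here: `from collections import abc` exposes the same abstract base classes as `collections.abc`, so the `isinstance(value, abc.Mapping)` test in `recursive_copy_to_device` still matches dict-like values. A minimal standard-library check:

from collections import abc

# dicts (and other mapping types) satisfy abc.Mapping; lists and tuples do not.
assert isinstance({"weight": 1.0}, abc.Mapping)
assert not isinstance([("weight", 1.0)], abc.Mapping)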
tests/nn/data_parallel/test_sharded_ddp_features.py

@@ -72,6 +72,7 @@ def run_one_step(
     grad_accumulation,
     reduce_buffer_size,
     optimizer_type,
+    reduce_fp16=False,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     if device == torch.device("cuda"):
@@ -93,7 +94,11 @@ def run_one_step(
     optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
     ddp_model = ShardedDataParallel(
-        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
+        model,
+        optimizer,
+        broadcast_buffers=broadcast_buffers,
+        reduce_buffer_size=reduce_buffer_size,
+        reduce_fp16=reduce_fp16,
     )

     # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
@@ -144,6 +149,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
 @pytest.mark.parametrize("grad_accumulation", [True, False])
 @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
 @pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
+@pytest.mark.parametrize("reduce_fp16", [False, True])
 @pytest.mark.parametrize(
     "setup",
     [
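Stacked `pytest.mark.parametrize` decorators take the cross-product of their value lists, so the new `reduce_fp16` axis doubles the number of generated `test_step` cases. A generic illustration (not code from this commit):

import pytest


@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("reduce_fp16", [False, True])
def test_matrix(reduce_buffer_size, reduce_fp16):
    # 2 buffer sizes x 2 fp16 settings = 4 collected test instances.
    assert reduce_buffer_size in (0, 2 ** 20)
    assert reduce_fp16 in (False, True)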
@@ -152,7 +158,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
         [dist.Backend.GLOO, torch.device("cuda")],
     ],
 )
-def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, setup):
+def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
@@ -167,6 +173,7 @@ def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimize
             grad_accumulation,
             reduce_buffer_size,
             optimizer_type,
+            reduce_fp16,
         ),
         nprocs=world_size,
         join=True,
@@ -248,6 +255,26 @@ def test_random_attributes():
     dist.destroy_process_group()


+def test_catch_grad_grad():
+    # Check that ShardedDDP exposes the original module's attributes
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+
+    model = Sequential(Linear(2, 3), Linear(3, 3))
+    model.train()
+    chained_grad = torch.zeros_like(next(model.parameters()))
+    chained_grad.requires_grad = True
+    next(model.parameters()).grad = chained_grad
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    inputs = torch.rand(100, 2)
+    with pytest.raises(RuntimeError):
+        _ = ddp_model(inputs)
+
+    dist.destroy_process_group()
+
+
 def test_mixed_types():
     # Check that ShardedDDP exposes the original module's attributes
     dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
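The new `test_catch_grad_grad` (its leading comment is carried over from the neighbouring test) plants a gradient that itself requires grad on one parameter and expects the ShardedDDP forward to raise a `RuntimeError`. The setup it relies on, shown in plain PyTorch without fairscale (an illustrative sketch):

import torch
from torch.nn import Linear, Sequential

model = Sequential(Linear(2, 3), Linear(3, 3))
param = next(model.parameters())

# Replace the parameter's gradient with a tensor that is itself part of an
# autograd graph, i.e. a "grad of grad" situation.
chained_grad = torch.zeros_like(param)
chained_grad.requires_grad = True
param.grad = chained_grad

print(param.grad.requires_grad)  # True: the state the new test expects ShardedDDP to reject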
@@ -312,6 +339,9 @@ def run_test_device_change(rank, world_size, backend, device, temp_file_name, re
     except AssertionError:
         pass

+    # Check that we can change the data type
+    ddp_model.to(device=torch.device("cpu"), dtype=torch.float16)
+
     dist.destroy_process_group()
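The added check exercises `Module.to` with a dtype argument on the wrapped model. In plain PyTorch the same call converts the module's parameters in place, e.g. (illustrative, no fairscale involved):

import torch
from torch.nn import Linear

m = Linear(2, 3)
m.to(device=torch.device("cpu"), dtype=torch.float16)
assert next(m.parameters()).dtype == torch.float16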