OpenDAS / fairscale · Commits

Unverified commit a265586b, authored Feb 02, 2021 by Benjamin Lefaudeux, committed by GitHub on Feb 02, 2021.
[fix] ShardedDDP - properly handle post device change (#353)
* adding the .to(device) support + unit testing
* doc update
Parent: 9e8929e6
Showing 2 changed files with 73 additions and 6 deletions (+73, -6):
fairscale/nn/data_parallel/sharded_ddp.py   +44  -6
tests/nn/data_parallel/test_sharded_ddp.py  +29  -0
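For context, the new unit test added in this commit exercises exactly this flow: wrap a CPU model, then move the wrapper to a GPU before the first forward pass. A minimal end-to-end sketch of that capability (single-rank process group with a hypothetical file-store path, hyperparameters purely illustrative; this snippet is not part of the commit itself):

# Sketch of the behavior this commit enables: construct ShardedDDP on CPU, then move it.
# Assumes at least one CUDA device; the init_method path and hyperparameters are illustrative.
import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

dist.init_process_group(backend="gloo", init_method="file:///tmp/sharded_ddp_example", rank=0, world_size=1)

model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# Before this fix, moving the wrapper after construction left the reduce buckets behind;
# with this commit, .to() also handles the bucket buffers.
ddp_model.to("cuda")

inputs = torch.rand((10, 2), device="cuda")
ddp_model(inputs).norm().backward()

dist.destroy_process_group()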
fairscale/nn/data_parallel/sharded_ddp.py  (view file @ a265586b)
@@ -11,7 +11,7 @@ reduction automatically.
 import contextlib
 from itertools import chain
 import logging
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -142,7 +142,6 @@ class ShardedDataParallel(nn.Module):
         self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
         self._should_bucket_grad: List[bool] = []
-        self._bucket_iterator: Optional[Iterable[Bucket]] = None
         self._setup_bucket_strategy()
 
         # - setup backward hooks which will be called by Torch's autograd in due time
@@ -172,6 +171,44 @@ class ShardedDataParallel(nn.Module):
         # Normal FW on the base model
         return self.module(*inputs, **kwargs)
 
+    def to(  # type: ignore
+        self,
+        device: Optional[Union[int, torch.device]],
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> "ShardedDataParallel":
+        """
+        Moves and/or casts the parameters and buffers.
+
+        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
+        floating point desired :attr:`dtype` s. In addition, this method will
+        only cast the floating point parameters and buffers to :attr:`dtype`
+        (if given). The integral parameters and buffers will be moved
+        :attr:`device`, if that is given, but with dtypes unchanged. When
+        :attr:`non_blocking` is set, it tries to convert/move asynchronously
+        with respect to the host if possible, e.g., moving CPU Tensors with
+        pinned memory to CUDA devices.
+
+        .. note::
+            This method modifies the module in-place.
+
+        Arguments:
+            device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
+            dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
+            non_blocking (bool): make it an asynchronous call.
+
+        Returns:
+            Module: self.
+        """
+        for optimizer in self.buckets.keys():
+            for device in self.buckets[optimizer].keys():
+                for bucket in self.buckets[optimizer][device]:
+                    bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
+        self.module.to(device)
+
     def reduce(self) -> None:
         """.. deprecated:: 0.0.4
@@ -215,14 +252,15 @@ class ShardedDataParallel(nn.Module):
     def _clear_counters(self) -> None:
         """Reset all the grad reduce and call counters"""
         if not self.should_accumulate_grads:
             self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
             self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
 
-            for o in self.buckets.keys():
-                for d in self.buckets[o].keys():
-                    for bucket in self.buckets[o][d]:
+            for optimizer in self.buckets.keys():
+                for device in self.buckets[optimizer].keys():
+                    for bucket in self.buckets[optimizer][device]:
                         assert bucket.sent, (
-                            "A bucket failed being sent, probably unused parameters."
+                            "A bucket failed to be sent, probably unused parameters."
                             + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
                         )
tests/nn/data_parallel/test_sharded_ddp.py  (view file @ a265586b)
@@ -342,6 +342,35 @@ def test_random_attributes():
     dist.destroy_process_group()
 
 
+def run_test_device_change(rank, world_size, backend, device, temp_file_name):
+    # Check that the wrapped module can change devices
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+
+    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+    ddp_model.to(device)
+
+    inputs = torch.rand((10, 2), device=device)
+    outputs = ddp_model(inputs)
+    # assert if the module has not been changed properly
+    loss = outputs.norm().backward()
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_device_change():
+    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
+    world_size = 2
+    backend = "gloo"
+    temp_file_name = tempfile.mkstemp()[1]
+    device = "cuda"
+    mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+
+
 def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)