Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
a265586b
Unverified
Commit
a265586b
authored
Feb 02, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Feb 02, 2021
Browse files
[fix] ShardedDDP - properly handle post device change (#353)
* adding the .to(device) support + unit testing * doc update
parent
9e8929e6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
6 deletions
+73
-6
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+44
-6
tests/nn/data_parallel/test_sharded_ddp.py
tests/nn/data_parallel/test_sharded_ddp.py
+29
-0
No files found.
fairscale/nn/data_parallel/sharded_ddp.py
View file @
a265586b
...
@@ -11,7 +11,7 @@ reduction automatically.
...
@@ -11,7 +11,7 @@ reduction automatically.
import
contextlib
import
contextlib
from
itertools
import
chain
from
itertools
import
chain
import
logging
import
logging
from
typing
import
Any
,
Callable
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -142,7 +142,6 @@ class ShardedDataParallel(nn.Module):
...
@@ -142,7 +142,6 @@ class ShardedDataParallel(nn.Module):
self
.
buckets
:
Dict
[
OSS
,
Dict
[
torch
.
device
,
List
[
Bucket
]]]
=
{
o
:
{}
for
o
in
self
.
sharded_optimizers
}
self
.
buckets
:
Dict
[
OSS
,
Dict
[
torch
.
device
,
List
[
Bucket
]]]
=
{
o
:
{}
for
o
in
self
.
sharded_optimizers
}
self
.
_should_bucket_grad
:
List
[
bool
]
=
[]
self
.
_should_bucket_grad
:
List
[
bool
]
=
[]
self
.
_bucket_iterator
:
Optional
[
Iterable
[
Bucket
]]
=
None
self
.
_setup_bucket_strategy
()
self
.
_setup_bucket_strategy
()
# - setup backward hooks which will be called by Torch's autograd in due time
# - setup backward hooks which will be called by Torch's autograd in due time
...
@@ -172,6 +171,44 @@ class ShardedDataParallel(nn.Module):
...
@@ -172,6 +171,44 @@ class ShardedDataParallel(nn.Module):
# Normal FW on the base model
# Normal FW on the base model
return
self
.
module
(
*
inputs
,
**
kwargs
)
return
self
.
module
(
*
inputs
,
**
kwargs
)
def to(  # type: ignore
    self,
    device: Optional[Union[int, torch.device]],
    dtype: Optional[torch.dtype] = None,
    non_blocking: bool = False,
) -> "ShardedDataParallel":
    """
    Move and/or cast the wrapped module and the reduce-bucket buffers.

    The arguments are forwarded as-is to :meth:`torch.Tensor.to` for every
    bucket buffer and to the ``to`` method of the wrapped module.

    .. note::
        This method modifies the module in-place.

    Arguments:
        device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
        dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
        non_blocking (bool): make it an asynchronous call.

    Returns:
        Module: self.
    """
    # Walk the nested {optimizer: {device: [bucket, ...]}} mapping without
    # rebinding `device`: the argument must still hold the requested target
    # when the wrapped module is moved below.
    for buckets_per_device in self.buckets.values():
        for bucket_list in buckets_per_device.values():
            for bucket in bucket_list:
                # Tensor.to is not in-place; it returns the moved/cast tensor,
                # so the result has to be stored back on the bucket.
                bucket.buffer = bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)

    # NOTE(review): the inner keys of `self.buckets` are left untouched, so
    # they still name the devices the buckets were created for - confirm
    # whether any consumer looks buckets up by their current device.
    self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
    return self
def
reduce
(
self
)
->
None
:
def
reduce
(
self
)
->
None
:
""".. deprecated:: 0.0.4
""".. deprecated:: 0.0.4
...
@@ -215,14 +252,15 @@ class ShardedDataParallel(nn.Module):
...
@@ -215,14 +252,15 @@ class ShardedDataParallel(nn.Module):
def
_clear_counters
(
self
)
->
None
:
def
_clear_counters
(
self
)
->
None
:
"""Reset all the grad reduce and call counters"""
"""Reset all the grad reduce and call counters"""
if
not
self
.
should_accumulate_grads
:
if
not
self
.
should_accumulate_grads
:
self
.
_grad_to_be_reduced
=
[
True
for
_
in
self
.
_grad_to_be_reduced
]
self
.
_grad_to_be_reduced
=
[
True
for
_
in
self
.
_grad_to_be_reduced
]
self
.
_reduced_grads
=
{
o
:
0
for
o
in
self
.
sharded_optimizers
}
self
.
_reduced_grads
=
{
o
:
0
for
o
in
self
.
sharded_optimizers
}
for
o
in
self
.
buckets
.
keys
():
for
o
ptimizer
in
self
.
buckets
.
keys
():
for
d
in
self
.
buckets
[
o
].
keys
():
for
d
evice
in
self
.
buckets
[
o
ptimizer
].
keys
():
for
bucket
in
self
.
buckets
[
o
][
d
]:
for
bucket
in
self
.
buckets
[
o
ptimizer
][
device
]:
assert
bucket
.
sent
,
(
assert
bucket
.
sent
,
(
"A bucket failed
being
sent, probably unused parameters."
"A bucket failed
to be
sent, probably unused parameters."
+
"Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
+
"Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
)
)
...
...
tests/nn/data_parallel/test_sharded_ddp.py
View file @
a265586b
...
@@ -342,6 +342,35 @@ def test_random_attributes():
...
@@ -342,6 +342,35 @@ def test_random_attributes():
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
    """Per-process body checking that a wrapped module can change devices.

    Builds a small CPU model wrapped in ShardedDataParallel, calls ``.to(device)``
    on the wrapper, then runs one forward and one backward pass with inputs
    created on ``device``.

    Arguments:
        rank: rank of this process in the process group.
        world_size: total number of processes in the group.
        backend: torch.distributed backend name.
        device: device the wrapper is moved to and the inputs are created on.
        temp_file_name: path used to build the ``file://`` init method.
    """
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    # The model and its sharded optimizer are deliberately created on CPU,
    # the device change only happens once everything is wrapped.
    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)
    ddp_model.to(device)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly

    # backward() returns None, so there is nothing worth binding to a name here
    outputs.norm().backward()

    dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_device_change():
    """Spawn two processes which each run ``run_test_device_change`` on "cuda"."""
    # Check that ShardedDDP handles the wrapped module being moved to another device
    world_size = 2
    backend = "gloo"
    device = "cuda"

    # Only the path is needed for the file:// rendezvous. Closing the handle
    # right away avoids leaving an open descriptor behind, which
    # `tempfile.mkstemp()[1]` did by dropping the descriptor it returned.
    with tempfile.NamedTemporaryFile(delete=False) as rendezvous_file:
        temp_file_name = rendezvous_file.name

    mp.spawn(
        run_test_device_change,
        args=(world_size, backend, device, temp_file_name),
        nprocs=world_size,
        join=True,
    )
def
run_test_ddp_sync_batch_norm
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
def
run_test_ddp_sync_batch_norm
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
url
=
"file://"
+
temp_file_name
url
=
"file://"
+
temp_file_name
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment