Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
4e438ba1
Unverified
Commit
4e438ba1
authored
May 03, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
May 03, 2021
Browse files
[fix] SDP: expose module property fix + unit test (#647)
* fix + unit test * changelog update
parent
b66168da
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
10 deletions
+15
-10
CHANGELOG.md
CHANGELOG.md
+2
-0
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+12
-10
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_features.py
+1
-0
No files found.
CHANGELOG.md
View file @
4e438ba1
...
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
...
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
## NEXT - TBD
### Fixed
### Fixed
-
setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
-
setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
-
SDP: re-expose the module property (
[
#647
](
https://github.com/facebookresearch/fairscale/pull/647
)
)
### Added
### Added
-
FSDP: better memory usage for reduce bucket (
[
#633
](
https://github.com/facebookresearch/fairscale/pull/633
)
)
-
FSDP: better memory usage for reduce bucket (
[
#633
](
https://github.com/facebookresearch/fairscale/pull/633
)
)
...
...
fairscale/nn/data_parallel/sharded_ddp.py
View file @
4e438ba1
...
@@ -103,7 +103,9 @@ class ShardedDataParallel(nn.Module):
...
@@ -103,7 +103,9 @@ class ShardedDataParallel(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
_module
=
module
# This field needs to be exposed to insure interface parity with DDP
self
.
module
=
module
self
.
_sharded_optimizers
=
[
sharded_optimizer
]
if
not
isinstance
(
sharded_optimizer
,
list
)
else
sharded_optimizer
self
.
_sharded_optimizers
=
[
sharded_optimizer
]
if
not
isinstance
(
sharded_optimizer
,
list
)
else
sharded_optimizer
self
.
_enable_broadcast_buffers
=
broadcast_buffers
self
.
_enable_broadcast_buffers
=
broadcast_buffers
self
.
_auto_refresh_trainable
=
auto_refresh_trainable
self
.
_auto_refresh_trainable
=
auto_refresh_trainable
...
@@ -133,10 +135,10 @@ class ShardedDataParallel(nn.Module):
...
@@ -133,10 +135,10 @@ class ShardedDataParallel(nn.Module):
# Expose some of the PytorchDDP attributes, some frameworks rely on them.
# Expose some of the PytorchDDP attributes, some frameworks rely on them.
# See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
# See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
# device_id related logic is not present, this is not handled
# device_id related logic is not present, this is not handled
devices
=
{
p
.
device
for
p
in
self
.
_
module
.
parameters
()}
devices
=
{
p
.
device
for
p
in
self
.
module
.
parameters
()}
self
.
is_multi_device_module
=
len
(
devices
)
>
1
self
.
is_multi_device_module
=
len
(
devices
)
>
1
distinct_device_types
=
{
p
.
device
.
type
for
p
in
self
.
_
module
.
parameters
()}
distinct_device_types
=
{
p
.
device
.
type
for
p
in
self
.
module
.
parameters
()}
assert
len
(
distinct_device_types
)
==
1
,
(
assert
len
(
distinct_device_types
)
==
1
,
(
"ShardedDataParallel's input module must be on "
"ShardedDataParallel's input module must be on "
"the same type of devices, but input module parameters are located on {} different device types."
"the same type of devices, but input module parameters are located on {} different device types."
...
@@ -161,7 +163,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -161,7 +163,7 @@ class ShardedDataParallel(nn.Module):
self
.
_reference_trainable_mask
=
list
(
map
(
_trainable
,
self
.
_all_params
))
self
.
_reference_trainable_mask
=
list
(
map
(
_trainable
,
self
.
_all_params
))
# - setup buckets and tensor views
# - setup buckets and tensor views
model_size
=
sum
([
p
.
numel
()
for
p
in
self
.
_
module
.
parameters
()])
model_size
=
sum
([
p
.
numel
()
for
p
in
self
.
module
.
parameters
()])
self
.
_buffer_max_size
=
min
(
reduce_buffer_size
,
model_size
)
self
.
_buffer_max_size
=
min
(
reduce_buffer_size
,
model_size
)
if
dist
.
get_world_size
(
self
.
_process_group
)
==
1
:
if
dist
.
get_world_size
(
self
.
_process_group
)
==
1
:
...
@@ -185,7 +187,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -185,7 +187,7 @@ class ShardedDataParallel(nn.Module):
self
.
_manual_reduce
:
List
[
Callable
]
=
[]
self
.
_manual_reduce
:
List
[
Callable
]
=
[]
# passing a handle to torch.nn.SyncBatchNorm layer
# passing a handle to torch.nn.SyncBatchNorm layer
self
.
_passing_sync_batchnorm_handle
(
self
.
_
module
)
self
.
_passing_sync_batchnorm_handle
(
self
.
module
)
# Make sure that all ranks start with the same model
# Make sure that all ranks start with the same model
if
sync_models_at_startup
:
if
sync_models_at_startup
:
...
@@ -219,7 +221,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -219,7 +221,7 @@ class ShardedDataParallel(nn.Module):
self
.
_clear_counters
()
self
.
_clear_counters
()
# Normal FW on the base model
# Normal FW on the base model
return
self
.
_
module
(
*
inputs
,
**
kwargs
)
return
self
.
module
(
*
inputs
,
**
kwargs
)
def
to
(
# type: ignore
def
to
(
# type: ignore
self
,
self
,
...
@@ -267,7 +269,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -267,7 +269,7 @@ class ShardedDataParallel(nn.Module):
for
bucket
in
self
.
_buckets
[
_device
].
values
():
for
bucket
in
self
.
_buckets
[
_device
].
values
():
bucket
.
to
(
device
=
_device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
)
bucket
.
to
(
device
=
_device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
)
self
.
_
module
.
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
)
self
.
module
.
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
)
def
refresh_trainable
(
self
)
->
None
:
def
refresh_trainable
(
self
)
->
None
:
""" If the module trainability has changed, update all the assumptions """
""" If the module trainability has changed, update all the assumptions """
...
@@ -328,7 +330,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -328,7 +330,7 @@ class ShardedDataParallel(nn.Module):
with
profiler
.
record_function
(
"fairscale::sdp::sync_buffers"
):
with
profiler
.
record_function
(
"fairscale::sdp::sync_buffers"
):
work_handles
=
[]
work_handles
=
[]
for
buffer
in
self
.
_
module
.
buffers
(
recurse
=
True
):
for
buffer
in
self
.
module
.
buffers
(
recurse
=
True
):
work_handles
.
append
(
work_handles
.
append
(
dist
.
broadcast
(
buffer
.
data
,
self
.
_reference_global_rank
,
self
.
_process_group
,
async_op
=
True
)
dist
.
broadcast
(
buffer
.
data
,
self
.
_reference_global_rank
,
self
.
_process_group
,
async_op
=
True
)
)
)
...
@@ -362,7 +364,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -362,7 +364,7 @@ class ShardedDataParallel(nn.Module):
try
:
try
:
return
super
().
__getattr__
(
name
)
# defer to nn.Module's logic
return
super
().
__getattr__
(
name
)
# defer to nn.Module's logic
except
AttributeError
:
except
AttributeError
:
return
getattr
(
self
.
_
module
,
name
)
return
getattr
(
self
.
module
,
name
)
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
no_sync
(
self
)
->
Generator
:
def
no_sync
(
self
)
->
Generator
:
...
@@ -528,7 +530,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -528,7 +530,7 @@ class ShardedDataParallel(nn.Module):
work_handles
=
[]
work_handles
=
[]
for
t
in
self
.
_
module
.
state_dict
().
values
():
for
t
in
self
.
module
.
state_dict
().
values
():
work_handles
.
append
(
work_handles
.
append
(
dist
.
broadcast
(
t
,
src
=
self
.
_reference_global_rank
,
group
=
self
.
_process_group
,
async_op
=
True
)
dist
.
broadcast
(
t
,
src
=
self
.
_reference_global_rank
,
group
=
self
.
_process_group
,
async_op
=
True
)
)
)
...
...
tests/nn/data_parallel/test_sharded_ddp_features.py
View file @
4e438ba1
...
@@ -236,6 +236,7 @@ def test_ddp_attributes():
...
@@ -236,6 +236,7 @@ def test_ddp_attributes():
assert
hasattr
(
ddp_model
,
"is_multi_device_module"
)
assert
hasattr
(
ddp_model
,
"is_multi_device_module"
)
assert
hasattr
(
ddp_model
,
"device_type"
)
assert
hasattr
(
ddp_model
,
"device_type"
)
assert
hasattr
(
ddp_model
,
"module"
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment