OpenDAS / fairscale · Commits · 4e438ba1

Unverified commit 4e438ba1, authored May 03, 2021 by Benjamin Lefaudeux, committed by GitHub on May 03, 2021.

[fix] SDP: expose module property fix + unit test (#647)

* fix + unit test
* changelog update
parent b66168da

Showing 3 changed files, with 15 additions and 10 deletions (+15 −10):
- CHANGELOG.md (+2 −0)
- fairscale/nn/data_parallel/sharded_ddp.py (+12 −10)
- tests/nn/data_parallel/test_sharded_ddp_features.py (+1 −0)
CHANGELOG.md (view file @ 4e438ba1)

```diff
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
+- SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
 ### Added
 - FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
```
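For context on what the re-exposed attribute enables, here is a minimal sketch, not part of the commit itself: the single-process "gloo" group and the toy `Linear` model are assumptions made purely so the snippet runs standalone. Frameworks that unwrap a `torch.nn.parallel.DistributedDataParallel` instance via `.module` can now do the same with `ShardedDataParallel`.

```python
import os

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.oss import OSS

# Assumed setup: a single-process "gloo" group, only to make the example runnable.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp_model = ShardedDDP(model, optimizer)

# With this fix, the wrapped model is reachable the same way as with
# torch.nn.parallel.DistributedDataParallel.
assert ddp_model.module is model

dist.destroy_process_group()
```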
fairscale/nn/data_parallel/sharded_ddp.py (view file @ 4e438ba1)

```diff
@@ -103,7 +103,9 @@ class ShardedDataParallel(nn.Module):
     ):
         super().__init__()
-        self._module = module
+        # This field needs to be exposed to insure interface parity with DDP
+        self.module = module
         self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
         self._enable_broadcast_buffers = broadcast_buffers
         self._auto_refresh_trainable = auto_refresh_trainable
@@ -133,10 +135,10 @@ class ShardedDataParallel(nn.Module):
         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
         # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
         # device_id related logic is not present, this is not handled
-        devices = {p.device for p in self._module.parameters()}
+        devices = {p.device for p in self.module.parameters()}
         self.is_multi_device_module = len(devices) > 1
-        distinct_device_types = {p.device.type for p in self._module.parameters()}
+        distinct_device_types = {p.device.type for p in self.module.parameters()}
         assert len(distinct_device_types) == 1, (
             "ShardedDataParallel's input module must be on "
             "the same type of devices, but input module parameters are located on {} different device types."
@@ -161,7 +163,7 @@ class ShardedDataParallel(nn.Module):
         self._reference_trainable_mask = list(map(_trainable, self._all_params))

         # - setup buckets and tensor views
-        model_size = sum([p.numel() for p in self._module.parameters()])
+        model_size = sum([p.numel() for p in self.module.parameters()])
         self._buffer_max_size = min(reduce_buffer_size, model_size)

         if dist.get_world_size(self._process_group) == 1:
@@ -185,7 +187,7 @@ class ShardedDataParallel(nn.Module):
         self._manual_reduce: List[Callable] = []

         # passing a handle to torch.nn.SyncBatchNorm layer
-        self._passing_sync_batchnorm_handle(self._module)
+        self._passing_sync_batchnorm_handle(self.module)

         # Make sure that all ranks start with the same model
         if sync_models_at_startup:
@@ -219,7 +221,7 @@ class ShardedDataParallel(nn.Module):
             self._clear_counters()

         # Normal FW on the base model
-        return self._module(*inputs, **kwargs)
+        return self.module(*inputs, **kwargs)

     def to(  # type: ignore
         self,
@@ -267,7 +269,7 @@ class ShardedDataParallel(nn.Module):
             for bucket in self._buckets[_device].values():
                 bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)

-        self._module.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)

     def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
@@ -328,7 +330,7 @@ class ShardedDataParallel(nn.Module):
         with profiler.record_function("fairscale::sdp::sync_buffers"):
             work_handles = []

-            for buffer in self._module.buffers(recurse=True):
+            for buffer in self.module.buffers(recurse=True):
                 work_handles.append(
                     dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
                 )
@@ -362,7 +364,7 @@ class ShardedDataParallel(nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self._module, name)
+            return getattr(self.module, name)

     @contextlib.contextmanager
     def no_sync(self) -> Generator:
@@ -528,7 +530,7 @@ class ShardedDataParallel(nn.Module):
         work_handles = []

-        for t in self._module.state_dict().values():
+        for t in self.module.state_dict().values():
             work_handles.append(
                 dist.broadcast(t, src=self._reference_global_rank, group=self._process_group, async_op=True)
             )
```
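The `__getattr__` hunk above routes attribute lookups that `nn.Module` cannot resolve on the wrapper to the wrapped module, which is what keeps the interface close to DDP's. A small self-contained sketch of that delegation pattern follows; the `_ToyWrapper` class and the `custom_flag` attribute are hypothetical, for illustration only, and are not part of fairscale.

```python
import torch


class _ToyWrapper(torch.nn.Module):
    """Hypothetical stand-in mirroring the lookup order used above."""

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        # Registering under `module` keeps parity with DDP's exposed attribute.
        self.module = module

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)  # then fall back to the wrapped module


inner = torch.nn.Linear(4, 2)
inner.custom_flag = True  # attribute that only exists on the wrapped module

wrapper = _ToyWrapper(inner)
assert wrapper.module is inner      # exposed directly, as after this commit
assert wrapper.custom_flag is True  # resolved through the __getattr__ fallback
```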
tests/nn/data_parallel/test_sharded_ddp_features.py (view file @ 4e438ba1)

```diff
@@ -236,6 +236,7 @@ def test_ddp_attributes():
     assert hasattr(ddp_model, "is_multi_device_module")
     assert hasattr(ddp_model, "device_type")
+    assert hasattr(ddp_model, "module")
     dist.destroy_process_group()
```