Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
4e438ba1
Unverified
Commit
4e438ba1
authored
May 03, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
May 03, 2021
Browse files
[fix] SDP: expose module property fix + unit test (#647)
* fix + unit test * changelog update
parent
b66168da
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
10 deletions
+15
-10
CHANGELOG.md
CHANGELOG.md
+2
-0
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+12
-10
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_features.py
+1
-0
No files found.
CHANGELOG.md
View file @
4e438ba1
...
...
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
### Fixed
-
setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
-
SDP: re-expose the module property (
[
#647
](
https://github.com/facebookresearch/fairscale/pull/647
)
)
### Added
-
FSDP: better memory usage for reduce bucket (
[
#633
](
https://github.com/facebookresearch/fairscale/pull/633
)
)
...
...
fairscale/nn/data_parallel/sharded_ddp.py
View file @
4e438ba1
...
...
@@ -103,7 +103,9 @@ class ShardedDataParallel(nn.Module):
):
super
().
__init__
()
self
.
_module
=
module
# This field needs to be exposed to ensure interface parity with DDP
self
.
module
=
module
self
.
_sharded_optimizers
=
[
sharded_optimizer
]
if
not
isinstance
(
sharded_optimizer
,
list
)
else
sharded_optimizer
self
.
_enable_broadcast_buffers
=
broadcast_buffers
self
.
_auto_refresh_trainable
=
auto_refresh_trainable
...
...
@@ -133,10 +135,10 @@ class ShardedDataParallel(nn.Module):
# Expose some of the PytorchDDP attributes, some frameworks rely on them.
# See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
# device_id related logic is not present, this is not handled
devices
=
{
p
.
device
for
p
in
self
.
_
module
.
parameters
()}
devices
=
{
p
.
device
for
p
in
self
.
module
.
parameters
()}
self
.
is_multi_device_module
=
len
(
devices
)
>
1
distinct_device_types
=
{
p
.
device
.
type
for
p
in
self
.
_
module
.
parameters
()}
distinct_device_types
=
{
p
.
device
.
type
for
p
in
self
.
module
.
parameters
()}
assert
len
(
distinct_device_types
)
==
1
,
(
"ShardedDataParallel's input module must be on "
"the same type of devices, but input module parameters are located on {} different device types."
...
...
@@ -161,7 +163,7 @@ class ShardedDataParallel(nn.Module):
self
.
_reference_trainable_mask
=
list
(
map
(
_trainable
,
self
.
_all_params
))
# - setup buckets and tensor views
model_size
=
sum
([
p
.
numel
()
for
p
in
self
.
_
module
.
parameters
()])
model_size
=
sum
([
p
.
numel
()
for
p
in
self
.
module
.
parameters
()])
self
.
_buffer_max_size
=
min
(
reduce_buffer_size
,
model_size
)
if
dist
.
get_world_size
(
self
.
_process_group
)
==
1
:
...
...
@@ -185,7 +187,7 @@ class ShardedDataParallel(nn.Module):
self
.
_manual_reduce
:
List
[
Callable
]
=
[]
# passing a handle to torch.nn.SyncBatchNorm layer
self
.
_passing_sync_batchnorm_handle
(
self
.
_
module
)
self
.
_passing_sync_batchnorm_handle
(
self
.
module
)
# Make sure that all ranks start with the same model
if
sync_models_at_startup
:
...
...
@@ -219,7 +221,7 @@ class ShardedDataParallel(nn.Module):
self
.
_clear_counters
()
# Normal FW on the base model
return
self
.
_
module
(
*
inputs
,
**
kwargs
)
return
self
.
module
(
*
inputs
,
**
kwargs
)
def
to
(
# type: ignore
self
,
...
...
@@ -267,7 +269,7 @@ class ShardedDataParallel(nn.Module):
for
bucket
in
self
.
_buckets
[
_device
].
values
():
bucket
.
to
(
device
=
_device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
)
self
.
_
module
.
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
)
self
.
module
.
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
)
def
refresh_trainable
(
self
)
->
None
:
""" If the module trainability has changed, update all the assumptions """
...
...
@@ -328,7 +330,7 @@ class ShardedDataParallel(nn.Module):
with
profiler
.
record_function
(
"fairscale::sdp::sync_buffers"
):
work_handles
=
[]
for
buffer
in
self
.
_
module
.
buffers
(
recurse
=
True
):
for
buffer
in
self
.
module
.
buffers
(
recurse
=
True
):
work_handles
.
append
(
dist
.
broadcast
(
buffer
.
data
,
self
.
_reference_global_rank
,
self
.
_process_group
,
async_op
=
True
)
)
...
...
@@ -362,7 +364,7 @@ class ShardedDataParallel(nn.Module):
try
:
return
super
().
__getattr__
(
name
)
# defer to nn.Module's logic
except
AttributeError
:
return
getattr
(
self
.
_
module
,
name
)
return
getattr
(
self
.
module
,
name
)
@
contextlib
.
contextmanager
def
no_sync
(
self
)
->
Generator
:
...
...
@@ -528,7 +530,7 @@ class ShardedDataParallel(nn.Module):
work_handles
=
[]
for
t
in
self
.
_
module
.
state_dict
().
values
():
for
t
in
self
.
module
.
state_dict
().
values
():
work_handles
.
append
(
dist
.
broadcast
(
t
,
src
=
self
.
_reference_global_rank
,
group
=
self
.
_process_group
,
async_op
=
True
)
)
...
...
tests/nn/data_parallel/test_sharded_ddp_features.py
View file @
4e438ba1
...
...
@@ -236,6 +236,7 @@ def test_ddp_attributes():
assert
hasattr
(
ddp_model
,
"is_multi_device_module"
)
assert
hasattr
(
ddp_model
,
"device_type"
)
assert
hasattr
(
ddp_model
,
"module"
)
dist
.
destroy_process_group
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment