OpenDAS / fairscale · Commits

Unverified commit 1204c7cf, authored Mar 06, 2021 by Myle Ott, committed by GitHub on Mar 06, 2021.

[perf] FSDP: speed up no_sync and test communication volume (#470)

Parent: 0b8d0753
Showing 5 changed files with 231 additions and 131 deletions (+231, -131).
fairscale/nn/data_parallel/fully_sharded_data_parallel.py    +93  -50
tests/ci_test_list_3.txt                                      +1   -0
tests/nn/data_parallel/test_fsdp.py                           +1   -80
tests/nn/data_parallel/test_fsdp_no_sync.py                   +131 -0
tests/nn/data_parallel/test_fsdp_summon_full_params.py        +5   -1
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -211,6 +211,9 @@ class FullyShardedDataParallel(nn.Module):
         # Enum to indicate if we're in the forward/backward pass, idle, etc.
         self.training_state = TrainingState.IDLE

+        # Flag to indicate if the full params are gathered.
+        self.has_full_params: bool = False
+
         # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
         # prefix and before load_state_dict() to add it back.
         self._register_state_dict_hook(_post_state_dict_hook)
@@ -511,7 +514,11 @@ class FullyShardedDataParallel(nn.Module):
         A context manager to disable gradient synchronizations across DDP
         processes. Within this context, gradients will be accumulated on module
         variables, which will later be synchronized in the first
-        forward-backward pass exiting the context.
+        forward-backward pass after exiting the context.
+
+        .. note:: This may result in higher memory usage because we will
+            accumulate the full model gradients (instead of gradient shards)
+            until the eventual sync.
         """
         self._lazy_init()
         assert self._is_root, "no_sync on inner FSDP is not supported"
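For context, the gradient-accumulation pattern this docstring describes looks roughly like the sketch below. This is an illustrative usage sketch, not code from the commit; the environment setup (NCCL process group, one GPU per rank) and the toy module and batch names are assumptions.

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set and one GPU per rank.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [torch.randn(4, 8, device="cuda") for _ in range(4)]

optim.zero_grad()
with model.no_sync():
    # Gradients accumulate locally (full-sized, per the note above); no
    # reduce-scatter is issued for these backward passes.
    for x in micro_batches[:-1]:
        model(x).sum().backward()

# The first backward after exiting no_sync performs the synchronization.
model(micro_batches[-1]).sum().backward()
optim.step()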
@@ -575,6 +582,7 @@ class FullyShardedDataParallel(nn.Module):
         # forward/backward.
         self.training_state = TrainingState.SUMMON_FULL_PARAMS
         full_tensors = self._rebuild_full_params(full_precision=True)
+        assert full_tensors is not None
         with contextlib.ExitStack() as stack:
             if self.flatten_parameters and self.module.is_flattened:
                 # Update flattened views to point to fully-sized tensors. We
@@ -596,6 +604,7 @@ class FullyShardedDataParallel(nn.Module):
                     p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
                 if safe_to_free:
                     free_storage_(full_tensor)
+        self.has_full_params = False
         self._use_fp32_param_shard()
         self.training_state = TrainingState.IDLE
@@ -833,6 +842,7 @@ class FullyShardedDataParallel(nn.Module):
             self._rebuild_full_params()
         else:
             self._use_full_params()

         # Make sure p.grad has the correct size/device (or set it to None).
         self._prep_grads_for_backward()
@@ -891,15 +901,22 @@ class FullyShardedDataParallel(nn.Module):
         if param.grad.requires_grad:
             raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")

-        # Free full params and switch to FP32 shard after backward.
-        self._free_full_params([param])
-        self._use_fp32_param_shard([param])
+        if not self._is_root or self._require_backward_grad_sync:
+            # Free full params. As a special case, we don't free the full params
+            # on the root instance when in a ``no_sync`` context (as indicated
+            # by ``self._require_backward_grad_sync``), since we will need the
+            # params again immediately for the next forward.
+            self._free_full_params([param])

         if self.mixed_precision:
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
             self._free_fp16_param_shard([param])

+        # Switch to FP32 shard after backward.
+        self._use_fp32_param_shard([param])
+
         # (try to) Enqueue a callback at the end of the backward pass to ensure that all
         # post-backward work has finished. We only need one callback and all instances
         # of FSDP (root and children) make this attempt here to queue to ensure it is queued
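The ``self._require_backward_grad_sync`` flag referenced in the new comment is the switch that ``no_sync`` flips for the duration of the context. A stripped-down sketch of that general pattern (not the actual FSDP implementation, just the contextmanager idea it follows):

import contextlib

class _SyncSwitch:
    """Minimal sketch of a DDP/FSDP-style no_sync switch."""

    def __init__(self) -> None:
        self._require_backward_grad_sync = True

    @contextlib.contextmanager
    def no_sync(self):
        old_flag = self._require_backward_grad_sync
        self._require_backward_grad_sync = False  # backward passes skip the sync
        try:
            yield
        finally:
            self._require_backward_grad_sync = old_flag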
@@ -966,9 +983,10 @@ class FullyShardedDataParallel(nn.Module):
     def _queue_wait_for_post_backward(self) -> None:
         """Try to queue a `wait_for_post_backward` callback.

-        Only called on root and only queue one callback.
-        But can be called by children FSDPs via a closure in case the
-        root instance doesn't own any params.
+        Only called on root and only queue one callback. But can be called by
+        children FSDPs via a closure in case the root instance doesn't own any
+        params.
         """
         assert self._is_root
         self.assert_state(TrainingState.BACKWARD)
@@ -978,10 +996,10 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
-        """Wait for post-backward work to finish. Only called on root instance.
-        """
+        """Wait for post-backward to finish. Only called on root instance."""
         assert self._is_root
         self.assert_state(TrainingState.BACKWARD)
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self._streams["post_backward"]):
                 assert self._reducer is not None
@@ -997,7 +1015,7 @@ class FullyShardedDataParallel(nn.Module):
             m.training_state = TrainingState.IDLE

     @torch.no_grad()
-    def _rebuild_full_params(self, full_precision: bool = False) -> List[Tuple[torch.Tensor, bool]]:
+    def _rebuild_full_params(self, full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
         """
         Gather all shards of params.
@@ -1008,26 +1026,49 @@ class FullyShardedDataParallel(nn.Module):
             (e.g., FP32), possibly in fresh storage.

         Returns:
-            a list of tuples, where the first element is the full-sized param
+            A list of tuples, where the first element is the full-sized param
             and the second element is a bool indicating if it's safe for the
-            caller to free the full-sized param
+            caller to free the full-sized param. This will be ``None`` if
+            ``full_precision=False`` and the full params are already gathered.
         """
         output_tensors: List[Tuple[torch.Tensor, bool]] = []
-        with torch.cuda.stream(self._streams["all_gather"]):
-            if self.mixed_precision and not full_precision:
-                self._cast_fp32_param_shards_to_fp16()
-            for p in self.params:
-                if not p._is_sharded:  # e.g., when world_size == 1
-                    if self.mixed_precision and not full_precision:
-                        p.data = p._fp16_shard
-                        output_tensors.append((p.data, True))
-                    else:
-                        # Here p.data == p._fp32_shard, so it's not safe to free.
-                        output_tensors.append((p.data, False))
-                    continue
-                # If self.cpu_offload and full_precision, we need to cast the
-                # FP32 CPU param to CUDA for the all-gather.
-                p_data = p.data.to(p._full_param_padded.device)
-                p_size = p._full_param_padded.size()
+
+        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
+            if custom_output_tensor is not None:
+                assert p._is_sharded
+                p.data = custom_output_tensor
+                output_tensors.append((p.data, True))
+            elif not p._is_sharded:
+                if self.mixed_precision and not full_precision:
+                    p.data = p._fp16_shard
+                    output_tensors.append((p.data, True))
+                else:
+                    # Here p.data == p._fp32_shard, so it's not safe to free.
+                    output_tensors.append((p.data, False))
+            else:
+                p.data = p._full_param_padded
+                output_tensors.append((p.data, True))
+            # Trim any padding and reshape to match original size.
+            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
+
+        # Early exit if we already have full params and don't need full precision.
+        if self.has_full_params and not full_precision:
+            for p in self.params:
+                update_p_data()
+            return output_tensors
+
+        self.has_full_params = True
+
+        with torch.cuda.stream(self._streams["all_gather"]):
+            if self.mixed_precision and not full_precision:
+                self._cast_fp32_param_shards_to_fp16()
+            for p in self.params:
+                if not p._is_sharded:  # e.g., when world_size == 1
+                    update_p_data()
+                else:
+                    # If self.cpu_offload and full_precision, we need to cast
+                    # the FP32 CPU param to CUDA for the all-gather.
+                    p_data = p.data.to(p._full_param_padded.device)
+                    p_size = p._full_param_padded.size()
@@ -1040,13 +1081,13 @@ class FullyShardedDataParallel(nn.Module):
                     else:
                         # Allocate fresh tensor in full precision.
                         output_tensor = p_data.new_zeros(p_size)
-                    output_tensors.append((output_tensor, True))
                     # Fill output_tensor with (p.data for each shard in self.world_size)
                     chunks = list(output_tensor.chunk(self.world_size))
                     dist.all_gather(chunks, p_data, group=self.process_group)
-                    p.data = output_tensor[: p._orig_size.numel()].view(p._orig_size)
+                    # Set p.data = output_tensor (with padding trimmed)
+                    update_p_data(output_tensor)

                     if self.mixed_precision and not full_precision:
                         self._free_fp16_param_shard([p])
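The gather pattern used just above (all-gather each rank's flat, padded shard into equal-sized chunks of one output tensor, then trim the padding) can be seen in isolation in the following sketch. It assumes ``torch.distributed`` is already initialized; the function name and arguments are illustrative, not part of the diff.

import torch
import torch.distributed as dist

def gather_full_param(shard: torch.Tensor, orig_numel: int, group=None) -> torch.Tensor:
    """Sketch: `shard` is this rank's flat (1-D), padded shard of a parameter."""
    world_size = dist.get_world_size(group)
    output = shard.new_zeros(world_size * shard.numel())
    # Each chunk is a view into `output`, so all_gather fills `output` in place.
    chunks = list(output.chunk(world_size))
    dist.all_gather(chunks, shard, group=group)
    # Trim the padding that was added to make the shards evenly divisible.
    return output[:orig_numel]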
@@ -1055,10 +1096,12 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _use_full_params(self) -> None:
-        """Switching p.data pointers to use the full params.
+        """
+        Switch p.data pointers to use the full params.

-        Note: this is used assuming full param gathering is already done.
+        Note: this assumes full params are already gathered.
         """
+        assert self.has_full_params
         for p in self.params:
             if not p._is_sharded:
                 if self.mixed_precision:
@@ -1080,6 +1123,7 @@ class FullyShardedDataParallel(nn.Module):
         """Free up storage for full parameters."""
         if params is None:
             params = self.params
+        self.has_full_params = False
         current_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
@@ -1176,7 +1220,6 @@ def free_storage_(data: torch.Tensor) -> None:
         # Since we're modifying the Tensor's Storage directly, make sure the Tensor
         # is the sole occupant of the Storage.
         assert data.storage_offset() == 0
-        assert data.storage().size() == data.numel()
         data.storage().resize_(0)
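The ``free_storage_`` helper above relies on a general PyTorch behavior: a tensor's Storage can be resized to zero elements independently of the tensor's shape metadata. A minimal standalone sketch of that trick (the ``_free_storage`` name and the buffer are illustrative, not from the diff):

import torch

def _free_storage(t: torch.Tensor) -> None:
    # Shrink the underlying Storage to zero elements. The tensor object keeps its
    # shape/dtype/device metadata, so it can be re-materialized later by resizing
    # the storage back up and copying fresh data in.
    assert t.storage_offset() == 0
    t.storage().resize_(0)

buf = torch.empty(1 << 20)            # ~4 MB of float32
_free_storage(buf)                    # memory released; buf.numel() is unchanged
print(buf.storage().size())           # 0
buf.storage().resize_(buf.numel())    # re-allocate before touching buf's data again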
tests/ci_test_list_3.txt

 tests/nn/data_parallel/test_fsdp_uneven.py
 tests/nn/data_parallel/test_fsdp_grad_scaler.py
+tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/data_parallel/test_fsdp.py

@@ -660,86 +660,7 @@ class TestNoGrad(DistributedTest):
         with torch.no_grad():
             no_grad_output = model(*input)
-        assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output"
-
-
-class TestNoSync(DistributedTest):
-    def test_transformer(self):
-        fn = functools.partial(self._test_transformer, config={})
-        spawn_and_init(fn)
-
-    def test_transformer_no_flat_params(self):
-        config = {"flatten_parameters": False}
-        fn = functools.partial(self._test_transformer, config=config)
-        spawn_and_init(fn)
-
-    def test_nested_wrapper(self):
-        fn = functools.partial(self._test_nested_wrapper, config={})
-        spawn_and_init(fn)
-
-    def test_no_sync_before_first_forward(self):
-        group = DummyProcessGroup(rank=0, size=1)
-        model = self.get_wrapped_model(group, config={}, add_bn=False)
-        batch = model.module.get_input(torch.device("cuda"))
-        with model.no_sync():
-            output = model(*batch)
-            loss = model.module.get_loss(batch, output)
-            loss.backward()
-        output = model(*batch)
-        loss = model.module.get_loss(batch, output)
-        loss.backward()
-
-    @classmethod
-    def _test_transformer(self, rank, group, config):
-        model = self.get_wrapped_model(group, config=config, add_bn=False)
-        model.eval()  # turn off dropout for the test
-        self._test_no_sync(model, batch_dim=1)
-
-    @classmethod
-    def _test_nested_wrapper(self, rank, group, config):
-        model = NestedWrappedModule(group, config)
-        model = FullyShardedDataParallel(model, group, **config).cuda()
-        self._test_no_sync(model, batch_dim=0)
-
-    @classmethod
-    def _test_no_sync(self, model, batch_dim):
-        # Generate two input batches. We'll test that we get the same grads if
-        # we train on them sequentially while accumulating grads (with no_sync)
-        # vs. concatenating the batches and training in one go.
-        batch1 = model.module.get_input(torch.device("cuda"))
-        assert isinstance(batch1, tuple)
-        batch2 = tuple(
-            # This randomly permutes the values in a multi-dim tensor.
-            x.view(-1)[torch.randperm(x.numel())].view_as(x)
-            for x in batch1
-        )
-        for x, y in zip(batch1, batch2):
-            assert not torch.all(x == y)
-
-        # Concat the batches along batch dimension.
-        concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))
-
-        # Establish reference behavior on the concat batch.
-        model.zero_grad()
-        output = model(*concat_batch)
-        ref_loss = model.module.get_loss(concat_batch, output)
-        ref_loss.backward()
-        ref_grads = [p.grad.detach().clone() for p in model.parameters()]
-
-        # Test that we get the same results by accumulating grads.
-        model.zero_grad()
-        with model.no_sync():  # accumulate gradients from the first batch
-            output = model(*batch1)
-            loss1 = model.module.get_loss(batch1, output)
-            loss1.backward()
-        output = model(*batch2)
-        loss2 = model.module.get_loss(batch2, output)
-        loss2.backward()
-        accumulated_loss = loss1 + loss2
-        accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]
-
-        torch.testing.assert_allclose(ref_loss, accumulated_loss)
-        assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)
+        assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)


 class TransformerWithSharedParams(nn.Module):
tests/nn/data_parallel/test_fsdp_no_sync.py (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest
from unittest.mock import patch

import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init


class TestNoSync(DistributedTest):
    def test_transformer(self):
        fn = functools.partial(self._test_transformer, config={})
        spawn_and_init(fn)

    def test_transformer_no_flat_params(self):
        config = {"flatten_parameters": False}
        fn = functools.partial(self._test_transformer, config=config)
        spawn_and_init(fn)

    def test_nested_wrapper(self):
        fn = functools.partial(self._test_nested_wrapper, config={})
        spawn_and_init(fn)

    def test_no_sync_before_first_forward(self):
        group = DummyProcessGroup(rank=0, size=1)
        model = self.get_wrapped_model(group, config={}, add_bn=False)
        batch = model.module.get_input(torch.device("cuda"))
        with model.no_sync():
            output = model(*batch)
            loss = model.module.get_loss(batch, output)
            loss.backward()
        output = model(*batch)
        loss = model.module.get_loss(batch, output)
        loss.backward()

    @classmethod
    def _test_transformer(self, rank, group, config):
        model = self.get_wrapped_model(group, config=config, add_bn=False)
        model.eval()  # turn off dropout for the test
        self._test_no_sync(model, batch_dim=1)

    @classmethod
    def _test_nested_wrapper(self, rank, group, config):
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        self._test_no_sync(model, batch_dim=0)

    @classmethod
    def _test_no_sync(self, model, batch_dim):
        # Generate two input batches. We'll test that we get the same grads if
        # we train on them sequentially while accumulating grads (with no_sync)
        # vs. concatenating the batches and training in one go.
        batch1 = model.module.get_input(torch.device("cuda"))
        assert isinstance(batch1, tuple)
        batch2 = tuple(
            # This randomly permutes the values in a multi-dim tensor.
            x.view(-1)[torch.randperm(x.numel())].view_as(x)
            for x in batch1
        )
        for x, y in zip(batch1, batch2):
            assert not torch.all(x == y)

        # Concat the batches along batch dimension.
        concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

        # Establish reference behavior on the concat batch.
        model.zero_grad()
        output = model(*concat_batch)
        ref_loss = model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [p.grad.detach().clone() for p in model.parameters()]

        # Test that we get the same results by accumulating grads.
        model.zero_grad()
        with model.no_sync():  # accumulate gradients from the first batch
            output = model(*batch1)
            loss1 = model.module.get_loss(batch1, output)
            loss1.backward()
        output = model(*batch2)
        loss2 = model.module.get_loss(batch2, output)
        loss2.backward()
        accumulated_loss = loss1 + loss2
        accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]

        torch.testing.assert_allclose(ref_loss, accumulated_loss)
        assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


class TestNoSyncCommunication(DistributedTest):
    def test_communication(self):
        config = {"mixed_precision": True}
        fn = functools.partial(self._test_communication, config=config)
        spawn_and_init(fn)

    @classmethod
    def _test_communication(self, rank, group, config):
        if group.size() == 1:
            return
        model = self.get_wrapped_model(group, config=config)
        batch = model.module.get_input(torch.device("cuda"))

        with patch("torch.distributed.all_gather") as mock_all_gather:
            with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                with model.no_sync():
                    output = model(*batch)
                    loss = model.module.get_loss(batch, output)
                    loss.backward()
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 0

                output = model(*batch)
                loss = model.module.get_loss(batch, output)
                loss.backward()
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 1


if __name__ == "__main__":
    unittest.main()
tests/nn/data_parallel/test_fsdp_summon_full_params.py

@@ -48,7 +48,11 @@ class TestMemory(DistributedTest):
             del state_dict
             mems.append(get_cuda_mem())

-        assert mems[4] == mems[0]
+        # Any value other than `==` indicates a memory leak. If mems[4] >
+        # mems[0], that indicates we're not cleaning up params properly in
+        # summon_full_params. If mems[4] < mems[0], that indicates there's a
+        # memory leak in _train_for_several_steps.
+        assert mems[4] == mems[0], f"memory leak detected, {mems[4]} != {mems[0]}"


 class TestPersistence(DistributedTest):
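The memory check above compares CUDA allocation snapshots taken at the same point in successive iterations. A hedged sketch of that measurement idea follows; the helper name ``get_cuda_mem`` comes from the test, but its implementation here is a guess based on ``torch.cuda.memory_allocated``, the loop body is a placeholder, and a CUDA device is assumed.

import torch

def get_cuda_mem() -> int:
    # Snapshot of currently-allocated CUDA memory, in bytes.
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated()

mems = []
for _ in range(5):
    # ... one training step plus a summon_full_params / state_dict round-trip ...
    mems.append(get_cuda_mem())

# In steady state, iteration 4 must match iteration 0 exactly; growth means params
# are not cleaned up properly, shrinkage points at a leak in the training helper.
assert mems[4] == mems[0], f"memory leak detected, {mems[4]} != {mems[0]}"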