OpenDAS / fairscale / Commits / 1204c7cf

Unverified commit 1204c7cf, authored Mar 06, 2021 by Myle Ott, committed by GitHub on Mar 06, 2021.

[perf] FSDP: speed up no_sync and test communication volume (#470)

Parent: 0b8d0753

Showing 5 changed files with 231 additions and 131 deletions (+231, -131).
Changed files:

  fairscale/nn/data_parallel/fully_sharded_data_parallel.py   +93   -50
  tests/ci_test_list_3.txt                                     +1    -0
  tests/nn/data_parallel/test_fsdp.py                          +1    -80
  tests/nn/data_parallel/test_fsdp_no_sync.py (new file)       +131  -0
  tests/nn/data_parallel/test_fsdp_summon_full_params.py       +5    -1
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -211,6 +211,9 @@ class FullyShardedDataParallel(nn.Module):
         # Enum to indicate if we're in the forward/backward pass, idle, etc.
         self.training_state = TrainingState.IDLE
 
+        # Flag to indicate if the full params are gathered.
+        self.has_full_params: bool = False
+
         # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
         # prefix and before load_state_dict() to add it back.
         self._register_state_dict_hook(_post_state_dict_hook)
@@ -511,7 +514,11 @@ class FullyShardedDataParallel(nn.Module):
         A context manager to disable gradient synchronizations across DDP
         processes. Within this context, gradients will be accumulated on module
         variables, which will later be synchronized in the first
-        forward-backward pass exiting the context.
+        forward-backward pass after exiting the context.
+
+        .. note:: This may result in higher memory usage because we will
+            accumulate the full model gradients (instead of gradient shards)
+            until the eventual sync.
         """
         self._lazy_init()
         assert self._is_root, "no_sync on inner FSDP is not supported"
@@ -575,6 +582,7 @@ class FullyShardedDataParallel(nn.Module):
             # forward/backward.
             self.training_state = TrainingState.SUMMON_FULL_PARAMS
             full_tensors = self._rebuild_full_params(full_precision=True)
+            assert full_tensors is not None
             with contextlib.ExitStack() as stack:
                 if self.flatten_parameters and self.module.is_flattened:
                     # Update flattened views to point to fully-sized tensors. We
@@ -596,6 +604,7 @@ class FullyShardedDataParallel(nn.Module):
                         p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
                     if safe_to_free:
                         free_storage_(full_tensor)
+                self.has_full_params = False
                 self._use_fp32_param_shard()
                 self.training_state = TrainingState.IDLE
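Both hunks above are inside summon_full_params(). For orientation, a minimal hypothetical usage sketch follows; it assumes `model` is an already constructed FullyShardedDataParallel instance on a CUDA device and is not code from this commit.

# Hypothetical sketch: reading the full, unsharded parameters on each rank.
with model.summon_full_params():
    # Inside the context each rank sees the gathered full parameters,
    # e.g. for inspection or consolidated checkpointing.
    total_norm = sum(p.norm().item() for p in model.parameters())
print(total_norm)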
@@ -833,6 +842,7 @@ class FullyShardedDataParallel(nn.Module):
             self._rebuild_full_params()
         else:
             self._use_full_params()
+
         # Make sure p.grad has the correct size/device (or set it to None).
         self._prep_grads_for_backward()
@@ -891,15 +901,22 @@ class FullyShardedDataParallel(nn.Module):
         if param.grad.requires_grad:
             raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
-        # Free full params and switch to FP32 shard after backward.
-        self._free_full_params([param])
-        self._use_fp32_param_shard([param])
+
+        # Free full params. As a special case, we don't free the full params
+        # on the root instance when in a ``no_sync`` context (as indicated
+        # by ``self._require_backward_grad_sync``), since we will need the
+        # params again immediately for the next forward.
+        if not self._is_root or self._require_backward_grad_sync:
+            self._free_full_params([param])
+
         if self.mixed_precision:
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
             self._free_fp16_param_shard([param])
+
+        # Switch to FP32 shard after backward.
+        self._use_fp32_param_shard([param])
+
         # (try to) Enqueue a callback at the end of the backward pass to ensure that all
         # post-backward work has finished. We only need one callback and all instances
         # of FSDP (root and children) make this attempt here to queue to ensure it is queued
@@ -966,9 +983,10 @@ class FullyShardedDataParallel(nn.Module):
     def _queue_wait_for_post_backward(self) -> None:
         """Try to queue a `wait_for_post_backward` callback.
 
-        Only called on root and only queue one callback.
-        But can be called by children FSDPs via a closure in case the
-        root instance doesn't own any params.
+        Only called on root and only queue one callback. But can be called by
+        children FSDPs via a closure in case the root instance doesn't own any
+        params.
         """
         assert self._is_root
         self.assert_state(TrainingState.BACKWARD)
@@ -978,18 +996,18 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
-        """Wait for post-backward work to finish. Only called on root instance.
-        """
+        """Wait for post-backward to finish. Only called on root instance."""
         assert self._is_root
         self.assert_state(TrainingState.BACKWARD)
-        # Flush any unreduced buckets in the post_backward stream.
-        with torch.cuda.stream(self._streams["post_backward"]):
-            assert self._reducer is not None
-            self._reducer.flush()
-        torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
-        if self.move_grads_to_cpu:
-            # Wait for the non-blocking GPU -> CPU grad transfers to finish.
-            torch.cuda.current_stream().synchronize()
+        if self._require_backward_grad_sync:
+            # Flush any unreduced buckets in the post_backward stream.
+            with torch.cuda.stream(self._streams["post_backward"]):
+                assert self._reducer is not None
+                self._reducer.flush()
+            torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
+            if self.move_grads_to_cpu:
+                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
+                torch.cuda.current_stream().synchronize()
         # A backward pass is done, update root and nested FSDP's flags.
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
@@ -997,7 +1015,7 @@ class FullyShardedDataParallel(nn.Module):
                 m.training_state = TrainingState.IDLE
 
     @torch.no_grad()
-    def _rebuild_full_params(self, full_precision: bool = False) -> List[Tuple[torch.Tensor, bool]]:
+    def _rebuild_full_params(self, full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
         """
         Gather all shards of params.
@@ -1008,57 +1026,82 @@ class FullyShardedDataParallel(nn.Module):
             (e.g., FP32), possibly in fresh storage.
 
         Returns:
-            a list of tuples, where the first element is the full-sized param
+            A list of tuples, where the first element is the full-sized param
             and the second element is a bool indicating if it's safe for the
-            caller to free the full-sized param
+            caller to free the full-sized param. This will be ``None`` if
+            ``full_precision=False`` and the full params are already gathered.
         """
         output_tensors: List[Tuple[torch.Tensor, bool]] = []
+
+        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
+            if custom_output_tensor is not None:
+                assert p._is_sharded
+                p.data = custom_output_tensor
+                output_tensors.append((p.data, True))
+            elif not p._is_sharded:
+                if self.mixed_precision and not full_precision:
+                    p.data = p._fp16_shard
+                    output_tensors.append((p.data, True))
+                else:
+                    # Here p.data == p._fp32_shard, so it's not safe to free.
+                    output_tensors.append((p.data, False))
+            else:
+                p.data = p._full_param_padded
+                output_tensors.append((p.data, True))
+            # Trim any padding and reshape to match original size.
+            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
+
+        # Early exit if we already have full params and don't need full precision.
+        if self.has_full_params and not full_precision:
+            for p in self.params:
+                update_p_data()
+            return output_tensors
+
+        self.has_full_params = True
+
         with torch.cuda.stream(self._streams["all_gather"]):
             if self.mixed_precision and not full_precision:
                 self._cast_fp32_param_shards_to_fp16()
 
             for p in self.params:
                 if not p._is_sharded:  # e.g., when world_size == 1
-                    if self.mixed_precision and not full_precision:
-                        p.data = p._fp16_shard
-                        output_tensors.append((p.data, True))
-                    else:
-                        output_tensors.append((p.data, False))
-                    continue
-                # If self.cpu_offload and full_precision, we need to cast the
-                # FP32 CPU param to CUDA for the all-gather.
-                p_data = p.data.to(p._full_param_padded.device)
-                p_size = p._full_param_padded.size()
-                assert p_size.numel() % self.world_size == 0
-                if not self.mixed_precision or not full_precision:
-                    if p._full_param_padded.storage().size() != p_size.numel():
-                        # Allocate based on full size from all shards.
-                        alloc_storage_(p._full_param_padded, size=p_size)
-                    output_tensor = p._full_param_padded
-                else:
-                    # Allocate fresh tensor in full precision.
-                    output_tensor = p_data.new_zeros(p_size)
-                output_tensors.append((output_tensor, True))
-                # Fill output_tensor with (p.data for each shard in self.world_size)
-                chunks = list(output_tensor.chunk(self.world_size))
-                dist.all_gather(chunks, p_data, group=self.process_group)
-                p.data = output_tensor[: p._orig_size.numel()].view(p._orig_size)
-                if self.mixed_precision and not full_precision:
-                    self._free_fp16_param_shard([p])
+                    update_p_data()
+                else:
+                    # If self.cpu_offload and full_precision, we need to cast
+                    # the FP32 CPU param to CUDA for the all-gather.
+                    p_data = p.data.to(p._full_param_padded.device)
+                    p_size = p._full_param_padded.size()
+                    assert p_size.numel() % self.world_size == 0
+                    if not self.mixed_precision or not full_precision:
+                        if p._full_param_padded.storage().size() != p_size.numel():
+                            # Allocate based on full size from all shards.
+                            alloc_storage_(p._full_param_padded, size=p_size)
+                        output_tensor = p._full_param_padded
+                    else:
+                        # Allocate fresh tensor in full precision.
+                        output_tensor = p_data.new_zeros(p_size)
+                    # Fill output_tensor with (p.data for each shard in self.world_size)
+                    chunks = list(output_tensor.chunk(self.world_size))
+                    dist.all_gather(chunks, p_data, group=self.process_group)
+                    # Set p.data = output_tensor (with padding trimmed)
+                    update_p_data(output_tensor)
+                    if self.mixed_precision and not full_precision:
+                        self._free_fp16_param_shard([p])
 
         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors
 
     @torch.no_grad()
     def _use_full_params(self) -> None:
-        """Switching p.data pointers to use the full params.
+        """
+        Switch p.data pointers to use the full params.
 
-        Note: this is used assuming full param gathering is already done.
+        Note: this assumes full params are already gathered.
         """
+        assert self.has_full_params
         for p in self.params:
             if not p._is_sharded:
                 if self.mixed_precision:
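The core of the rebuilt _rebuild_full_params is a chunked all-gather into a padded full-size tensor, followed by trimming the padding. The standalone sketch below illustrates that pattern outside FSDP; it assumes an initialized torch.distributed process group and is not code from this commit.

import torch
import torch.distributed as dist

def gather_full_param(padded_shard: torch.Tensor, orig_numel: int) -> torch.Tensor:
    """Gather equal-sized shards from all ranks and trim the padding."""
    world_size = dist.get_world_size()
    # The padded full size is shard size * world_size by construction.
    output = padded_shard.new_zeros(padded_shard.numel() * world_size)
    # chunk() returns views into `output`, so all_gather fills it in place.
    chunks = list(output.chunk(world_size))
    dist.all_gather(chunks, padded_shard)
    return output[:orig_numel]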
@@ -1080,6 +1123,7 @@ class FullyShardedDataParallel(nn.Module):
         """Free up storage for full parameters."""
         if params is None:
             params = self.params
+        self.has_full_params = False
         current_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
@@ -1176,7 +1220,6 @@ def free_storage_(data: torch.Tensor) -> None:
         # Since we're modifying the Tensor's Storage directly, make sure the Tensor
         # is the sole occupant of the Storage.
         assert data.storage_offset() == 0
-        assert data.storage().size() == data.numel()
         data.storage().resize_(0)
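The assertion removed above required the tensor to exactly fill its storage before freeing. For reference, here is a standalone sketch of the storage-resizing trick that free_storage_ relies on (plain PyTorch, not code from this commit):

import torch

t = torch.empty(1024)
assert t.storage().size() == 1024
t.storage().resize_(0)     # releases the backing memory; `t` keeps its shape metadata
assert t.storage().size() == 0
t.storage().resize_(1024)  # storage can be re-allocated later; contents are undefined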
tests/ci_test_list_3.txt

 tests/nn/data_parallel/test_fsdp_uneven.py
 tests/nn/data_parallel/test_fsdp_grad_scaler.py
+tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/data_parallel/test_fsdp.py

@@ -660,86 +660,7 @@ class TestNoGrad(DistributedTest):
         with torch.no_grad():
             no_grad_output = model(*input)
-        assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output"
+        assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)
 
-
-class TestNoSync(DistributedTest):
-    ... (the entire TestNoSync class, including test_transformer, test_transformer_no_flat_params,
-    test_nested_wrapper, test_no_sync_before_first_forward, _test_transformer, _test_nested_wrapper,
-    and _test_no_sync, is removed here; it moves verbatim into the new file
-    tests/nn/data_parallel/test_fsdp_no_sync.py shown below)
-
 
 class TransformerWithSharedParams(nn.Module):
tests/nn/data_parallel/test_fsdp_no_sync.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest
from unittest.mock import patch

import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init


class TestNoSync(DistributedTest):
    def test_transformer(self):
        fn = functools.partial(self._test_transformer, config={})
        spawn_and_init(fn)

    def test_transformer_no_flat_params(self):
        config = {"flatten_parameters": False}
        fn = functools.partial(self._test_transformer, config=config)
        spawn_and_init(fn)

    def test_nested_wrapper(self):
        fn = functools.partial(self._test_nested_wrapper, config={})
        spawn_and_init(fn)

    def test_no_sync_before_first_forward(self):
        group = DummyProcessGroup(rank=0, size=1)
        model = self.get_wrapped_model(group, config={}, add_bn=False)
        batch = model.module.get_input(torch.device("cuda"))
        with model.no_sync():
            output = model(*batch)
            loss = model.module.get_loss(batch, output)
            loss.backward()
        output = model(*batch)
        loss = model.module.get_loss(batch, output)
        loss.backward()

    @classmethod
    def _test_transformer(self, rank, group, config):
        model = self.get_wrapped_model(group, config=config, add_bn=False)
        model.eval()  # turn off dropout for the test
        self._test_no_sync(model, batch_dim=1)

    @classmethod
    def _test_nested_wrapper(self, rank, group, config):
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        self._test_no_sync(model, batch_dim=0)

    @classmethod
    def _test_no_sync(self, model, batch_dim):
        # Generate two input batches. We'll test that we get the same grads if
        # we train on them sequentially while accumulating grads (with no_sync)
        # vs. concatenating the batches and training in one go.
        batch1 = model.module.get_input(torch.device("cuda"))
        assert isinstance(batch1, tuple)
        batch2 = tuple(
            # This randomly permutes the values in a multi-dim tensor.
            x.view(-1)[torch.randperm(x.numel())].view_as(x)
            for x in batch1
        )
        for x, y in zip(batch1, batch2):
            assert not torch.all(x == y)

        # Concat the batches along batch dimension.
        concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

        # Establish reference behavior on the concat batch.
        model.zero_grad()
        output = model(*concat_batch)
        ref_loss = model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [p.grad.detach().clone() for p in model.parameters()]

        # Test that we get the same results by accumulating grads.
        model.zero_grad()
        with model.no_sync():
            # accumulate gradients from the first batch
            output = model(*batch1)
            loss1 = model.module.get_loss(batch1, output)
            loss1.backward()
        output = model(*batch2)
        loss2 = model.module.get_loss(batch2, output)
        loss2.backward()

        accumulated_loss = loss1 + loss2
        accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]

        torch.testing.assert_allclose(ref_loss, accumulated_loss)
        assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


class TestNoSyncCommunication(DistributedTest):
    def test_communication(self):
        config = {"mixed_precision": True}
        fn = functools.partial(self._test_communication, config=config)
        spawn_and_init(fn)

    @classmethod
    def _test_communication(self, rank, group, config):
        if group.size() == 1:
            return
        model = self.get_wrapped_model(group, config=config)
        batch = model.module.get_input(torch.device("cuda"))

        with patch("torch.distributed.all_gather") as mock_all_gather:
            with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                with model.no_sync():
                    output = model(*batch)
                    loss = model.module.get_loss(batch, output)
                    loss.backward()
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 0

                output = model(*batch)
                loss = model.module.get_loss(batch, output)
                loss.backward()
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 1


if __name__ == "__main__":
    unittest.main()
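TestNoSyncCommunication verifies communication volume by patching the torch.distributed collectives and counting calls. The same pattern can be reused standalone; the helper below is a hypothetical sketch (run_step is a placeholder for whatever code is being measured, not part of this commit):

from typing import Callable, Tuple
from unittest.mock import patch

def count_fsdp_collectives(run_step: Callable[[], None]) -> Tuple[int, int]:
    """Return (all_gather calls, reduce_scatter calls) made while run_step() executes."""
    with patch("torch.distributed.all_gather") as mock_all_gather:
        with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
            run_step()
            return mock_all_gather.call_count, mock_reduce_scatter.call_count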
tests/nn/data_parallel/test_fsdp_summon_full_params.py

@@ -48,7 +48,11 @@ class TestMemory(DistributedTest):
             del state_dict
             mems.append(get_cuda_mem())
-        assert mems[4] == mems[0]
+        # Any value other than `==` indicates a memory leak. If mems[4] >
+        # mems[0], that indicates we're not cleaning up params properly in
+        # summon_full_params. If mems[4] < mems[0], that indicates there's a
+        # memory leak in _train_for_several_steps.
+        assert mems[4] == mems[0], f"memory leak detected, {mems[4]} != {mems[0]}"
 
 
 class TestPersistence(DistributedTest):