OpenDAS / fairscale / Commits

Commit d2924670 (unverified), authored Mar 02, 2021 by Myle Ott, committed by GitHub on Mar 02, 2021.
[fix] Make state_dict all-gather FP32 params (#451)
parent f3359550

Showing 7 changed files with 317 additions and 101 deletions (+317 −101)
fairscale/nn/data_parallel/fully_sharded_data_parallel.py   +121 −55
fairscale/nn/misc/flatten_params_wrapper.py                  +55 −27
tests/ci_test_list_1.txt                                     +1 −0
tests/nn/data_parallel/test_fsdp.py                          +16 −10
tests/nn/data_parallel/test_fsdp_summon_full_params.py       +71 −0
tests/nn/data_parallel/test_fsdp_uneven.py                   +18 −9
tests/nn/misc/test_flatten_params_wrapper.py                 +35 −0
fairscale/nn/data_parallel/fully_sharded_data_parallel.py  (view file @ d2924670)
...
@@ -180,7 +180,10 @@ class FullyShardedDataParallel(nn.Module):
         params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
         self._has_params = len(params) > 0
-        if self.flatten_parameters and self._has_params:
+        if not self._has_params:
+            self.flatten_parameters = False
+        if self.flatten_parameters:
             self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
             del module  # free original module in case it helps garbage collection
             self.params = [self._fsdp_wrapped_module.flat_param]
...
@@ -335,22 +338,27 @@ class FullyShardedDataParallel(nn.Module):
                 continue
             p._is_sharded = True

-            # Shard using torch.chunk to match all-gather/reduce-scatter.
-            chunks = list(torch.flatten(p.data).chunk(self.world_size))
-            while len(chunks) < self.world_size:
-                chunks.append(chunks[0].new_empty(0))
-
-            # Determine number of padding elements.
-            num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
-            assert num_to_pad >= 0, num_to_pad
-
             # Replace p.data with the relevant shard.
             orig_data = p.data
-            p.data = chunks[self.rank].clone()  # clone since we free storage below
-            if num_to_pad > 0:
-                p.data = F.pad(p.data, [0, num_to_pad])
+            p.data = self._get_shard(p.data)
             free_storage_(orig_data)

+    def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Return the local shard of a given full tensor."""
+        # Shard using torch.chunk to match all-gather/reduce-scatter.
+        chunks = list(torch.flatten(tensor).chunk(self.world_size))
+        while len(chunks) < self.world_size:
+            chunks.append(chunks[0].new_empty(0))
+
+        # Determine number of padding elements.
+        num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
+        assert num_to_pad >= 0, num_to_pad
+
+        shard = chunks[self.rank].clone()
+        if num_to_pad > 0:
+            shard = F.pad(shard, [0, num_to_pad])
+        return shard
+
     def extra_repr(self) -> str:
         return (
             f"rank={self.rank}, world_size={self.world_size}, "
...
@@ -408,32 +416,34 @@ class FullyShardedDataParallel(nn.Module):
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
         wrapped Module without any sharding-specific logic. Returned tensors
-        will always be typed float32.
+        will be full precision (e.g., FP32).

         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
         if self.mixed_precision:
             # Buffers dtype stays consistent with parameters.
             self._all_buffers_to(dtype=torch.float32)

         if self._return_full_state_dict:
             if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
-                with self.summon_full_params():
+                with self.summon_full_params(volatile=True):
                     state_dict = super().state_dict(*args, **kwargs)
             else:
+                torch.cuda.synchronize()
+                self._lazy_init()
                 state_dict = super().state_dict(*args, **kwargs)
         else:
+            torch.cuda.synchronize()
+            self._lazy_init()
             if self.flatten_parameters:
                 assert isinstance(self.module, FlattenParamsWrapper)
                 state_dict = self.module.flat_state_dict(*args, **kwargs)
             else:
                 state_dict = super().state_dict(*args, **kwargs)

+        if self.cpu_offload:
+            for k in state_dict.keys():
+                state_dict[k] = state_dict[k].cpu()
+
         if self.mixed_precision:
             # In case we are in mixed precision, restore buffers back to fp16.
             self._all_buffers_to(dtype=self.compute_dtype)
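With this change, state_dict() gathers parameters in full precision even when the wrapper runs in mixed precision. A hedged usage sketch (the model, path, and rank handling below are illustrative, not part of this commit):

    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    def save_checkpoint(model: FSDP, path: str) -> None:
        # Must run on all ranks: state_dict() uses collective communication.
        state_dict = model.state_dict()
        # Floating-point entries come back in full precision (FP32).
        assert all(t.dtype == torch.float32 for t in state_dict.values() if t.is_floating_point())
        if dist.get_rank() == 0:
            torch.save(state_dict, path)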
...
@@ -516,29 +526,42 @@ class FullyShardedDataParallel(nn.Module):
             m._require_backward_grad_sync = old_flag

     @contextlib.contextmanager
-    def summon_full_params(self, recurse: bool = True) -> Generator:
+    def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
         """
         A context manager to expose full params for the current FSDP instance.
         Can be useful *after* forward/backward for a model to get the params for
-        additional processing or checking.
+        additional processing or checking. Parameters will be gathered in full
+        precision (e.g., FP32).
+
+        By default this will recursively summon all params for nested FSDP
+        instances; this can be disabled by setting ``recurse=False``.

         .. note:: This can be used on inner FSDPs.
         .. note:: This can *not* be used within a forward or backward pass. Nor
             can forward and backward be started from within this context.
-        .. note:: The full parameters will be freed after the context manager
-            exits; it is up to the caller to clone them if needed.
+        .. note:: The full parameters can be modified, but only the portion
+            corresponding to the local param shard will persist after the
+            context manager exits (unless ``volatile=True``, in which case there
+            are no guarantees about persistence).
+
+        Args:
+            recurse (bool, Optional): recursively summon all params for nested
+                FSDP instances (default: True)
+            volatile (bool, Optional): if ``True``, modifications to params are
+                not guaranteed to persist after the context manager exits;
+                enabling this can be slightly more efficient (default: False)
         """
         if recurse:
             with contextlib.ExitStack() as stack:
-                # summon all params for any nested FlattenParamsWrapper instances
+                # Summon all params for any nested FSDP instances.
                 for module in self.modules():
                     if isinstance(module, FullyShardedDataParallel):
-                        stack.enter_context(module.summon_full_params(recurse=False))
-                # yield to the caller, with full params in all nested instances
+                        stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
+                # Yield to the caller, with full params in all nested instances.
                 yield
-            # exiting from the ExitStack will re-shard params
+            # Exiting from the ExitStack will re-shard params.
             return
         else:
             torch.cuda.synchronize()
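A hedged sketch of how the new volatile flag might be used (fsdp_model is illustrative): read-only inspection can pass volatile=True and skip the shard copy-back, while the default persists in-place edits to the portion covered by each rank's local shard.

    # Read-only access: no persistence guarantees, slightly cheaper.
    with fsdp_model.summon_full_params(volatile=True):
        total_norm = sum(p.norm() ** 2 for p in fsdp_model.parameters()) ** 0.5

    # Default (volatile=False): identical in-place edits on every rank persist
    # via the copy-back of each rank's local shard on context exit.
    with fsdp_model.summon_full_params():
        for p in fsdp_model.parameters():
            p.data.mul_(0.999)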
...
@@ -547,13 +570,30 @@ class FullyShardedDataParallel(nn.Module):
             # Set the state so that we assert when trying to go into
             # forward/backward.
             self.training_state = TrainingState.SUMMON_FULL_PARAMS

-            self._rebuild_full_params()
-            try:
-                yield
-            finally:
-                self._free_full_params()
-                self._use_fp32_param_shard()
-                self.training_state = TrainingState.IDLE
+            full_tensors = self._rebuild_full_params(full_precision=True)
+            with contextlib.ExitStack() as stack:
+                if self.flatten_parameters and self.module.is_flattened:
+                    # Update flattened views to point to fully-sized tensors. We
+                    # use self.params[0] instead of full_tensors since the
+                    # latter may contain padding.
+                    assert len(self.params) == 1
+                    assert isinstance(self.module, FlattenParamsWrapper)
+                    stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
+                try:
+                    yield
+                finally:
+                    stack.close()
+                    assert len(full_tensors) == len(self.params)
+                    for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
+                        if not volatile:
+                            # Copy any changes made to the full params back into
+                            # the corresponding local shards.
+                            local_shard = self._get_shard(full_tensor)
+                            p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
+                        if safe_to_free:
+                            free_storage_(full_tensor)
+                    self._use_fp32_param_shard()
+                    self.training_state = TrainingState.IDLE

     def _reset_lazy_init(self) -> None:
         """Reset instance so :func:`_lazy_init` will run on the next forward."""
...
@@ -953,35 +993,61 @@ class FullyShardedDataParallel(nn.Module):
                 m.training_state = TrainingState.IDLE

     @torch.no_grad()
-    def _rebuild_full_params(self) -> None:
-        """Gather all shards of params."""
+    def _rebuild_full_params(self, full_precision: bool = False) -> List[Tuple[torch.Tensor, bool]]:
+        """
+        Gather all shards of params.
+
+        Args:
+            full_precision (bool, Optional): by default params will be gathered
+                in ``compute_dtype`` (e.g., FP16), unless *full_precision* is
+                ``True``, in which case they will be gathered in full precision
+                (e.g., FP32), possibly in fresh storage.
+
+        Returns:
+            a list of tuples, where the first element is the full-sized param
+            and the second element is a bool indicating if it's safe for the
+            caller to free the full-sized param
+        """
+        output_tensors: List[Tuple[torch.Tensor, bool]] = []
         with torch.cuda.stream(self._streams["all_gather"]):
-            if self.mixed_precision:
+            if self.mixed_precision and not full_precision:
                 self._cast_fp32_param_shards_to_fp16()

             for p in self.params:
-                if not p._is_sharded:
-                    if self.mixed_precision:
+                if not p._is_sharded:  # e.g., when world_size == 1
+                    if self.mixed_precision and not full_precision:
                         p.data = p._fp16_shard
+                        output_tensors.append((p.data, True))
+                    else:
+                        output_tensors.append((p.data, False))
                     continue

-                p_size = p._full_param_padded.size()
-                if p._full_param_padded.storage().size() != p_size.numel():
-                    # Allocate based on full size from all shards.
-                    alloc_storage_(p._full_param_padded, size=p_size)
-                    assert p_size.numel() % self.world_size == 0
-                    if p._is_sharded:
-                        # Fill p._full_param_padded with (p.data for each shard in self.world_size)
-                        chunks = list(p._full_param_padded.chunk(self.world_size))
-                        dist.all_gather(chunks, p.data, group=self.process_group)
-                    else:
-                        p._full_param_padded.copy_(torch.flatten(p.data), non_blocking=True)
-
-                p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)
-
-                if self.mixed_precision:
+                # If self.cpu_offload and full_precision, we need to cast the
+                # FP32 CPU param to CUDA for the all-gather.
+                p_data = p.data.to(p._full_param_padded.device)
+
+                p_size = p._full_param_padded.size()
+                assert p_size.numel() % self.world_size == 0
+                if not self.mixed_precision or not full_precision:
+                    if p._full_param_padded.storage().size() != p_size.numel():
+                        # Allocate based on full size from all shards.
+                        alloc_storage_(p._full_param_padded, size=p_size)
+                    output_tensor = p._full_param_padded
+                else:
+                    # Allocate fresh tensor in full precision.
+                    output_tensor = p_data.new_zeros(p_size)
+                output_tensors.append((output_tensor, True))
+
+                # Fill output_tensor with (p.data for each shard in self.world_size)
+                chunks = list(output_tensor.chunk(self.world_size))
+                dist.all_gather(chunks, p_data, group=self.process_group)
+
+                p.data = output_tensor[: p._orig_size.numel()].view(p._orig_size)
+
+                if self.mixed_precision and not full_precision:
                     self._free_fp16_param_shard([p])
         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+        return output_tensors

     @torch.no_grad()
     def _use_full_params(self) -> None:
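The gather itself is the standard chunked all-gather pattern: each rank contributes its padded shard and the equal-sized chunks of one output buffer are filled in rank order. A minimal sketch (not the fairscale API; assumes an initialized process group):

    import torch
    import torch.distributed as dist

    def gather_full_param(shard: torch.Tensor, orig_numel: int, world_size: int) -> torch.Tensor:
        # Output buffer sized to world_size equal chunks (may include padding).
        output = shard.new_zeros(shard.numel() * world_size)
        chunks = list(output.chunk(world_size))
        dist.all_gather(chunks, shard)  # each chunk is a view into `output`
        return output[:orig_numel]      # strip trailing padding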
...
@@ -1013,7 +1079,7 @@ class FullyShardedDataParallel(nn.Module):
         current_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
-                if not p._is_sharded:
+                if not p._is_sharded:  # e.g., world_size == 1
                     if self.mixed_precision:
                         self._free_fp16_param_shard([p])
                     continue
...
fairscale/nn/misc/flatten_params_wrapper.py  (view file @ d2924670)
...
@@ -53,9 +53,6 @@ class FlattenParamsWrapper(nn.Module):
         self._flatten_params()

-        # register the views as plain attributes
-        self._unflatten_params_as_views()
-
         # Register hook to be called after state_dict() to remove the
         # "_fpw_module." prefix and before load_state_dict() to add it back.
         self._register_state_dict_hook(_post_state_dict_hook)
...
@@ -70,10 +67,7 @@ class FlattenParamsWrapper(nn.Module):
     def module(self) -> nn.Module:
         return self._fpw_module

-    def _flatten_params(self) -> None:
-        assert not self.is_flattened
-        self.is_flattened = True
-
+    def _init_flatten_params(self) -> List[Tensor]:
         param_infos = []
         shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
         shared_param_infos = []
...
@@ -102,11 +96,22 @@ class FlattenParamsWrapper(nn.Module):
         self._param_numels = tuple(param_numels)
         self._param_shapes = tuple(param_shapes)

+        return params
+
+    def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
+        assert not self.is_flattened
+        self.is_flattened = True
+
+        if not hasattr(self, "_param_infos"):
+            assert flat_param is None
+            params = self._init_flatten_params()
+            flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
+            self.param_numel = flat_param.numel()
+            del params
+
         # flatten
-        flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
+        assert flat_param is not None
         self.register_parameter("flat_param", flat_param)
-        self.param_numel = flat_param.numel()
-        del params

         # deregister the names as parameters
         for m, n in self._param_infos:
@@ -114,14 +119,18 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -114,14 +119,18 @@ class FlattenParamsWrapper(nn.Module):
for
m
,
n
,
_
,
_
in
self
.
_shared_param_infos
:
for
m
,
n
,
_
,
_
in
self
.
_shared_param_infos
:
delattr
(
m
,
n
)
delattr
(
m
,
n
)
def
_get_param_views
(
self
)
->
Generator
:
# register the views as plain attributes
return
(
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
self
.
flat
_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
)
)
self
.
_un
flat
ten_params_as_views
(
)
def
_unflatten_params
(
self
)
->
None
:
def
_get_param_views
(
self
,
flat_param
:
Tensor
)
->
Generator
:
assert
self
.
is_flattened
return
(
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
flat_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
))
def
_unflatten_params
(
self
,
flat_param
:
Optional
[
Tensor
]
=
None
)
->
None
:
assert
self
.
is_flattened
or
flat_param
is
not
None
self
.
is_flattened
=
False
self
.
is_flattened
=
False
flat_param
=
flat_param
if
flat_param
is
not
None
else
self
.
flat_param
ps
=
self
.
_get_param_views
()
ps
=
self
.
_get_param_views
(
flat_param
)
for
(
m
,
n
),
p
in
zip
(
self
.
_param_infos
,
ps
):
for
(
m
,
n
),
p
in
zip
(
self
.
_param_infos
,
ps
):
if
hasattr
(
m
,
n
):
if
hasattr
(
m
,
n
):
delattr
(
m
,
n
)
delattr
(
m
,
n
)
...
@@ -130,41 +139,60 @@ class FlattenParamsWrapper(nn.Module):
                 if hasattr(m, n):
                     delattr(m, n)
                 m.register_parameter(n, getattr(shared_m, shared_n))

-        del self.flat_param
+        if hasattr(self, "flat_param"):
+            del self.flat_param

     def _unflatten_params_as_views(self) -> None:
         assert self.is_flattened
-        ps = self._get_param_views()
+        ps = self._get_param_views(self.flat_param)
         for (m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr
         for (m, n, shared_m, shared_n) in self._shared_param_infos:
             setattr(m, n, getattr(shared_m, shared_n))

     @contextmanager
-    def unflatten_params(self, recurse: bool = True) -> Generator:
+    def unflatten_params(self, recurse: bool = True, flat_param: Optional[Tensor] = None) -> Generator:
         """
-        Unflatten params (optionally recursively on all nested instances).
-        If the current instance is already unflattened, then it will remain
-        unflattened after the context manager exits.
+        Unflatten params. If the current instance is already unflattened, then
+        it will remain unflattened after the context manager exits.
+
+        Args:
+            recurse (bool, Optional): recursively unflatten all nested instances
+                (default: True)
+            flat_param (Tensor, Optional): flat param to use for unflattening.
+                If provided, the current instance must be in a flattened state
+                at the start of the context manager. The provided Tensor must be
+                appropriately sized and will only be used within the context
+                manager. After the context manager exits, we will revert to
+                using ``self.flat_param`` (default: None).
         """
         if recurse:
             with ExitStack() as stack:
                 # unflatten any nested FlattenParamsWrapper instances
-                for module in self.modules():
+                for name, module in self.named_modules():
                     if isinstance(module, FlattenParamsWrapper):
-                        stack.enter_context(module.unflatten_params(recurse=False))
+                        is_self = name == ""
+                        stack.enter_context(
+                            module.unflatten_params(recurse=False, flat_param=flat_param if is_self else None)
+                        )
                 # yield to the caller, with unflattened params in all nested instances
                 yield
             # exiting from the ExitStack will re-flatten params
             return
         else:
+            assert (
+                flat_param is None or self.is_flattened
+            ), "Unflattening with custom flat_param requires current instance to be flattened"
             orig_flattened = self.is_flattened
-            if self.is_flattened:
-                self._unflatten_params()
+            if orig_flattened:
+                orig_flat_param = self.flat_param
+                self._unflatten_params(flat_param)
             yield
             if orig_flattened:
-                self._flatten_params()
+                self._flatten_params(orig_flat_param)
+                self._unflatten_params_as_views()

     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
...
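A hedged sketch of the new flat_param argument, following the pattern exercised by the new unit test further down (my_module is illustrative): temporarily view an alternative flat tensor as the wrapped module's unflattened parameters, without changing self.flat_param.

    import torch
    from fairscale.nn.misc import FlattenParamsWrapper

    wrapped = FlattenParamsWrapper(my_module)
    candidate = torch.full_like(wrapped.flat_param, 0.5)
    with wrapped.unflatten_params(flat_param=candidate):
        # Inside the context, parameter values reflect `candidate`.
        sd = {k: v.clone() for k, v in wrapped.state_dict().items()}
    # After exit, wrapped.flat_param holds its original values again.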
tests/ci_test_list_1.txt  (view file @ d2924670)

 tests/nn/misc/test_flatten_params_wrapper.py
 tests/nn/misc/test_checkpoint_activations.py
 tests/nn/data_parallel/test_fsdp.py
+tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/wrap/test_wrap.py
tests/nn/data_parallel/test_fsdp.py  (view file @ d2924670)

-# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 #
-# This source code is licensed under the MIT license found in the
+# This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

 import functools
 import itertools
 from math import inf
...
@@ -182,8 +183,13 @@ class TestComparisonToPyTorchDDP(DistributedTest):
     PyTorch DDP vs. FullyShardedDataParallel.
     """

-    def test_nested_all_wrapped_model(self):
-        config = {"mixed_precision": True}
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_wrapped_model(self, config):
+        test_fn = functools.partial(self._test_identical_outputs, NestedWrappedModule, config)
+        spawn_and_init(test_fn)
+
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_all_wrapped_model(self, config):
         model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
         test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
         spawn_and_init(test_fn)
...
@@ -280,6 +286,9 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         model = ref_ddp_fn(model, group)
         ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
         ref_state_dict = model.module.state_dict()
+        if config.get("cpu_offload", False):
+            for k in ref_state_dict.keys():
+                ref_state_dict[k] = ref_state_dict[k].cpu()

         # Confirm we get the same behavior using FullyShardedDataParallel.
         model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
@@ -456,9 +465,8 @@ class TestSaveLoadStateDict(DistributedTest):
...
@@ -456,9 +465,8 @@ class TestSaveLoadStateDict(DistributedTest):
def
_test_state_dict_before_forward
(
cls
,
config
,
rank
,
group
):
def
_test_state_dict_before_forward
(
cls
,
config
,
rank
,
group
):
ddp_model
=
cls
.
get_wrapped_model
(
group
,
cuda_first
=
False
,
config
=
config
)
ddp_model
=
cls
.
get_wrapped_model
(
group
,
cuda_first
=
False
,
config
=
config
)
sd
=
ddp_model
.
state_dict
()
sd
=
ddp_model
.
state_dict
()
expected_dtype
=
torch
.
float16
if
ddp_model
.
mixed_precision
else
torch
.
float32
wt
=
sd
[
"embed_tokens.weight"
]
wt
=
sd
[
"embed_tokens.weight"
]
assert
wt
.
dtype
==
expected_dtype
,
f
"got dtype
{
wt
.
dtype
}
expected
{
expected_dtype
}
"
assert
wt
.
dtype
==
torch
.
float32
,
f
"got dtype
{
wt
.
dtype
}
expected
torch.float32
"
cls
.
_train_for_several_steps
(
ddp_model
,
1
,
ddp_model
.
mixed_precision
)
cls
.
_train_for_several_steps
(
ddp_model
,
1
,
ddp_model
.
mixed_precision
)
@
classmethod
@
classmethod
...
@@ -480,15 +488,11 @@ class TestSaveLoadStateDict(DistributedTest):
     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_nested_wrapped_model(self, config):
-        if config["mixed_precision"]:
-            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
         test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
         spawn_and_init(test_fn)

     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_nested_wrapped_model_local_state_dict(self, config):
-        if config["mixed_precision"]:
-            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
         test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
         spawn_and_init(test_fn)
...
@@ -501,6 +505,8 @@ class TestSaveLoadStateDict(DistributedTest):
         ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

         # Create a nested FSDP-wrapped instance.
+        if config["mixed_precision"]:
+            config["compute_dtype"] = torch.float32
         model = NestedWrappedModule(group, config)
         model = FullyShardedDataParallel(model, group, **config).cuda()
         cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
...
tests/nn/data_parallel/test_fsdp_summon_full_params.py  (new file, 0 → 100644; view file @ d2924670)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import gc
import unittest

from parameterized import parameterized
import torch

from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init


def get_cuda_mem():
    torch.cuda.synchronize()
    gc.collect()
    return torch.cuda.memory_allocated()


class TestMemory(DistributedTest):
    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_memory(self, config):
        spawn_and_init(functools.partial(self._test_memory, config))

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_memory_volatile(self, config):
        spawn_and_init(functools.partial(self._test_memory, config, volatile=True))

    @classmethod
    def _test_memory(self, config, rank, group, volatile=False):
        model = self.get_wrapped_model(group, cuda_first=False, config=config)
        self._train_for_several_steps(model, 1, autocast=model.mixed_precision)

        mems = [get_cuda_mem()]

        with model.summon_full_params(volatile=volatile):
            mems.append(get_cuda_mem())
            assert mems[1] >= mems[0]

            state_dict = model.state_dict()
            mems.append(get_cuda_mem())
            assert mems[2] >= mems[1]

        mems.append(get_cuda_mem())
        assert mems[3] <= mems[2]

        del state_dict
        mems.append(get_cuda_mem())
        assert mems[4] == mems[0]


class TestPersistence(DistributedTest):
    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_non_volatile(self, config):
        spawn_and_init(functools.partial(self._test_persistence, config))

    @classmethod
    def _test_persistence(self, config, rank, group, volatile=False):
        model = self.get_wrapped_model(group, cuda_first=False, config=config)

        with model.summon_full_params(volatile=False):
            model.module.embed_tokens.weight.data.fill_(42)

        with model.summon_full_params():
            # non-volatile changes are persisted
            assert torch.all(model.module.embed_tokens.weight.data == 42.0)


if __name__ == "__main__":
    unittest.main()
tests/nn/data_parallel/test_fsdp_uneven.py  (view file @ d2924670)

...
@@ -29,17 +29,26 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
     my_lr = 0.1

+    device = torch.device("cuda")
+    if fsdp_config.get("mixed_precision", False):
+        dtype = torch.float16
+        fsdp_config["fp32_reduce_scatter"] = True
+    else:
+        dtype = torch.float32
+
     if test_case["assert_ref_out"]:
         with torch.no_grad():
             # Compute one iteration local output.
-            weight = model.weight.T.clone().cuda()
-            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
+            fp32_weight = model.weight.T.clone().to(device)
+            weight = fp32_weight.to(dtype)
+            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
             ref_forward_output_my_rank = torch.matmul(v, weight)
             # Compute one iteration global weight update.
-            v = torch.Tensor(test_case["inputs"][0][: world_size]).cuda()
-            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
-            ref_weight_out = weight - grad.T * my_lr
-    model.to("cuda")
+            v = torch.Tensor(test_case["inputs"][0][: world_size]).to(device, dtype)
+            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
+            ref_weight_out = fp32_weight - grad.T * my_lr
+            assert ref_weight_out.dtype == torch.float32
+    model.to(device)  # not dtype, since FSDP will manage mixed precision internally
     assert isinstance(fsdp_config, dict), str(fsdp_config)
     model = FSDP(model, **fsdp_config)
     optim = SGD(model.parameters(), lr=my_lr)
...
@@ -47,9 +56,9 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
     assert len(inputs) == 1 or not test_case["assert_ref_out"]
     assert len(inputs[0]) >= world_size
     for in_data in inputs:
-        in_data = Tensor(in_data[rank]).cuda()
+        in_data = Tensor(in_data[rank]).to(device, dtype)
         out = model(in_data)
-        out.sum().backward()
+        out.float().sum().backward()
         optim.step()
         optim.zero_grad()
     if test_case["assert_ref_out"]:
@@ -70,7 +79,7 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
...
@@ -70,7 +79,7 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
@
skip_if_single_gpu
@
skip_if_single_gpu
@
pytest
.
mark
.
parametrize
(
"test_case"
,
[{
"inputs"
:
[
torch
.
rand
(
8
,
3
)],
"assert_ref_out"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"test_case"
,
[{
"inputs"
:
[
torch
.
rand
(
8
,
3
)],
"assert_ref_out"
:
True
}])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"fsdp_config"
,
[{},
{
"flatten_parameters"
:
False
}],
"fsdp_config"
,
[{},
{
"flatten_parameters"
:
False
}
,
{
"mixed_precision"
:
True
}
],
)
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
list
(
range
(
2
,
9
)))
@
pytest
.
mark
.
parametrize
(
"world_size"
,
list
(
range
(
2
,
9
)))
def
test_one_iteration
(
world_size
,
test_case
,
fsdp_config
):
def
test_one_iteration
(
world_size
,
test_case
,
fsdp_config
):
...
...
tests/nn/misc/test_flatten_params_wrapper.py  (view file @ d2924670)

...
@@ -7,6 +7,7 @@
 Test FlattenParamsWrapper
 """

+from collections import OrderedDict
 import unittest

 import torch
...
@@ -196,6 +197,40 @@ class TestFlattenParams(unittest.TestCase):
         assert objects_are_equal(ref_output, new_output)

+    def test_unflatten_params(self):
+        for module_init_fn in self._get_module_init_fns():
+            module = FlattenParamsWrapper(module_init_fn())
+            buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}
+
+            def clone_state_dict():
+                return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())
+
+            ref_flat_param = module.flat_param.clone()
+            with module.unflatten_params():
+                ref_state_dict = clone_state_dict()
+            assert not torch.all(ref_flat_param == 0)
+
+            # confirm that unflatten_params reflects values from new_flat_param
+            new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
+            with module.unflatten_params(flat_param=new_flat_param):
+                new_state_dict = clone_state_dict()
+            assert new_state_dict.keys() == ref_state_dict.keys()
+            for k, v in new_state_dict.items():
+                if k in buffers:  # buffers are not changed
+                    torch.testing.assert_allclose(v, ref_state_dict[k])
+                else:  # params reflect new_flat_param value
+                    assert torch.all(v == 42.0)
+
+            # after context manager exits, we go back to previous (reference) state
+            torch.testing.assert_allclose(module.flat_param, ref_flat_param)
+            with module.unflatten_params():
+                ref_state_dict2 = clone_state_dict()
+            assert objects_are_equal(ref_state_dict, ref_state_dict2)
+
+            # if we load the new_state_dict, then the flat param should match new_flat_param
+            module.load_state_dict(new_state_dict)
+            torch.testing.assert_allclose(module.flat_param, new_flat_param)
+

 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
 class TestFlattenParamsCUDA(TestFlattenParams):
...