Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
506d6209
Unverified
Commit
506d6209
authored
Feb 26, 2021
by
Myle Ott
Committed by
GitHub
Feb 26, 2021
Browse files
[fix] Fix nested FlattenParamsWrapper state_dict/load_state_dict (#434)
parent
9163e381
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
281 additions
and
44 deletions
+281
-44
fairscale/nn/misc/flatten_params_wrapper.py
fairscale/nn/misc/flatten_params_wrapper.py
+97
-15
fairscale/utils/state_dict.py
fairscale/utils/state_dict.py
+75
-0
stubs/torch/__init__.pyi
stubs/torch/__init__.pyi
+1
-0
stubs/torch/cuda/__init__.pyi
stubs/torch/cuda/__init__.pyi
+1
-0
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+2
-1
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_flatten_params_wrapper.py
+70
-28
tests/utils/test_state_dict.py
tests/utils/test_state_dict.py
+35
-0
No files found.
fairscale/nn/misc/flatten_params_wrapper.py
View file @
506d6209
# Copyright (c) Tongzhou Wang
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.
# Licensed under the MIT License.
from
contextlib
import
contextmanager
from
contextlib
import
ExitStack
,
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torch.nn
as
nn
import
torch.nn
as
nn
from
fairscale.utils.state_dict
import
replace_by_prefix_
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
collections
import
OrderedDict
# noqa: F401
from
collections
import
OrderedDict
# noqa: F401
...
@@ -32,7 +34,8 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -32,7 +34,8 @@ class FlattenParamsWrapper(nn.Module):
def
__init__
(
self
,
module
:
nn
.
Module
,
param_list
:
Optional
[
List
[
nn
.
Parameter
]]
=
None
):
def
__init__
(
self
,
module
:
nn
.
Module
,
param_list
:
Optional
[
List
[
nn
.
Parameter
]]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
module
=
module
self
.
_fpw_module
=
module
self
.
is_flattened
=
False
if
param_list
is
not
None
:
if
param_list
is
not
None
:
assert
len
(
param_list
)
>
0
,
"param_list can't be empty"
assert
len
(
param_list
)
>
0
,
"param_list can't be empty"
...
@@ -53,7 +56,24 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -53,7 +56,24 @@ class FlattenParamsWrapper(nn.Module):
# register the views as plain attributes
# register the views as plain attributes
self
.
_unflatten_params_as_views
()
self
.
_unflatten_params_as_views
()
# Register hook to be called after state_dict() to remove the
# "_fpw_module." prefix and before load_state_dict() to add it back.
self
.
_register_state_dict_hook
(
_post_state_dict_hook
)
self
.
_register_load_state_dict_pre_hook
(
_pre_load_state_dict_hook
)
# Flag to indicate whether state_dict() should automatically unflatten
# params. This defaults to True, but may be set to False if the user
# explicitly requests a flat state dict via flat_state_dict().
self
.
_auto_unflatten_state_dict
=
True
@
property
def
module
(
self
)
->
nn
.
Module
:
return
self
.
_fpw_module
def
_flatten_params
(
self
)
->
None
:
def
_flatten_params
(
self
)
->
None
:
assert
not
self
.
is_flattened
self
.
is_flattened
=
True
param_infos
=
[]
param_infos
=
[]
shared_param_memo
:
Dict
[
nn
.
Parameter
,
Tuple
[
nn
.
Module
,
str
]]
=
{}
shared_param_memo
:
Dict
[
nn
.
Parameter
,
Tuple
[
nn
.
Module
,
str
]]
=
{}
shared_param_infos
=
[]
shared_param_infos
=
[]
...
@@ -98,6 +118,9 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -98,6 +118,9 @@ class FlattenParamsWrapper(nn.Module):
return
(
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
self
.
flat_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
))
return
(
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
self
.
flat_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
))
def
_unflatten_params
(
self
)
->
None
:
def
_unflatten_params
(
self
)
->
None
:
assert
self
.
is_flattened
self
.
is_flattened
=
False
ps
=
self
.
_get_param_views
()
ps
=
self
.
_get_param_views
()
for
(
m
,
n
),
p
in
zip
(
self
.
_param_infos
,
ps
):
for
(
m
,
n
),
p
in
zip
(
self
.
_param_infos
,
ps
):
if
hasattr
(
m
,
n
):
if
hasattr
(
m
,
n
):
...
@@ -110,6 +133,7 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -110,6 +133,7 @@ class FlattenParamsWrapper(nn.Module):
del
self
.
flat_param
del
self
.
flat_param
def
_unflatten_params_as_views
(
self
)
->
None
:
def
_unflatten_params_as_views
(
self
)
->
None
:
assert
self
.
is_flattened
ps
=
self
.
_get_param_views
()
ps
=
self
.
_get_param_views
()
for
(
m
,
n
),
p
in
zip
(
self
.
_param_infos
,
ps
):
for
(
m
,
n
),
p
in
zip
(
self
.
_param_infos
,
ps
):
setattr
(
m
,
n
,
p
)
# This will set as plain attr
setattr
(
m
,
n
,
p
)
# This will set as plain attr
...
@@ -117,11 +141,30 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -117,11 +141,30 @@ class FlattenParamsWrapper(nn.Module):
setattr
(
m
,
n
,
getattr
(
shared_m
,
shared_n
))
setattr
(
m
,
n
,
getattr
(
shared_m
,
shared_n
))
@
contextmanager
@
contextmanager
def
unflatten_params
(
self
)
->
Generator
:
def
unflatten_params
(
self
,
recurse
:
bool
=
True
)
->
Generator
:
self
.
_unflatten_params
()
"""
yield
Unflatten params (optionally recursively on all nested instances).
self
.
_flatten_params
()
If the current instance is already unflattened, then it will remain
self
.
_unflatten_params_as_views
()
unflattened after the context manager exits.
"""
if
recurse
:
with
ExitStack
()
as
stack
:
# unflatten any nested FlattenParamsWrapper instances
for
module
in
self
.
modules
():
if
isinstance
(
module
,
FlattenParamsWrapper
):
stack
.
enter_context
(
module
.
unflatten_params
(
recurse
=
False
))
# yield to the caller, with unflattened params in all nested instances
yield
# exiting from the ExitStack will re-flatten params
return
else
:
orig_flattened
=
self
.
is_flattened
if
self
.
is_flattened
:
self
.
_unflatten_params
()
yield
if
orig_flattened
:
self
.
_flatten_params
()
self
.
_unflatten_params_as_views
()
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
"""Forward missing attributes to wrapped module."""
"""Forward missing attributes to wrapped module."""
...
@@ -131,23 +174,62 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -131,23 +174,62 @@ class FlattenParamsWrapper(nn.Module):
return
getattr
(
self
.
module
,
name
)
# fallback to wrapped module
return
getattr
(
self
.
module
,
name
)
# fallback to wrapped module
def
state_dict
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
"OrderedDict[str, Tensor]"
:
# type: ignore
def
state_dict
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
"OrderedDict[str, Tensor]"
:
# type: ignore
"""Return an unflattened state_dict."""
"""Return the wrapped module's state_dict (unflattened)."""
with
self
.
unflatten_params
():
if
self
.
is_flattened
and
self
.
_auto_unflatten_state_dict
:
return
self
.
module
.
state_dict
(
*
args
,
**
kwargs
)
with
self
.
unflatten_params
(
recurse
=
False
):
return
super
().
state_dict
(
*
args
,
**
kwargs
)
else
:
return
super
().
state_dict
(
*
args
,
**
kwargs
)
def
flat_state_dict
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Dict
[
str
,
Any
]:
def
flat_state_dict
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Dict
[
str
,
Any
]:
"""Return the flattened state_dict."""
"""Return the flattened state_dict."""
return
super
().
state_dict
(
*
args
,
**
kwargs
)
assert
self
.
is_flattened
with
ExitStack
()
as
stack
:
# tell any nested FlattenParamsWrapper instances not to auto unflatten
for
module
in
self
.
modules
():
# includes self
if
isinstance
(
module
,
FlattenParamsWrapper
):
stack
.
enter_context
(
module
.
_no_auto_unflatten_state_dict
())
state_dict
=
self
.
state_dict
(
*
args
,
**
kwargs
)
return
state_dict
@
contextmanager
def
_no_auto_unflatten_state_dict
(
self
)
->
Generator
:
backup
=
self
.
_auto_unflatten_state_dict
self
.
_auto_unflatten_state_dict
=
False
yield
self
.
_auto_unflatten_state_dict
=
backup
def
load_state_dict
(
def
load_state_dict
(
self
,
state_dict
:
Union
[
Dict
[
str
,
Tensor
],
"OrderedDict[str, Tensor]"
],
strict
:
bool
=
True
self
,
state_dict
:
Union
[
Dict
[
str
,
Tensor
],
"OrderedDict[str, Tensor]"
],
strict
:
bool
=
True
)
->
NamedTuple
:
)
->
NamedTuple
:
if
"flat_param"
in
state_dict
:
"""
return
super
().
load_state_dict
(
state_dict
,
strict
=
strict
)
Load a state dict. If necessary, ``unflatten_params`` will be called to
match the input state_dict.
"""
# unflatten the module automatically if the state_dict is non-flat
if
self
.
is_flattened
and
"flat_param"
not
in
state_dict
:
with
self
.
unflatten_params
(
recurse
=
True
):
return
super
().
load_state_dict
(
state_dict
,
strict
)
else
:
else
:
with
self
.
unflatten_params
():
return
super
().
load_state_dict
(
state_dict
,
strict
)
return
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
def
forward
(
self
,
*
inputs
:
Any
,
**
kwinputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
,
**
kwinputs
:
Any
)
->
Any
:
self
.
_unflatten_params_as_views
()
self
.
_unflatten_params_as_views
()
return
self
.
module
(
*
inputs
,
**
kwinputs
)
return
self
.
module
(
*
inputs
,
**
kwinputs
)
def
_post_state_dict_hook
(
module
:
nn
.
Module
,
state_dict
:
"OrderedDict[str, Tensor]"
,
prefix
:
str
,
*
args
:
Any
)
->
"OrderedDict[str, Tensor]"
:
replace_by_prefix_
(
state_dict
,
prefix
+
"_fpw_module."
,
prefix
)
return
state_dict
def
_pre_load_state_dict_hook
(
state_dict
:
Union
[
Dict
[
str
,
Tensor
],
"OrderedDict[str, Tensor]"
],
prefix
:
str
,
*
args
:
Any
)
->
None
:
replace_by_prefix_
(
state_dict
,
prefix
,
prefix
+
"_fpw_module."
)
# flat_param actually needs to move one level up though
flat_param_key
=
prefix
+
"_fpw_module.flat_param"
if
flat_param_key
in
state_dict
:
replace_by_prefix_
(
state_dict
,
flat_param_key
,
prefix
+
"flat_param"
)
fairscale/utils/state_dict.py
0 → 100644
View file @
506d6209
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Useful functions for manipulating state_dicts."""
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Tuple
,
Type
,
Union
from
torch
import
Tensor
,
nn
if
TYPE_CHECKING
:
from
collections
import
OrderedDict
# noqa: F401
def
find_module_instances
(
module
:
nn
.
Module
,
search_class
:
Type
[
nn
.
Module
])
->
List
[
Tuple
[
str
,
nn
.
Module
]]:
"""
Find all occurrences of a given search_class among the given Modules's
children and return the corresponding paths in the same format as
state_dicts.
Usage::
net = nn.Sequential(
nn.Linear(1, 1),
nn.ModuleDict({"ln": nn.LayerNorm(1), "linear": nn.Linear(1, 1)}),
nn.LayerNorm(1)
)
>>> find_module_instances(net, nn.LayerNorm)
[('1.ln.', LayerNorm((1,), eps=1e-05, elementwise_affine=True)), ('2.', LayerNorm((1,), eps=1e-05, elementwise_affine=True))]
>>> find_module_instances(net, nn.Dropout)
[]
>>> find_module_instances(net, nn.Sequential)
[('', Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ModuleDict(
(ln): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
(linear): Linear(in_features=1, out_features=1, bias=True)
)
(2): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
))]
"""
paths
=
[]
def
add_paths_
(
module
:
nn
.
Module
,
prefix
:
str
=
""
)
->
None
:
if
isinstance
(
module
,
search_class
):
paths
.
append
((
prefix
,
module
))
for
name
,
child
in
module
.
named_children
():
add_paths_
(
child
,
prefix
+
name
+
"."
)
add_paths_
(
module
)
return
paths
def
replace_by_prefix_
(
state_dict
:
Union
[
Dict
[
str
,
Tensor
],
"OrderedDict[str, Tensor]"
],
old_prefix
:
str
,
new_prefix
:
str
)
->
None
:
"""
Replace all keys that match a given old_prefix with a new_prefix (in-place).
Usage::
state_dict = {"layer.xyz": torch.tensor(1)}
replace_by_prefix_(state_dict, "layer.", "module.layer.")
assert state_dict == {"module.layer.xyz": torch.tensor(1)}
"""
if
old_prefix
==
new_prefix
:
raise
ValueError
(
"old_prefix and new_prefix must be distinct"
)
for
key
in
list
(
state_dict
.
keys
()):
if
not
key
.
startswith
(
old_prefix
):
continue
new_key
=
new_prefix
+
key
[
len
(
old_prefix
)
:]
state_dict
[
new_key
]
=
state_dict
[
key
]
del
state_dict
[
key
]
stubs/torch/__init__.pyi
View file @
506d6209
...
@@ -28,6 +28,7 @@ from . import cuda as cuda
...
@@ -28,6 +28,7 @@ from . import cuda as cuda
from . import optim as optim
from . import optim as optim
from . import nn as nn
from . import nn as nn
from . import testing as testing
from . import testing as testing
from . import utils as utils
#MODIFIED BY TORCHGPIPE
#MODIFIED BY TORCHGPIPE
from . import backends
from . import backends
...
...
stubs/torch/cuda/__init__.pyi
View file @
506d6209
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
from typing import Optional, Tuple, Union, Dict, Any
from typing import Optional, Tuple, Union, Dict, Any
import ctypes
import ctypes
from . import amp
from .. import device as _device
from .. import device as _device
def is_available() -> bool: ...
def is_available() -> bool: ...
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
506d6209
...
@@ -409,7 +409,8 @@ class TestLocalStateDict(DistributedTest):
...
@@ -409,7 +409,8 @@ class TestLocalStateDict(DistributedTest):
# Assert that parameters were updated since before training
# Assert that parameters were updated since before training
unchanged
=
[]
unchanged
=
[]
buffers
=
{
name
for
name
,
_
in
model
.
module
.
named_buffers
()}
unwrapped_model
=
model
.
module
.
module
if
config
[
"flatten_parameters"
]
else
model
.
module
buffers
=
{
name
for
name
,
_
in
unwrapped_model
.
named_buffers
()}
for
k
in
state_1
:
for
k
in
state_1
:
if
(
state_before_training
[
k
]
==
state_after_training
[
k
]).
all
()
and
(
k
not
in
buffers
):
if
(
state_before_training
[
k
]
==
state_after_training
[
k
]).
all
()
and
(
k
not
in
buffers
):
unchanged
.
append
(
k
)
unchanged
.
append
(
k
)
...
...
tests/nn/misc/test_flatten_params_wrapper.py
View file @
506d6209
...
@@ -16,12 +16,26 @@ from fairscale.utils.testing import objects_are_equal
...
@@ -16,12 +16,26 @@ from fairscale.utils.testing import objects_are_equal
class
TestFlattenParams
(
unittest
.
TestCase
):
class
TestFlattenParams
(
unittest
.
TestCase
):
def
_get_module_init_fns
(
self
):
return
[
self
.
_get_shared_params_transformer
,
self
.
_get_nested_flat_module
,
]
def
_get_transformer
(
self
,
seed
=
0
):
def
_get_transformer
(
self
,
seed
=
0
):
torch
.
manual_seed
(
seed
)
# keep everything deterministic
torch
.
manual_seed
(
seed
)
# keep everything deterministic
module
=
torch
.
nn
.
Transformer
(
module
=
torch
.
nn
.
Transformer
(
d_model
=
32
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
128
,
dropout
=
0.1
,
d_model
=
32
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
128
,
dropout
=
0.1
,
)
)
module
.
register_buffer
(
"dummy_buffer"
,
torch
.
tensor
(
1.0
))
module
.
register_buffer
(
"dummy_buffer"
,
torch
.
tensor
(
1.0
))
def
get_input
(
device
,
dtype
):
torch
.
manual_seed
(
1
)
# keep everything deterministic
src
=
torch
.
rand
(
20
,
8
,
32
).
to
(
device
=
device
,
dtype
=
dtype
)
# T x B x C
tgt
=
torch
.
rand
(
10
,
8
,
32
).
to
(
device
=
device
,
dtype
=
dtype
)
# T x B x C
return
(
src
,
tgt
)
module
.
get_input
=
get_input
return
module
return
module
def
_get_shared_params_transformer
(
self
,
seed
=
0
):
def
_get_shared_params_transformer
(
self
,
seed
=
0
):
...
@@ -32,13 +46,27 @@ class TestFlattenParams(unittest.TestCase):
...
@@ -32,13 +46,27 @@ class TestFlattenParams(unittest.TestCase):
dec_layer
.
linear2
.
weight
=
enc_layer
.
linear2
.
weight
dec_layer
.
linear2
.
weight
=
enc_layer
.
linear2
.
weight
return
module
return
module
def
_get_nested_flat_module
(
self
,
seed
=
0
):
module
=
torch
.
nn
.
Sequential
(
FlattenParamsWrapper
(
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
4
,
8
),
FlattenParamsWrapper
(
torch
.
nn
.
Linear
(
8
,
8
)))
),
FlattenParamsWrapper
(
torch
.
nn
.
Sequential
(
FlattenParamsWrapper
(
torch
.
nn
.
Linear
(
8
,
16
)))),
FlattenParamsWrapper
(
torch
.
nn
.
Linear
(
16
,
4
)),
)
def
get_input
(
device
,
dtype
):
torch
.
manual_seed
(
1
)
# keep everything deterministic
return
(
torch
.
rand
(
8
,
4
).
to
(
device
=
device
,
dtype
=
dtype
),)
module
.
get_input
=
get_input
return
module
def
_get_output
(
self
,
module
):
def
_get_output
(
self
,
module
):
torch
.
manual_seed
(
1
)
# keep everything deterministic
device
=
next
(
module
.
parameters
()).
device
device
=
next
(
module
.
parameters
()).
device
dtype
=
next
(
module
.
parameters
()).
dtype
dtype
=
next
(
module
.
parameters
()).
dtype
src
=
torch
.
rand
(
20
,
8
,
32
).
to
(
device
=
device
,
dtype
=
dtype
)
# T x B x C
input
=
module
.
get_input
(
device
,
dtype
)
tgt
=
torch
.
rand
(
10
,
8
,
32
).
to
(
device
=
device
,
dtype
=
dtype
)
# T x B x C
return
module
(
*
input
)
return
module
(
src
,
tgt
)
def
_get_pnorm_after_step
(
self
,
module
):
def
_get_pnorm_after_step
(
self
,
module
):
optim
=
torch
.
optim
.
SGD
(
module
.
parameters
(),
lr
=
0.01
)
optim
=
torch
.
optim
.
SGD
(
module
.
parameters
(),
lr
=
0.01
)
...
@@ -120,39 +148,53 @@ class TestFlattenParams(unittest.TestCase):
...
@@ -120,39 +148,53 @@ class TestFlattenParams(unittest.TestCase):
torch
.
testing
.
assert_allclose
(
ref_pnorm_after_step
,
flat_pnorm_after_step
)
torch
.
testing
.
assert_allclose
(
ref_pnorm_after_step
,
flat_pnorm_after_step
)
def
test_state_dict_equality
(
self
):
def
test_state_dict_equality
(
self
):
module
=
self
.
_get_shared_params_transformer
()
"""Test that unflattened state dict matches original (unwrapped) one."""
ref_state_dict
=
module
.
state_dict
()
modules_to_test
=
[
init_fn
()
for
init_fn
in
self
.
_get_module_init_fns
()]
for
module
in
modules_to_test
:
ref_state_dict
=
module
.
state_dict
()
flat_module
=
FlattenParamsWrapper
(
module
)
flat_module
=
FlattenParamsWrapper
(
module
)
flat_state_dict
=
flat_module
.
state_dict
()
flat_state_dict
=
flat_module
.
state_dict
()
assert
objects_are_equal
(
ref_state_dict
,
flat_state_dict
)
assert
(
ref_state_dict
.
keys
()
==
flat_state_dict
.
keys
()
),
f
"
{
ref_state_dict
.
keys
()
}
!=
{
flat_state_dict
.
keys
()
}
"
assert
objects_are_equal
(
ref_state_dict
,
flat_state_dict
),
f
"
{
ref_state_dict
}
!=
{
flat_state_dict
}
"
def
test_load_state_dict
(
self
):
def
test_load_state_dict
(
self
):
module
=
self
.
_get_shared_params_transformer
()
"""Test that original (unwrapped) state_dict can be loaded in wrapped module."""
ref_state_dict
=
module
.
state_dict
()
for
module_init_fn
in
self
.
_get_module_init_fns
():
ref_output
=
self
.
_get_output
(
module
)
module
=
module_init_fn
()
ref_state_dict
=
module
.
state_dict
()
module
=
self
.
_get_shared_params_transformer
(
seed
=
1234
)
ref_output
=
self
.
_get_output
(
module
)
flat_module
=
FlattenParamsWrapper
(
module
)
flat_module
.
load_state_dict
(
ref_state_dict
)
module
=
module_init_fn
(
seed
=
1234
)
flat_output
=
self
.
_get_output
(
flat_module
)
flat_module
=
FlattenParamsWrapper
(
module
)
assert
objects_are_equal
(
ref_output
,
flat_output
)
# This should work without the unflatten_params context manager
flat_module
.
load_state_dict
(
ref_state_dict
)
flat_output
=
self
.
_get_output
(
flat_module
)
assert
objects_are_equal
(
ref_output
,
flat_output
)
# And it should work with the context manager too
with
flat_module
.
unflatten_params
():
flat_module
.
load_state_dict
(
ref_state_dict
)
flat_output
=
self
.
_get_output
(
flat_module
)
assert
objects_are_equal
(
ref_output
,
flat_output
)
def
test_flat_state_dict
(
self
):
def
test_flat_state_dict
(
self
):
flat_module
=
self
.
_get_shared_params_transformer
()
"""Test that flat state dict can be reloaded and produces the same results."""
flat_module
=
FlattenParamsWrapper
(
flat_module
)
for
module_init_fn
in
self
.
_get_module_init_fns
():
ref_output
=
self
.
_get_output
(
flat_module
)
flat_module
=
FlattenParamsWrapper
(
module_init_fn
())
ref_output
=
self
.
_get_output
(
flat_module
)
flat_state_dict
=
flat_module
.
flat_state_dict
()
flat_state_dict
=
flat_module
.
flat_state_dict
()
new_module
=
self
.
_get_shared_params_transformer
(
seed
=
1234
)
new_module
=
FlattenParamsWrapper
(
module_init_fn
(
seed
=
1234
))
new_module
=
FlattenParamsWrapper
(
new_module
)
new_module
.
load_state_dict
(
flat_state_dict
)
new_module
.
load_state_dict
(
flat_state_dict
)
new_output
=
self
.
_get_output
(
new_module
)
new_output
=
self
.
_get_output
(
new_module
)
assert
objects_are_equal
(
ref_output
,
new_output
)
assert
objects_are_equal
(
ref_output
,
new_output
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
"test requires a GPU"
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
"test requires a GPU"
)
...
...
tests/utils/test_state_dict.py
0 → 100644
View file @
506d6209
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test utility classes from state_dict.py. """
import
torch
from
torch
import
nn
from
fairscale.utils.state_dict
import
find_module_instances
,
replace_by_prefix_
def
test_find_module_instances
():
net
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
),
nn
.
ModuleDict
({
"ln"
:
nn
.
LayerNorm
(
1
),
"linear"
:
nn
.
Linear
(
1
,
1
)}),
nn
.
LayerNorm
(
1
)
)
assert
find_module_instances
(
net
,
nn
.
LayerNorm
)
==
[(
"1.ln."
,
net
[
1
][
"ln"
]),
(
"2."
,
net
[
2
])]
assert
find_module_instances
(
net
,
nn
.
Linear
)
==
[(
"0."
,
net
[
0
]),
(
"1.linear."
,
net
[
1
][
"linear"
])]
assert
find_module_instances
(
net
,
nn
.
Dropout
)
==
[]
assert
find_module_instances
(
net
,
nn
.
Sequential
)
==
[(
""
,
net
)]
def
test_replace_by_prefix
():
state_dict
=
{
"layer.a"
:
torch
.
tensor
(
1
),
"abc.layer.def"
:
torch
.
tensor
(
2
),
"layer.b"
:
torch
.
tensor
(
3
)}
replace_by_prefix_
(
state_dict
,
"layer."
,
"module.layer."
)
assert
state_dict
==
{
"module.layer.a"
:
torch
.
tensor
(
1
),
"abc.layer.def"
:
torch
.
tensor
(
2
),
"module.layer.b"
:
torch
.
tensor
(
3
),
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment