OpenDAS / fairscale · Commits

Commit 506d6209 (unverified)
Authored Feb 26, 2021 by Myle Ott; committed by GitHub on Feb 26, 2021

[fix] Fix nested FlattenParamsWrapper state_dict/load_state_dict (#434)

Parent: 9163e381
Showing 7 changed files with 281 additions and 44 deletions (+281 −44).
Changed files:

- fairscale/nn/misc/flatten_params_wrapper.py (+97 −15)
- fairscale/utils/state_dict.py (+75 −0, new file)
- stubs/torch/__init__.pyi (+1 −0)
- stubs/torch/cuda/__init__.pyi (+1 −0)
- tests/nn/data_parallel/test_fsdp.py (+2 −1)
- tests/nn/misc/test_flatten_params_wrapper.py (+70 −28)
- tests/utils/test_state_dict.py (+35 −0, new file)
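For context on what the fix enables, the sketch below is illustrative only and not part of the commit: it builds a nested FlattenParamsWrapper hierarchy (similar in spirit to the `_get_nested_flat_module` fixture added in the tests) and round-trips an unflattened state_dict into a second, identically structured copy, which is the scenario the old code mishandled. The layer sizes, variable names, and the direct module import path are assumptions made for the example.

```python
# Illustrative sketch only (assumes a fairscale checkout that includes this
# commit). Layer sizes and names are made up for the example.
import torch
import torch.nn as nn

from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper

inner = FlattenParamsWrapper(nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)))
outer = FlattenParamsWrapper(nn.Sequential(inner, nn.Linear(4, 2)))

# With the new hooks, state_dict() hides both the flattening and the wrapper
# prefixes, even for nested wrappers ...
sd = outer.state_dict()
assert all("flat_param" not in k and "_fpw_module" not in k for k in sd)

# ... and load_state_dict() accepts that unflattened dict on a fresh,
# still-flattened copy with the same structure.
inner2 = FlattenParamsWrapper(nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)))
outer2 = FlattenParamsWrapper(nn.Sequential(inner2, nn.Linear(4, 2)))
outer2.load_state_dict(sd)

x = torch.rand(3, 4)
assert torch.allclose(outer(x), outer2(x))
```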
fairscale/nn/misc/flatten_params_wrapper.py (+97 −15)

```diff
 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.
 
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 import torch.nn as nn
 
+from fairscale.utils.state_dict import replace_by_prefix_
+
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
@@ -32,7 +34,8 @@ class FlattenParamsWrapper(nn.Module):
     def __init__(self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None):
         super().__init__()
-        self.module = module
+        self._fpw_module = module
+        self.is_flattened = False
 
         if param_list is not None:
             assert len(param_list) > 0, "param_list can't be empty"
@@ -53,7 +56,24 @@ class FlattenParamsWrapper(nn.Module):
         # register the views as plain attributes
         self._unflatten_params_as_views()
 
+        # Register hook to be called after state_dict() to remove the
+        # "_fpw_module." prefix and before load_state_dict() to add it back.
+        self._register_state_dict_hook(_post_state_dict_hook)
+        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
+
+        # Flag to indicate whether state_dict() should automatically unflatten
+        # params. This defaults to True, but may be set to False if the user
+        # explicitly requests a flat state dict via flat_state_dict().
+        self._auto_unflatten_state_dict = True
+
+    @property
+    def module(self) -> nn.Module:
+        return self._fpw_module
+
     def _flatten_params(self) -> None:
+        assert not self.is_flattened
+        self.is_flattened = True
+
         param_infos = []
         shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
         shared_param_infos = []
@@ -98,6 +118,9 @@ class FlattenParamsWrapper(nn.Module):
         return (t.view(s) for (t, s) in zip(self.flat_param.split(self._param_numels), self._param_shapes))
 
     def _unflatten_params(self) -> None:
+        assert self.is_flattened
+        self.is_flattened = False
+
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             if hasattr(m, n):
@@ -110,6 +133,7 @@ class FlattenParamsWrapper(nn.Module):
         del self.flat_param
 
     def _unflatten_params_as_views(self) -> None:
+        assert self.is_flattened
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr
@@ -117,11 +141,30 @@ class FlattenParamsWrapper(nn.Module):
             setattr(m, n, getattr(shared_m, shared_n))
 
     @contextmanager
-    def unflatten_params(self) -> Generator:
-        self._unflatten_params()
-        yield
-        self._flatten_params()
-        self._unflatten_params_as_views()
+    def unflatten_params(self, recurse: bool = True) -> Generator:
+        """
+        Unflatten params (optionally recursively on all nested instances).
+
+        If the current instance is already unflattened, then it will remain
+        unflattened after the context manager exits.
+        """
+        if recurse:
+            with ExitStack() as stack:
+                # unflatten any nested FlattenParamsWrapper instances
+                for module in self.modules():
+                    if isinstance(module, FlattenParamsWrapper):
+                        stack.enter_context(module.unflatten_params(recurse=False))
+                # yield to the caller, with unflattened params in all nested instances
+                yield
+            # exiting from the ExitStack will re-flatten params
+            return
+        else:
+            orig_flattened = self.is_flattened
+            if self.is_flattened:
+                self._unflatten_params()
+            yield
+            if orig_flattened:
+                self._flatten_params()
+                self._unflatten_params_as_views()
 
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
@@ -131,23 +174,62 @@ class FlattenParamsWrapper(nn.Module):
             return getattr(self.module, name)  # fallback to wrapped module
 
     def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]":  # type: ignore
-        """Return an unflattened state_dict."""
-        with self.unflatten_params():
-            return self.module.state_dict(*args, **kwargs)
+        """Return the wrapped module's state_dict (unflattened)."""
+        if self.is_flattened and self._auto_unflatten_state_dict:
+            with self.unflatten_params(recurse=False):
+                return super().state_dict(*args, **kwargs)
+        else:
+            return super().state_dict(*args, **kwargs)
 
     def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         """Return the flattened state_dict."""
-        return super().state_dict(*args, **kwargs)
+        assert self.is_flattened
+        with ExitStack() as stack:
+            # tell any nested FlattenParamsWrapper instances not to auto unflatten
+            for module in self.modules():  # includes self
+                if isinstance(module, FlattenParamsWrapper):
+                    stack.enter_context(module._no_auto_unflatten_state_dict())
+            state_dict = self.state_dict(*args, **kwargs)
+        return state_dict
+
+    @contextmanager
+    def _no_auto_unflatten_state_dict(self) -> Generator:
+        backup = self._auto_unflatten_state_dict
+        self._auto_unflatten_state_dict = False
+        yield
+        self._auto_unflatten_state_dict = backup
 
     def load_state_dict(
         self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True
     ) -> NamedTuple:
-        if "flat_param" in state_dict:
-            return super().load_state_dict(state_dict, strict=strict)
-        else:
-            with self.unflatten_params():
-                return self.module.load_state_dict(state_dict, strict)
+        """
+        Load a state dict. If necessary, ``unflatten_params`` will be called to
+        match the input state_dict.
+        """
+        # unflatten the module automatically if the state_dict is non-flat
+        if self.is_flattened and "flat_param" not in state_dict:
+            with self.unflatten_params(recurse=True):
+                return super().load_state_dict(state_dict, strict)
+        else:
+            return super().load_state_dict(state_dict, strict)
 
     def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
         self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)
+
+
+def _post_state_dict_hook(
+    module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
+) -> "OrderedDict[str, Tensor]":
+    replace_by_prefix_(state_dict, prefix + "_fpw_module.", prefix)
+    return state_dict
+
+
+def _pre_load_state_dict_hook(
+    state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any
+) -> None:
+    replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
+    # flat_param actually needs to move one level up though
+    flat_param_key = prefix + "_fpw_module.flat_param"
+    if flat_param_key in state_dict:
+        replace_by_prefix_(state_dict, flat_param_key, prefix + "flat_param")
```
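As a quick illustration of the hook behaviour added above, here is a small sketch (not part of the commit; the wrapped layer and its sizes are arbitrary) contrasting `state_dict()` with `flat_state_dict()` on a single wrapper:

```python
# Illustrative sketch of the key handling added in this file.
import torch.nn as nn

from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper

wrapped = FlattenParamsWrapper(nn.Linear(2, 3))

# _post_state_dict_hook strips the "_fpw_module." prefix, so the keys look
# exactly like the unwrapped nn.Linear's.
assert set(wrapped.state_dict().keys()) == {"weight", "bias"}

# flat_state_dict() disables the auto-unflattening and exposes the single
# flattened parameter instead.
assert set(wrapped.flat_state_dict().keys()) == {"flat_param"}
```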
fairscale/utils/state_dict.py (new file, +75 −0)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Useful functions for manipulating state_dicts."""

from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union

from torch import Tensor, nn

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401


def find_module_instances(module: nn.Module, search_class: Type[nn.Module]) -> List[Tuple[str, nn.Module]]:
    """
    Find all occurrences of a given search_class among the given Module's
    children and return the corresponding paths in the same format as
    state_dicts.

    Usage::

        net = nn.Sequential(
            nn.Linear(1, 1),
            nn.ModuleDict({"ln": nn.LayerNorm(1), "linear": nn.Linear(1, 1)}),
            nn.LayerNorm(1)
        )

        >>> find_module_instances(net, nn.LayerNorm)
        [('1.ln.', LayerNorm((1,), eps=1e-05, elementwise_affine=True)), ('2.', LayerNorm((1,), eps=1e-05, elementwise_affine=True))]
        >>> find_module_instances(net, nn.Dropout)
        []
        >>> find_module_instances(net, nn.Sequential)
        [('', Sequential(
          (0): Linear(in_features=1, out_features=1, bias=True)
          (1): ModuleDict(
            (ln): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
            (linear): Linear(in_features=1, out_features=1, bias=True)
          )
          (2): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
        ))]
    """
    paths = []

    def add_paths_(module: nn.Module, prefix: str = "") -> None:
        if isinstance(module, search_class):
            paths.append((prefix, module))
        for name, child in module.named_children():
            add_paths_(child, prefix + name + ".")

    add_paths_(module)
    return paths


def replace_by_prefix_(
    state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], old_prefix: str, new_prefix: str
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        replace_by_prefix_(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]
```
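`replace_by_prefix_` is the primitive the FlattenParamsWrapper hooks above are built on. A brief illustrative round trip follows; the keys and tensor shapes are made up for the example, mimicking how the hooks move the "_fpw_module." prefix in and out of a state_dict:

```python
# Illustrative use of replace_by_prefix_ with made-up keys.
import torch

from fairscale.utils.state_dict import replace_by_prefix_

state_dict = {"_fpw_module.weight": torch.zeros(3, 2), "_fpw_module.bias": torch.zeros(3)}

# What _post_state_dict_hook does (with an empty outer prefix): hide the wrapper.
replace_by_prefix_(state_dict, "_fpw_module.", "")
assert set(state_dict) == {"weight", "bias"}

# What _pre_load_state_dict_hook does on the way back in.
replace_by_prefix_(state_dict, "", "_fpw_module.")
assert set(state_dict) == {"_fpw_module.weight", "_fpw_module.bias"}
```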
stubs/torch/__init__.pyi (+1 −0)

```diff
@@ -28,6 +28,7 @@ from . import cuda as cuda
 from . import optim as optim
 from . import nn as nn
 from . import testing as testing
+from . import utils as utils
 
 #MODIFIED BY TORCHGPIPE
 from . import backends
```
stubs/torch/cuda/__init__.pyi (+1 −0)

```diff
@@ -2,6 +2,7 @@
 from typing import Optional, Tuple, Union, Dict, Any
 import ctypes
+from . import amp
 from .. import device as _device
 
 def is_available() -> bool: ...
```
tests/nn/data_parallel/test_fsdp.py (+2 −1)

```diff
@@ -409,7 +409,8 @@ class TestLocalStateDict(DistributedTest):
         # Assert that parameters were updated since before training
         unchanged = []
-        buffers = {name for name, _ in model.module.named_buffers()}
+        unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
+        buffers = {name for name, _ in unwrapped_model.named_buffers()}
         for k in state_1:
             if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
                 unchanged.append(k)
```
tests/nn/misc/test_flatten_params_wrapper.py (+70 −28)

```diff
@@ -16,12 +16,26 @@ from fairscale.utils.testing import objects_are_equal
 class TestFlattenParams(unittest.TestCase):
+    def _get_module_init_fns(self):
+        return [
+            self._get_shared_params_transformer,
+            self._get_nested_flat_module,
+        ]
+
     def _get_transformer(self, seed=0):
         torch.manual_seed(seed)  # keep everything deterministic
         module = torch.nn.Transformer(
             d_model=32, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128, dropout=0.1,
         )
         module.register_buffer("dummy_buffer", torch.tensor(1.0))
+
+        def get_input(device, dtype):
+            torch.manual_seed(1)  # keep everything deterministic
+            src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
+            tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
+            return (src, tgt)
+
+        module.get_input = get_input
         return module
 
     def _get_shared_params_transformer(self, seed=0):
@@ -32,13 +46,27 @@ class TestFlattenParams(unittest.TestCase):
             dec_layer.linear2.weight = enc_layer.linear2.weight
         return module
 
+    def _get_nested_flat_module(self, seed=0):
+        module = torch.nn.Sequential(
+            FlattenParamsWrapper(torch.nn.Sequential(torch.nn.Linear(4, 8), FlattenParamsWrapper(torch.nn.Linear(8, 8)))),
+            FlattenParamsWrapper(torch.nn.Sequential(FlattenParamsWrapper(torch.nn.Linear(8, 16)))),
+            FlattenParamsWrapper(torch.nn.Linear(16, 4)),
+        )
+
+        def get_input(device, dtype):
+            torch.manual_seed(1)  # keep everything deterministic
+            return (torch.rand(8, 4).to(device=device, dtype=dtype),)
+
+        module.get_input = get_input
+        return module
+
     def _get_output(self, module):
-        torch.manual_seed(1)  # keep everything deterministic
         device = next(module.parameters()).device
         dtype = next(module.parameters()).dtype
-        src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
-        tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
-        return module(src, tgt)
+        input = module.get_input(device, dtype)
+        return module(*input)
 
     def _get_pnorm_after_step(self, module):
         optim = torch.optim.SGD(module.parameters(), lr=0.01)
@@ -120,39 +148,53 @@ class TestFlattenParams(unittest.TestCase):
         torch.testing.assert_allclose(ref_pnorm_after_step, flat_pnorm_after_step)
 
     def test_state_dict_equality(self):
-        module = self._get_shared_params_transformer()
-        ref_state_dict = module.state_dict()
+        """Test that unflattened state dict matches original (unwrapped) one."""
+        modules_to_test = [init_fn() for init_fn in self._get_module_init_fns()]
+        for module in modules_to_test:
+            ref_state_dict = module.state_dict()
 
-        flat_module = FlattenParamsWrapper(module)
-        flat_state_dict = flat_module.state_dict()
+            flat_module = FlattenParamsWrapper(module)
+            flat_state_dict = flat_module.state_dict()
 
-        assert objects_are_equal(ref_state_dict, flat_state_dict)
+            assert (
+                ref_state_dict.keys() == flat_state_dict.keys()
+            ), f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"
+            assert objects_are_equal(ref_state_dict, flat_state_dict), f"{ref_state_dict} != {flat_state_dict}"
 
     def test_load_state_dict(self):
-        module = self._get_shared_params_transformer()
-        ref_state_dict = module.state_dict()
-        ref_output = self._get_output(module)
+        """Test that original (unwrapped) state_dict can be loaded in wrapped module."""
+        for module_init_fn in self._get_module_init_fns():
+            module = module_init_fn()
+            ref_state_dict = module.state_dict()
+            ref_output = self._get_output(module)
 
-        module = self._get_shared_params_transformer(seed=1234)
-        flat_module = FlattenParamsWrapper(module)
-        flat_module.load_state_dict(ref_state_dict)
-        flat_output = self._get_output(flat_module)
-        assert objects_are_equal(ref_output, flat_output)
+            module = module_init_fn(seed=1234)
+            flat_module = FlattenParamsWrapper(module)
+
+            # This should work without the unflatten_params context manager
+            flat_module.load_state_dict(ref_state_dict)
+            flat_output = self._get_output(flat_module)
+            assert objects_are_equal(ref_output, flat_output)
+
+            # And it should work with the context manager too
+            with flat_module.unflatten_params():
+                flat_module.load_state_dict(ref_state_dict)
+            flat_output = self._get_output(flat_module)
+            assert objects_are_equal(ref_output, flat_output)
 
     def test_flat_state_dict(self):
-        flat_module = self._get_shared_params_transformer()
-        flat_module = FlattenParamsWrapper(flat_module)
-        ref_output = self._get_output(flat_module)
+        """Test that flat state dict can be reloaded and produces the same results."""
+        for module_init_fn in self._get_module_init_fns():
+            flat_module = FlattenParamsWrapper(module_init_fn())
+            ref_output = self._get_output(flat_module)
 
-        flat_state_dict = flat_module.flat_state_dict()
+            flat_state_dict = flat_module.flat_state_dict()
 
-        new_module = self._get_shared_params_transformer(seed=1234)
-        new_module = FlattenParamsWrapper(new_module)
-        new_module.load_state_dict(flat_state_dict)
-        new_output = self._get_output(new_module)
-        assert objects_are_equal(ref_output, new_output)
+            new_module = FlattenParamsWrapper(module_init_fn(seed=1234))
+            new_module.load_state_dict(flat_state_dict)
+            new_output = self._get_output(new_module)
+            assert objects_are_equal(ref_output, new_output)
 
     @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
```
tests/utils/test_state_dict.py (new file, +35 −0)

```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test utility classes from state_dict.py. """

import torch
from torch import nn

from fairscale.utils.state_dict import find_module_instances, replace_by_prefix_


def test_find_module_instances():
    net = nn.Sequential(
        nn.Linear(1, 1), nn.ModuleDict({"ln": nn.LayerNorm(1), "linear": nn.Linear(1, 1)}), nn.LayerNorm(1)
    )
    assert find_module_instances(net, nn.LayerNorm) == [("1.ln.", net[1]["ln"]), ("2.", net[2])]
    assert find_module_instances(net, nn.Linear) == [("0.", net[0]), ("1.linear.", net[1]["linear"])]
    assert find_module_instances(net, nn.Dropout) == []
    assert find_module_instances(net, nn.Sequential) == [("", net)]


def test_replace_by_prefix():
    state_dict = {"layer.a": torch.tensor(1), "abc.layer.def": torch.tensor(2), "layer.b": torch.tensor(3)}
    replace_by_prefix_(state_dict, "layer.", "module.layer.")
    assert state_dict == {
        "module.layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "module.layer.b": torch.tensor(3),
    }
```