Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
bd5d0496
Unverified
Commit
bd5d0496
authored
Jan 21, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 21, 2021
Browse files
[fix] Lint flattenparams (#320)
* working around broken mypy
parent
a6ed6da8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
37 deletions
+12
-37
fairscale/nn/misc/flatten_params_wrapper.py
fairscale/nn/misc/flatten_params_wrapper.py
+6
-20
fairscale/utils/testing.py
fairscale/utils/testing.py
+2
-5
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_flatten_params_wrapper.py
+4
-12
No files found.
fairscale/nn/misc/flatten_params_wrapper.py
View file @
bd5d0496
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.
from
collections
import
OrderedDict
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
torch
import
Tensor
import
torch.nn
as
nn
class
FlattenParamsWrapper
(
nn
.
Module
):
...
...
@@ -28,9 +27,7 @@ class FlattenParamsWrapper(nn.Module):
appearing in the given list (default: flatten all parameters)
"""
def
__init__
(
self
,
module
:
nn
.
Module
,
param_list
:
Optional
[
List
[
nn
.
Parameter
]]
=
None
):
def
__init__
(
self
,
module
:
nn
.
Module
,
param_list
:
Optional
[
List
[
nn
.
Parameter
]]
=
None
):
super
().
__init__
()
self
.
module
=
module
...
...
@@ -74,9 +71,7 @@ class FlattenParamsWrapper(nn.Module):
param_shapes
.
append
(
p
.
size
())
del
shared_param_memo
assert
(
len
(
set
(
p
.
dtype
for
p
in
params
))
<=
1
),
"expects all parameters in module to have same dtype"
assert
len
(
set
(
p
.
dtype
for
p
in
params
))
<=
1
,
"expects all parameters in module to have same dtype"
# store the info for unflatten
self
.
_param_infos
=
tuple
(
param_infos
)
...
...
@@ -97,12 +92,7 @@ class FlattenParamsWrapper(nn.Module):
delattr
(
m
,
n
)
def
_get_param_views
(
self
)
->
Generator
:
return
(
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
self
.
flat_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
)
)
return
(
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
self
.
flat_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
))
def
_unflatten_params
(
self
)
->
None
:
ps
=
self
.
_get_param_views
()
...
...
@@ -137,9 +127,7 @@ class FlattenParamsWrapper(nn.Module):
except
AttributeError
:
return
getattr
(
self
.
module
,
name
)
# fallback to wrapped module
def
state_dict
(
self
,
prefix
:
str
=
""
,
keep_vars
:
bool
=
False
,
)
->
OrderedDict
[
str
,
Tensor
]:
def
state_dict
(
self
,
prefix
:
str
=
""
,
keep_vars
:
bool
=
False
)
->
"OrderedDict[str, Tensor]"
:
# type: ignore
"""Return an unflattened state_dict."""
with
self
.
unflatten_params
():
return
self
.
module
.
state_dict
(
prefix
=
prefix
,
keep_vars
=
keep_vars
)
...
...
@@ -148,9 +136,7 @@ class FlattenParamsWrapper(nn.Module):
"""Return the flattened state_dict."""
return
super
().
state_dict
(
*
args
,
**
kwargs
)
def
load_state_dict
(
self
,
state_dict
:
Dict
[
str
,
Any
],
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
def
load_state_dict
(
self
,
state_dict
:
Dict
[
str
,
Any
],
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
if
"flat_param"
in
state_dict
:
super
().
load_state_dict
(
state_dict
,
strict
=
True
)
else
:
...
...
fairscale/utils/testing.py
View file @
bd5d0496
...
...
@@ -66,7 +66,7 @@ class IdentityLayer(torch.nn.Module):
def
set_random_seed
(
seed
:
int
)
->
None
:
"""Set random seed for reproduc
a
bility."""
"""Set random seed for reproduc
i
bility."""
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
...
...
@@ -388,9 +388,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
try
:
torch
.
testing
.
assert_allclose
(
a
,
b
)
# assert_allclose doesn't strictly test shape, dtype and device
shape_dtype_device_match
=
(
a
.
size
()
==
b
.
size
()
and
a
.
dtype
==
b
.
dtype
and
a
.
device
==
b
.
device
)
shape_dtype_device_match
=
a
.
size
()
==
b
.
size
()
and
a
.
dtype
==
b
.
dtype
and
a
.
device
==
b
.
device
assert
shape_dtype_device_match
return
True
except
AssertionError
as
e
:
...
...
@@ -400,4 +398,3 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
return
False
else
:
return
a
==
b
tests/nn/misc/test_flatten_params_wrapper.py
View file @
bd5d0496
...
...
@@ -10,6 +10,7 @@ Test FlattenParamsWrapper
import
unittest
import
torch
from
fairscale.nn
import
FlattenParamsWrapper
from
fairscale.utils.testing
import
objects_are_equal
...
...
@@ -18,11 +19,7 @@ class TestFlattenParams(unittest.TestCase):
def
_get_transformer
(
self
,
seed
=
0
):
torch
.
manual_seed
(
seed
)
# keep everything deterministic
module
=
torch
.
nn
.
Transformer
(
d_model
=
32
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
128
,
dropout
=
0.1
,
d_model
=
32
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
128
,
dropout
=
0.1
,
)
module
.
register_buffer
(
"dummy_buffer"
,
torch
.
tensor
(
1.0
))
return
module
...
...
@@ -70,10 +67,7 @@ class TestFlattenParams(unittest.TestCase):
module
=
self
.
_get_transformer
()
num_params
=
sum
(
p
.
numel
()
for
p
in
module
.
parameters
())
params_to_flatten
=
(
list
(
module
.
encoder
.
layers
[
1
].
parameters
())
+
list
(
module
.
decoder
.
layers
[
0
].
parameters
())
)
params_to_flatten
=
list
(
module
.
encoder
.
layers
[
1
].
parameters
())
+
list
(
module
.
decoder
.
layers
[
0
].
parameters
())
num_params_to_flatten
=
sum
(
p
.
numel
()
for
p
in
params_to_flatten
)
module
=
FlattenParamsWrapper
(
module
,
param_list
=
params_to_flatten
)
...
...
@@ -92,9 +86,7 @@ class TestFlattenParams(unittest.TestCase):
orig_dtype
=
params_to_flatten
[
0
].
dtype
new_dtype
=
torch
.
float32
if
orig_dtype
==
torch
.
float16
else
torch
.
float16
assert
module
.
flat_param
.
dtype
==
orig_dtype
assert
all
(
p
.
dtype
==
orig_dtype
for
p
in
module
.
encoder
.
layers
[
0
].
parameters
()
)
assert
all
(
p
.
dtype
==
orig_dtype
for
p
in
module
.
encoder
.
layers
[
0
].
parameters
())
module
=
module
.
to
(
dtype
=
new_dtype
)
assert
module
.
flat_param
.
dtype
==
new_dtype
assert
all
(
p
.
dtype
==
new_dtype
for
p
in
module
.
encoder
.
layers
[
0
].
parameters
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment