OpenDAS / fairscale · Commits · bd5d0496

Unverified commit bd5d0496, authored Jan 21, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 21, 2021.
[fix] Lint flattenparams (#320)
* working around broken mypy
Parent: a6ed6da8

Showing 3 changed files with 12 additions and 37 deletions (+12 -37):

    fairscale/nn/misc/flatten_params_wrapper.py   +6  -20
    fairscale/utils/testing.py                    +2  -5
    tests/nn/misc/test_flatten_params_wrapper.py  +4  -12
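The "broken mypy" workaround referred to in the commit message shows up in the state_dict hunk of flatten_params_wrapper.py below: the return annotation becomes the string "OrderedDict[str, Tensor]", a trailing # type: ignore is added, and the from collections import OrderedDict line is dropped. The following is a minimal, self-contained sketch of that pattern (TinyWrapper is hypothetical, not part of fairscale), assuming the motivation is that collections.OrderedDict is not subscriptable at runtime before Python 3.9 and that mypy still objects to this override of nn.Module.state_dict:

import torch.nn as nn
from torch import Tensor  # only referenced inside the quoted annotation below


class TinyWrapper(nn.Module):
    """Hypothetical wrapper used only to illustrate the annotation pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 2)

    # The quoted return type is never evaluated at runtime, so
    # collections.OrderedDict is never subscripted (a Python 3.9+ feature)
    # and the import can be dropped; the "# type: ignore" suppresses what
    # mypy still flags on this override.
    def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]":  # type: ignore
        return super().state_dict(prefix=prefix, keep_vars=keep_vars)


print(list(TinyWrapper().state_dict().keys()))  # ['linear.weight', 'linear.bias']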
fairscale/nn/misc/flatten_params_wrapper.py  +6 -20

 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.

-from collections import OrderedDict
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, List, Optional, Tuple

 import torch
-import torch.nn as nn
 from torch import Tensor
+import torch.nn as nn


 class FlattenParamsWrapper(nn.Module):

@@ -28,9 +27,7 @@ class FlattenParamsWrapper(nn.Module):
         appearing in the given list (default: flatten all parameters)
     """

-    def __init__(
-        self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None
-    ):
+    def __init__(self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None):
         super().__init__()
         self.module = module

@@ -74,9 +71,7 @@ class FlattenParamsWrapper(nn.Module):
             param_shapes.append(p.size())
         del shared_param_memo

-        assert (
-            len(set(p.dtype for p in params)) <= 1
-        ), "expects all parameters in module to have same dtype"
+        assert len(set(p.dtype for p in params)) <= 1, "expects all parameters in module to have same dtype"

         # store the info for unflatten
         self._param_infos = tuple(param_infos)

@@ -97,12 +92,7 @@ class FlattenParamsWrapper(nn.Module):
            delattr(m, n)

     def _get_param_views(self) -> Generator:
-        return (
-            t.view(s)
-            for (t, s) in zip(
-                self.flat_param.split(self._param_numels), self._param_shapes
-            )
-        )
+        return (t.view(s) for (t, s) in zip(self.flat_param.split(self._param_numels), self._param_shapes))

     def _unflatten_params(self) -> None:
         ps = self._get_param_views()

@@ -137,9 +127,7 @@ class FlattenParamsWrapper(nn.Module):
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module

-    def state_dict(
-        self, prefix: str = "", keep_vars: bool = False,
-    ) -> OrderedDict[str, Tensor]:
+    def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]":  # type: ignore
         """Return an unflattened state_dict."""
         with self.unflatten_params():
             return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

@@ -148,9 +136,7 @@ class FlattenParamsWrapper(nn.Module):
         """Return the flattened state_dict."""
         return super().state_dict(*args, **kwargs)

-    def load_state_dict(
-        self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any
-    ) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
         if "flat_param" in state_dict:
             super().load_state_dict(state_dict, strict=True)
         else:
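For orientation, here is a short usage sketch of the wrapper whose methods the hunks above reformat, based only on the behaviour visible in this diff and in the test file further down. The toy nn.Sequential is made up for illustration; it is not the transformer fixture the tests use:

import torch.nn as nn

from fairscale.nn import FlattenParamsWrapper

module = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

# Restrict flattening to the last Linear, as the tests do via param_list.
params_to_flatten = list(module[2].parameters())
wrapped = FlattenParamsWrapper(module, param_list=params_to_flatten)

# The selected parameters now live in a single 1-D flat_param tensor.
assert wrapped.flat_param.numel() == sum(p.numel() for p in params_to_flatten)

# state_dict() temporarily unflattens, so keys match the wrapped module's own
# names; load_state_dict() accepts either that form or a flattened dict that
# contains a "flat_param" entry.
state = wrapped.state_dict()
wrapped.load_state_dict(state)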
fairscale/utils/testing.py  +2 -5

@@ -66,7 +66,7 @@ class IdentityLayer(torch.nn.Module):

 def set_random_seed(seed: int) -> None:
-    """Set random seed for reproducability."""
+    """Set random seed for reproducibility."""
     random.seed(seed)
     numpy.random.seed(seed)
     torch.manual_seed(seed)

@@ -388,9 +388,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
         try:
             torch.testing.assert_allclose(a, b)
             # assert_allclose doesn't strictly test shape, dtype and device
-            shape_dtype_device_match = (
-                a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
-            )
+            shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
             assert shape_dtype_device_match
             return True
         except AssertionError as e:

@@ -400,4 +398,3 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
             return False
     else:
         return a == b
-
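Both helpers touched in this file are small test utilities. The sketch below shows how they are typically called, inferred only from the code visible in these hunks; the tensors a and b are made up for illustration:

import torch

from fairscale.utils.testing import objects_are_equal, set_random_seed

# Seeds Python's random, numpy.random and torch in one call, as above.
set_random_seed(1234)

a = torch.randn(3, 3)
b = a.clone()

# objects_are_equal runs torch.testing.assert_allclose and then the extra
# shape/dtype/device check from the hunk above; it returns a bool rather than
# raising, unless raise_exception=True is passed.
assert objects_are_equal(a, b)
assert not objects_are_equal(a, a.double())  # same values, but dtype differs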
tests/nn/misc/test_flatten_params_wrapper.py  +4 -12

@@ -10,6 +10,7 @@ Test FlattenParamsWrapper

 import unittest

 import torch
+
 from fairscale.nn import FlattenParamsWrapper
 from fairscale.utils.testing import objects_are_equal

@@ -18,11 +19,7 @@ class TestFlattenParams(unittest.TestCase):
     def _get_transformer(self, seed=0):
         torch.manual_seed(seed)  # keep everything deterministic
         module = torch.nn.Transformer(
-            d_model=32,
-            num_encoder_layers=2,
-            num_decoder_layers=2,
-            dim_feedforward=128,
-            dropout=0.1,
+            d_model=32, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128, dropout=0.1,
         )
         module.register_buffer("dummy_buffer", torch.tensor(1.0))
         return module

@@ -70,10 +67,7 @@ class TestFlattenParams(unittest.TestCase):
         module = self._get_transformer()
         num_params = sum(p.numel() for p in module.parameters())

-        params_to_flatten = (
-            list(module.encoder.layers[1].parameters())
-            + list(module.decoder.layers[0].parameters())
-        )
+        params_to_flatten = list(module.encoder.layers[1].parameters()) + list(module.decoder.layers[0].parameters())
         num_params_to_flatten = sum(p.numel() for p in params_to_flatten)

         module = FlattenParamsWrapper(module, param_list=params_to_flatten)

@@ -92,9 +86,7 @@ class TestFlattenParams(unittest.TestCase):
         orig_dtype = params_to_flatten[0].dtype
         new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
         assert module.flat_param.dtype == orig_dtype
-        assert all(
-            p.dtype == orig_dtype for p in module.encoder.layers[0].parameters()
-        )
+        assert all(p.dtype == orig_dtype for p in module.encoder.layers[0].parameters())
         module = module.to(dtype=new_dtype)
         assert module.flat_param.dtype == new_dtype
         assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())
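The last hunk belongs to a dtype-conversion test; the expectation it encodes is that converting the wrapped module also converts the single flat_param. A minimal sketch of that expectation, using a hypothetical nn.Linear instead of the transformer fixture:

import torch
import torch.nn as nn

from fairscale.nn import FlattenParamsWrapper

wrapped = FlattenParamsWrapper(nn.Linear(8, 8))
assert wrapped.flat_param.dtype == torch.float32

# nn.Module.to returns the module itself, so the reassignment mirrors the test.
wrapped = wrapped.to(dtype=torch.float16)
assert wrapped.flat_param.dtype == torch.float16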