Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
bd5d0496
Unverified
Commit
bd5d0496
authored
Jan 21, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 21, 2021
Browse files
[fix] Lint flattenparams (#320)
* working around broken mypy
parent
a6ed6da8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
37 deletions
+12
-37
fairscale/nn/misc/flatten_params_wrapper.py
fairscale/nn/misc/flatten_params_wrapper.py
+6
-20
fairscale/utils/testing.py
fairscale/utils/testing.py
+2
-5
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_flatten_params_wrapper.py
+4
-12
No files found.
fairscale/nn/misc/flatten_params_wrapper.py
View file @
bd5d0496
# Copyright (c) Tongzhou Wang
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.
# Licensed under the MIT License.
from
collections
import
OrderedDict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.nn
as
nn
from
torch
import
Tensor
from
torch
import
Tensor
import
torch.nn
as
nn
class
FlattenParamsWrapper
(
nn
.
Module
):
class
FlattenParamsWrapper
(
nn
.
Module
):
...
@@ -28,9 +27,7 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -28,9 +27,7 @@ class FlattenParamsWrapper(nn.Module):
appearing in the given list (default: flatten all parameters)
appearing in the given list (default: flatten all parameters)
"""
"""
def
__init__
(
def
__init__
(
self
,
module
:
nn
.
Module
,
param_list
:
Optional
[
List
[
nn
.
Parameter
]]
=
None
):
self
,
module
:
nn
.
Module
,
param_list
:
Optional
[
List
[
nn
.
Parameter
]]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
module
=
module
self
.
module
=
module
...
@@ -74,9 +71,7 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -74,9 +71,7 @@ class FlattenParamsWrapper(nn.Module):
param_shapes
.
append
(
p
.
size
())
param_shapes
.
append
(
p
.
size
())
del
shared_param_memo
del
shared_param_memo
assert
(
assert
len
(
set
(
p
.
dtype
for
p
in
params
))
<=
1
,
"expects all parameters in module to have same dtype"
len
(
set
(
p
.
dtype
for
p
in
params
))
<=
1
),
"expects all parameters in module to have same dtype"
# store the info for unflatten
# store the info for unflatten
self
.
_param_infos
=
tuple
(
param_infos
)
self
.
_param_infos
=
tuple
(
param_infos
)
...
@@ -97,12 +92,7 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -97,12 +92,7 @@ class FlattenParamsWrapper(nn.Module):
delattr
(
m
,
n
)
delattr
(
m
,
n
)
def
_get_param_views
(
self
)
->
Generator
:
def
_get_param_views
(
self
)
->
Generator
:
return
(
return
(
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
self
.
flat_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
))
t
.
view
(
s
)
for
(
t
,
s
)
in
zip
(
self
.
flat_param
.
split
(
self
.
_param_numels
),
self
.
_param_shapes
)
)
def
_unflatten_params
(
self
)
->
None
:
def
_unflatten_params
(
self
)
->
None
:
ps
=
self
.
_get_param_views
()
ps
=
self
.
_get_param_views
()
...
@@ -137,9 +127,7 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -137,9 +127,7 @@ class FlattenParamsWrapper(nn.Module):
except
AttributeError
:
except
AttributeError
:
return
getattr
(
self
.
module
,
name
)
# fallback to wrapped module
return
getattr
(
self
.
module
,
name
)
# fallback to wrapped module
def
state_dict
(
def
state_dict
(
self
,
prefix
:
str
=
""
,
keep_vars
:
bool
=
False
)
->
"OrderedDict[str, Tensor]"
:
# type: ignore
self
,
prefix
:
str
=
""
,
keep_vars
:
bool
=
False
,
)
->
OrderedDict
[
str
,
Tensor
]:
"""Return an unflattened state_dict."""
"""Return an unflattened state_dict."""
with
self
.
unflatten_params
():
with
self
.
unflatten_params
():
return
self
.
module
.
state_dict
(
prefix
=
prefix
,
keep_vars
=
keep_vars
)
return
self
.
module
.
state_dict
(
prefix
=
prefix
,
keep_vars
=
keep_vars
)
...
@@ -148,9 +136,7 @@ class FlattenParamsWrapper(nn.Module):
...
@@ -148,9 +136,7 @@ class FlattenParamsWrapper(nn.Module):
"""Return the flattened state_dict."""
"""Return the flattened state_dict."""
return
super
().
state_dict
(
*
args
,
**
kwargs
)
return
super
().
state_dict
(
*
args
,
**
kwargs
)
def
load_state_dict
(
def
load_state_dict
(
self
,
state_dict
:
Dict
[
str
,
Any
],
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
self
,
state_dict
:
Dict
[
str
,
Any
],
*
args
:
Any
,
**
kwargs
:
Any
)
->
None
:
if
"flat_param"
in
state_dict
:
if
"flat_param"
in
state_dict
:
super
().
load_state_dict
(
state_dict
,
strict
=
True
)
super
().
load_state_dict
(
state_dict
,
strict
=
True
)
else
:
else
:
...
...
fairscale/utils/testing.py
View file @
bd5d0496
...
@@ -66,7 +66,7 @@ class IdentityLayer(torch.nn.Module):
...
@@ -66,7 +66,7 @@ class IdentityLayer(torch.nn.Module):
def
set_random_seed
(
seed
:
int
)
->
None
:
def
set_random_seed
(
seed
:
int
)
->
None
:
"""Set random seed for reproduc
a
bility."""
"""Set random seed for reproduc
i
bility."""
random
.
seed
(
seed
)
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
...
@@ -388,9 +388,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
...
@@ -388,9 +388,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
try
:
try
:
torch
.
testing
.
assert_allclose
(
a
,
b
)
torch
.
testing
.
assert_allclose
(
a
,
b
)
# assert_allclose doesn't strictly test shape, dtype and device
# assert_allclose doesn't strictly test shape, dtype and device
shape_dtype_device_match
=
(
shape_dtype_device_match
=
a
.
size
()
==
b
.
size
()
and
a
.
dtype
==
b
.
dtype
and
a
.
device
==
b
.
device
a
.
size
()
==
b
.
size
()
and
a
.
dtype
==
b
.
dtype
and
a
.
device
==
b
.
device
)
assert
shape_dtype_device_match
assert
shape_dtype_device_match
return
True
return
True
except
AssertionError
as
e
:
except
AssertionError
as
e
:
...
@@ -400,4 +398,3 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
...
@@ -400,4 +398,3 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
return
False
return
False
else
:
else
:
return
a
==
b
return
a
==
b
tests/nn/misc/test_flatten_params_wrapper.py
View file @
bd5d0496
...
@@ -10,6 +10,7 @@ Test FlattenParamsWrapper
...
@@ -10,6 +10,7 @@ Test FlattenParamsWrapper
import
unittest
import
unittest
import
torch
import
torch
from
fairscale.nn
import
FlattenParamsWrapper
from
fairscale.nn
import
FlattenParamsWrapper
from
fairscale.utils.testing
import
objects_are_equal
from
fairscale.utils.testing
import
objects_are_equal
...
@@ -18,11 +19,7 @@ class TestFlattenParams(unittest.TestCase):
...
@@ -18,11 +19,7 @@ class TestFlattenParams(unittest.TestCase):
def
_get_transformer
(
self
,
seed
=
0
):
def
_get_transformer
(
self
,
seed
=
0
):
torch
.
manual_seed
(
seed
)
# keep everything deterministic
torch
.
manual_seed
(
seed
)
# keep everything deterministic
module
=
torch
.
nn
.
Transformer
(
module
=
torch
.
nn
.
Transformer
(
d_model
=
32
,
d_model
=
32
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
128
,
dropout
=
0.1
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
128
,
dropout
=
0.1
,
)
)
module
.
register_buffer
(
"dummy_buffer"
,
torch
.
tensor
(
1.0
))
module
.
register_buffer
(
"dummy_buffer"
,
torch
.
tensor
(
1.0
))
return
module
return
module
...
@@ -70,10 +67,7 @@ class TestFlattenParams(unittest.TestCase):
...
@@ -70,10 +67,7 @@ class TestFlattenParams(unittest.TestCase):
module
=
self
.
_get_transformer
()
module
=
self
.
_get_transformer
()
num_params
=
sum
(
p
.
numel
()
for
p
in
module
.
parameters
())
num_params
=
sum
(
p
.
numel
()
for
p
in
module
.
parameters
())
params_to_flatten
=
(
params_to_flatten
=
list
(
module
.
encoder
.
layers
[
1
].
parameters
())
+
list
(
module
.
decoder
.
layers
[
0
].
parameters
())
list
(
module
.
encoder
.
layers
[
1
].
parameters
())
+
list
(
module
.
decoder
.
layers
[
0
].
parameters
())
)
num_params_to_flatten
=
sum
(
p
.
numel
()
for
p
in
params_to_flatten
)
num_params_to_flatten
=
sum
(
p
.
numel
()
for
p
in
params_to_flatten
)
module
=
FlattenParamsWrapper
(
module
,
param_list
=
params_to_flatten
)
module
=
FlattenParamsWrapper
(
module
,
param_list
=
params_to_flatten
)
...
@@ -92,9 +86,7 @@ class TestFlattenParams(unittest.TestCase):
...
@@ -92,9 +86,7 @@ class TestFlattenParams(unittest.TestCase):
orig_dtype
=
params_to_flatten
[
0
].
dtype
orig_dtype
=
params_to_flatten
[
0
].
dtype
new_dtype
=
torch
.
float32
if
orig_dtype
==
torch
.
float16
else
torch
.
float16
new_dtype
=
torch
.
float32
if
orig_dtype
==
torch
.
float16
else
torch
.
float16
assert
module
.
flat_param
.
dtype
==
orig_dtype
assert
module
.
flat_param
.
dtype
==
orig_dtype
assert
all
(
assert
all
(
p
.
dtype
==
orig_dtype
for
p
in
module
.
encoder
.
layers
[
0
].
parameters
())
p
.
dtype
==
orig_dtype
for
p
in
module
.
encoder
.
layers
[
0
].
parameters
()
)
module
=
module
.
to
(
dtype
=
new_dtype
)
module
=
module
.
to
(
dtype
=
new_dtype
)
assert
module
.
flat_param
.
dtype
==
new_dtype
assert
module
.
flat_param
.
dtype
==
new_dtype
assert
all
(
p
.
dtype
==
new_dtype
for
p
in
module
.
encoder
.
layers
[
0
].
parameters
())
assert
all
(
p
.
dtype
==
new_dtype
for
p
in
module
.
encoder
.
layers
[
0
].
parameters
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment