OpenDAS / fairscale, commit a6ed6da8 (unverified)

[fix] lint/typing in FlattenParamsWrapper (#318)

Authored Jan 21, 2021 by Myle Ott; committed by GitHub on Jan 21, 2021.
Parent: 35fdf537

Showing 5 changed files with 42 additions and 25 deletions (+42 -25):
  fairscale/nn/misc/flatten_params_wrapper.py   +30 -23
  fairscale/utils/testing.py                     +1  -1
  stubs/torch/__init__.pyi                       +1  -0
  stubs/torch/testing/__init__.pyi               +9  -0
  tests/nn/misc/test_flatten_params_wrapper.py   +1  -1
fairscale/nn/misc/flatten_params_wrapper.py

 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.

 from collections import namedtuple
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, Dict, Generator, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from torch import Tensor


 class FlattenParamsWrapper(nn.Module):

@@ -36,32 +37,32 @@ class FlattenParamsWrapper(nn.Module):
         if param_list is not None:
             assert len(param_list) > 0, "param_list can't be empty"
         else:
-            param_list = module.parameters()
-        param_list = set(param_list)
+            param_list = list(module.parameters())
+        param_set = set(param_list)

         # convert from list of Parameters to set of (Module, name) tuples, which
         # will survive in case the Parameter instances are reset
-        self._param_list = set()
+        self._param_set = set()
         for m in self.modules():
             for n, p in m.named_parameters(recurse=False):
-                if p in param_list:
-                    self._param_list.add((m, n))
+                if p in param_set:
+                    self._param_set.add((m, n))

         self._flatten_params()

         # register the views as plain attributes
         self._unflatten_params_as_views()

-    def _flatten_params(self):
+    def _flatten_params(self) -> None:
         param_infos = []
-        shared_param_memo = {}
+        shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
         shared_param_infos = []
         params = []
         param_numels = []
         param_shapes = []
         for m in self.modules():
             for n, p in m.named_parameters(recurse=False):
-                if p is not None and (m, n) in self._param_list:
+                if p is not None and (m, n) in self._param_set:
                     if p in shared_param_memo:
                         shared_m, shared_n = shared_param_memo[p]
                         shared_param_infos.append((m, n, shared_m, shared_n))

@@ -95,7 +96,7 @@ class FlattenParamsWrapper(nn.Module):
         for m, n, _, _ in self._shared_param_infos:
             delattr(m, n)

-    def _get_param_views(self):
+    def _get_param_views(self) -> Generator:
         return (t.view(s) for (t, s) in zip(

@@ -103,7 +104,7 @@ class FlattenParamsWrapper(nn.Module):
             )
         )

-    def _unflatten_params(self):
+    def _unflatten_params(self) -> None:
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             if hasattr(m, n):

@@ -115,7 +116,7 @@ class FlattenParamsWrapper(nn.Module):
             m.register_parameter(n, getattr(shared_m, shared_n))
         del self.flat_param

-    def _unflatten_params_as_views(self):
+    def _unflatten_params_as_views(self) -> None:
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr

@@ -123,33 +124,39 @@ class FlattenParamsWrapper(nn.Module):
             setattr(m, n, getattr(shared_m, shared_n))

     @contextmanager
-    def unflatten_params(self):
+    def unflatten_params(self) -> Generator:
         self._unflatten_params()
         yield
         self._flatten_params()
         self._unflatten_params_as_views()

-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module

-    def state_dict(self, *args, unflatten_params=True, **kwargs):
-        if unflatten_params:
-            with self.unflatten_params():
-                return self.module.state_dict()
-        else:
-            return super().state_dict()
+    def state_dict(
+        self, prefix: str = "", keep_vars: bool = False,
+    ) -> OrderedDict[str, Tensor]:
+        """Return an unflattened state_dict."""
+        with self.unflatten_params():
+            return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        """Return the flattened state_dict."""
+        return super().state_dict(*args, **kwargs)

-    def load_state_dict(self, state_dict, *args, **kwargs):
+    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
         if "flat_param" in state_dict:
             super().load_state_dict(state_dict, strict=True)
         else:
             with self.unflatten_params():
                 return self.module.load_state_dict(state_dict, *args, **kwargs)

-    def forward(self, *inputs, **kwinputs):
+    def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
         self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)
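For orientation, here is a minimal usage sketch of the API after this change. It is not part of the diff: the example module and variable names are invented, and the import assumes the class is loaded from the module path shown above.

# Sketch of the state_dict()/flat_state_dict() split introduced by this commit.
import torch.nn as nn

from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper

module = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
wrapped = FlattenParamsWrapper(module)

# state_dict() now always returns the unflattened, per-parameter form.
unflat_sd = wrapped.state_dict()

# The flattened form moves from state_dict(unflatten_params=False) to the new
# flat_state_dict() method; it carries the single "flat_param" tensor.
flat_sd = wrapped.flat_state_dict()
assert "flat_param" in flat_sd

# load_state_dict() dispatches on that key, so either form can be restored.
wrapped.load_state_dict(flat_sd)
wrapped.load_state_dict(unflat_sd)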
fairscale/utils/testing.py

@@ -366,7 +366,7 @@ class GPT2(nn.Module):
         return self.clf_head(h), logits


-def objects_are_equal(a, b, raise_exception=False) -> bool:
+def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
     """
     Test that two objects are equal. Tensors are compared to ensure matching
     size, dtype, device and values.
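For reference, a small sketch of how the newly annotated helper is used; it relies only on the signature and docstring shown above, and the import path follows the file location.

# Sketch: objects_are_equal returns a bool; tensors must match in size, dtype,
# device and values for it to return True (per the docstring above).
import torch

from fairscale.utils.testing import objects_are_equal

assert objects_are_equal(torch.ones(2, 3), torch.ones(2, 3))
assert not objects_are_equal(torch.ones(2, 3), torch.zeros(2, 3))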
stubs/torch/__init__.pyi

@@ -27,6 +27,7 @@ from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
 from . import cuda as cuda
 from . import optim as optim
 from . import nn as nn
+from . import testing as testing

 #MODIFIED BY TORCHGPIPE
 from . import backends
stubs/torch/testing/__init__.pyi (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#MODIFIED FOR FlattenParamsWrapper
from typing import Any

def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
#END
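The stub exists so mypy can resolve calls into torch.testing. As an illustration (not part of the commit), a call like the following now type-checks against the stubbed signature; torch.testing.assert_allclose is the standard PyTorch helper that raises AssertionError when two tensors differ beyond rtol/atol.

# Sketch: the kind of call the new stub makes visible to mypy.
import torch

expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected + 1e-7
torch.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)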
tests/nn/misc/test_flatten_params_wrapper.py

@@ -153,7 +153,7 @@ class TestFlattenParams(unittest.TestCase):
         flat_module = FlattenParamsWrapper(flat_module)
         ref_output = self._get_output(flat_module)
-        flat_state_dict = flat_module.state_dict(unflatten_params=False)
+        flat_state_dict = flat_module.flat_state_dict()

         new_module = self._get_shared_params_transformer(seed=1234)
         new_module = FlattenParamsWrapper(new_module)
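The remainder of the test is elided above; the line below is only a sketch (an assumption, not taken from the diff) of how a flat state dict like the one captured here is typically consumed, based on the load_state_dict() logic in flatten_params_wrapper.py.

# Sketch: "flat_param" is present in flat_state_dict, so load_state_dict() takes
# the strict flattened path on the identically structured new_module wrapper.
new_module.load_state_dict(flat_state_dict)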