OpenDAS / fairscale · commit a6ed6da8
"...text-generation-inference.git" did not exist on "895a341d064c9930b2a9bd60cff0df42f91b52fa"
Unverified commit a6ed6da8, authored Jan 21, 2021 by Myle Ott; committed by GitHub on Jan 21, 2021.
[fix] lint/typing in FlattenParamsWrapper (#318)
Parent: 35fdf537
Showing 5 changed files with 42 additions and 25 deletions:
- fairscale/nn/misc/flatten_params_wrapper.py (+30 -23)
- fairscale/utils/testing.py (+1 -1)
- stubs/torch/__init__.pyi (+1 -0)
- stubs/torch/testing/__init__.pyi (+9 -0)
- tests/nn/misc/test_flatten_params_wrapper.py (+1 -1)
fairscale/nn/misc/flatten_params_wrapper.py (+30 -23)

 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.

-from collections import namedtuple
+from collections import OrderedDict
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, Dict, Generator, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from torch import Tensor


 class FlattenParamsWrapper(nn.Module):
...
@@ -36,32 +37,32 @@ class FlattenParamsWrapper(nn.Module):
         if param_list is not None:
             assert len(param_list) > 0, "param_list can't be empty"
         else:
-            param_list = module.parameters()
+            param_list = list(module.parameters())
-        param_list = set(param_list)
+        param_set = set(param_list)

         # convert from list of Parameters to set of (Module, name) tuples, which
         # will survive in case the Parameter instances are reset
-        self._param_list = set()
+        self._param_set = set()
         for m in self.modules():
             for n, p in m.named_parameters(recurse=False):
-                if p in param_list:
+                if p in param_set:
-                    self._param_list.add((m, n))
+                    self._param_set.add((m, n))

         self._flatten_params()

         # register the views as plain attributes
         self._unflatten_params_as_views()

-    def _flatten_params(self):
+    def _flatten_params(self) -> None:
         param_infos = []
-        shared_param_memo = {}
+        shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
         shared_param_infos = []
         params = []
         param_numels = []
         param_shapes = []
         for m in self.modules():
             for n, p in m.named_parameters(recurse=False):
-                if p is not None and (m, n) in self._param_list:
+                if p is not None and (m, n) in self._param_set:
                     if p in shared_param_memo:
                         shared_m, shared_n = shared_param_memo[p]
                         shared_param_infos.append((m, n, shared_m, shared_n))
...
@@ -95,7 +96,7 @@ class FlattenParamsWrapper(nn.Module):
         for m, n, _, _ in self._shared_param_infos:
             delattr(m, n)

-    def _get_param_views(self):
+    def _get_param_views(self) -> Generator:
         return (
             t.view(s)
             for (t, s) in zip(
...
@@ -103,7 +104,7 @@ class FlattenParamsWrapper(nn.Module):
             )
         )

-    def _unflatten_params(self):
+    def _unflatten_params(self) -> None:
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             if hasattr(m, n):
...
@@ -115,7 +116,7 @@ class FlattenParamsWrapper(nn.Module):
             m.register_parameter(n, getattr(shared_m, shared_n))
         del self.flat_param

-    def _unflatten_params_as_views(self):
+    def _unflatten_params_as_views(self) -> None:
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr
...
@@ -123,33 +124,39 @@ class FlattenParamsWrapper(nn.Module):
             setattr(m, n, getattr(shared_m, shared_n))

     @contextmanager
-    def unflatten_params(self):
+    def unflatten_params(self) -> Generator:
         self._unflatten_params()
         yield
         self._flatten_params()
         self._unflatten_params_as_views()

-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module

-    def state_dict(self, *args, unflatten_params=True, **kwargs):
-        if unflatten_params:
-            with self.unflatten_params():
-                return self.module.state_dict()
-        else:
-            return super().state_dict()
+    def state_dict(
+        self, prefix: str = "", keep_vars: bool = False,
+    ) -> OrderedDict[str, Tensor]:
+        """Return an unflattened state_dict."""
+        with self.unflatten_params():
+            return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        """Return the flattened state_dict."""
+        return super().state_dict(*args, **kwargs)

-    def load_state_dict(self, state_dict, *args, **kwargs):
+    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
         if "flat_param" in state_dict:
             super().load_state_dict(state_dict, strict=True)
         else:
             with self.unflatten_params():
                 return self.module.load_state_dict(state_dict, *args, **kwargs)

-    def forward(self, *inputs, **kwinputs):
+    def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
         self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)
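
To make the API change above concrete, here is a minimal usage sketch of the wrapper as it behaves after this commit. The toy Sequential model and tensor shapes are illustrative assumptions, the import path simply mirrors the file changed above, and only methods shown in the diff are relied on.

    # Minimal sketch, assuming FlattenParamsWrapper is imported straight from the changed module.
    import torch
    import torch.nn as nn
    from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper

    wrapped = FlattenParamsWrapper(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)))

    # forward() re-exposes per-parameter views of the flat buffer, then calls the wrapped module.
    out = wrapped(torch.randn(3, 4))

    # state_dict() now always returns unflattened keys (the old unflatten_params=True path),
    # while the flattened form moves to the new flat_state_dict() method.
    unflat_sd = wrapped.state_dict()
    flat_sd = wrapped.flat_state_dict()
    assert "flat_param" in flat_sd

    # load_state_dict() dispatches on the "flat_param" key, so both forms round-trip.
    wrapped.load_state_dict(flat_sd)
    wrapped.load_state_dict(unflat_sd)

Note that the unflatten_params keyword argument is gone from state_dict(); callers that passed unflatten_params=False should switch to flat_state_dict(), as the test change further below does.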
fairscale/utils/testing.py (+1 -1)

...
@@ -366,7 +366,7 @@ class GPT2(nn.Module):
         return self.clf_head(h), logits


-def objects_are_equal(a, b, raise_exception=False) -> bool:
+def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
     """
     Test that two objects are equal. Tensors are compared to ensure matching
     size, dtype, device and values.
...
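
The change to objects_are_equal only adds annotations; behavior is unchanged. A hedged illustration of a call against the annotated signature (the tensors are arbitrary, and this assumes the test-utility module imports cleanly in your environment):

    # Signature per the hunk above: objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool
    import torch
    from fairscale.utils.testing import objects_are_equal

    assert objects_are_equal(torch.ones(3), torch.ones(3))        # same size/dtype/device/values
    assert not objects_are_equal(torch.ones(3), torch.zeros(3))   # values differ, returns False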
stubs/torch/__init__.pyi (+1 -0)

...
@@ -27,6 +27,7 @@ from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
 from . import cuda as cuda
 from . import optim as optim
 from . import nn as nn
+from . import testing as testing
 #MODIFIED BY TORCHGPIPE
 from . import backends
...
stubs/torch/testing/__init__.pyi (new file, 0 → 100644, +9 -0)

+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+#MODIFIED FOR FlattenParamsWrapper
+from typing import Any
+
+def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
+#END
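
The two stub changes exist so that mypy can resolve torch.testing.assert_allclose with the signature declared above. A small example of the real call the stub describes (values and tolerances are arbitrary):

    import torch
    from torch.testing import assert_allclose

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 1e-9                                  # within default tolerances
    assert_allclose(a, b)                         # passes silently
    assert_allclose(a, b, rtol=1e-5, atol=1e-8)   # explicit tolerances, matching the stubbed parameters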
tests/nn/misc/test_flatten_params_wrapper.py (+1 -1)

...
@@ -153,7 +153,7 @@ class TestFlattenParams(unittest.TestCase):
         flat_module = FlattenParamsWrapper(flat_module)
         ref_output = self._get_output(flat_module)
-        flat_state_dict = flat_module.state_dict(unflatten_params=False)
+        flat_state_dict = flat_module.flat_state_dict()
         new_module = self._get_shared_params_transformer(seed=1234)
         new_module = FlattenParamsWrapper(new_module)
...
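
The test-side change swaps state_dict(unflatten_params=False) for the new flat_state_dict(). A hedged sketch of the same round-trip pattern with a toy model; the real test's helpers (e.g. _get_shared_params_transformer, _get_output) are specific to the test class and are not reproduced here:

    # Serialize one wrapper's flat buffer, load it into a second identically shaped wrapper,
    # and check that the two wrappers now produce the same output.
    import torch
    import torch.nn as nn
    from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper
    from torch.testing import assert_allclose

    def build(seed: int) -> nn.Module:
        torch.manual_seed(seed)
        return nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

    src = FlattenParamsWrapper(build(seed=0))
    dst = FlattenParamsWrapper(build(seed=1234))

    flat_sd = src.flat_state_dict()   # was: src.state_dict(unflatten_params=False)
    dst.load_state_dict(flat_sd)      # dispatches on the "flat_param" key

    x = torch.randn(2, 4)
    assert_allclose(dst(x), src(x))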