OpenDAS / apex · Commit adbe075a (unverified)
Authored Mar 08, 2022 by Masaki Kozuki; committed via GitHub on Mar 08, 2022

Deprecate reparameterization module (#1316)
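The commit only adds deprecation notices; it does not point users at a replacement. For code that used apex's weight-norm helper, a minimal migration sketch built on PyTorch's own torch.nn.utils.weight_norm (an assumption here, the commit itself names no successor) could look like this:

    # Migration sketch (assumption: torch.nn.utils.weight_norm covers the same use case;
    # this commit does not name a replacement).
    import torch
    import torch.nn as nn
    from torch.nn.utils import weight_norm, remove_weight_norm

    m = weight_norm(nn.Linear(20, 40), name='weight', dim=0)  # creates weight_g / weight_v
    out = m(torch.randn(8, 20))       # forward pass behaves as before
    print(m.weight_g.size())          # torch.Size([40, 1])
    print(m.weight_v.size())          # torch.Size([40, 20])
    remove_weight_norm(m)             # folds the factors back into `weight`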
parent 74e04667

Showing 2 changed files with 17 additions and 2 deletions:

apex/reparameterization/__init__.py     +15 -2
apex/reparameterization/weight_norm.py   +2 -0
apex/reparameterization/__init__.py (view file @ adbe075a)
+import warnings
 from .weight_norm import WeightNorm
 from .reparameterization import Reparameterization

 def apply_weight_norm(module, name='', dim=0, hook_child=True):
     r"""
     Applies weight normalization to a parameter in the given module.
...
...
@@ -28,7 +31,7 @@ def apply_weight_norm(module, name='', dim=0, hook_child=True):
         module (nn.Module): containing module
         name (str, optional): name of weight parameter
         dim (int, optional): dimension over which to compute the norm
-        hook_child (boolean, optional): adds reparameterization hook to direct parent of the
+        hook_child (boolean, optional): adds reparameterization hook to direct parent of the
             parameters. If False, it's added to `module` instead. Default: True
     Returns:
...
...
@@ -44,6 +47,7 @@ def apply_weight_norm(module, name='', dim=0, hook_child=True):
         torch.Size([40, 20])
     """
+    warnings.warn("`apply_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
     return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child, name=name, dim=dim)
...
...
@@ -58,6 +62,7 @@ def remove_weight_norm(module, name='', remove_all=False):
         >>> m = apply_weight_norm(nn.Linear(20, 40))
         >>> remove_weight_norm(m)
     """
+    warnings.warn("`remove_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
     return remove_reparameterization(module, reparameterization=WeightNorm, name=name, remove_all=remove_all)
...
...
@@ -72,7 +77,7 @@ def apply_reparameterization(module, reparameterization=None, name='', dim=0, ho
         reparameterization (Reparameterization): reparamaterization class to apply
         name (str, optional): name of weight parameter
         dim (int, optional): dimension over which to perform reparameterization op
-        hook_child (boolean, optional): adds reparameterization hook to direct parent of the
+        hook_child (boolean, optional): adds reparameterization hook to direct parent of the
             parameters. If False, it's added to `module` instead. Default: True
     Returns:
...
...
@@ -84,6 +89,10 @@ def apply_reparameterization(module, reparameterization=None, name='', dim=0, ho
         Linear (20 -> 40)
     """
+    warnings.warn(
+        "`apply_reparameterization` will be removed by the end of June, 2022.",
+        FutureWarning,
+    )
     assert reparameterization is not None
     if name != '':
         Reparameterization.apply(module, name, dim, reparameterization, hook_child)
...
...
@@ -107,6 +116,10 @@ def remove_reparameterization(module, reparameterization=Reparameterization,
         >>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm)
         >>> remove_reparameterization(m)
     """
+    warnings.warn(
+        "`remove_reparameterization` will be removed by the end of June, 2022.",
+        FutureWarning,
+    )
     if name != '' or remove_all:
         to_remove = []
         for k, hook in module._forward_pre_hooks.items():
...
...
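For reference, a short sketch of how the FutureWarnings added to __init__.py above reach callers; it assumes an apex build that already contains this commit:

    import warnings
    import torch.nn as nn
    from apex.reparameterization import apply_weight_norm, remove_weight_norm

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")           # surface every warning, even repeats
        m = apply_weight_norm(nn.Linear(20, 40))  # now emits a FutureWarning
        remove_weight_norm(m)                     # emits its own FutureWarning
    assert any(issubclass(w.category, FutureWarning) for w in caught)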
apex/reparameterization/weight_norm.py (view file @ adbe075a)
+import warnings
 import torch
 from torch.nn.parameter import Parameter
 from ..fp16_utils import Fused_Weight_Norm
...
...
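Downstream code that wants to stay quiet until the announced removal date can filter these specific notices rather than all FutureWarnings; the message pattern below is taken from the diff, the rest is a suggested sketch:

    import warnings

    # Ignore only the apex.reparameterization deprecation notices added in this commit.
    warnings.filterwarnings(
        "ignore",
        category=FutureWarning,
        message=r".*will be removed by the end of June, 2022.*",
    )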