OpenDAS / apex · Commit 44c30436 (Unverified)

Authored Mar 08, 2022 by Masaki Kozuki; committed by GitHub on Mar 08, 2022
Revert "Deprecate reparameterization module (#1316)" (#1319)
This reverts commit
adbe075a
.
parent
79143c31
Showing 2 changed files with 2 additions and 17 deletions (+2 / -17):

apex/reparameterization/__init__.py     +2 -15
apex/reparameterization/weight_norm.py  +0  -2
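The deprecation being undone consisted of `FutureWarning` calls ("… will be removed by the end of June, 2022.") visible in the hunks below, which this revert removes (-17 lines across the two files), so the reparameterization helpers stay usable without that warning. As a rough sketch only (not part of the commit) of what that means for callers who escalate warnings during testing, assuming `apex` is installed and this module imports cleanly:

import warnings

import torch.nn as nn
from apex.reparameterization import apply_weight_norm

with warnings.catch_warnings():
    # Treat any FutureWarning as an error. Before this revert the deprecated
    # apply_weight_norm would have raised here; with the warning removed, it is silent.
    warnings.simplefilter("error", FutureWarning)
    m = apply_weight_norm(nn.Linear(20, 40))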
apex/reparameterization/__init__.py (view file @ 44c30436)
import warnings

from .weight_norm import WeightNorm
from .reparameterization import Reparameterization

def apply_weight_norm(module, name='', dim=0, hook_child=True):
    r"""
    Applies weight normalization to a parameter in the given module.
    ...
    ...
@@ -31,7 +28,7 @@ def apply_weight_norm(module, name='', dim=0, hook_child=True):
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm
        hook_child (boolean, optional): adds reparameterization hook to direct parent of the
            parameters. If False, it's added to `module` instead. Default: True
    Returns:
    ...
    ...
@@ -47,7 +44,6 @@ def apply_weight_norm(module, name='', dim=0, hook_child=True):
        torch.Size([40, 20])
    """
    warnings.warn("`apply_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
    return apply_reparameterization(module, reparameterization=WeightNorm,
                                    hook_child=hook_child, name=name, dim=dim)
...
...
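For context, a minimal usage sketch of `apply_weight_norm`; it just restates the docstring example that appears later in this file, and assumes `apex` is installed and importable:

import torch.nn as nn
from apex.reparameterization import apply_weight_norm

# Mirror of the docstring example: wrap a Linear layer so its weight is
# maintained through the weight-norm reparameterization hook.
m = apply_weight_norm(nn.Linear(20, 40))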
@@ -62,7 +58,6 @@ def remove_weight_norm(module, name='', remove_all=False):
        >>> m = apply_weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    warnings.warn("`remove_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
    return remove_reparameterization(module, reparameterization=WeightNorm,
                                     name=name, remove_all=remove_all)
...
...
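The `remove_weight_norm` docstring above shows the round trip; a hedged sketch of the same, again assuming `apex` imports cleanly:

import torch.nn as nn
from apex.reparameterization import apply_weight_norm, remove_weight_norm

# Apply the reparameterization, then strip it again, leaving an ordinary Linear.
m = apply_weight_norm(nn.Linear(20, 40))
remove_weight_norm(m)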
@@ -77,7 +72,7 @@ def apply_reparameterization(module, reparameterization=None, name='', dim=0, ho
        reparameterization (Reparameterization): reparamaterization class to apply
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to perform reparameterization op
        hook_child (boolean, optional): adds reparameterization hook to direct parent of the
            parameters. If False, it's added to `module` instead. Default: True
    Returns:
...
...
@@ -89,10 +84,6 @@ def apply_reparameterization(module, reparameterization=None, name='', dim=0, ho
        Linear (20 -> 40)
    """
    warnings.warn(
        "`apply_reparameterization` will be removed by the end of June, 2022.",
        FutureWarning,
    )
    assert reparameterization is not None
    if name != '':
        Reparameterization.apply(module, name, dim, reparameterization, hook_child)
...
...
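The branch shown above handles the case where a specific parameter name is given; a hedged sketch of calling the generic entry point that way (the name 'weight' is just an illustrative choice for nn.Linear, and `apex` is assumed importable):

import torch.nn as nn
from apex.reparameterization import WeightNorm, apply_reparameterization

# With a non-empty `name`, the hunk above dispatches straight to
# Reparameterization.apply for that single parameter.
m = apply_reparameterization(nn.Linear(20, 40), reparameterization=WeightNorm,
                             name='weight', dim=0)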
@@ -116,10 +107,6 @@ def remove_reparameterization(module, reparameterization=Reparameterization,
        >>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm)
        >>> remove_reparameterization(m)
    """
    warnings.warn(
        "`remove_reparameterization` will be removed by the end of June, 2022.",
        FutureWarning,
    )
    if name != '' or remove_all:
        to_remove = []
        for k, hook in module._forward_pre_hooks.items():
...
...
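And the `remove_reparameterization` docstring round trip, sketched out under the same assumptions:

import torch.nn as nn
from apex.reparameterization import WeightNorm, apply_reparameterization, remove_reparameterization

# Docstring round trip: attach WeightNorm via the generic entry point, then detach it.
m = apply_reparameterization(nn.Linear(20, 40), WeightNorm)
remove_reparameterization(m)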
apex/reparameterization/weight_norm.py (view file @ 44c30436)
import warnings

import torch
from torch.nn.parameter import Parameter
from ..fp16_utils import Fused_Weight_Norm
...
...
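The `WeightNorm` class in this file, together with the `Fused_Weight_Norm` kernel it imports, implements the standard weight-normalization decomposition w = g * v / ||v||. As an illustration only (plain torch, not code from this file), for a 2D weight with dim=0 as used by `apply_weight_norm` above:

import torch

# Weight normalization splits a weight into a direction v and a magnitude g,
# then recomputes w = g * v / ||v|| with the norm taken per output row (dim=0).
v = torch.randn(40, 20)                 # direction parameter
g = torch.randn(40, 1).abs()            # magnitude parameter
norm = v.norm(2, dim=1, keepdim=True)   # ||v|| for each row
w = g * v / norm
print(w.shape)  # torch.Size([40, 20])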