Unverified Commit adbe075a authored by Masaki Kozuki, committed by GitHub

Deprecate reparameterization module (#1316)

parent 74e04667
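The diff below adds FutureWarning deprecation notices to the public helpers of the reparameterization module. A minimal sketch of what callers see after this change, assuming an apex install where `apex.reparameterization` imports cleanly (the module path is inferred from the imports in the diff, not stated in the commit):

```python
# Sketch only: demonstrates the FutureWarning added by this commit.
# Assumes apex is installed and `apex.reparameterization` exposes
# apply_weight_norm / remove_weight_norm, as in the diff below.
import warnings

import torch.nn as nn
from apex.reparameterization import apply_weight_norm, remove_weight_norm

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    m = apply_weight_norm(nn.Linear(20, 40))  # warns: removed by end of June, 2022
    remove_weight_norm(m)                     # warns as well

assert any(issubclass(w.category, FutureWarning) for w in caught)
```

Note that each wrapper also forwards to `apply_reparameterization` / `remove_reparameterization`, which carry their own warnings, so a single call can emit two FutureWarnings.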
import warnings
from .weight_norm import WeightNorm
from .reparameterization import Reparameterization
def apply_weight_norm(module, name='', dim=0, hook_child=True):
r"""
Applies weight normalization to a parameter in the given module.
@@ -28,7 +31,7 @@ def apply_weight_norm(module, name='', dim=0, hook_child=True):
module (nn.Module): containing module
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute the norm
hook_child (boolean, optional): adds reparameterization hook to direct parent of the
parameters. If False, it's added to `module` instead. Default: True
Returns:
@@ -44,6 +47,7 @@ def apply_weight_norm(module, name='', dim=0, hook_child=True):
torch.Size([40, 20])
"""
warnings.warn("`apply_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child,
name=name, dim=dim)
@@ -58,6 +62,7 @@ def remove_weight_norm(module, name='', remove_all=False):
>>> m = apply_weight_norm(nn.Linear(20, 40))
>>> remove_weight_norm(m)
"""
warnings.warn("`remove_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
return remove_reparameterization(module, reparameterization=WeightNorm,
name=name, remove_all=remove_all)
@@ -72,7 +77,7 @@ def apply_reparameterization(module, reparameterization=None, name='', dim=0, hook_child=True):
reparameterization (Reparameterization): reparamaterization class to apply
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to perform reparameterization op
hook_child (boolean, optional): adds reparameterization hook to direct parent of the
parameters. If False, it's added to `module` instead. Default: True
Returns:
@@ -84,6 +89,10 @@ def apply_reparameterization(module, reparameterization=None, name='', dim=0, hook_child=True):
Linear (20 -> 40)
"""
warnings.warn(
"`apply_reparameterization` will be removed by the end of June, 2022.",
FutureWarning,
)
assert reparameterization is not None
if name != '':
Reparameterization.apply(module, name, dim, reparameterization, hook_child)
@@ -107,6 +116,10 @@ def remove_reparameterization(module, reparameterization=Reparameterization,
>>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm)
>>> remove_reparameterization(m)
"""
warnings.warn(
"`remove_reparameterization` will be removed by the end of June, 2022.",
FutureWarning,
)
if name != '' or remove_all:
to_remove = []
for k, hook in module._forward_pre_hooks.items():
......
import warnings
import torch
from torch.nn.parameter import Parameter
from ..fp16_utils import Fused_Weight_Norm
......
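The commit does not name a replacement. As an assumption (not stated anywhere in this diff), the common single-parameter weight-norm use case served by these helpers is also covered by PyTorch's built-in `torch.nn.utils.weight_norm` / `remove_weight_norm`, sketched here:

```python
# Possible migration path (assumption; not stated in the commit): PyTorch's
# built-in weight-norm utilities handle the per-parameter case.
import torch.nn as nn
from torch.nn.utils import remove_weight_norm, weight_norm

m = weight_norm(nn.Linear(20, 40), name="weight", dim=0)
print(m.weight_g.size())                  # torch.Size([40, 1]) -- one norm per output row
m = remove_weight_norm(m, name="weight")  # restores a plain `weight` Parameter
```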