Unverified Commit adbe075a authored by Masaki Kozuki, committed by GitHub

Deprecate reparameterization module (#1316)

parent 74e04667
import warnings
from .weight_norm import WeightNorm
from .reparameterization import Reparameterization

def apply_weight_norm(module, name='', dim=0, hook_child=True):
    r"""
    Applies weight normalization to a parameter in the given module.
@@ -44,6 +47,7 @@ def apply_weight_norm(module, name='', dim=0, hook_child=True):
    torch.Size([40, 20])
    """
    warnings.warn("`apply_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
    return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child,
                                    name=name, dim=dim)
@@ -58,6 +62,7 @@ def remove_weight_norm(module, name='', remove_all=False):
    >>> m = apply_weight_norm(nn.Linear(20, 40))
    >>> remove_weight_norm(m)
    """
    warnings.warn("`remove_weight_norm` will be removed by the end of June, 2022.", FutureWarning)
    return remove_reparameterization(module, reparameterization=WeightNorm,
                                     name=name, remove_all=remove_all)
@@ -84,6 +89,10 @@ def apply_reparameterization(module, reparameterization=None, name='', dim=0, ho
    Linear (20 -> 40)
    """
    warnings.warn(
        "`apply_reparameterization` will be removed by the end of June, 2022.",
        FutureWarning,
    )
    assert reparameterization is not None
    if name != '':
        Reparameterization.apply(module, name, dim, reparameterization, hook_child)
@@ -107,6 +116,10 @@ def remove_reparameterization(module, reparameterization=Reparameterization,
    >>> m = apply_reparameterization(nn.Linear(20, 40), WeightNorm)
    >>> remove_reparameterization(m)
    """
    warnings.warn(
        "`remove_reparameterization` will be removed by the end of June, 2022.",
        FutureWarning,
    )
    if name != '' or remove_all:
        to_remove = []
        for k, hook in module._forward_pre_hooks.items():
...
import warnings
import torch
from torch.nn.parameter import Parameter
from ..fp16_utils import Fused_Weight_Norm
...
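
For context, a minimal sketch of what this change means for callers, written against the public wrappers touched by this diff. The `apex.reparameterization` import path is an assumption inferred from the relative imports above, not something the diff itself states.

    # Sketch: each deprecated wrapper now emits a FutureWarning before
    # delegating to apply_reparameterization/remove_reparameterization.
    import warnings

    import torch.nn as nn
    from apex.reparameterization import apply_weight_norm, remove_weight_norm  # path assumed

    with warnings.catch_warnings(record=True) as caught:
        # "always" keeps the default filters from deduplicating repeated warnings.
        warnings.simplefilter("always")
        m = apply_weight_norm(nn.Linear(20, 40))
        remove_weight_norm(m)

    # Both calls above should have been recorded as FutureWarning.
    assert any(issubclass(w.category, FutureWarning) for w in caught)

Using FutureWarning rather than DeprecationWarning is deliberate for end-user-facing deprecations: Python shows FutureWarning by default, while DeprecationWarning is hidden unless it originates in `__main__` or warning filters are adjusted.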