Commit 144cfcfc authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Migrate weight_norm (#3523)

Summary:
torch.nn.utils.weight_norm is deprecated.
Replacing this with new API

Pull Request resolved: https://github.com/pytorch/audio/pull/3523

Reviewed By: huangruizhe

Differential Revision: D47932384

Pulled By: mthrok

fbshipit-source-id: 344abfa12bd11da779f7fd13b74a1e009a582b52
parent cbfde17b
......@@ -208,18 +208,13 @@ class ConvolutionalPositionalEmbedding(Module):
groups=groups,
)
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
def __prepare_scriptable__(self):
for hook in self.conv._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
_LG.warning("Removing weight_norm from %s", self.__class__.__name__)
torch.nn.utils.remove_weight_norm(self.conv)
if self.conv.__class__.__name__ == "ParametrizedConv1d":
_LG.warning("Removing weight_norm from %s", self.__class__.__name__)
torch.nn.utils.parametrize.remove_parametrizations(self.conv, "weight")
return self
def forward(self, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment