Unverified Commit eb984418 authored by Sanchit Gandhi, committed by GitHub

[VITS] Handle deprecated weight norm (#25946)

parent f435003e
@@ -357,9 +357,14 @@ class VitsWaveNet(torch.nn.Module):
         self.res_skip_layers = torch.nn.ModuleList()
         self.dropout = nn.Dropout(config.wavenet_dropout)
 
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+        else:
+            weight_norm = nn.utils.weight_norm
+
         if config.speaker_embedding_size != 0:
             cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+            self.cond_layer = weight_norm(cond_layer, name="weight")
 
         for i in range(num_layers):
             dilation = config.wavenet_dilation_rate**i
@@ -371,7 +376,7 @@ class VitsWaveNet(torch.nn.Module):
                 dilation=dilation,
                 padding=padding,
             )
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+            in_layer = weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)
 
             # last one is not necessary
@@ -381,7 +386,7 @@ class VitsWaveNet(torch.nn.Module):
                 res_skip_channels = config.hidden_size
 
             res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            res_skip_layer = weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
 
     def forward(self, inputs, padding_mask, global_conditioning=None):
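For reference, a minimal standalone sketch (not part of this commit) of the same fallback pattern: prefer the parametrization-based weight norm where the installed PyTorch provides it, and fall back to the deprecated legacy helper otherwise. The Conv1d channel sizes and the input tensor below are illustrative assumptions, not values from the VITS model.

    # Compatibility shim: newer PyTorch releases expose weight norm as a
    # parametrization and deprecate torch.nn.utils.weight_norm.
    import torch
    from torch import nn

    if hasattr(nn.utils.parametrizations, "weight_norm"):
        weight_norm = nn.utils.parametrizations.weight_norm
    else:
        weight_norm = nn.utils.weight_norm

    # Apply weight norm to an example layer (sizes are hypothetical).
    conv = weight_norm(nn.Conv1d(16, 32, kernel_size=1), name="weight")

    # Both code paths produce the same forward behaviour; they differ in how
    # the decomposed weights are stored (legacy: weight_g / weight_v
    # parameters, parametrization API: conv.parametrizations.weight).
    x = torch.randn(1, 16, 10)
    print(conv(x).shape)  # torch.Size([1, 32, 10])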