Unverified Commit eb984418 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[VITS] Handle deprecated weight norm (#25946)

parent f435003e
...@@ -357,9 +357,14 @@ class VitsWaveNet(torch.nn.Module): ...@@ -357,9 +357,14 @@ class VitsWaveNet(torch.nn.Module):
self.res_skip_layers = torch.nn.ModuleList() self.res_skip_layers = torch.nn.ModuleList()
self.dropout = nn.Dropout(config.wavenet_dropout) self.dropout = nn.Dropout(config.wavenet_dropout)
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
else:
weight_norm = nn.utils.weight_norm
if config.speaker_embedding_size != 0: if config.speaker_embedding_size != 0:
cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1) cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") self.cond_layer = weight_norm(cond_layer, name="weight")
for i in range(num_layers): for i in range(num_layers):
dilation = config.wavenet_dilation_rate**i dilation = config.wavenet_dilation_rate**i
...@@ -371,7 +376,7 @@ class VitsWaveNet(torch.nn.Module): ...@@ -371,7 +376,7 @@ class VitsWaveNet(torch.nn.Module):
dilation=dilation, dilation=dilation,
padding=padding, padding=padding,
) )
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") in_layer = weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
# last one is not necessary # last one is not necessary
...@@ -381,7 +386,7 @@ class VitsWaveNet(torch.nn.Module): ...@@ -381,7 +386,7 @@ class VitsWaveNet(torch.nn.Module):
res_skip_channels = config.hidden_size res_skip_channels = config.hidden_size
res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1) res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
def forward(self, inputs, padding_mask, global_conditioning=None): def forward(self, inputs, padding_mask, global_conditioning=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment