Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
eb984418
"tests/vscode:/vscode.git/clone" did not exist on "23c146c38b42d1193849fbd6f2943bf754b7c428"
Unverified
Commit
eb984418
authored
Sep 04, 2023
by
Sanchit Gandhi
Committed by
GitHub
Sep 04, 2023
Browse files
[VITS] Handle deprecated weight norm (#25946)
parent
f435003e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
3 deletions
+8
-3
src/transformers/models/vits/modeling_vits.py
src/transformers/models/vits/modeling_vits.py
+8
-3
No files found.
src/transformers/models/vits/modeling_vits.py
View file @
eb984418
...
...
@@ -357,9 +357,14 @@ class VitsWaveNet(torch.nn.Module):
self
.
res_skip_layers
=
torch
.
nn
.
ModuleList
()
self
.
dropout
=
nn
.
Dropout
(
config
.
wavenet_dropout
)
if
hasattr
(
nn
.
utils
.
parametrizations
,
"weight_norm"
):
weight_norm
=
nn
.
utils
.
parametrizations
.
weight_norm
else
:
weight_norm
=
nn
.
utils
.
weight_norm
if
config
.
speaker_embedding_size
!=
0
:
cond_layer
=
torch
.
nn
.
Conv1d
(
config
.
speaker_embedding_size
,
2
*
config
.
hidden_size
*
num_layers
,
1
)
self
.
cond_layer
=
torch
.
nn
.
utils
.
weight_norm
(
cond_layer
,
name
=
"weight"
)
self
.
cond_layer
=
weight_norm
(
cond_layer
,
name
=
"weight"
)
for
i
in
range
(
num_layers
):
dilation
=
config
.
wavenet_dilation_rate
**
i
...
...
@@ -371,7 +376,7 @@ class VitsWaveNet(torch.nn.Module):
dilation
=
dilation
,
padding
=
padding
,
)
in_layer
=
torch
.
nn
.
utils
.
weight_norm
(
in_layer
,
name
=
"weight"
)
in_layer
=
weight_norm
(
in_layer
,
name
=
"weight"
)
self
.
in_layers
.
append
(
in_layer
)
# last one is not necessary
...
...
@@ -381,7 +386,7 @@ class VitsWaveNet(torch.nn.Module):
res_skip_channels
=
config
.
hidden_size
res_skip_layer
=
torch
.
nn
.
Conv1d
(
config
.
hidden_size
,
res_skip_channels
,
1
)
res_skip_layer
=
torch
.
nn
.
utils
.
weight_norm
(
res_skip_layer
,
name
=
"weight"
)
res_skip_layer
=
weight_norm
(
res_skip_layer
,
name
=
"weight"
)
self
.
res_skip_layers
.
append
(
res_skip_layer
)
def
forward
(
self
,
inputs
,
padding_mask
,
global_conditioning
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment