Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
49c48f93
Commit
49c48f93
authored
Oct 11, 2021
by
moto
Browse files
Fix the main loop of tacotron2 decoder inference (#1849)
To handle batched input properly.
parent
ab97afa0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
17 deletions
+13
-17
torchaudio/models/tacotron2.py
torchaudio/models/tacotron2.py
+13
-17
No files found.
torchaudio/models/tacotron2.py
View file @
49c48f93
...
@@ -904,6 +904,8 @@ class _Decoder(nn.Module):
...
@@ -904,6 +904,8 @@ class _Decoder(nn.Module):
alignments (Tensor): Sequence of attention weights from the decoder
alignments (Tensor): Sequence of attention weights from the decoder
with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
"""
"""
batch_size
,
device
=
memory
.
size
(
0
),
memory
.
device
decoder_input
=
self
.
_get_go_frame
(
memory
)
decoder_input
=
self
.
_get_go_frame
(
memory
)
mask
=
_get_mask_from_lengths
(
memory_lengths
)
mask
=
_get_mask_from_lengths
(
memory_lengths
)
...
@@ -918,17 +920,12 @@ class _Decoder(nn.Module):
...
@@ -918,17 +920,12 @@ class _Decoder(nn.Module):
processed_memory
,
processed_memory
,
)
=
self
.
_initialize_decoder_states
(
memory
)
)
=
self
.
_initialize_decoder_states
(
memory
)
mel_specgram_lengths
=
torch
.
ones
(
mel_specgram_lengths
=
torch
.
zeros
([
batch_size
],
dtype
=
torch
.
int32
,
device
=
device
)
[
memory
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
memory
.
device
finished
=
torch
.
zeros
([
batch_size
],
dtype
=
torch
.
bool
,
device
=
device
)
)
not_finished
=
torch
.
ones
(
[
memory
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
memory
.
device
)
mel_specgrams
:
List
[
Tensor
]
=
[]
mel_specgrams
:
List
[
Tensor
]
=
[]
gate_outputs
:
List
[
Tensor
]
=
[]
gate_outputs
:
List
[
Tensor
]
=
[]
alignments
:
List
[
Tensor
]
=
[]
alignments
:
List
[
Tensor
]
=
[]
while
True
:
for
_
in
range
(
self
.
decoder_max_step
)
:
decoder_input
=
self
.
prenet
(
decoder_input
)
decoder_input
=
self
.
prenet
(
decoder_input
)
(
(
mel_specgram
,
mel_specgram
,
...
@@ -957,20 +954,19 @@ class _Decoder(nn.Module):
...
@@ -957,20 +954,19 @@ class _Decoder(nn.Module):
mel_specgrams
.
append
(
mel_specgram
.
unsqueeze
(
0
))
mel_specgrams
.
append
(
mel_specgram
.
unsqueeze
(
0
))
gate_outputs
.
append
(
gate_output
.
transpose
(
0
,
1
))
gate_outputs
.
append
(
gate_output
.
transpose
(
0
,
1
))
alignments
.
append
(
attention_weights
)
alignments
.
append
(
attention_weights
)
mel_specgram_lengths
[
~
finished
]
+=
1
dec
=
torch
.
le
(
torch
.
sigmoid
(
gate_output
),
self
.
gate_threshold
).
to
(
torch
.
int32
).
squeeze
(
1
)
finished
|=
torch
.
sigmoid
(
gate_output
.
squeeze
(
1
))
>
self
.
gate_threshold
if
self
.
decoder_early_stopping
and
torch
.
all
(
finished
):
not_finished
=
not_finished
*
dec
if
self
.
decoder_early_stopping
and
torch
.
sum
(
not_finished
)
==
0
:
break
if
len
(
mel_specgrams
)
==
self
.
decoder_max_step
:
warnings
.
warn
(
"Reached max decoder steps"
)
break
break
mel_specgram_lengths
+=
not_finished
decoder_input
=
mel_specgram
decoder_input
=
mel_specgram
if
len
(
mel_specgrams
)
==
self
.
decoder_max_step
:
warnings
.
warn
(
"Reached max decoder steps. The generated spectrogram might not cover "
"the whole transcript."
)
mel_specgrams
=
torch
.
cat
(
mel_specgrams
,
dim
=
0
)
mel_specgrams
=
torch
.
cat
(
mel_specgrams
,
dim
=
0
)
gate_outputs
=
torch
.
cat
(
gate_outputs
,
dim
=
0
)
gate_outputs
=
torch
.
cat
(
gate_outputs
,
dim
=
0
)
alignments
=
torch
.
cat
(
alignments
,
dim
=
0
)
alignments
=
torch
.
cat
(
alignments
,
dim
=
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment