Project: OpenDAS / Torchaudio
Commit 976f56e8 (Unverified)
Authored Oct 08, 2021 by moto; committed by GitHub on Oct 08, 2021
Make `text_length` optional in `Tacotron2.infer` (#1839)
Parent: fd7fcf93

Showing 1 changed file with 8 additions and 3 deletions.

torchaudio/models/tacotron2.py (+8 -3)
...
...
@@ -1130,7 +1130,7 @@ class Tacotron2(nn.Module):
         return mel_specgram, mel_specgram_postnet, gate_outputs, alignments

     @torch.jit.export
-    def infer(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+    def infer(self, text: Tensor, text_lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
         r"""Using Tacotron2 for inference. The input is a batch of encoded
         sentences (text) and its corresponding lengths (text_lengths). The
         output is the generated mel spectrograms, its corresponding lengths, and
...
...
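The signature change above relies on TorchScript accepting an `Optional[Tensor]` argument with a default on an exported method. The following is a minimal sketch of that pattern, assuming only `torch` is installed. `ToyModel` is a hypothetical stand-in and not the real `Tacotron2`; it returns the resolved lengths rather than running a decoder.

```python
import torch
from torch import Tensor, nn
from typing import Optional


class ToyModel(nn.Module):
    """Hypothetical stand-in used only to illustrate the signature pattern."""

    def forward(self, text: Tensor) -> Tensor:
        return text

    @torch.jit.export
    def infer(self, text: Tensor, text_lengths: Optional[Tensor] = None) -> Tensor:
        n_batch, max_length = text.shape
        if text_lengths is None:
            # No lengths given: treat every row as fully valid.
            text_lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype)
        # Narrows Optional[Tensor] to Tensor for the TorchScript compiler.
        assert text_lengths is not None
        return text_lengths


scripted = torch.jit.script(ToyModel())
text = torch.ones(2, 4, dtype=torch.long)
default_lengths = scripted.infer(text)                      # lengths omitted
given_lengths = scripted.infer(text, torch.tensor([4, 3]))  # lengths supplied
```

Existing callers that pass `text_lengths` explicitly keep working, because the new parameter keeps its position and only gains a default.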
@@ -1140,7 +1140,8 @@ class Tacotron2(nn.Module):

         Args:
             text (Tensor): The input text to Tacotron2 with shape (n_batch, max of ``text_lengths``).
-            text_lengths (Tensor): The length of each text with shape (n_batch, ).
+            text_lengths (Tensor or None, optional): The length of each text with shape `(n_batch, )`.
+                If ``None``, it is assumed that all the texts are valid. Default: ``None``

         Return:
             mel_specgram (Tensor): The predicted mel spectrogram
...
...
@@ -1150,6 +1151,11 @@ class Tacotron2(nn.Module):
             alignments (Tensor): Sequence of attention weights from the decoder.
                 with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
         """
+        n_batch, max_length = text.shape
+        if text_lengths is None:
+            text_lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype)
+
+        assert text_lengths is not None  # For TorchScript compiler

         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
...
...
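To make the fallback in the hunk above concrete, here is a small sketch that evaluates the same expression on a dummy batch outside the model. The `torch.full` line is an equivalent formulation added only for comparison; it is an assumption of this note and not part of the commit.

```python
import torch

# Dummy batch of 3 encoded sentences, each padded to width 5.
text = torch.zeros(3, 5, dtype=torch.int32)
n_batch, max_length = text.shape

# Expression from the commit: one length per row, equal to the padded width,
# placed on the same device and cast to the same dtype as `text`.
lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype)

# Equivalent formulation (not from the commit), shown for comparison.
alternative = torch.full((n_batch,), max_length, dtype=text.dtype, device=text.device)
```

Both tensors have shape `(n_batch,)`, so the encoder receives one length per row, with every length equal to the padded width.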
@@ -1160,7 +1166,6 @@ class Tacotron2(nn.Module):
         mel_outputs_postnet = self.postnet(mel_specgram)
         mel_outputs_postnet = mel_specgram + mel_outputs_postnet

-        n_batch = mel_outputs_postnet.size(0)
         alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)

         return mel_outputs_postnet, mel_specgram_lengths, alignments
...
...