Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
9f9b6537
You need to sign in or sign up before continuing.
Commit
9f9b6537
authored
Oct 08, 2021
by
moto
Browse files
Replace `text` with `token` in Tacotron2 API (#1844)
parent
cb77a86c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
23 deletions
+24
-23
torchaudio/models/tacotron2.py
torchaudio/models/tacotron2.py
+24
-23
No files found.
torchaudio/models/tacotron2.py
View file @
9f9b6537
...
...
@@ -1079,20 +1079,20 @@ class Tacotron2(nn.Module):
def
forward
(
self
,
t
ext
:
Tensor
,
t
ext
_lengths
:
Tensor
,
t
okens
:
Tensor
,
t
oken
_lengths
:
Tensor
,
mel_specgram
:
Tensor
,
mel_specgram_lengths
:
Tensor
,
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
,
Tensor
]:
r
"""Pass the input through the Tacotron2 model. This is in teacher
forcing mode, which is generally used for training.
The input ``t
ext
`` should be padded with zeros to length max of ``t
ext
_lengths``.
The input ``t
okens
`` should be padded with zeros to length max of ``t
oken
_lengths``.
The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.
Args:
t
ext
(Tensor): The input t
ext
to Tacotron2 with shape `(n_batch, max of t
ext
_lengths)`.
t
ext
_lengths (Tensor): The length of each
text
with shape `(n_batch, )`.
t
okens
(Tensor): The input t
okens
to Tacotron2 with shape `(n_batch, max of t
oken
_lengths)`.
t
oken
_lengths (Tensor): The
valid
length of each
sample in ``tokens``
with shape `(n_batch, )`.
mel_specgram (Tensor): The target mel spectrogram
with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
...
...
@@ -1107,14 +1107,14 @@ class Tacotron2(nn.Module):
The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
Tensor
Sequence of attention weights from the decoder with
shape `(n_batch, max of mel_specgram_lengths, max of t
ext
_lengths)`.
shape `(n_batch, max of mel_specgram_lengths, max of t
oken
_lengths)`.
"""
embedded_inputs
=
self
.
embedding
(
t
ext
).
transpose
(
1
,
2
)
embedded_inputs
=
self
.
embedding
(
t
okens
).
transpose
(
1
,
2
)
encoder_outputs
=
self
.
encoder
(
embedded_inputs
,
t
ext
_lengths
)
encoder_outputs
=
self
.
encoder
(
embedded_inputs
,
t
oken
_lengths
)
mel_specgram
,
gate_outputs
,
alignments
=
self
.
decoder
(
encoder_outputs
,
mel_specgram
,
memory_lengths
=
t
ext
_lengths
encoder_outputs
,
mel_specgram
,
memory_lengths
=
t
oken
_lengths
)
mel_specgram_postnet
=
self
.
postnet
(
mel_specgram
)
...
...
@@ -1132,18 +1132,19 @@ class Tacotron2(nn.Module):
return
mel_specgram
,
mel_specgram_postnet
,
gate_outputs
,
alignments
@
torch
.
jit
.
export
def
infer
(
self
,
t
ext
:
Tensor
,
text_
lengths
:
Optional
[
Tensor
]
=
None
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
def
infer
(
self
,
t
okens
:
Tensor
,
lengths
:
Optional
[
Tensor
]
=
None
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
r
"""Using Tacotron2 for inference. The input is a batch of encoded
sentences (
text
) and its corresponding lengths (
text_
lengths). The
sentences (
``tokens``
) and its corresponding lengths (
``
lengths
``
). The
output is the generated mel spectrograms, its corresponding lengths, and
the attention weights from the decoder.
The input `t
ext
` should be padded with zeros to length max of ``
text_
lengths``.
The input `t
okens
` should be padded with zeros to length max of ``lengths``.
Args:
text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`.
text_lengths (Tensor or None, optional): The length of each text with shape `(n_batch, )`.
If ``None``, it is assumed that the all the texts are valid. Default: ``None``
tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
lengths (Tensor or None, optional):
The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
If ``None``, it is assumed that the all the tokens are valid. Default: ``None``
Returns:
Tensor, Tensor, and Tensor:
...
...
@@ -1153,18 +1154,18 @@ class Tacotron2(nn.Module):
The length of the predicted mel spectrogram with shape `(n_batch, )`.
Tensor
Sequence of attention weights from the decoder with shape
`(n_batch, max of mel_specgram_lengths, max of
text_
lengths)`.
`(n_batch, max of mel_specgram_lengths, max of lengths)`.
"""
n_batch
,
max_length
=
t
ext
.
shape
if
text_
lengths
is
None
:
text_
lengths
=
torch
.
tensor
([
max_length
]).
expand
(
n_batch
).
to
(
t
ext
.
device
,
t
ext
.
dtype
)
n_batch
,
max_length
=
t
okens
.
shape
if
lengths
is
None
:
lengths
=
torch
.
tensor
([
max_length
]).
expand
(
n_batch
).
to
(
t
okens
.
device
,
t
okens
.
dtype
)
assert
text_
lengths
is
not
None
# For TorchScript compiler
assert
lengths
is
not
None
# For TorchScript compiler
embedded_inputs
=
self
.
embedding
(
t
ext
).
transpose
(
1
,
2
)
encoder_outputs
=
self
.
encoder
(
embedded_inputs
,
text_
lengths
)
embedded_inputs
=
self
.
embedding
(
t
okens
).
transpose
(
1
,
2
)
encoder_outputs
=
self
.
encoder
(
embedded_inputs
,
lengths
)
mel_specgram
,
mel_specgram_lengths
,
_
,
alignments
=
self
.
decoder
.
infer
(
encoder_outputs
,
text_
lengths
encoder_outputs
,
lengths
)
mel_outputs_postnet
=
self
.
postnet
(
mel_specgram
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment