OpenDAS / Torchaudio

Unverified commit 635a4a0a, authored Oct 08, 2021 by moto, committed by GitHub on Oct 08, 2021
Replace `text` with `token` in Tacotron2 API (#1844)
Parent: cd8f87bd
Changes: 1 changed file with 24 additions and 23 deletions

torchaudio/models/tacotron2.py (+24, -23)
...
...
@@ -1079,20 +1079,20 @@ class Tacotron2(nn.Module):
     def forward(
         self,
-        text: Tensor,
-        text_lengths: Tensor,
+        tokens: Tensor,
+        token_lengths: Tensor,
         mel_specgram: Tensor,
         mel_specgram_lengths: Tensor,
     ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         r"""Pass the input through the Tacotron2 model. This is in teacher
         forcing mode, which is generally used for training.

-        The input ``text`` should be padded with zeros to length max of ``text_lengths``.
+        The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
         The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.

         Args:
-            text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`.
-            text_lengths (Tensor): The length of each text with shape `(n_batch, )`.
+            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
+            token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
             mel_specgram (Tensor): The target mel spectrogram
                 with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
             mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
...
...
@@ -1107,14 +1107,14 @@ class Tacotron2(nn.Module):
                 The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
             Tensor
                 Sequence of attention weights from the decoder with
-                shape `(n_batch, max of mel_specgram_lengths, max of text_lengths)`.
+                shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
         """
-        embedded_inputs = self.embedding(text).transpose(1, 2)
-        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        embedded_inputs = self.embedding(tokens).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, token_lengths)
         mel_specgram, gate_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specgram, memory_lengths=text_lengths
+            encoder_outputs, mel_specgram, memory_lengths=token_lengths
         )
         mel_specgram_postnet = self.postnet(mel_specgram)
...
...
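The two hunks above rename the teacher-forcing inputs of ``forward``. A minimal caller-side sketch under the new names follows; the batch shapes, the symbol range, and the use of the default ``Tacotron2()`` constructor are assumptions for illustration, not part of this commit.

import torch
from torchaudio.models import Tacotron2

model = Tacotron2()  # assumes the constructor defaults are acceptable here

# Zero-padded token IDs, as the docstring requires (symbol range is assumed).
n_batch, max_token_len = 2, 50
tokens = torch.zeros(n_batch, max_token_len, dtype=torch.long)
tokens[0, :50] = torch.randint(1, 148, (50,))
tokens[1, :42] = torch.randint(1, 148, (42,))
token_lengths = torch.tensor([50, 42])

# Target mel spectrograms for teacher forcing (n_mels=80 is the default).
mel_specgram = torch.randn(n_batch, 80, 200)
mel_specgram_lengths = torch.tensor([200, 176])

# Keyword callers must switch to ``tokens`` / ``token_lengths``
# (formerly ``text`` / ``text_lengths``); positional callers are unaffected.
mel_out, mel_out_postnet, gate_out, alignments = model(
    tokens=tokens,
    token_lengths=token_lengths,
    mel_specgram=mel_specgram,
    mel_specgram_lengths=mel_specgram_lengths,
)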
@@ -1132,18 +1132,19 @@ class Tacotron2(nn.Module):
         return mel_specgram, mel_specgram_postnet, gate_outputs, alignments

     @torch.jit.export
-    def infer(self, text: Tensor, text_lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
+    def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
         r"""Using Tacotron2 for inference. The input is a batch of encoded
-        sentences (text) and its corresponding lengths (text_lengths). The
+        sentences (``tokens``) and its corresponding lengths (``lengths``). The
         output is the generated mel spectrograms, its corresponding lengths, and
         the attention weights from the decoder.

-        The input `text` should be padded with zeros to length max of ``text_lengths``.
+        The input `tokens` should be padded with zeros to length max of ``lengths``.

         Args:
-            text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`.
-            text_lengths (Tensor or None, optional): The length of each text with shape `(n_batch, )`.
-                If ``None``, it is assumed that the all the texts are valid. Default: ``None``
+            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
+            lengths (Tensor or None, optional):
+                The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
+                If ``None``, it is assumed that the all the tokens are valid. Default: ``None``

         Returns:
             Tensor, Tensor, and Tensor:
...
...
@@ -1153,18 +1154,18 @@ class Tacotron2(nn.Module):
                 The length of the predicted mel spectrogram with shape `(n_batch, )`.
             Tensor
                 Sequence of attention weights from the decoder with shape
-                `(n_batch, max of mel_specgram_lengths, max of text_lengths)`.
+                `(n_batch, max of mel_specgram_lengths, max of lengths)`.
         """
-        n_batch, max_length = text.shape
-        if text_lengths is None:
-            text_lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype)
+        n_batch, max_length = tokens.shape
+        if lengths is None:
+            lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)

-        assert text_lengths is not None  # For TorchScript compiler
+        assert lengths is not None  # For TorchScript compiler

-        embedded_inputs = self.embedding(text).transpose(1, 2)
-        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        embedded_inputs = self.embedding(tokens).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, lengths)
         mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(
-            encoder_outputs, text_lengths
+            encoder_outputs, lengths
         )

         mel_outputs_postnet = self.postnet(mel_specgram)
...
...
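For the renamed ``infer`` API, a similarly hedged sketch (batch size and symbol range are assumptions; passing ``lengths=None`` treats every sequence in ``tokens`` as fully valid):

import torch
from torchaudio.models import Tacotron2

model = Tacotron2().eval()  # assumes the constructor defaults are acceptable here

tokens = torch.randint(1, 148, (1, 30))  # one encoded sentence, no padding
lengths = torch.tensor([30])             # could also be omitted (None)

with torch.no_grad():
    # Formerly: model.infer(text=..., text_lengths=...)
    mel_specgram, mel_specgram_lengths, alignments = model.infer(
        tokens=tokens, lengths=lengths
    )

# mel_specgram: (1, n_mels, max of mel_specgram_lengths)
# alignments:   (1, max of mel_specgram_lengths, 30)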