OpenDAS / Torchaudio, commit 9f9b6537
Authored Oct 08, 2021 by moto

Replace `text` with `token` in Tacotron2 API (#1844)

Parent: cb77a86c
Showing 1 changed file with 24 additions and 23 deletions.

torchaudio/models/tacotron2.py (+24, -23)
@@ -1079,20 +1079,20 @@ class Tacotron2(nn.Module):
     def forward(
         self,
-        text: Tensor,
-        text_lengths: Tensor,
+        tokens: Tensor,
+        token_lengths: Tensor,
         mel_specgram: Tensor,
         mel_specgram_lengths: Tensor,
     ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         r"""Pass the input through the Tacotron2 model. This is in teacher
         forcing mode, which is generally used for training.
 
-        The input ``text`` should be padded with zeros to length max of ``text_lengths``.
+        The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
         The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.
 
         Args:
-            text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`.
-            text_lengths (Tensor): The length of each text with shape `(n_batch, )`.
+            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
+            token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
             mel_specgram (Tensor): The target mel spectrogram
                 with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
             mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
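The zero-padding convention in the docstring above can be sketched with `torch.nn.utils.rnn.pad_sequence` (a hedged illustration: the token IDs are made up, and `tacotron2` stands in for any constructed `Tacotron2` instance):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two encoded sentences of unequal length (token IDs are arbitrary).
seqs = [torch.tensor([5, 12, 9, 3]), torch.tensor([7, 2])]

# ``tokens`` is zero-padded to max of ``token_lengths``, per the docstring.
tokens = pad_sequence(seqs, batch_first=True)          # shape: (n_batch, max of token_lengths)
token_lengths = torch.tensor([len(s) for s in seqs])   # shape: (n_batch, )

# After this commit a training call site reads, e.g.:
#   out = tacotron2(tokens, token_lengths, mel_specgram, mel_specgram_lengths)
# where the first two arguments were previously named ``text`` and ``text_lengths``.
```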
@@ -1107,14 +1107,14 @@ class Tacotron2(nn.Module):
                 The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
             Tensor
                 Sequence of attention weights from the decoder with
-                shape `(n_batch, max of mel_specgram_lengths, max of text_lengths)`.
+                shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
         """
-        embedded_inputs = self.embedding(text).transpose(1, 2)
-        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        embedded_inputs = self.embedding(tokens).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, token_lengths)
         mel_specgram, gate_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specgram, memory_lengths=text_lengths
+            encoder_outputs, mel_specgram, memory_lengths=token_lengths
         )
         mel_specgram_postnet = self.postnet(mel_specgram)
@@ -1132,18 +1132,19 @@ class Tacotron2(nn.Module):
         return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
 
     @torch.jit.export
-    def infer(self, text: Tensor, text_lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
+    def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
         r"""Using Tacotron2 for inference. The input is a batch of encoded
-        sentences (text) and its corresponding lengths (text_lengths). The
+        sentences (``tokens``) and its corresponding lengths (``lengths``). The
         output is the generated mel spectrograms, its corresponding lengths, and
         the attention weights from the decoder.
 
-        The input `text` should be padded with zeros to length max of ``text_lengths``.
+        The input `tokens` should be padded with zeros to length max of ``lengths``.
 
         Args:
-            text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`.
-            text_lengths (Tensor or None, optional): The length of each text with shape `(n_batch, )`.
-                If ``None``, it is assumed that all the texts are valid. Default: ``None``
+            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
+            lengths (Tensor or None, optional):
+                The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
+                If ``None``, it is assumed that all the tokens are valid. Default: ``None``
 
         Returns:
             Tensor, Tensor, and Tensor:
@@ -1153,18 +1154,18 @@ class Tacotron2(nn.Module):
                 The length of the predicted mel spectrogram with shape `(n_batch, )`.
             Tensor
                 Sequence of attention weights from the decoder with shape
-                `(n_batch, max of mel_specgram_lengths, max of text_lengths)`.
+                `(n_batch, max of mel_specgram_lengths, max of lengths)`.
         """
-        n_batch, max_length = text.shape
-        if text_lengths is None:
-            text_lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype)
+        n_batch, max_length = tokens.shape
+        if lengths is None:
+            lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)
 
-        assert text_lengths is not None  # For TorchScript compiler
+        assert lengths is not None  # For TorchScript compiler
 
-        embedded_inputs = self.embedding(text).transpose(1, 2)
-        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        embedded_inputs = self.embedding(tokens).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, lengths)
         mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(
-            encoder_outputs, text_lengths
+            encoder_outputs, lengths
         )
         mel_outputs_postnet = self.postnet(mel_specgram)
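The ``lengths is None`` fallback in the last hunk treats every sample as fully valid. Its effect can be reproduced in isolation (a sketch with arbitrary tensor shapes, independent of the model itself):

```python
import torch

tokens = torch.randint(1, 100, (3, 7))  # 3 samples, padded length 7

# Mirrors the fallback in ``infer`` when ``lengths`` is None:
n_batch, max_length = tokens.shape
lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)

# Every sample is assumed fully valid, with the same dtype/device as ``tokens``.
print(lengths)
```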