Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
sambert-hifigan_pytorch
Commits
51782715
Commit
51782715
authored
Feb 23, 2024
by
liugh5
Browse files
update
parent
8b4e9acd
Changes
182
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
2533 deletions
+0
-2533
kantts/models/sambert/__pycache__/adaptors.cpython-38.pyc
kantts/models/sambert/__pycache__/adaptors.cpython-38.pyc
+0
-0
kantts/models/sambert/__pycache__/alignment.cpython-38.pyc
kantts/models/sambert/__pycache__/alignment.cpython-38.pyc
+0
-0
kantts/models/sambert/__pycache__/attention.cpython-38.pyc
kantts/models/sambert/__pycache__/attention.cpython-38.pyc
+0
-0
kantts/models/sambert/__pycache__/fsmn.cpython-38.pyc
kantts/models/sambert/__pycache__/fsmn.cpython-38.pyc
+0
-0
kantts/models/sambert/__pycache__/kantts_sambert.cpython-38.pyc
.../models/sambert/__pycache__/kantts_sambert.cpython-38.pyc
+0
-0
kantts/models/sambert/__pycache__/kantts_sambert_divide.cpython-38.pyc
.../sambert/__pycache__/kantts_sambert_divide.cpython-38.pyc
+0
-0
kantts/models/sambert/__pycache__/positions.cpython-38.pyc
kantts/models/sambert/__pycache__/positions.cpython-38.pyc
+0
-0
kantts/models/sambert/adaptors.py
kantts/models/sambert/adaptors.py
+0
-141
kantts/models/sambert/alignment.py
kantts/models/sambert/alignment.py
+0
-71
kantts/models/sambert/attention.py
kantts/models/sambert/attention.py
+0
-125
kantts/models/sambert/fsmn.py
kantts/models/sambert/fsmn.py
+0
-124
kantts/models/sambert/kantts_sambert.py
kantts/models/sambert/kantts_sambert.py
+0
-1068
kantts/models/sambert/kantts_sambert_divide.py
kantts/models/sambert/kantts_sambert_divide.py
+0
-883
kantts/models/sambert/positions.py
kantts/models/sambert/positions.py
+0
-98
kantts/models/utils.py
kantts/models/utils.py
+0
-23
kantts/preprocess/__init__.py
kantts/preprocess/__init__.py
+0
-0
kantts/preprocess/__pycache__/__init__.cpython-38.pyc
kantts/preprocess/__pycache__/__init__.cpython-38.pyc
+0
-0
kantts/preprocess/__pycache__/fp_processor.cpython-38.pyc
kantts/preprocess/__pycache__/fp_processor.cpython-38.pyc
+0
-0
kantts/preprocess/audio_processor/__init__.py
kantts/preprocess/audio_processor/__init__.py
+0
-0
kantts/preprocess/audio_processor/__pycache__/__init__.cpython-38.pyc
...ocess/audio_processor/__pycache__/__init__.cpython-38.pyc
+0
-0
No files found.
kantts/models/sambert/__pycache__/adaptors.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/sambert/__pycache__/alignment.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/sambert/__pycache__/attention.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/sambert/__pycache__/fsmn.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/sambert/__pycache__/kantts_sambert.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/sambert/__pycache__/kantts_sambert_divide.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/sambert/__pycache__/positions.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/sambert/adaptors.py
deleted
100644 → 0
View file @
8b4e9acd
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
kantts.models.sambert.fsmn
import
FsmnEncoderV2
from
kantts.models.sambert
import
Prenet
class
LengthRegulator
(
nn
.
Module
):
def
__init__
(
self
,
r
=
1
):
super
(
LengthRegulator
,
self
).
__init__
()
self
.
r
=
r
def
forward
(
self
,
inputs
,
durations
,
masks
=
None
):
reps
=
(
durations
+
0.5
).
long
()
output_lens
=
reps
.
sum
(
dim
=
1
)
max_len
=
output_lens
.
max
()
reps_cumsum
=
torch
.
cumsum
(
F
.
pad
(
reps
.
float
(),
(
1
,
0
,
0
,
0
),
value
=
0.0
),
dim
=
1
)[
:,
None
,
:
]
range_
=
torch
.
arange
(
max_len
).
to
(
inputs
.
device
)[
None
,
:,
None
]
mult
=
(
reps_cumsum
[:,
:,
:
-
1
]
<=
range_
)
&
(
reps_cumsum
[:,
:,
1
:]
>
range_
)
mult
=
mult
.
float
()
out
=
torch
.
matmul
(
mult
,
inputs
)
if
masks
is
not
None
:
out
=
out
.
masked_fill
(
masks
.
unsqueeze
(
-
1
),
0.0
)
seq_len
=
out
.
size
(
1
)
padding
=
self
.
r
-
int
(
seq_len
)
%
self
.
r
if
padding
<
self
.
r
:
out
=
F
.
pad
(
out
.
transpose
(
1
,
2
),
(
0
,
padding
,
0
,
0
,
0
,
0
),
value
=
0.0
)
out
=
out
.
transpose
(
1
,
2
)
return
out
,
output_lens
class
VarRnnARPredictor
(
nn
.
Module
):
def
__init__
(
self
,
cond_units
,
prenet_units
,
rnn_units
):
super
(
VarRnnARPredictor
,
self
).
__init__
()
self
.
prenet
=
Prenet
(
1
,
prenet_units
)
self
.
lstm
=
nn
.
LSTM
(
prenet_units
[
-
1
]
+
cond_units
,
rnn_units
,
num_layers
=
2
,
batch_first
=
True
,
bidirectional
=
False
,
)
self
.
fc
=
nn
.
Linear
(
rnn_units
,
1
)
def
forward
(
self
,
inputs
,
cond
,
h
=
None
,
masks
=
None
):
x
=
torch
.
cat
([
self
.
prenet
(
inputs
),
cond
],
dim
=-
1
)
# The input can also be a packed variable length sequence,
# here we just omit it for simplicity due to the mask and uni-directional lstm.
x
,
h_new
=
self
.
lstm
(
x
,
h
)
x
=
self
.
fc
(
x
).
squeeze
(
-
1
)
x
=
F
.
relu
(
x
)
if
masks
is
not
None
:
x
=
x
.
masked_fill
(
masks
,
0.0
)
return
x
,
h_new
def
infer
(
self
,
cond
,
masks
=
None
):
batch_size
,
length
=
cond
.
size
(
0
),
cond
.
size
(
1
)
output
=
[]
x
=
torch
.
zeros
((
batch_size
,
1
)).
to
(
cond
.
device
)
h
=
None
for
i
in
range
(
length
):
x
,
h
=
self
.
forward
(
x
.
unsqueeze
(
1
),
cond
[:,
i
:
i
+
1
,
:],
h
=
h
)
output
.
append
(
x
)
output
=
torch
.
cat
(
output
,
dim
=-
1
)
if
masks
is
not
None
:
output
=
output
.
masked_fill
(
masks
,
0.0
)
return
output
class
VarFsmnRnnNARPredictor
(
nn
.
Module
):
def
__init__
(
self
,
in_dim
,
filter_size
,
fsmn_num_layers
,
num_memory_units
,
ffn_inner_dim
,
dropout
,
shift
,
lstm_units
,
):
super
(
VarFsmnRnnNARPredictor
,
self
).
__init__
()
self
.
fsmn
=
FsmnEncoderV2
(
filter_size
,
fsmn_num_layers
,
in_dim
,
num_memory_units
,
ffn_inner_dim
,
dropout
,
shift
,
)
self
.
blstm
=
nn
.
LSTM
(
num_memory_units
,
lstm_units
,
num_layers
=
1
,
batch_first
=
True
,
bidirectional
=
True
,
)
self
.
fc
=
nn
.
Linear
(
2
*
lstm_units
,
1
)
def
forward
(
self
,
inputs
,
masks
=
None
):
input_lengths
=
None
if
masks
is
not
None
:
input_lengths
=
torch
.
sum
((
~
masks
).
float
(),
dim
=
1
).
long
()
x
=
self
.
fsmn
(
inputs
,
masks
)
if
input_lengths
is
not
None
:
x
=
nn
.
utils
.
rnn
.
pack_padded_sequence
(
x
,
input_lengths
.
tolist
(),
batch_first
=
True
,
enforce_sorted
=
False
)
x
,
_
=
self
.
blstm
(
x
)
x
,
_
=
nn
.
utils
.
rnn
.
pad_packed_sequence
(
x
,
batch_first
=
True
,
total_length
=
inputs
.
size
(
1
)
)
else
:
x
,
_
=
self
.
blstm
(
x
)
x
=
self
.
fc
(
x
).
squeeze
(
-
1
)
if
masks
is
not
None
:
x
=
x
.
masked_fill
(
masks
,
0.0
)
return
x
kantts/models/sambert/alignment.py
deleted
100644 → 0
View file @
8b4e9acd
import
numpy
as
np
import
numba
as
nb
@
nb
.
jit
(
nopython
=
True
)
def
mas
(
attn_map
,
width
=
1
):
# assumes mel x text
opt
=
np
.
zeros_like
(
attn_map
)
attn_map
=
np
.
log
(
attn_map
)
attn_map
[
0
,
1
:]
=
-
np
.
inf
log_p
=
np
.
zeros_like
(
attn_map
)
log_p
[
0
,
:]
=
attn_map
[
0
,
:]
prev_ind
=
np
.
zeros_like
(
attn_map
,
dtype
=
np
.
int64
)
for
i
in
range
(
1
,
attn_map
.
shape
[
0
]):
for
j
in
range
(
attn_map
.
shape
[
1
]):
# for each text dim
prev_j
=
np
.
arange
(
max
(
0
,
j
-
width
),
j
+
1
)
prev_log
=
np
.
array
([
log_p
[
i
-
1
,
prev_idx
]
for
prev_idx
in
prev_j
])
ind
=
np
.
argmax
(
prev_log
)
log_p
[
i
,
j
]
=
attn_map
[
i
,
j
]
+
prev_log
[
ind
]
prev_ind
[
i
,
j
]
=
prev_j
[
ind
]
# now backtrack
curr_text_idx
=
attn_map
.
shape
[
1
]
-
1
for
i
in
range
(
attn_map
.
shape
[
0
]
-
1
,
-
1
,
-
1
):
opt
[
i
,
curr_text_idx
]
=
1
curr_text_idx
=
prev_ind
[
i
,
curr_text_idx
]
opt
[
0
,
curr_text_idx
]
=
1
return
opt
@
nb
.
jit
(
nopython
=
True
)
def
mas_width1
(
attn_map
):
"""mas with hardcoded width=1"""
# assumes mel x text
opt
=
np
.
zeros_like
(
attn_map
)
attn_map
=
np
.
log
(
attn_map
)
attn_map
[
0
,
1
:]
=
-
np
.
inf
log_p
=
np
.
zeros_like
(
attn_map
)
log_p
[
0
,
:]
=
attn_map
[
0
,
:]
prev_ind
=
np
.
zeros_like
(
attn_map
,
dtype
=
np
.
int64
)
for
i
in
range
(
1
,
attn_map
.
shape
[
0
]):
for
j
in
range
(
attn_map
.
shape
[
1
]):
# for each text dim
prev_log
=
log_p
[
i
-
1
,
j
]
prev_j
=
j
if
j
-
1
>=
0
and
log_p
[
i
-
1
,
j
-
1
]
>=
log_p
[
i
-
1
,
j
]:
prev_log
=
log_p
[
i
-
1
,
j
-
1
]
prev_j
=
j
-
1
log_p
[
i
,
j
]
=
attn_map
[
i
,
j
]
+
prev_log
prev_ind
[
i
,
j
]
=
prev_j
# now backtrack
curr_text_idx
=
attn_map
.
shape
[
1
]
-
1
for
i
in
range
(
attn_map
.
shape
[
0
]
-
1
,
-
1
,
-
1
):
opt
[
i
,
curr_text_idx
]
=
1
curr_text_idx
=
prev_ind
[
i
,
curr_text_idx
]
opt
[
0
,
curr_text_idx
]
=
1
return
opt
@
nb
.
jit
(
nopython
=
True
,
parallel
=
True
)
def
b_mas
(
b_attn_map
,
in_lens
,
out_lens
,
width
=
1
):
assert
width
==
1
attn_out
=
np
.
zeros_like
(
b_attn_map
)
for
b
in
nb
.
prange
(
b_attn_map
.
shape
[
0
]):
out
=
mas_width1
(
b_attn_map
[
b
,
0
,
:
out_lens
[
b
],
:
in_lens
[
b
]])
attn_out
[
b
,
0
,
:
out_lens
[
b
],
:
in_lens
[
b
]]
=
out
return
attn_out
kantts/models/sambert/attention.py
deleted
100644 → 0
View file @
8b4e9acd
import
numpy
as
np
import
torch
from
torch
import
nn
class
ConvNorm
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
None
,
dilation
=
1
,
bias
=
True
,
w_init_gain
=
"linear"
,
):
super
(
ConvNorm
,
self
).
__init__
()
if
padding
is
None
:
assert
kernel_size
%
2
==
1
padding
=
int
(
dilation
*
(
kernel_size
-
1
)
/
2
)
self
.
conv
=
torch
.
nn
.
Conv1d
(
in_channels
,
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
)
torch
.
nn
.
init
.
xavier_uniform_
(
self
.
conv
.
weight
,
gain
=
torch
.
nn
.
init
.
calculate_gain
(
w_init_gain
)
)
def
forward
(
self
,
signal
):
conv_signal
=
self
.
conv
(
signal
)
return
conv_signal
class
ConvAttention
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
n_mel_channels
=
80
,
n_text_channels
=
512
,
n_att_channels
=
80
,
temperature
=
1.0
,
use_query_proj
=
True
,
):
super
(
ConvAttention
,
self
).
__init__
()
self
.
temperature
=
temperature
self
.
att_scaling_factor
=
np
.
sqrt
(
n_att_channels
)
self
.
softmax
=
torch
.
nn
.
Softmax
(
dim
=
3
)
self
.
log_softmax
=
torch
.
nn
.
LogSoftmax
(
dim
=
3
)
self
.
attn_proj
=
torch
.
nn
.
Conv2d
(
n_att_channels
,
1
,
kernel_size
=
1
)
self
.
use_query_proj
=
bool
(
use_query_proj
)
self
.
key_proj
=
nn
.
Sequential
(
ConvNorm
(
n_text_channels
,
n_text_channels
*
2
,
kernel_size
=
3
,
bias
=
True
,
w_init_gain
=
"relu"
,
),
torch
.
nn
.
ReLU
(),
ConvNorm
(
n_text_channels
*
2
,
n_att_channels
,
kernel_size
=
1
,
bias
=
True
),
)
self
.
query_proj
=
nn
.
Sequential
(
ConvNorm
(
n_mel_channels
,
n_mel_channels
*
2
,
kernel_size
=
3
,
bias
=
True
,
w_init_gain
=
"relu"
,
),
torch
.
nn
.
ReLU
(),
ConvNorm
(
n_mel_channels
*
2
,
n_mel_channels
,
kernel_size
=
1
,
bias
=
True
),
torch
.
nn
.
ReLU
(),
ConvNorm
(
n_mel_channels
,
n_att_channels
,
kernel_size
=
1
,
bias
=
True
),
)
def
forward
(
self
,
queries
,
keys
,
mask
=
None
,
attn_prior
=
None
):
"""Attention mechanism for flowtron parallel
Unlike in Flowtron, we have no restrictions such as causality etc,
since we only need this during training.
Args:
queries (torch.tensor): B x C x T1 tensor
(probably going to be mel data)
keys (torch.tensor): B x C2 x T2 tensor (text data)
mask (torch.tensor): uint8 binary mask for variable length entries
(should be in the T2 domain)
Output:
attn (torch.tensor): B x 1 x T1 x T2 attention mask.
Final dim T2 should sum to 1
"""
keys_enc
=
self
.
key_proj
(
keys
)
# B x n_attn_dims x T2
# Beware can only do this since query_dim = attn_dim = n_mel_channels
if
self
.
use_query_proj
:
queries_enc
=
self
.
query_proj
(
queries
)
else
:
queries_enc
=
queries
# different ways of computing attn,
# one is isotopic gaussians (per phoneme)
# Simplistic Gaussian Isotopic Attention
# B x n_attn_dims x T1 x T2
attn
=
(
queries_enc
[:,
:,
:,
None
]
-
keys_enc
[:,
:,
None
])
**
2
# compute log likelihood from a gaussian
attn
=
-
0.0005
*
attn
.
sum
(
1
,
keepdim
=
True
)
if
attn_prior
is
not
None
:
attn
=
self
.
log_softmax
(
attn
)
+
torch
.
log
(
attn_prior
[:,
None
]
+
1e-8
)
attn_logprob
=
attn
.
clone
()
if
mask
is
not
None
:
attn
.
data
.
masked_fill_
(
mask
.
unsqueeze
(
1
).
unsqueeze
(
1
),
-
float
(
"inf"
))
attn
=
self
.
softmax
(
attn
)
# Softmax along T2
return
attn
,
attn_logprob
kantts/models/sambert/fsmn.py
deleted
100644 → 0
View file @
8b4e9acd
"""
FSMN Pytorch Version
"""
import
torch.nn
as
nn
import
torch.nn.functional
as
F
class
FeedForwardNet
(
nn
.
Module
):
""" A two-feed-forward-layer module """
def
__init__
(
self
,
d_in
,
d_hid
,
d_out
,
kernel_size
=
[
1
,
1
],
dropout
=
0.1
):
super
().
__init__
()
# Use Conv1D
# position-wise
self
.
w_1
=
nn
.
Conv1d
(
d_in
,
d_hid
,
kernel_size
=
kernel_size
[
0
],
padding
=
(
kernel_size
[
0
]
-
1
)
//
2
,
)
# position-wise
self
.
w_2
=
nn
.
Conv1d
(
d_hid
,
d_out
,
kernel_size
=
kernel_size
[
1
],
padding
=
(
kernel_size
[
1
]
-
1
)
//
2
,
bias
=
False
,
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
def
forward
(
self
,
x
):
output
=
x
.
transpose
(
1
,
2
)
output
=
F
.
relu
(
self
.
w_1
(
output
))
output
=
self
.
dropout
(
output
)
output
=
self
.
w_2
(
output
)
output
=
output
.
transpose
(
1
,
2
)
return
output
class
MemoryBlockV2
(
nn
.
Module
):
def
__init__
(
self
,
d
,
filter_size
,
shift
,
dropout
=
0.0
):
super
(
MemoryBlockV2
,
self
).
__init__
()
left_padding
=
int
(
round
((
filter_size
-
1
)
/
2
))
right_padding
=
int
((
filter_size
-
1
)
/
2
)
if
shift
>
0
:
left_padding
+=
shift
right_padding
-=
shift
self
.
lp
,
self
.
rp
=
left_padding
,
right_padding
self
.
conv_dw
=
nn
.
Conv1d
(
d
,
d
,
filter_size
,
1
,
0
,
groups
=
d
,
bias
=
False
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
def
forward
(
self
,
input
,
mask
=
None
):
if
mask
is
not
None
:
input
=
input
.
masked_fill
(
mask
.
unsqueeze
(
-
1
),
0
)
x
=
F
.
pad
(
input
,
(
0
,
0
,
self
.
lp
,
self
.
rp
,
0
,
0
),
mode
=
"constant"
,
value
=
0.0
)
output
=
(
self
.
conv_dw
(
x
.
contiguous
().
transpose
(
1
,
2
)).
contiguous
().
transpose
(
1
,
2
)
)
output
+=
input
output
=
self
.
dropout
(
output
)
if
mask
is
not
None
:
output
=
output
.
masked_fill
(
mask
.
unsqueeze
(
-
1
),
0
)
return
output
class
FsmnEncoderV2
(
nn
.
Module
):
def
__init__
(
self
,
filter_size
,
fsmn_num_layers
,
input_dim
,
num_memory_units
,
ffn_inner_dim
,
dropout
=
0.0
,
shift
=
0
,
):
super
(
FsmnEncoderV2
,
self
).
__init__
()
self
.
filter_size
=
filter_size
self
.
fsmn_num_layers
=
fsmn_num_layers
self
.
num_memory_units
=
num_memory_units
self
.
ffn_inner_dim
=
ffn_inner_dim
self
.
dropout
=
dropout
self
.
shift
=
shift
if
not
isinstance
(
shift
,
list
):
self
.
shift
=
[
shift
for
_
in
range
(
self
.
fsmn_num_layers
)]
self
.
ffn_lst
=
nn
.
ModuleList
()
self
.
ffn_lst
.
append
(
FeedForwardNet
(
input_dim
,
ffn_inner_dim
,
num_memory_units
,
dropout
=
dropout
)
)
for
i
in
range
(
1
,
fsmn_num_layers
):
self
.
ffn_lst
.
append
(
FeedForwardNet
(
num_memory_units
,
ffn_inner_dim
,
num_memory_units
,
dropout
=
dropout
)
)
self
.
memory_block_lst
=
nn
.
ModuleList
()
for
i
in
range
(
fsmn_num_layers
):
self
.
memory_block_lst
.
append
(
MemoryBlockV2
(
num_memory_units
,
filter_size
,
self
.
shift
[
i
],
dropout
)
)
def
forward
(
self
,
input
,
mask
=
None
):
x
=
F
.
dropout
(
input
,
self
.
dropout
,
self
.
training
)
for
(
ffn
,
memory_block
)
in
zip
(
self
.
ffn_lst
,
self
.
memory_block_lst
):
context
=
ffn
(
x
)
memory
=
memory_block
(
context
,
mask
)
memory
=
F
.
dropout
(
memory
,
self
.
dropout
,
self
.
training
)
if
memory
.
size
(
-
1
)
==
x
.
size
(
-
1
):
memory
+=
x
x
=
memory
return
x
kantts/models/sambert/kantts_sambert.py
deleted
100644 → 0
View file @
8b4e9acd
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
kantts.models.sambert
import
FFTBlock
,
PNCABlock
,
Prenet
from
kantts.models.sambert.positions
import
(
SinusoidalPositionEncoder
,
DurSinusoidalPositionEncoder
,
)
from
kantts.models.sambert.adaptors
import
(
LengthRegulator
,
VarFsmnRnnNARPredictor
,
VarRnnARPredictor
,
)
from
kantts.models.sambert.fsmn
import
FsmnEncoderV2
from
kantts.models.sambert.alignment
import
b_mas
from
kantts.models.sambert.attention
import
ConvAttention
from
kantts.models.utils
import
get_mask_from_lengths
class
SelfAttentionEncoder
(
nn
.
Module
):
def
__init__
(
self
,
n_layer
,
d_in
,
d_model
,
n_head
,
d_head
,
d_inner
,
dropout
,
dropout_att
,
dropout_relu
,
position_encoder
,
):
super
(
SelfAttentionEncoder
,
self
).
__init__
()
self
.
d_in
=
d_in
self
.
d_model
=
d_model
self
.
dropout
=
dropout
d_in_lst
=
[
d_in
]
+
[
d_model
]
*
(
n_layer
-
1
)
self
.
fft
=
nn
.
ModuleList
(
[
FFTBlock
(
d
,
d_model
,
n_head
,
d_head
,
d_inner
,
(
3
,
1
),
dropout
,
dropout_att
,
dropout_relu
,
)
for
d
in
d_in_lst
]
)
self
.
ln
=
nn
.
LayerNorm
(
d_model
,
eps
=
1e-6
)
self
.
position_enc
=
position_encoder
def
forward
(
self
,
input
,
mask
=
None
,
return_attns
=
False
):
input
*=
self
.
d_model
**
0.5
if
isinstance
(
self
.
position_enc
,
SinusoidalPositionEncoder
):
input
=
self
.
position_enc
(
input
)
else
:
raise
NotImplementedError
input
=
F
.
dropout
(
input
,
p
=
self
.
dropout
,
training
=
self
.
training
)
enc_slf_attn_list
=
[]
max_len
=
input
.
size
(
1
)
if
mask
is
not
None
:
slf_attn_mask
=
mask
.
unsqueeze
(
1
).
expand
(
-
1
,
max_len
,
-
1
)
else
:
slf_attn_mask
=
None
enc_output
=
input
for
id
,
layer
in
enumerate
(
self
.
fft
):
enc_output
,
enc_slf_attn
=
layer
(
enc_output
,
mask
=
mask
,
slf_attn_mask
=
slf_attn_mask
)
if
return_attns
:
enc_slf_attn_list
+=
[
enc_slf_attn
]
enc_output
=
self
.
ln
(
enc_output
)
return
enc_output
,
enc_slf_attn_list
class
HybridAttentionDecoder
(
nn
.
Module
):
def
__init__
(
self
,
d_in
,
prenet_units
,
n_layer
,
d_model
,
d_mem
,
n_head
,
d_head
,
d_inner
,
dropout
,
dropout_att
,
dropout_relu
,
d_out
,
):
super
(
HybridAttentionDecoder
,
self
).
__init__
()
self
.
d_model
=
d_model
self
.
dropout
=
dropout
self
.
prenet
=
Prenet
(
d_in
,
prenet_units
,
d_model
)
self
.
dec_in_proj
=
nn
.
Linear
(
d_model
+
d_mem
,
d_model
)
self
.
pnca
=
nn
.
ModuleList
(
[
PNCABlock
(
d_model
,
d_mem
,
n_head
,
d_head
,
d_inner
,
(
1
,
1
),
dropout
,
dropout_att
,
dropout_relu
,
)
for
_
in
range
(
n_layer
)
]
)
self
.
ln
=
nn
.
LayerNorm
(
d_model
,
eps
=
1e-6
)
self
.
dec_out_proj
=
nn
.
Linear
(
d_model
,
d_out
)
def
reset_state
(
self
):
for
layer
in
self
.
pnca
:
layer
.
reset_state
()
def
get_pnca_attn_mask
(
self
,
device
,
max_len
,
x_band_width
,
h_band_width
,
mask
=
None
):
if
mask
is
not
None
:
pnca_attn_mask
=
mask
.
unsqueeze
(
1
).
expand
(
-
1
,
max_len
,
-
1
)
else
:
pnca_attn_mask
=
None
range_
=
torch
.
arange
(
max_len
).
to
(
device
)
x_start
=
torch
.
clamp_min
(
range_
-
x_band_width
,
0
)[
None
,
None
,
:]
x_end
=
(
range_
+
1
)[
None
,
None
,
:]
h_start
=
range_
[
None
,
None
,
:]
h_end
=
torch
.
clamp_max
(
range_
+
h_band_width
+
1
,
max_len
+
1
)[
None
,
None
,
:]
pnca_x_attn_mask
=
~
(
(
x_start
<=
range_
[
None
,
:,
None
])
&
(
x_end
>
range_
[
None
,
:,
None
])
).
transpose
(
1
,
2
)
pnca_h_attn_mask
=
~
(
(
h_start
<=
range_
[
None
,
:,
None
])
&
(
h_end
>
range_
[
None
,
:,
None
])
).
transpose
(
1
,
2
)
if
pnca_attn_mask
is
not
None
:
pnca_x_attn_mask
=
pnca_x_attn_mask
|
pnca_attn_mask
pnca_h_attn_mask
=
pnca_h_attn_mask
|
pnca_attn_mask
pnca_x_attn_mask
=
pnca_x_attn_mask
.
masked_fill
(
pnca_attn_mask
.
transpose
(
1
,
2
),
False
)
pnca_h_attn_mask
=
pnca_h_attn_mask
.
masked_fill
(
pnca_attn_mask
.
transpose
(
1
,
2
),
False
)
return
pnca_attn_mask
,
pnca_x_attn_mask
,
pnca_h_attn_mask
# must call reset_state before
def
forward
(
self
,
input
,
memory
,
x_band_width
,
h_band_width
,
mask
=
None
,
return_attns
=
False
):
input
=
self
.
prenet
(
input
)
input
=
torch
.
cat
([
memory
,
input
],
dim
=-
1
)
input
=
self
.
dec_in_proj
(
input
)
if
mask
is
not
None
:
input
=
input
.
masked_fill
(
mask
.
unsqueeze
(
-
1
),
0
)
input
*=
self
.
d_model
**
0.5
input
=
F
.
dropout
(
input
,
p
=
self
.
dropout
,
training
=
self
.
training
)
max_len
=
input
.
size
(
1
)
pnca_attn_mask
,
pnca_x_attn_mask
,
pnca_h_attn_mask
=
self
.
get_pnca_attn_mask
(
input
.
device
,
max_len
,
x_band_width
,
h_band_width
,
mask
)
dec_pnca_attn_x_list
=
[]
dec_pnca_attn_h_list
=
[]
dec_output
=
input
for
id
,
layer
in
enumerate
(
self
.
pnca
):
dec_output
,
dec_pnca_attn_x
,
dec_pnca_attn_h
=
layer
(
dec_output
,
memory
,
mask
=
mask
,
pnca_x_attn_mask
=
pnca_x_attn_mask
,
pnca_h_attn_mask
=
pnca_h_attn_mask
,
)
if
return_attns
:
dec_pnca_attn_x_list
+=
[
dec_pnca_attn_x
]
dec_pnca_attn_h_list
+=
[
dec_pnca_attn_h
]
dec_output
=
self
.
ln
(
dec_output
)
dec_output
=
self
.
dec_out_proj
(
dec_output
)
return
dec_output
,
dec_pnca_attn_x_list
,
dec_pnca_attn_h_list
# must call reset_state before when step == 0
def
infer
(
self
,
step
,
input
,
memory
,
x_band_width
,
h_band_width
,
mask
=
None
,
return_attns
=
False
,
):
max_len
=
memory
.
size
(
1
)
input
=
self
.
prenet
(
input
)
input
=
torch
.
cat
([
memory
[:,
step
:
step
+
1
,
:],
input
],
dim
=-
1
)
input
=
self
.
dec_in_proj
(
input
)
input
*=
self
.
d_model
**
0.5
input
=
F
.
dropout
(
input
,
p
=
self
.
dropout
,
training
=
self
.
training
)
pnca_attn_mask
,
pnca_x_attn_mask
,
pnca_h_attn_mask
=
self
.
get_pnca_attn_mask
(
input
.
device
,
max_len
,
x_band_width
,
h_band_width
,
mask
)
dec_pnca_attn_x_list
=
[]
dec_pnca_attn_h_list
=
[]
dec_output
=
input
for
id
,
layer
in
enumerate
(
self
.
pnca
):
if
mask
is
not
None
:
mask_step
=
mask
[:,
step
:
step
+
1
]
else
:
mask_step
=
None
dec_output
,
dec_pnca_attn_x
,
dec_pnca_attn_h
=
layer
(
dec_output
,
memory
,
mask
=
mask_step
,
pnca_x_attn_mask
=
pnca_x_attn_mask
[:,
step
:
step
+
1
,
:
(
step
+
1
)],
pnca_h_attn_mask
=
pnca_h_attn_mask
[:,
step
:
step
+
1
,
:],
)
if
return_attns
:
dec_pnca_attn_x_list
+=
[
dec_pnca_attn_x
]
dec_pnca_attn_h_list
+=
[
dec_pnca_attn_h
]
dec_output
=
self
.
ln
(
dec_output
)
dec_output
=
self
.
dec_out_proj
(
dec_output
)
return
dec_output
,
dec_pnca_attn_x_list
,
dec_pnca_attn_h_list
class
TextFftEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
TextFftEncoder
,
self
).
__init__
()
d_emb
=
config
[
"embedding_dim"
]
self
.
using_byte
=
False
if
config
.
get
(
"using_byte"
,
False
):
self
.
using_byte
=
True
nb_ling_byte_index
=
config
[
"byte_index"
]
self
.
byte_index_emb
=
nn
.
Embedding
(
nb_ling_byte_index
,
d_emb
)
else
:
# linguistic unit lookup table
nb_ling_sy
=
config
[
"sy"
]
nb_ling_tone
=
config
[
"tone"
]
nb_ling_syllable_flag
=
config
[
"syllable_flag"
]
nb_ling_ws
=
config
[
"word_segment"
]
self
.
sy_emb
=
nn
.
Embedding
(
nb_ling_sy
,
d_emb
)
self
.
tone_emb
=
nn
.
Embedding
(
nb_ling_tone
,
d_emb
)
self
.
syllable_flag_emb
=
nn
.
Embedding
(
nb_ling_syllable_flag
,
d_emb
)
self
.
ws_emb
=
nn
.
Embedding
(
nb_ling_ws
,
d_emb
)
max_len
=
config
[
"max_len"
]
nb_layers
=
config
[
"encoder_num_layers"
]
nb_heads
=
config
[
"encoder_num_heads"
]
d_model
=
config
[
"encoder_num_units"
]
d_head
=
d_model
//
nb_heads
d_inner
=
config
[
"encoder_ffn_inner_dim"
]
dropout
=
config
[
"encoder_dropout"
]
dropout_attn
=
config
[
"encoder_attention_dropout"
]
dropout_relu
=
config
[
"encoder_relu_dropout"
]
d_proj
=
config
[
"encoder_projection_units"
]
self
.
d_model
=
d_model
position_enc
=
SinusoidalPositionEncoder
(
max_len
,
d_emb
)
self
.
ling_enc
=
SelfAttentionEncoder
(
nb_layers
,
d_emb
,
d_model
,
nb_heads
,
d_head
,
d_inner
,
dropout
,
dropout_attn
,
dropout_relu
,
position_enc
,
)
self
.
ling_proj
=
nn
.
Linear
(
d_model
,
d_proj
,
bias
=
False
)
def
forward
(
self
,
inputs_ling
,
masks
=
None
,
return_attns
=
False
):
# Parse inputs_ling_seq
if
self
.
using_byte
:
inputs_byte_index
=
inputs_ling
[:,
:,
0
]
byte_index_embedding
=
self
.
byte_index_emb
(
inputs_byte_index
)
ling_embedding
=
byte_index_embedding
else
:
inputs_sy
=
inputs_ling
[:,
:,
0
]
inputs_tone
=
inputs_ling
[:,
:,
1
]
inputs_syllable_flag
=
inputs_ling
[:,
:,
2
]
inputs_ws
=
inputs_ling
[:,
:,
3
]
# Lookup table
sy_embedding
=
self
.
sy_emb
(
inputs_sy
)
tone_embedding
=
self
.
tone_emb
(
inputs_tone
)
syllable_flag_embedding
=
self
.
syllable_flag_emb
(
inputs_syllable_flag
)
ws_embedding
=
self
.
ws_emb
(
inputs_ws
)
ling_embedding
=
(
sy_embedding
+
tone_embedding
+
syllable_flag_embedding
+
ws_embedding
)
enc_output
,
enc_slf_attn_list
=
self
.
ling_enc
(
ling_embedding
,
masks
,
return_attns
)
if
hasattr
(
self
,
"ling_proj"
):
enc_output
=
self
.
ling_proj
(
enc_output
)
return
enc_output
,
enc_slf_attn_list
,
ling_embedding
class
VarianceAdaptor
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
VarianceAdaptor
,
self
).
__init__
()
input_dim
=
(
config
[
"encoder_projection_units"
]
+
config
[
"emotion_units"
]
+
config
[
"speaker_units"
]
)
filter_size
=
config
[
"predictor_filter_size"
]
fsmn_num_layers
=
config
[
"predictor_fsmn_num_layers"
]
num_memory_units
=
config
[
"predictor_num_memory_units"
]
ffn_inner_dim
=
config
[
"predictor_ffn_inner_dim"
]
dropout
=
config
[
"predictor_dropout"
]
shift
=
config
[
"predictor_shift"
]
lstm_units
=
config
[
"predictor_lstm_units"
]
dur_pred_prenet_units
=
config
[
"dur_pred_prenet_units"
]
dur_pred_lstm_units
=
config
[
"dur_pred_lstm_units"
]
self
.
pitch_predictor
=
VarFsmnRnnNARPredictor
(
input_dim
,
filter_size
,
fsmn_num_layers
,
num_memory_units
,
ffn_inner_dim
,
dropout
,
shift
,
lstm_units
,
)
self
.
energy_predictor
=
VarFsmnRnnNARPredictor
(
input_dim
,
filter_size
,
fsmn_num_layers
,
num_memory_units
,
ffn_inner_dim
,
dropout
,
shift
,
lstm_units
,
)
self
.
duration_predictor
=
VarRnnARPredictor
(
input_dim
,
dur_pred_prenet_units
,
dur_pred_lstm_units
)
self
.
length_regulator
=
LengthRegulator
(
config
[
"outputs_per_step"
])
self
.
dur_position_encoder
=
DurSinusoidalPositionEncoder
(
config
[
"encoder_projection_units"
],
config
[
"outputs_per_step"
]
)
self
.
pitch_emb
=
nn
.
Conv1d
(
1
,
config
[
"encoder_projection_units"
],
kernel_size
=
9
,
padding
=
4
)
self
.
energy_emb
=
nn
.
Conv1d
(
1
,
config
[
"encoder_projection_units"
],
kernel_size
=
9
,
padding
=
4
)
def
forward
(
self
,
inputs_text_embedding
,
inputs_emo_embedding
,
inputs_spk_embedding
,
masks
=
None
,
output_masks
=
None
,
duration_targets
=
None
,
pitch_targets
=
None
,
energy_targets
=
None
,
):
batch_size
=
inputs_text_embedding
.
size
(
0
)
variance_predictor_inputs
=
torch
.
cat
(
[
inputs_text_embedding
,
inputs_spk_embedding
,
inputs_emo_embedding
],
dim
=-
1
)
pitch_predictions
=
self
.
pitch_predictor
(
variance_predictor_inputs
,
masks
)
energy_predictions
=
self
.
energy_predictor
(
variance_predictor_inputs
,
masks
)
if
pitch_targets
is
not
None
:
pitch_embeddings
=
self
.
pitch_emb
(
pitch_targets
.
unsqueeze
(
1
)).
transpose
(
1
,
2
)
else
:
pitch_embeddings
=
self
.
pitch_emb
(
pitch_predictions
.
unsqueeze
(
1
)).
transpose
(
1
,
2
)
if
energy_targets
is
not
None
:
energy_embeddings
=
self
.
energy_emb
(
energy_targets
.
unsqueeze
(
1
)).
transpose
(
1
,
2
)
else
:
energy_embeddings
=
self
.
energy_emb
(
energy_predictions
.
unsqueeze
(
1
)
).
transpose
(
1
,
2
)
inputs_text_embedding_aug
=
(
inputs_text_embedding
+
pitch_embeddings
+
energy_embeddings
)
duration_predictor_cond
=
torch
.
cat
(
[
inputs_text_embedding_aug
,
inputs_spk_embedding
,
inputs_emo_embedding
],
dim
=-
1
,
)
if
duration_targets
is
not
None
:
duration_predictor_go_frame
=
torch
.
zeros
(
batch_size
,
1
).
to
(
inputs_text_embedding
.
device
)
duration_predictor_input
=
torch
.
cat
(
[
duration_predictor_go_frame
,
duration_targets
[:,
:
-
1
].
float
()],
dim
=-
1
)
duration_predictor_input
=
torch
.
log
(
duration_predictor_input
+
1
)
log_duration_predictions
,
_
=
self
.
duration_predictor
(
duration_predictor_input
.
unsqueeze
(
-
1
),
duration_predictor_cond
,
masks
=
masks
,
)
duration_predictions
=
torch
.
exp
(
log_duration_predictions
)
-
1
else
:
log_duration_predictions
=
self
.
duration_predictor
.
infer
(
duration_predictor_cond
,
masks
=
masks
)
duration_predictions
=
torch
.
exp
(
log_duration_predictions
)
-
1
if
duration_targets
is
not
None
:
LR_text_outputs
,
LR_length_rounded
=
self
.
length_regulator
(
inputs_text_embedding_aug
,
duration_targets
,
masks
=
output_masks
)
LR_position_embeddings
=
self
.
dur_position_encoder
(
duration_targets
,
masks
=
output_masks
)
LR_emo_outputs
,
_
=
self
.
length_regulator
(
inputs_emo_embedding
,
duration_targets
,
masks
=
output_masks
)
LR_spk_outputs
,
_
=
self
.
length_regulator
(
inputs_spk_embedding
,
duration_targets
,
masks
=
output_masks
)
else
:
LR_text_outputs
,
LR_length_rounded
=
self
.
length_regulator
(
inputs_text_embedding_aug
,
duration_predictions
,
masks
=
output_masks
)
LR_position_embeddings
=
self
.
dur_position_encoder
(
duration_predictions
,
masks
=
output_masks
)
LR_emo_outputs
,
_
=
self
.
length_regulator
(
inputs_emo_embedding
,
duration_predictions
,
masks
=
output_masks
)
LR_spk_outputs
,
_
=
self
.
length_regulator
(
inputs_spk_embedding
,
duration_predictions
,
masks
=
output_masks
)
LR_text_outputs
=
LR_text_outputs
+
LR_position_embeddings
return
(
LR_text_outputs
,
LR_emo_outputs
,
LR_spk_outputs
,
LR_length_rounded
,
log_duration_predictions
,
pitch_predictions
,
energy_predictions
,
)
class
MelPNCADecoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
MelPNCADecoder
,
self
).
__init__
()
prenet_units
=
config
[
"decoder_prenet_units"
]
nb_layers
=
config
[
"decoder_num_layers"
]
nb_heads
=
config
[
"decoder_num_heads"
]
d_model
=
config
[
"decoder_num_units"
]
d_head
=
d_model
//
nb_heads
d_inner
=
config
[
"decoder_ffn_inner_dim"
]
dropout
=
config
[
"decoder_dropout"
]
dropout_attn
=
config
[
"decoder_attention_dropout"
]
dropout_relu
=
config
[
"decoder_relu_dropout"
]
outputs_per_step
=
config
[
"outputs_per_step"
]
d_mem
=
(
config
[
"encoder_projection_units"
]
*
outputs_per_step
+
config
[
"emotion_units"
]
+
config
[
"speaker_units"
]
)
d_mel
=
config
[
"num_mels"
]
self
.
d_mel
=
d_mel
self
.
r
=
outputs_per_step
self
.
nb_layers
=
nb_layers
self
.
mel_dec
=
HybridAttentionDecoder
(
d_mel
,
prenet_units
,
nb_layers
,
d_model
,
d_mem
,
nb_heads
,
d_head
,
d_inner
,
dropout
,
dropout_attn
,
dropout_relu
,
d_mel
*
outputs_per_step
,
)
def
forward
(
self
,
memory
,
x_band_width
,
h_band_width
,
target
=
None
,
mask
=
None
,
return_attns
=
False
,
):
batch_size
=
memory
.
size
(
0
)
go_frame
=
torch
.
zeros
((
batch_size
,
1
,
self
.
d_mel
)).
to
(
memory
.
device
)
if
target
is
not
None
:
self
.
mel_dec
.
reset_state
()
input
=
target
[:,
self
.
r
-
1
::
self
.
r
,
:]
input
=
torch
.
cat
([
go_frame
,
input
],
dim
=
1
)[:,
:
-
1
,
:]
dec_output
,
dec_pnca_attn_x_list
,
dec_pnca_attn_h_list
=
self
.
mel_dec
(
input
,
memory
,
x_band_width
,
h_band_width
,
mask
=
mask
,
return_attns
=
return_attns
,
)
else
:
dec_output
=
[]
dec_pnca_attn_x_list
=
[[]
for
_
in
range
(
self
.
nb_layers
)]
dec_pnca_attn_h_list
=
[[]
for
_
in
range
(
self
.
nb_layers
)]
self
.
mel_dec
.
reset_state
()
input
=
go_frame
for
step
in
range
(
memory
.
size
(
1
)):
(
dec_output_step
,
dec_pnca_attn_x_step
,
dec_pnca_attn_h_step
,
)
=
self
.
mel_dec
.
infer
(
step
,
input
,
memory
,
x_band_width
,
h_band_width
,
mask
=
mask
,
return_attns
=
return_attns
,
)
input
=
dec_output_step
[:,
:,
-
self
.
d_mel
:]
dec_output
.
append
(
dec_output_step
)
for
layer_id
,
(
pnca_x_attn
,
pnca_h_attn
)
in
enumerate
(
zip
(
dec_pnca_attn_x_step
,
dec_pnca_attn_h_step
)
):
left
=
memory
.
size
(
1
)
-
pnca_x_attn
.
size
(
-
1
)
if
left
>
0
:
padding
=
torch
.
zeros
((
pnca_x_attn
.
size
(
0
),
1
,
left
)).
to
(
pnca_x_attn
)
pnca_x_attn
=
torch
.
cat
([
pnca_x_attn
,
padding
],
dim
=-
1
)
dec_pnca_attn_x_list
[
layer_id
].
append
(
pnca_x_attn
)
dec_pnca_attn_h_list
[
layer_id
].
append
(
pnca_h_attn
)
dec_output
=
torch
.
cat
(
dec_output
,
dim
=
1
)
for
layer_id
in
range
(
self
.
nb_layers
):
dec_pnca_attn_x_list
[
layer_id
]
=
torch
.
cat
(
dec_pnca_attn_x_list
[
layer_id
],
dim
=
1
)
dec_pnca_attn_h_list
[
layer_id
]
=
torch
.
cat
(
dec_pnca_attn_h_list
[
layer_id
],
dim
=
1
)
return
dec_output
,
dec_pnca_attn_x_list
,
dec_pnca_attn_h_list
class
PostNet
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
PostNet
,
self
).
__init__
()
self
.
filter_size
=
config
[
"postnet_filter_size"
]
self
.
fsmn_num_layers
=
config
[
"postnet_fsmn_num_layers"
]
self
.
num_memory_units
=
config
[
"postnet_num_memory_units"
]
self
.
ffn_inner_dim
=
config
[
"postnet_ffn_inner_dim"
]
self
.
dropout
=
config
[
"postnet_dropout"
]
self
.
shift
=
config
[
"postnet_shift"
]
self
.
lstm_units
=
config
[
"postnet_lstm_units"
]
self
.
num_mels
=
config
[
"num_mels"
]
self
.
fsmn
=
FsmnEncoderV2
(
self
.
filter_size
,
self
.
fsmn_num_layers
,
self
.
num_mels
,
self
.
num_memory_units
,
self
.
ffn_inner_dim
,
self
.
dropout
,
self
.
shift
,
)
self
.
lstm
=
nn
.
LSTM
(
self
.
num_memory_units
,
self
.
lstm_units
,
num_layers
=
1
,
batch_first
=
True
)
self
.
fc
=
nn
.
Linear
(
self
.
lstm_units
,
self
.
num_mels
)
def
forward
(
self
,
x
,
mask
=
None
):
postnet_fsmn_output
=
self
.
fsmn
(
x
,
mask
)
# The input can also be a packed variable length sequence,
# here we just omit it for simpliciy due to the mask and uni-directional lstm.
postnet_lstm_output
,
_
=
self
.
lstm
(
postnet_fsmn_output
)
mel_residual_output
=
self
.
fc
(
postnet_lstm_output
)
return
mel_residual_output
def
average_frame_feat
(
pitch
,
durs
):
durs_cums_ends
=
torch
.
cumsum
(
durs
,
dim
=
1
).
long
()
durs_cums_starts
=
F
.
pad
(
durs_cums_ends
[:,
:
-
1
],
(
1
,
0
))
pitch_nonzero_cums
=
F
.
pad
(
torch
.
cumsum
(
pitch
!=
0.0
,
dim
=
2
),
(
1
,
0
))
pitch_cums
=
F
.
pad
(
torch
.
cumsum
(
pitch
,
dim
=
2
),
(
1
,
0
))
bs
,
lengths
=
durs_cums_ends
.
size
()
n_formants
=
pitch
.
size
(
1
)
dcs
=
durs_cums_starts
[:,
None
,
:].
expand
(
bs
,
n_formants
,
lengths
)
dce
=
durs_cums_ends
[:,
None
,
:].
expand
(
bs
,
n_formants
,
lengths
)
pitch_sums
=
(
torch
.
gather
(
pitch_cums
,
2
,
dce
)
-
torch
.
gather
(
pitch_cums
,
2
,
dcs
)
).
float
()
pitch_nelems
=
(
torch
.
gather
(
pitch_nonzero_cums
,
2
,
dce
)
-
torch
.
gather
(
pitch_nonzero_cums
,
2
,
dcs
)
).
float
()
pitch_avg
=
torch
.
where
(
pitch_nelems
==
0.0
,
pitch_nelems
,
pitch_sums
/
pitch_nelems
)
return
pitch_avg
class
FP_Predictor
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
FP_Predictor
,
self
).
__init__
()
self
.
w_1
=
nn
.
Conv1d
(
config
[
"encoder_projection_units"
],
config
[
"embedding_dim"
]
//
2
,
kernel_size
=
3
,
padding
=
1
,
)
self
.
w_2
=
nn
.
Conv1d
(
config
[
"embedding_dim"
]
//
2
,
config
[
"encoder_projection_units"
],
kernel_size
=
1
,
padding
=
0
,
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
[
"embedding_dim"
]
//
2
,
eps
=
1e-6
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
config
[
"encoder_projection_units"
],
eps
=
1e-6
)
self
.
dropout_inner
=
nn
.
Dropout
(
0.1
)
self
.
dropout
=
nn
.
Dropout
(
0.1
)
self
.
fc
=
nn
.
Linear
(
config
[
"encoder_projection_units"
],
4
)
def
forward
(
self
,
x
):
x
=
x
.
transpose
(
1
,
2
)
x
=
F
.
relu
(
self
.
w_1
(
x
))
x
=
x
.
transpose
(
1
,
2
)
x
=
self
.
dropout_inner
(
self
.
layer_norm1
(
x
))
x
=
x
.
transpose
(
1
,
2
)
x
=
F
.
relu
(
self
.
w_2
(
x
))
x
=
x
.
transpose
(
1
,
2
)
x
=
self
.
dropout
(
self
.
layer_norm2
(
x
))
output
=
F
.
softmax
(
self
.
fc
(
x
),
dim
=
2
)
return
output
class
KanTtsSAMBERT
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
KanTtsSAMBERT
,
self
).
__init__
()
self
.
text_encoder
=
TextFftEncoder
(
config
)
self
.
se_enable
=
config
.
get
(
"SE"
,
False
)
if
not
self
.
se_enable
:
self
.
spk_tokenizer
=
nn
.
Embedding
(
config
[
"speaker"
],
config
[
"speaker_units"
])
self
.
emo_tokenizer
=
nn
.
Embedding
(
config
[
"emotion"
],
config
[
"emotion_units"
])
self
.
variance_adaptor
=
VarianceAdaptor
(
config
)
self
.
mel_decoder
=
MelPNCADecoder
(
config
)
self
.
mel_postnet
=
PostNet
(
config
)
self
.
MAS
=
False
if
config
.
get
(
"MAS"
,
False
):
self
.
MAS
=
True
self
.
align_attention
=
ConvAttention
(
n_mel_channels
=
config
[
"num_mels"
],
n_text_channels
=
config
[
"embedding_dim"
],
n_att_channels
=
config
[
"num_mels"
],
)
self
.
fp_enable
=
config
.
get
(
"FP"
,
False
)
if
self
.
fp_enable
:
self
.
FP_predictor
=
FP_Predictor
(
config
)
def
get_lfr_mask_from_lengths
(
self
,
lengths
,
max_len
):
batch_size
=
lengths
.
size
(
0
)
# padding according to the outputs_per_step
padded_lr_lengths
=
torch
.
zeros_like
(
lengths
)
for
i
in
range
(
batch_size
):
len_item
=
int
(
lengths
[
i
].
item
())
padding
=
self
.
mel_decoder
.
r
-
len_item
%
self
.
mel_decoder
.
r
if
padding
<
self
.
mel_decoder
.
r
:
padded_lr_lengths
[
i
]
=
(
len_item
+
padding
)
//
self
.
mel_decoder
.
r
else
:
padded_lr_lengths
[
i
]
=
len_item
//
self
.
mel_decoder
.
r
return
get_mask_from_lengths
(
padded_lr_lengths
,
max_len
=
max_len
//
self
.
mel_decoder
.
r
)
def
binarize_attention_parallel
(
self
,
attn
,
in_lens
,
out_lens
):
"""For training purposes only. Binarizes attention with MAS.
These will no longer recieve a gradient.
Args:
attn: B x 1 x max_mel_len x max_text_len
"""
with
torch
.
no_grad
():
attn_cpu
=
attn
.
data
.
cpu
().
numpy
()
attn_out
=
b_mas
(
attn_cpu
,
in_lens
.
cpu
().
numpy
(),
out_lens
.
cpu
().
numpy
(),
width
=
1
)
return
torch
.
from_numpy
(
attn_out
).
to
(
attn
.
get_device
())
def
insert_fp
(
self
,
text_hid
,
FP_p
,
fp_label
,
fp_dict
,
inputs_emotion
,
inputs_speaker
,
input_lengths
,
input_masks
,
):
en
,
_
,
_
=
self
.
text_encoder
(
fp_dict
[
1
],
return_attns
=
True
)
a
,
_
,
_
=
self
.
text_encoder
(
fp_dict
[
2
],
return_attns
=
True
)
e
,
_
,
_
=
self
.
text_encoder
(
fp_dict
[
3
],
return_attns
=
True
)
en
=
en
.
squeeze
()
a
=
a
.
squeeze
()
e
=
e
.
squeeze
()
max_len_ori
=
max
(
input_lengths
)
if
fp_label
is
None
:
input_masks_r
=
~
input_masks
fp_mask
=
(
FP_p
==
FP_p
.
max
(
dim
=
2
,
keepdim
=
True
)[
0
]).
to
(
dtype
=
torch
.
int32
)
fp_mask
=
fp_mask
[:,
:,
1
:]
*
input_masks_r
.
unsqueeze
(
2
).
expand
(
-
1
,
-
1
,
3
)
fp_number
=
torch
.
sum
(
torch
.
sum
(
fp_mask
,
dim
=
2
),
dim
=
1
)
else
:
fp_number
=
torch
.
sum
((
fp_label
>
0
),
dim
=
1
)
inter_lengths
=
input_lengths
+
3
*
fp_number
max_len
=
max
(
inter_lengths
)
delta
=
max_len
-
max_len_ori
if
delta
>
0
:
if
delta
>
text_hid
.
shape
[
1
]:
nrepeat
=
delta
//
text_hid
.
shape
[
1
]
bias
=
delta
%
text_hid
.
shape
[
1
]
text_hid
=
torch
.
cat
(
(
text_hid
,
text_hid
.
repeat
(
1
,
nrepeat
,
1
),
text_hid
[:,
:
bias
,
:]),
1
)
inputs_emotion
=
torch
.
cat
(
(
inputs_emotion
,
inputs_emotion
.
repeat
(
1
,
nrepeat
),
inputs_emotion
[:,
:
bias
],
),
1
,
)
inputs_speaker
=
torch
.
cat
(
(
inputs_speaker
,
inputs_speaker
.
repeat
(
1
,
nrepeat
),
inputs_speaker
[:,
:
bias
],
),
1
,
)
else
:
text_hid
=
torch
.
cat
((
text_hid
,
text_hid
[:,
:
delta
,
:]),
1
)
inputs_emotion
=
torch
.
cat
(
(
inputs_emotion
,
inputs_emotion
[:,
:
delta
]),
1
)
inputs_speaker
=
torch
.
cat
(
(
inputs_speaker
,
inputs_speaker
[:,
:
delta
]),
1
)
if
fp_label
is
None
:
for
i
in
range
(
fp_mask
.
shape
[
0
]):
for
j
in
range
(
fp_mask
.
shape
[
1
]
-
1
,
-
1
,
-
1
):
if
fp_mask
[
i
][
j
][
0
]
==
1
:
text_hid
[
i
]
=
torch
.
cat
(
(
text_hid
[
i
][:
j
],
en
,
text_hid
[
i
][
j
:
-
3
]),
0
)
elif
fp_mask
[
i
][
j
][
1
]
==
1
:
text_hid
[
i
]
=
torch
.
cat
(
(
text_hid
[
i
][:
j
],
a
,
text_hid
[
i
][
j
:
-
3
]),
0
)
elif
fp_mask
[
i
][
j
][
2
]
==
1
:
text_hid
[
i
]
=
torch
.
cat
(
(
text_hid
[
i
][:
j
],
e
,
text_hid
[
i
][
j
:
-
3
]),
0
)
else
:
for
i
in
range
(
fp_label
.
shape
[
0
]):
for
j
in
range
(
fp_label
.
shape
[
1
]
-
1
,
-
1
,
-
1
):
if
fp_label
[
i
][
j
]
==
1
:
text_hid
[
i
]
=
torch
.
cat
(
(
text_hid
[
i
][:
j
],
en
,
text_hid
[
i
][
j
:
-
3
]),
0
)
elif
fp_label
[
i
][
j
]
==
2
:
text_hid
[
i
]
=
torch
.
cat
(
(
text_hid
[
i
][:
j
],
a
,
text_hid
[
i
][
j
:
-
3
]),
0
)
elif
fp_label
[
i
][
j
]
==
3
:
text_hid
[
i
]
=
torch
.
cat
(
(
text_hid
[
i
][:
j
],
e
,
text_hid
[
i
][
j
:
-
3
]),
0
)
return
text_hid
,
inputs_emotion
,
inputs_speaker
,
inter_lengths
def
forward
(
self
,
inputs_ling
,
inputs_emotion
,
inputs_speaker
,
input_lengths
,
output_lengths
=
None
,
mel_targets
=
None
,
duration_targets
=
None
,
pitch_targets
=
None
,
energy_targets
=
None
,
attn_priors
=
None
,
fp_label
=
None
,
):
batch_size
=
inputs_ling
.
size
(
0
)
is_training
=
mel_targets
is
not
None
input_masks
=
get_mask_from_lengths
(
input_lengths
,
max_len
=
inputs_ling
.
size
(
1
))
text_hid
,
enc_sla_attn_lst
,
ling_embedding
=
self
.
text_encoder
(
inputs_ling
,
input_masks
,
return_attns
=
True
)
inter_lengths
=
input_lengths
FP_p
=
None
if
self
.
fp_enable
:
FP_p
=
self
.
FP_predictor
(
text_hid
)
fp_dict
=
self
.
fp_dict
text_hid
,
inputs_emotion
,
inputs_speaker
,
inter_lengths
=
self
.
insert_fp
(
text_hid
,
FP_p
,
fp_label
,
fp_dict
,
inputs_emotion
,
inputs_speaker
,
input_lengths
,
input_masks
,
)
# Monotonic-Alignment-Search
if
self
.
MAS
and
is_training
:
attn_soft
,
attn_logprob
=
self
.
align_attention
(
mel_targets
.
permute
(
0
,
2
,
1
),
ling_embedding
.
permute
(
0
,
2
,
1
),
input_masks
,
attn_priors
,
)
attn_hard
=
self
.
binarize_attention_parallel
(
attn_soft
,
input_lengths
,
output_lengths
)
attn_hard_dur
=
attn_hard
.
sum
(
2
)[:,
0
,
:]
duration_targets
=
attn_hard_dur
assert
torch
.
all
(
torch
.
eq
(
duration_targets
.
sum
(
dim
=
1
),
output_lengths
))
pitch_targets
=
average_frame_feat
(
pitch_targets
.
unsqueeze
(
1
),
duration_targets
).
squeeze
(
1
)
energy_targets
=
average_frame_feat
(
energy_targets
.
unsqueeze
(
1
),
duration_targets
).
squeeze
(
1
)
# Padding the POS length to make it sum equal to max rounded output length
for
i
in
range
(
batch_size
):
len_item
=
int
(
output_lengths
[
i
].
item
())
padding
=
mel_targets
.
size
(
1
)
-
len_item
duration_targets
[
i
,
input_lengths
[
i
]]
=
padding
emo_hid
=
self
.
emo_tokenizer
(
inputs_emotion
)
spk_hid
=
inputs_speaker
if
self
.
se_enable
else
self
.
spk_tokenizer
(
inputs_speaker
)
inter_masks
=
get_mask_from_lengths
(
inter_lengths
,
max_len
=
text_hid
.
size
(
1
))
if
output_lengths
is
not
None
:
output_masks
=
get_mask_from_lengths
(
output_lengths
,
max_len
=
mel_targets
.
size
(
1
)
)
else
:
output_masks
=
None
(
LR_text_outputs
,
LR_emo_outputs
,
LR_spk_outputs
,
LR_length_rounded
,
log_duration_predictions
,
pitch_predictions
,
energy_predictions
,
)
=
self
.
variance_adaptor
(
text_hid
,
emo_hid
,
spk_hid
,
masks
=
inter_masks
,
output_masks
=
output_masks
,
duration_targets
=
duration_targets
,
pitch_targets
=
pitch_targets
,
energy_targets
=
energy_targets
,
)
if
output_lengths
is
not
None
:
lfr_masks
=
self
.
get_lfr_mask_from_lengths
(
output_lengths
,
max_len
=
LR_text_outputs
.
size
(
1
)
)
else
:
output_masks
=
get_mask_from_lengths
(
LR_length_rounded
,
max_len
=
LR_text_outputs
.
size
(
1
)
)
lfr_masks
=
None
# LFR with the factor of outputs_per_step
LFR_text_inputs
=
LR_text_outputs
.
contiguous
().
view
(
batch_size
,
-
1
,
self
.
mel_decoder
.
r
*
text_hid
.
shape
[
-
1
]
)
LFR_emo_inputs
=
LR_emo_outputs
.
contiguous
().
view
(
batch_size
,
-
1
,
self
.
mel_decoder
.
r
*
emo_hid
.
shape
[
-
1
]
)[:,
:,
:
emo_hid
.
shape
[
-
1
]]
LFR_spk_inputs
=
LR_spk_outputs
.
contiguous
().
view
(
batch_size
,
-
1
,
self
.
mel_decoder
.
r
*
spk_hid
.
shape
[
-
1
]
)[:,
:,
:
spk_hid
.
shape
[
-
1
]]
memory
=
torch
.
cat
([
LFR_text_inputs
,
LFR_spk_inputs
,
LFR_emo_inputs
],
dim
=-
1
)
if
duration_targets
is
not
None
:
x_band_width
=
int
(
duration_targets
.
float
().
masked_fill
(
inter_masks
,
0
).
max
()
/
self
.
mel_decoder
.
r
+
0.5
)
h_band_width
=
x_band_width
else
:
x_band_width
=
int
(
(
torch
.
exp
(
log_duration_predictions
)
-
1
).
max
()
/
self
.
mel_decoder
.
r
+
0.5
)
h_band_width
=
x_band_width
dec_outputs
,
pnca_x_attn_lst
,
pnca_h_attn_lst
=
self
.
mel_decoder
(
memory
,
x_band_width
,
h_band_width
,
target
=
mel_targets
,
mask
=
lfr_masks
,
return_attns
=
True
,
)
# De-LFR with the factor of outputs_per_step
dec_outputs
=
dec_outputs
.
contiguous
().
view
(
batch_size
,
-
1
,
self
.
mel_decoder
.
d_mel
)
if
output_masks
is
not
None
:
dec_outputs
=
dec_outputs
.
masked_fill
(
output_masks
.
unsqueeze
(
-
1
),
0
)
postnet_outputs
=
self
.
mel_postnet
(
dec_outputs
,
output_masks
)
+
dec_outputs
if
output_masks
is
not
None
:
postnet_outputs
=
postnet_outputs
.
masked_fill
(
output_masks
.
unsqueeze
(
-
1
),
0
)
res
=
{
"x_band_width"
:
x_band_width
,
"h_band_width"
:
h_band_width
,
"enc_slf_attn_lst"
:
enc_sla_attn_lst
,
"pnca_x_attn_lst"
:
pnca_x_attn_lst
,
"pnca_h_attn_lst"
:
pnca_h_attn_lst
,
"dec_outputs"
:
dec_outputs
,
"postnet_outputs"
:
postnet_outputs
,
"LR_length_rounded"
:
LR_length_rounded
,
"log_duration_predictions"
:
log_duration_predictions
,
"pitch_predictions"
:
pitch_predictions
,
"energy_predictions"
:
energy_predictions
,
"duration_targets"
:
duration_targets
,
"pitch_targets"
:
pitch_targets
,
"energy_targets"
:
energy_targets
,
"fp_predictions"
:
FP_p
,
"valid_inter_lengths"
:
inter_lengths
,
}
res
[
"LR_text_outputs"
]
=
LR_text_outputs
res
[
"LR_emo_outputs"
]
=
LR_emo_outputs
res
[
"LR_spk_outputs"
]
=
LR_spk_outputs
if
self
.
MAS
and
is_training
:
res
[
"attn_soft"
]
=
attn_soft
res
[
"attn_hard"
]
=
attn_hard
res
[
"attn_logprob"
]
=
attn_logprob
return
res
class
KanTtsTextsyBERT
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
KanTtsTextsyBERT
,
self
).
__init__
()
self
.
text_encoder
=
TextFftEncoder
(
config
)
delattr
(
self
.
text_encoder
,
"ling_proj"
)
self
.
fc
=
nn
.
Linear
(
self
.
text_encoder
.
d_model
,
config
[
"sy"
])
def
forward
(
self
,
inputs_ling
,
input_lengths
):
res
=
{}
input_masks
=
get_mask_from_lengths
(
input_lengths
,
max_len
=
inputs_ling
.
size
(
1
))
text_hid
,
enc_sla_attn_lst
=
self
.
text_encoder
(
inputs_ling
,
input_masks
,
return_attns
=
True
)
logits
=
self
.
fc
(
text_hid
)
res
[
"logits"
]
=
logits
res
[
"enc_slf_attn_lst"
]
=
enc_sla_attn_lst
return
res
kantts/models/sambert/kantts_sambert_divide.py
deleted
100644 → 0
View file @
8b4e9acd
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
kantts.models.sambert
import
FFTBlock
,
PNCABlock
,
Prenet
from
kantts.models.sambert.positions
import
(
SinusoidalPositionEncoder
,
DurSinusoidalPositionEncoder
,
)
from
kantts.models.sambert.adaptors
import
(
LengthRegulator
,
VarFsmnRnnNARPredictor
,
VarRnnARPredictor
,
)
from
kantts.models.sambert.fsmn
import
FsmnEncoderV2
from
kantts.models.sambert.alignment
import
b_mas
from
kantts.models.sambert.attention
import
ConvAttention
from
kantts.models.utils
import
get_mask_from_lengths
class
SelfAttentionEncoder
(
nn
.
Module
):
def
__init__
(
self
,
n_layer
,
d_in
,
d_model
,
n_head
,
d_head
,
d_inner
,
dropout
,
dropout_att
,
dropout_relu
,
position_encoder
,
):
super
(
SelfAttentionEncoder
,
self
).
__init__
()
self
.
d_in
=
d_in
self
.
d_model
=
d_model
self
.
dropout
=
dropout
d_in_lst
=
[
d_in
]
+
[
d_model
]
*
(
n_layer
-
1
)
self
.
fft
=
nn
.
ModuleList
(
[
FFTBlock
(
d
,
d_model
,
n_head
,
d_head
,
d_inner
,
(
3
,
1
),
dropout
,
dropout_att
,
dropout_relu
,
)
for
d
in
d_in_lst
]
)
self
.
ln
=
nn
.
LayerNorm
(
d_model
,
eps
=
1e-6
)
self
.
position_enc
=
position_encoder
def
forward
(
self
,
input
,
mask
=
None
,
return_attns
=
False
):
input
*=
self
.
d_model
**
0.5
if
isinstance
(
self
.
position_enc
,
SinusoidalPositionEncoder
):
input
=
self
.
position_enc
(
input
)
else
:
raise
NotImplementedError
input
=
F
.
dropout
(
input
,
p
=
self
.
dropout
,
training
=
self
.
training
)
enc_slf_attn_list
=
[]
max_len
=
input
.
size
(
1
)
if
mask
is
not
None
:
slf_attn_mask
=
mask
.
unsqueeze
(
1
).
expand
(
-
1
,
max_len
,
-
1
)
else
:
slf_attn_mask
=
None
enc_output
=
input
for
id
,
layer
in
enumerate
(
self
.
fft
):
enc_output
,
enc_slf_attn
=
layer
(
enc_output
,
mask
=
mask
,
slf_attn_mask
=
slf_attn_mask
)
if
return_attns
:
enc_slf_attn_list
+=
[
enc_slf_attn
]
enc_output
=
self
.
ln
(
enc_output
)
return
enc_output
,
enc_slf_attn_list
class
HybridAttentionDecoder
(
nn
.
Module
):
def
__init__
(
self
,
d_in
,
prenet_units
,
n_layer
,
d_model
,
d_mem
,
n_head
,
d_head
,
d_inner
,
dropout
,
dropout_att
,
dropout_relu
,
d_out
,
):
super
(
HybridAttentionDecoder
,
self
).
__init__
()
self
.
d_model
=
d_model
self
.
dropout
=
dropout
self
.
prenet
=
Prenet
(
d_in
,
prenet_units
,
d_model
)
self
.
dec_in_proj
=
nn
.
Linear
(
d_model
+
d_mem
,
d_model
)
self
.
pnca
=
nn
.
ModuleList
(
[
PNCABlock
(
d_model
,
d_mem
,
n_head
,
d_head
,
d_inner
,
(
1
,
1
),
dropout
,
dropout_att
,
dropout_relu
,
)
for
_
in
range
(
n_layer
)
]
)
self
.
ln
=
nn
.
LayerNorm
(
d_model
,
eps
=
1e-6
)
self
.
dec_out_proj
=
nn
.
Linear
(
d_model
,
d_out
)
def
reset_state
(
self
):
for
layer
in
self
.
pnca
:
layer
.
reset_state
()
def
get_pnca_attn_mask
(
self
,
device
,
max_len
,
x_band_width
,
h_band_width
,
masks
=
None
):
if
masks
is
not
None
:
pnca_attn_mask
=
masks
.
unsqueeze
(
1
).
expand
(
-
1
,
max_len
,
-
1
)
else
:
pnca_attn_mask
=
None
range_
=
torch
.
arange
(
max_len
).
to
(
device
)
x_start
=
torch
.
clamp_min
(
range_
-
x_band_width
,
0
)[
None
,
None
,
:]
x_end
=
(
range_
+
1
)[
None
,
None
,
:]
h_start
=
range_
[
None
,
None
,
:]
h_end
=
torch
.
clamp_max
(
range_
+
h_band_width
+
1
,
max_len
+
1
)[
None
,
None
,
:]
pnca_x_attn_mask
=
~
(
(
x_start
<=
range_
[
None
,
:,
None
])
&
(
x_end
>
range_
[
None
,
:,
None
])
).
transpose
(
1
,
2
)
pnca_h_attn_mask
=
~
(
(
h_start
<=
range_
[
None
,
:,
None
])
&
(
h_end
>
range_
[
None
,
:,
None
])
).
transpose
(
1
,
2
)
if
pnca_attn_mask
is
not
None
:
pnca_x_attn_mask
=
pnca_x_attn_mask
|
pnca_attn_mask
pnca_h_attn_mask
=
pnca_h_attn_mask
|
pnca_attn_mask
pnca_x_attn_mask
=
pnca_x_attn_mask
.
masked_fill
(
pnca_attn_mask
.
transpose
(
1
,
2
),
False
)
pnca_h_attn_mask
=
pnca_h_attn_mask
.
masked_fill
(
pnca_attn_mask
.
transpose
(
1
,
2
),
False
)
return
pnca_attn_mask
,
pnca_x_attn_mask
,
pnca_h_attn_mask
# must call reset_state before
def
forward
(
self
,
input
,
memory
,
x_band_width
,
h_band_width
,
masks
=
None
,
return_attns
=
False
):
input
=
self
.
prenet
(
input
)
input
=
torch
.
cat
([
memory
,
input
],
dim
=-
1
)
input
=
self
.
dec_in_proj
(
input
)
if
masks
is
not
None
:
input
=
input
.
masked_fill
(
masks
.
unsqueeze
(
-
1
),
0
)
input
*=
self
.
d_model
**
0.5
input
=
F
.
dropout
(
input
,
p
=
self
.
dropout
,
training
=
self
.
training
)
max_len
=
input
.
size
(
1
)
pnca_attn_mask
,
pnca_x_attn_mask
,
pnca_h_attn_mask
=
self
.
get_pnca_attn_mask
(
input
.
device
,
max_len
,
x_band_width
,
h_band_width
,
masks
)
dec_pnca_attn_x_list
=
[]
dec_pnca_attn_h_list
=
[]
dec_output
=
input
for
id
,
layer
in
enumerate
(
self
.
pnca
):
dec_output
,
dec_pnca_attn_x
,
dec_pnca_attn_h
=
layer
(
dec_output
,
memory
,
masks
=
masks
,
pnca_x_attn_mask
=
pnca_x_attn_mask
,
pnca_h_attn_mask
=
pnca_h_attn_mask
,
)
if
return_attns
:
dec_pnca_attn_x_list
+=
[
dec_pnca_attn_x
]
dec_pnca_attn_h_list
+=
[
dec_pnca_attn_h
]
dec_output
=
self
.
ln
(
dec_output
)
dec_output
=
self
.
dec_out_proj
(
dec_output
)
return
dec_output
,
dec_pnca_attn_x_list
,
dec_pnca_attn_h_list
# must call reset_state before when step == 0
def
infer
(
self
,
step
,
input
,
memory
,
x_band_width
,
h_band_width
,
masks
=
None
,
return_attns
=
False
,
):
max_len
=
memory
.
size
(
1
)
input
=
self
.
prenet
(
input
)
input
=
torch
.
cat
([
memory
[:,
step
:
step
+
1
,
:],
input
],
dim
=-
1
)
input
=
self
.
dec_in_proj
(
input
)
input
*=
self
.
d_model
**
0.5
input
=
F
.
dropout
(
input
,
p
=
self
.
dropout
,
training
=
self
.
training
)
pnca_attn_mask
,
pnca_x_attn_mask
,
pnca_h_attn_mask
=
self
.
get_pnca_attn_mask
(
input
.
device
,
max_len
,
x_band_width
,
h_band_width
,
masks
)
dec_pnca_attn_x_list
=
[]
dec_pnca_attn_h_list
=
[]
dec_output
=
input
for
id
,
layer
in
enumerate
(
self
.
pnca
):
if
masks
is
not
None
:
mask_step
=
masks
[:,
step
:
step
+
1
]
else
:
mask_step
=
None
dec_output
,
dec_pnca_attn_x
,
dec_pnca_attn_h
=
layer
(
dec_output
,
memory
,
mask
=
mask_step
,
pnca_x_attn_mask
=
pnca_x_attn_mask
[:,
step
:
step
+
1
,
:
(
step
+
1
)],
pnca_h_attn_mask
=
pnca_h_attn_mask
[:,
step
:
step
+
1
,
:],
)
if
return_attns
:
dec_pnca_attn_x_list
+=
[
dec_pnca_attn_x
]
dec_pnca_attn_h_list
+=
[
dec_pnca_attn_h
]
dec_output
=
self
.
ln
(
dec_output
)
dec_output
=
self
.
dec_out_proj
(
dec_output
)
return
dec_output
,
dec_pnca_attn_x_list
,
dec_pnca_attn_h_list
class
TextFftEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
TextFftEncoder
,
self
).
__init__
()
d_emb
=
config
[
"embedding_dim"
]
self
.
using_byte
=
False
if
config
.
get
(
"using_byte"
,
False
):
self
.
using_byte
=
True
nb_ling_byte_index
=
config
[
"byte_index"
]
self
.
byte_index_emb
=
nn
.
Embedding
(
nb_ling_byte_index
,
d_emb
)
else
:
# linguistic unit lookup table
nb_ling_sy
=
config
[
"sy"
]
nb_ling_tone
=
config
[
"tone"
]
nb_ling_syllable_flag
=
config
[
"syllable_flag"
]
nb_ling_ws
=
config
[
"word_segment"
]
self
.
sy_emb
=
nn
.
Embedding
(
nb_ling_sy
,
d_emb
)
self
.
tone_emb
=
nn
.
Embedding
(
nb_ling_tone
,
d_emb
)
self
.
syllable_flag_emb
=
nn
.
Embedding
(
nb_ling_syllable_flag
,
d_emb
)
self
.
ws_emb
=
nn
.
Embedding
(
nb_ling_ws
,
d_emb
)
max_len
=
config
[
"max_len"
]
nb_layers
=
config
[
"encoder_num_layers"
]
nb_heads
=
config
[
"encoder_num_heads"
]
d_model
=
config
[
"encoder_num_units"
]
d_head
=
d_model
//
nb_heads
d_inner
=
config
[
"encoder_ffn_inner_dim"
]
dropout
=
config
[
"encoder_dropout"
]
dropout_attn
=
config
[
"encoder_attention_dropout"
]
dropout_relu
=
config
[
"encoder_relu_dropout"
]
d_proj
=
config
[
"encoder_projection_units"
]
self
.
d_model
=
d_model
position_enc
=
SinusoidalPositionEncoder
(
max_len
,
d_emb
)
self
.
ling_enc
=
SelfAttentionEncoder
(
nb_layers
,
d_emb
,
d_model
,
nb_heads
,
d_head
,
d_inner
,
dropout
,
dropout_attn
,
dropout_relu
,
position_enc
,
)
self
.
ling_proj
=
nn
.
Linear
(
d_model
,
d_proj
,
bias
=
False
)
def
forward
(
self
,
inputs_ling
,
masks
=
None
,
return_attns
=
False
):
# Parse inputs_ling_seq
if
self
.
using_byte
:
inputs_byte_index
=
inputs_ling
[:,
:,
0
]
byte_index_embedding
=
self
.
byte_index_emb
(
inputs_byte_index
)
ling_embedding
=
byte_index_embedding
else
:
inputs_sy
=
inputs_ling
[:,
:,
0
]
inputs_tone
=
inputs_ling
[:,
:,
1
]
inputs_syllable_flag
=
inputs_ling
[:,
:,
2
]
inputs_ws
=
inputs_ling
[:,
:,
3
]
# Lookup table
sy_embedding
=
self
.
sy_emb
(
inputs_sy
)
tone_embedding
=
self
.
tone_emb
(
inputs_tone
)
syllable_flag_embedding
=
self
.
syllable_flag_emb
(
inputs_syllable_flag
)
ws_embedding
=
self
.
ws_emb
(
inputs_ws
)
ling_embedding
=
(
sy_embedding
+
tone_embedding
+
syllable_flag_embedding
+
ws_embedding
)
enc_output
,
enc_slf_attn_lst
=
self
.
ling_enc
(
ling_embedding
,
masks
,
return_attns
)
if
hasattr
(
self
,
"ling_proj"
):
enc_output
=
self
.
ling_proj
(
enc_output
)
return
enc_output
,
enc_slf_attn_lst
,
ling_embedding
class
TextEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
TextEncoder
,
self
).
__init__
()
self
.
text_encoder
=
TextFftEncoder
(
config
)
self
.
se_enable
=
config
.
get
(
"SE"
,
False
)
if
not
self
.
se_enable
:
self
.
spk_tokenizer
=
nn
.
Embedding
(
config
[
"speaker"
],
config
[
"speaker_units"
])
self
.
emo_tokenizer
=
nn
.
Embedding
(
config
[
"emotion"
],
config
[
"emotion_units"
])
# self.variance_adaptor = VarianceAdaptor(config)
# self.mel_decoder = MelPNCADecoder(config)
# self.mel_postnet = PostNet(config)
self
.
MAS
=
False
if
config
.
get
(
"MAS"
,
False
):
self
.
MAS
=
True
self
.
align_attention
=
ConvAttention
(
n_mel_channels
=
config
[
"num_mels"
],
n_text_channels
=
config
[
"embedding_dim"
],
n_att_channels
=
config
[
"num_mels"
],
)
self
.
fp_enable
=
config
.
get
(
"FP"
,
False
)
if
self
.
fp_enable
:
self
.
FP_predictor
=
FP_Predictor
(
config
)
def
forward
(
self
,
inputs_ling
,
inputs_emotion
,
inputs_speaker
,
inputs_ling_masks
=
None
,
return_attns
=
False
):
text_hid
,
enc_sla_attn_lst
,
ling_embedding
=
self
.
text_encoder
(
inputs_ling
,
inputs_ling_masks
,
return_attns
)
emo_hid
=
self
.
emo_tokenizer
(
inputs_emotion
)
spk_hid
=
inputs_speaker
if
self
.
se_enable
else
self
.
spk_tokenizer
(
inputs_speaker
)
if
return_attns
:
return
text_hid
,
enc_sla_attn_lst
,
ling_embedding
,
emo_hid
,
spk_hid
else
:
return
text_hid
,
ling_embedding
,
emo_hid
,
spk_hid
class VarianceAdaptor(nn.Module):
    def __init__(self, config):
        super(VarianceAdaptor, self).__init__()
        input_dim = (
            config["encoder_projection_units"]
            + config["emotion_units"]
            + config["speaker_units"]
        )
        filter_size = config["predictor_filter_size"]
        fsmn_num_layers = config["predictor_fsmn_num_layers"]
        num_memory_units = config["predictor_num_memory_units"]
        ffn_inner_dim = config["predictor_ffn_inner_dim"]
        dropout = config["predictor_dropout"]
        shift = config["predictor_shift"]
        lstm_units = config["predictor_lstm_units"]
        dur_pred_prenet_units = config["dur_pred_prenet_units"]
        dur_pred_lstm_units = config["dur_pred_lstm_units"]

        self.pitch_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.energy_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.duration_predictor = VarRnnARPredictor(
            input_dim, dur_pred_prenet_units, dur_pred_lstm_units
        )
        self.length_regulator = LengthRegulator(config["outputs_per_step"])
        self.dur_position_encoder = DurSinusoidalPositionEncoder(
            config["encoder_projection_units"], config["outputs_per_step"]
        )
        self.pitch_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )
        self.energy_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )

    def forward(
        self,
        inputs_text_embedding,
        inputs_emo_embedding,
        inputs_spk_embedding,  # [1,20,192]
        masks=None,
        output_masks=None,
        duration_targets=None,
        pitch_targets=None,
        energy_targets=None,
    ):
        batch_size = inputs_text_embedding.size(0)

        variance_predictor_inputs = torch.cat(
            [inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding], dim=-1
        )
        pitch_predictions = self.pitch_predictor(variance_predictor_inputs, masks)
        energy_predictions = self.energy_predictor(variance_predictor_inputs, masks)

        if pitch_targets is not None:
            pitch_embeddings = self.pitch_emb(pitch_targets.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_embeddings = self.pitch_emb(pitch_predictions.unsqueeze(1)).transpose(1, 2)

        if energy_targets is not None:
            energy_embeddings = self.energy_emb(energy_targets.unsqueeze(1)).transpose(1, 2)
        else:
            energy_embeddings = self.energy_emb(energy_predictions.unsqueeze(1)).transpose(1, 2)

        inputs_text_embedding_aug = (
            inputs_text_embedding + pitch_embeddings + energy_embeddings
        )
        duration_predictor_cond = torch.cat(
            [inputs_text_embedding_aug, inputs_spk_embedding, inputs_emo_embedding],
            dim=-1,
        )

        if duration_targets is not None:
            duration_predictor_go_frame = torch.zeros(batch_size, 1).to(
                inputs_text_embedding.device
            )
            duration_predictor_input = torch.cat(
                [duration_predictor_go_frame, duration_targets[:, :-1].float()], dim=-1
            )
            duration_predictor_input = torch.log(duration_predictor_input + 1)
            log_duration_predictions, _ = self.duration_predictor(
                duration_predictor_input.unsqueeze(-1),
                duration_predictor_cond,
                masks=masks,
            )
            duration_predictions = torch.exp(log_duration_predictions) - 1
        else:
            log_duration_predictions = self.duration_predictor.infer(
                duration_predictor_cond, masks=masks
            )
            duration_predictions = torch.exp(log_duration_predictions) - 1

        if duration_targets is not None:
            LR_text_outputs, LR_length_rounded = self.length_regulator(
                inputs_text_embedding_aug, duration_targets, masks=output_masks
            )
            LR_position_embeddings = self.dur_position_encoder(
                duration_targets, masks=output_masks
            )
            LR_emo_outputs, _ = self.length_regulator(
                inputs_emo_embedding, duration_targets, masks=output_masks
            )
            LR_spk_outputs, _ = self.length_regulator(
                inputs_spk_embedding, duration_targets, masks=output_masks
            )
        else:
            LR_text_outputs, LR_length_rounded = self.length_regulator(
                inputs_text_embedding_aug, duration_predictions, masks=output_masks
            )
            LR_position_embeddings = self.dur_position_encoder(
                duration_predictions, masks=output_masks
            )
            LR_emo_outputs, _ = self.length_regulator(
                inputs_emo_embedding, duration_predictions, masks=output_masks
            )
            LR_spk_outputs, _ = self.length_regulator(
                inputs_spk_embedding, duration_predictions, masks=output_masks
            )

        LR_text_outputs = LR_text_outputs + LR_position_embeddings

        return (
            LR_text_outputs,
            LR_emo_outputs,
            LR_spk_outputs,  # [1,153,192]
            LR_length_rounded,
            log_duration_predictions,
            pitch_predictions,
            energy_predictions,
        )
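During training the autoregressive duration predictor is teacher-forced: the ground-truth durations are shifted right by one step behind a zero "go" frame, compressed to log(d + 1), and the predictions are mapped back with exp(.) - 1. A toy sketch of that transform (illustration only, not from the deleted file):

import torch

duration_targets = torch.tensor([[2.0, 3.0, 1.0, 4.0]])            # [batch, tokens]
go_frame = torch.zeros(duration_targets.size(0), 1)
ar_input = torch.cat([go_frame, duration_targets[:, :-1]], dim=-1)  # right-shifted
ar_input = torch.log(ar_input + 1)                                  # log(d + 1) domain
recovered = torch.exp(ar_input) - 1                                 # inverse transform
print(ar_input)    # tensor([[0.0000, 1.0986, 1.3863, 0.6931]])
print(recovered)   # tensor([[0., 2., 3., 1.]])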
class VarianceAdaptor2(nn.Module):
    def __init__(self, config):
        super(VarianceAdaptor2, self).__init__()
        input_dim = (
            config["encoder_projection_units"]
            + config["emotion_units"]
            + config["speaker_units"]
        )
        filter_size = config["predictor_filter_size"]
        fsmn_num_layers = config["predictor_fsmn_num_layers"]
        num_memory_units = config["predictor_num_memory_units"]
        ffn_inner_dim = config["predictor_ffn_inner_dim"]
        dropout = config["predictor_dropout"]
        shift = config["predictor_shift"]
        lstm_units = config["predictor_lstm_units"]
        dur_pred_prenet_units = config["dur_pred_prenet_units"]
        dur_pred_lstm_units = config["dur_pred_lstm_units"]

        self.pitch_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.energy_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.duration_predictor = VarRnnARPredictor(
            input_dim, dur_pred_prenet_units, dur_pred_lstm_units
        )
        self.length_regulator = LengthRegulator(config["outputs_per_step"])
        self.dur_position_encoder = DurSinusoidalPositionEncoder(
            config["encoder_projection_units"], config["outputs_per_step"]
        )
        self.pitch_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )
        self.energy_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )

    def forward(
        self,
        inputs_text_embedding,
        inputs_emo_embedding,
        inputs_spk_embedding,  # [1,20,192]
        scale=1.0,
        masks=None,
        output_masks=None,
        duration_targets=None,
        pitch_targets=None,
        energy_targets=None,
    ):
        batch_size = inputs_text_embedding.size(0)

        variance_predictor_inputs = torch.cat(
            [inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding], dim=-1
        )
        pitch_predictions = self.pitch_predictor(variance_predictor_inputs, masks)
        energy_predictions = self.energy_predictor(variance_predictor_inputs, masks)

        if pitch_targets is not None:
            pitch_embeddings = self.pitch_emb(pitch_targets.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_embeddings = self.pitch_emb(pitch_predictions.unsqueeze(1)).transpose(1, 2)

        if energy_targets is not None:
            energy_embeddings = self.energy_emb(energy_targets.unsqueeze(1)).transpose(1, 2)
        else:
            energy_embeddings = self.energy_emb(energy_predictions.unsqueeze(1)).transpose(1, 2)

        inputs_text_embedding_aug = (
            inputs_text_embedding + pitch_embeddings + energy_embeddings
        )
        duration_predictor_cond = torch.cat(
            [inputs_text_embedding_aug, inputs_spk_embedding, inputs_emo_embedding],
            dim=-1,
        )

        if duration_targets is not None:
            duration_predictor_go_frame = torch.zeros(batch_size, 1).to(
                inputs_text_embedding.device
            )
            duration_predictor_input = torch.cat(
                [duration_predictor_go_frame, duration_targets[:, :-1].float()], dim=-1
            )
            duration_predictor_input = torch.log(duration_predictor_input + 1)
            log_duration_predictions, _ = self.duration_predictor(
                duration_predictor_input.unsqueeze(-1),
                duration_predictor_cond,
                masks=masks,
            )
            duration_predictions = torch.exp(log_duration_predictions) - 1
        else:
            log_duration_predictions = self.duration_predictor.infer(
                duration_predictor_cond, masks=masks
            )
            duration_predictions = torch.exp(log_duration_predictions) - 1

        if duration_targets is not None:
            LR_text_outputs, LR_length_rounded = self.length_regulator(
                inputs_text_embedding_aug,
                duration_targets * scale,
                masks=output_masks
                # *scale
            )
            LR_position_embeddings = self.dur_position_encoder(
                duration_targets, masks=output_masks
            )
            LR_emo_outputs, _ = self.length_regulator(
                inputs_emo_embedding,
                duration_targets * scale,
                masks=output_masks
                # *scale
            )
            LR_spk_outputs, _ = self.length_regulator(
                inputs_spk_embedding,
                duration_targets * scale,
                masks=output_masks
                # *scale
            )
        else:
            LR_text_outputs, LR_length_rounded = self.length_regulator(
                inputs_text_embedding_aug,
                duration_predictions * scale,
                masks=output_masks
                # *scale
            )
            LR_position_embeddings = self.dur_position_encoder(
                duration_predictions * scale,
                masks=output_masks
                # *target_rate
            )
            LR_emo_outputs, _ = self.length_regulator(
                inputs_emo_embedding,
                duration_predictions * scale,
                masks=output_masks
                # *scale
            )
            LR_spk_outputs, _ = self.length_regulator(
                inputs_spk_embedding,
                duration_predictions * scale,
                masks=output_masks
                # *scale
            )

        LR_text_outputs = LR_text_outputs + LR_position_embeddings

        return (
            LR_text_outputs,
            LR_emo_outputs,
            LR_spk_outputs,  # [1,153,192]
            LR_length_rounded,
            log_duration_predictions,
            pitch_predictions,
            energy_predictions,
        )
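VarianceAdaptor2 differs from VarianceAdaptor only in the scale argument, which multiplies the per-token durations before length regulation and is presumably used to control global speaking rate at synthesis time (scale > 1 stretches the output, scale < 1 compresses it). A toy sketch of the effect on total output length, assuming a round-to-nearest convention for fractional durations (illustration only):

import torch

durations = torch.tensor([[2.0, 3.0, 1.0, 4.0]])    # frames per token
for scale in (0.8, 1.0, 1.2):
    frames = (durations * scale + 0.5).long()        # round to nearest frame count
    print(scale, frames.sum(dim=1).item())           # total mel frames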
class MelPNCADecoder(nn.Module):
    def __init__(self, config):
        super(MelPNCADecoder, self).__init__()
        prenet_units = config["decoder_prenet_units"]
        nb_layers = config["decoder_num_layers"]
        nb_heads = config["decoder_num_heads"]
        d_model = config["decoder_num_units"]
        d_head = d_model // nb_heads
        d_inner = config["decoder_ffn_inner_dim"]
        dropout = config["decoder_dropout"]
        dropout_attn = config["decoder_attention_dropout"]
        dropout_relu = config["decoder_relu_dropout"]
        outputs_per_step = config["outputs_per_step"]
        d_mem = (
            config["encoder_projection_units"] * outputs_per_step
            + config["emotion_units"]
            + config["speaker_units"]
        )
        d_mel = config["num_mels"]

        self.d_mel = d_mel
        self.r = outputs_per_step
        self.nb_layers = nb_layers

        self.mel_dec = HybridAttentionDecoder(
            d_mel,
            prenet_units,
            nb_layers,
            d_model,
            d_mem,
            nb_heads,
            d_head,
            d_inner,
            dropout,
            dropout_attn,
            dropout_relu,
            d_mel * outputs_per_step,
        )

    def forward(
        self,
        memory,
        x_band_width,
        h_band_width,
        target=None,
        masks=None,
        return_attns=False,
    ):
        batch_size = memory.size(0)
        go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device)

        if target is not None:
            self.mel_dec.reset_state()
            input = target[:, self.r - 1 :: self.r, :]
            input = torch.cat([go_frame, input], dim=1)[:, :-1, :]
            dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec(
                input,
                memory,
                x_band_width,
                h_band_width,
                masks=masks,
                return_attns=return_attns,
            )
        else:
            dec_output = []
            dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)]
            dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)]
            self.mel_dec.reset_state()
            input = go_frame
            for step in range(memory.size(1)):
                (
                    dec_output_step,
                    dec_pnca_attn_x_step,
                    dec_pnca_attn_h_step,
                ) = self.mel_dec.infer(
                    step,
                    input,
                    memory,
                    x_band_width,
                    h_band_width,
                    masks=masks,
                    return_attns=return_attns,
                )
                input = dec_output_step[:, :, -self.d_mel :]
                dec_output.append(dec_output_step)
                for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
                    zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)
                ):
                    left = memory.size(1) - pnca_x_attn.size(-1)
                    if left > 0:
                        padding = torch.zeros((pnca_x_attn.size(0), 1, left)).to(pnca_x_attn)
                        pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1)
                    dec_pnca_attn_x_list[layer_id].append(pnca_x_attn)
                    dec_pnca_attn_h_list[layer_id].append(pnca_h_attn)
            dec_output = torch.cat(dec_output, dim=1)
            if return_attns:
                for layer_id in range(self.nb_layers):
                    dec_pnca_attn_x_list[layer_id] = torch.cat(
                        dec_pnca_attn_x_list[layer_id], dim=1
                    )
                    dec_pnca_attn_h_list[layer_id] = torch.cat(
                        dec_pnca_attn_h_list[layer_id], dim=1
                    )

        if return_attns:
            return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
        else:
            return dec_output
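When no target is given, the decoder above runs step by step: each call to mel_dec.infer produces outputs_per_step mel frames flattened on the last dimension, the final predicted frame is fed back as the next input, and the per-step attention maps are zero-padded to the full memory length so they can be concatenated. A stub sketch of that feedback loop (illustration only; decoder_step is a placeholder for the real PNCA decoder):

import torch

d_mel, r, steps = 80, 3, 5

def decoder_step(step, prev_frame):
    # stand-in for self.mel_dec.infer(...): r mel frames flattened on the last dim
    return torch.zeros(prev_frame.size(0), 1, d_mel * r) + step

prev = torch.zeros(1, 1, d_mel)           # "go" frame
outputs = []
for step in range(steps):
    out = decoder_step(step, prev)
    prev = out[:, :, -d_mel:]             # feed the last predicted frame back in
    outputs.append(out)
print(torch.cat(outputs, dim=1).shape)    # torch.Size([1, 5, 240])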
class PostNet(nn.Module):
    def __init__(self, config):
        super(PostNet, self).__init__()
        self.filter_size = config["postnet_filter_size"]
        self.fsmn_num_layers = config["postnet_fsmn_num_layers"]
        self.num_memory_units = config["postnet_num_memory_units"]
        self.ffn_inner_dim = config["postnet_ffn_inner_dim"]
        self.dropout = config["postnet_dropout"]
        self.shift = config["postnet_shift"]
        self.lstm_units = config["postnet_lstm_units"]
        self.num_mels = config["num_mels"]

        self.fsmn = FsmnEncoderV2(
            self.filter_size,
            self.fsmn_num_layers,
            self.num_mels,
            self.num_memory_units,
            self.ffn_inner_dim,
            self.dropout,
            self.shift,
        )
        self.lstm = nn.LSTM(
            self.num_memory_units, self.lstm_units, num_layers=1, batch_first=True
        )
        self.fc = nn.Linear(self.lstm_units, self.num_mels)

    def forward(self, x, mask=None):
        postnet_fsmn_output = self.fsmn(x, mask)
        # The input could also be a packed variable-length sequence; we omit that here
        # for simplicity because of the mask and the uni-directional LSTM.
        postnet_lstm_output, _ = self.lstm(postnet_fsmn_output)
        mel_residual_output = self.fc(postnet_lstm_output)
        return mel_residual_output
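The postnet returns a mel residual rather than a finished spectrogram; the caller is presumably expected to add it to the coarse decoder output in the usual Tacotron-style refinement. A toy sketch of that residual pattern (illustration only; the Linear layer is a stand-in for the FSMN + LSTM stack above):

import torch
import torch.nn as nn

num_mels = 80
toy_postnet = nn.Linear(num_mels, num_mels)        # stand-in residual predictor
coarse_mel = torch.randn(1, 120, num_mels)         # [batch, frames, mels]
refined_mel = coarse_mel + toy_postnet(coarse_mel)  # add the predicted residual
print(refined_mel.shape)                            # torch.Size([1, 120, 80])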
class FP_Predictor(nn.Module):
    def __init__(self, config):
        super(FP_Predictor, self).__init__()
        self.w_1 = nn.Conv1d(
            config["encoder_projection_units"],
            config["embedding_dim"] // 2,
            kernel_size=3,
            padding=1,
        )
        self.w_2 = nn.Conv1d(
            config["embedding_dim"] // 2,
            config["encoder_projection_units"],
            kernel_size=1,
            padding=0,
        )
        self.layer_norm1 = nn.LayerNorm(config["embedding_dim"] // 2, eps=1e-6)
        self.layer_norm2 = nn.LayerNorm(config["encoder_projection_units"], eps=1e-6)
        self.dropout_inner = nn.Dropout(0.1)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(config["encoder_projection_units"], 4)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = F.relu(self.w_1(x))
        x = x.transpose(1, 2)
        x = self.dropout_inner(self.layer_norm1(x))
        x = x.transpose(1, 2)
        x = F.relu(self.w_2(x))
        x = x.transpose(1, 2)
        x = self.dropout(self.layer_norm2(x))
        output = F.softmax(self.fc(x), dim=2)
        return output
\ No newline at end of file
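FP_Predictor ends in a 4-way softmax over the projected encoder states, i.e. a per-token classifier; the class meanings are not spelled out in this diff (presumably "no filled pause" plus three filled-pause types, given the FP flag in TextEncoder). A sketch of consuming such an output (illustration only):

import torch

probs = torch.rand(1, 20, 4)
probs = probs / probs.sum(dim=2, keepdim=True)   # fake per-token distribution over 4 classes
fp_labels = probs.argmax(dim=2)                  # [batch, tokens] predicted class per token
print(fp_labels.shape)                           # torch.Size([1, 20])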
kantts/models/sambert/positions.py
deleted
100644 → 0
View file @
8b4e9acd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class SinusoidalPositionEncoder(nn.Module):
    def __init__(self, max_len, depth):
        super(SinusoidalPositionEncoder, self).__init__()
        self.max_len = max_len
        self.depth = depth
        self.position_enc = nn.Parameter(
            self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0),
            requires_grad=False,
        )

    def forward(self, input):
        bz_in, len_in, _ = input.size()
        if len_in > self.max_len:
            self.max_len = len_in
            self.position_enc.data = (
                self.get_sinusoid_encoding_table(self.max_len, self.depth)
                .unsqueeze(0)
                .to(input.device)
            )

        output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1)
        return output

    @staticmethod
    def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
        """ Sinusoid position encoding table """

        def cal_angle(position, hid_idx):
            return position / np.power(10000, hid_idx / float(d_hid / 2 - 1))

        def get_posi_angle_vec(position):
            return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)]

        scaled_time_table = np.array(
            [get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)]
        )
        sinusoid_table = np.zeros((n_position, d_hid))
        sinusoid_table[:, : d_hid // 2] = np.sin(scaled_time_table)
        sinusoid_table[:, d_hid // 2 :] = np.cos(scaled_time_table)

        if padding_idx is not None:
            # zero vector for padding dimension
            sinusoid_table[padding_idx] = 0.0

        return torch.FloatTensor(sinusoid_table)
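Note that the table above puts all sine components in the first d_hid // 2 channels and all cosines in the second half (rather than interleaving them), and that positions are counted from 1. A quick check, meant to be run alongside the class above (illustration only):

import torch

table = SinusoidalPositionEncoder.get_sinusoid_encoding_table(4, 8)
print(table.shape)                                                  # torch.Size([4, 8])
print(torch.allclose(table[0, 0], torch.sin(torch.tensor(1.0))))    # True: positions start at 1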
class DurSinusoidalPositionEncoder(nn.Module):
    def __init__(self, depth, outputs_per_step):
        super(DurSinusoidalPositionEncoder, self).__init__()
        self.depth = depth
        self.outputs_per_step = outputs_per_step
        inv_timescales = [
            np.power(10000, 2 * (hid_idx // 2) / depth) for hid_idx in range(depth)
        ]
        self.inv_timescales = nn.Parameter(
            torch.FloatTensor(inv_timescales), requires_grad=False
        )

    def forward(self, durations, masks=None):
        reps = (durations + 0.5).long()
        output_lens = reps.sum(dim=1)
        max_len = output_lens.max()
        reps_cumsum = torch.cumsum(F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[
            :, None, :
        ]
        range_ = torch.arange(max_len).to(durations.device)[None, :, None]
        mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
        mult = mult.float()
        offsets = torch.matmul(mult, reps_cumsum[:, 0, :-1].unsqueeze(-1)).squeeze(-1)
        dur_pos = range_[:, :, 0] - offsets + 1

        if masks is not None:
            assert masks.size(1) == dur_pos.size(1)
            dur_pos = dur_pos.masked_fill(masks, 0.0)

        seq_len = dur_pos.size(1)
        padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step
        if padding < self.outputs_per_step:
            dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0)

        position_embedding = dur_pos[:, :, None] / self.inv_timescales[None, None, :]
        position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :, 0::2])
        position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :, 1::2])
        return position_embedding
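DurSinusoidalPositionEncoder derives frame-level position indices that restart at 1 inside each token (durations [2, 3, 1] expand to positions [1, 2, 1, 2, 3, 1]) and only then applies the sinusoid. The vectorized cumsum/matmul code above computes those indices; the same idea written as a plain loop for clarity (illustration only):

import torch

durations = torch.tensor([2, 3, 1])
dur_pos = torch.cat([torch.arange(1, d + 1) for d in durations.tolist()])
print(dur_pos)   # tensor([1, 2, 1, 2, 3, 1])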
kantts/models/utils.py
deleted
100644 → 0
View file @
8b4e9acd
import torch
from distutils.version import LooseVersion

is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_mask_from_lengths(lengths, max_len=None):
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = (
        torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device)
    )
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask
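get_mask_from_lengths returns True at padded positions (index >= length), i.e. a padding mask rather than a validity mask. A small example, meant to be run alongside the function above:

import torch

lengths = torch.tensor([3, 5])
print(get_mask_from_lengths(lengths))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])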
kantts/preprocess/__init__.py
deleted
100644 → 0
View file @
8b4e9acd
kantts/preprocess/__pycache__/__init__.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/preprocess/__pycache__/fp_processor.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/preprocess/audio_processor/__init__.py
deleted
100644 → 0
View file @
8b4e9acd
kantts/preprocess/audio_processor/__pycache__/__init__.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted