Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
sambert-hifigan_pytorch
Commits
51782715
Commit
51782715
authored
Feb 23, 2024
by
liugh5
Browse files
update
parent
8b4e9acd
Changes
182
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3130 deletions
+0
-3130
kantts/configs/sambert_se_nsf_global_16k.yaml
kantts/configs/sambert_se_nsf_global_16k.yaml
+0
-111
kantts/configs/sambert_se_nsf_global_24k.yaml
kantts/configs/sambert_se_nsf_global_24k.yaml
+0
-111
kantts/configs/sambert_sichuan_16k.yaml
kantts/configs/sambert_sichuan_16k.yaml
+0
-106
kantts/configs/sybert.yaml
kantts/configs/sybert.yaml
+0
-69
kantts/datasets/__init__.py
kantts/datasets/__init__.py
+0
-0
kantts/datasets/__pycache__/__init__.cpython-38.pyc
kantts/datasets/__pycache__/__init__.cpython-38.pyc
+0
-0
kantts/datasets/__pycache__/dataset.cpython-38.pyc
kantts/datasets/__pycache__/dataset.cpython-38.pyc
+0
-0
kantts/datasets/data_types.py
kantts/datasets/data_types.py
+0
-36
kantts/datasets/dataset.py
kantts/datasets/dataset.py
+0
-1130
kantts/models/__init__.py
kantts/models/__init__.py
+0
-164
kantts/models/__pycache__/__init__.cpython-38.pyc
kantts/models/__pycache__/__init__.cpython-38.pyc
+0
-0
kantts/models/__pycache__/pqmf.cpython-38.pyc
kantts/models/__pycache__/pqmf.cpython-38.pyc
+0
-0
kantts/models/__pycache__/utils.cpython-38.pyc
kantts/models/__pycache__/utils.cpython-38.pyc
+0
-0
kantts/models/hifigan/__pycache__/hifigan.cpython-38.pyc
kantts/models/hifigan/__pycache__/hifigan.cpython-38.pyc
+0
-0
kantts/models/hifigan/__pycache__/layers.cpython-38.pyc
kantts/models/hifigan/__pycache__/layers.cpython-38.pyc
+0
-0
kantts/models/hifigan/hifigan.py
kantts/models/hifigan/hifigan.py
+0
-617
kantts/models/hifigan/layers.py
kantts/models/hifigan/layers.py
+0
-290
kantts/models/pqmf.py
kantts/models/pqmf.py
+0
-148
kantts/models/sambert/__init__.py
kantts/models/sambert/__init__.py
+0
-348
kantts/models/sambert/__pycache__/__init__.cpython-38.pyc
kantts/models/sambert/__pycache__/__init__.cpython-38.pyc
+0
-0
No files found.
kantts/configs/sambert_se_nsf_global_16k.yaml
deleted
100644 → 0
View file @
8b4e9acd
model_type
:
sambert
Model
:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT
:
params
:
max_len
:
800
embedding_dim
:
512
encoder_num_layers
:
8
encoder_num_heads
:
8
encoder_num_units
:
128
encoder_ffn_inner_dim
:
1024
encoder_dropout
:
0.1
encoder_attention_dropout
:
0.1
encoder_relu_dropout
:
0.1
encoder_projection_units
:
32
speaker_units
:
192
emotion_units
:
32
predictor_filter_size
:
41
predictor_fsmn_num_layers
:
3
predictor_num_memory_units
:
128
predictor_ffn_inner_dim
:
256
predictor_dropout
:
0.1
predictor_shift
:
0
predictor_lstm_units
:
128
dur_pred_prenet_units
:
[
128
,
128
]
dur_pred_lstm_units
:
128
decoder_prenet_units
:
[
256
,
256
]
decoder_num_layers
:
12
decoder_num_heads
:
8
decoder_num_units
:
128
decoder_ffn_inner_dim
:
1024
decoder_dropout
:
0.1
decoder_attention_dropout
:
0.1
decoder_relu_dropout
:
0.1
outputs_per_step
:
3
num_mels
:
82
postnet_filter_size
:
41
postnet_fsmn_num_layers
:
4
postnet_num_memory_units
:
256
postnet_ffn_inner_dim
:
512
postnet_dropout
:
0.1
postnet_shift
:
17
postnet_lstm_units
:
128
MAS
:
False
NSF
:
True
nsf_norm_type
:
global
nsf_f0_global_minimum
:
30.0
nsf_f0_global_maximum
:
730.0
SE
:
True
optimizer
:
type
:
Adam
params
:
lr
:
0.001
betas
:
[
0.9
,
0.98
]
eps
:
1.0e-9
weight_decay
:
0.0
scheduler
:
type
:
NoamLR
params
:
warmup_steps
:
4000
linguistic_unit
:
cleaners
:
english_cleaners
lfeat_type_list
:
sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list
:
F7
####################################################
# LOSS SETTING #
####################################################
Loss
:
MelReconLoss
:
enable
:
True
params
:
loss_type
:
mae
ProsodyReconLoss
:
enable
:
True
params
:
loss_type
:
mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size
:
32
pin_memory
:
False
num_workers
:
4
# FIXME: set > 0 may stuck on macos
remove_short_samples
:
False
allow_cache
:
False
grad_norm
:
1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps
:
1760101
# Number of training steps.
save_interval_steps
:
100
# Interval steps to save checkpoint.
eval_interval_steps
:
1000000000000
# Interval steps to evaluate the network.
log_interval_steps
:
10
# Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results
:
4
# Number of results to be saved as intermediate results.
kantts/configs/sambert_se_nsf_global_24k.yaml
deleted
100644 → 0
View file @
8b4e9acd
model_type
:
sambert
Model
:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT
:
params
:
max_len
:
800
embedding_dim
:
512
encoder_num_layers
:
8
encoder_num_heads
:
8
encoder_num_units
:
128
encoder_ffn_inner_dim
:
1024
encoder_dropout
:
0.1
encoder_attention_dropout
:
0.1
encoder_relu_dropout
:
0.1
encoder_projection_units
:
32
speaker_units
:
192
emotion_units
:
32
predictor_filter_size
:
41
predictor_fsmn_num_layers
:
3
predictor_num_memory_units
:
128
predictor_ffn_inner_dim
:
256
predictor_dropout
:
0.1
predictor_shift
:
0
predictor_lstm_units
:
128
dur_pred_prenet_units
:
[
128
,
128
]
dur_pred_lstm_units
:
128
decoder_prenet_units
:
[
256
,
256
]
decoder_num_layers
:
12
decoder_num_heads
:
8
decoder_num_units
:
128
decoder_ffn_inner_dim
:
1024
decoder_dropout
:
0.1
decoder_attention_dropout
:
0.1
decoder_relu_dropout
:
0.1
outputs_per_step
:
3
num_mels
:
82
postnet_filter_size
:
41
postnet_fsmn_num_layers
:
4
postnet_num_memory_units
:
256
postnet_ffn_inner_dim
:
512
postnet_dropout
:
0.1
postnet_shift
:
17
postnet_lstm_units
:
128
MAS
:
False
NSF
:
True
nsf_norm_type
:
global
nsf_f0_global_minimum
:
30.0
nsf_f0_global_maximum
:
730.0
SE
:
True
optimizer
:
type
:
Adam
params
:
lr
:
0.001
betas
:
[
0.9
,
0.98
]
eps
:
1.0e-9
weight_decay
:
0.0
scheduler
:
type
:
NoamLR
params
:
warmup_steps
:
4000
linguistic_unit
:
cleaners
:
english_cleaners
lfeat_type_list
:
sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list
:
F7
####################################################
# LOSS SETTING #
####################################################
Loss
:
MelReconLoss
:
enable
:
True
params
:
loss_type
:
mae
ProsodyReconLoss
:
enable
:
True
params
:
loss_type
:
mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size
:
32
pin_memory
:
False
num_workers
:
4
# FIXME: set > 0 may stuck on macos
remove_short_samples
:
False
allow_cache
:
False
grad_norm
:
1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps
:
2500000
# Number of training steps.
save_interval_steps
:
20000
# Interval steps to save checkpoint.
eval_interval_steps
:
1000000000000
# Interval steps to evaluate the network.
log_interval_steps
:
1000
# Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results
:
4
# Number of results to be saved as intermediate results.
kantts/configs/sambert_sichuan_16k.yaml
deleted
100644 → 0
View file @
8b4e9acd
model_type
:
sambert
Model
:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT
:
params
:
max_len
:
800
embedding_dim
:
512
encoder_num_layers
:
8
encoder_num_heads
:
8
encoder_num_units
:
128
encoder_ffn_inner_dim
:
1024
encoder_dropout
:
0.1
encoder_attention_dropout
:
0.1
encoder_relu_dropout
:
0.1
encoder_projection_units
:
32
speaker_units
:
32
emotion_units
:
32
predictor_filter_size
:
41
predictor_fsmn_num_layers
:
3
predictor_num_memory_units
:
128
predictor_ffn_inner_dim
:
256
predictor_dropout
:
0.1
predictor_shift
:
0
predictor_lstm_units
:
128
dur_pred_prenet_units
:
[
128
,
128
]
dur_pred_lstm_units
:
128
decoder_prenet_units
:
[
256
,
256
]
decoder_num_layers
:
12
decoder_num_heads
:
8
decoder_num_units
:
128
decoder_ffn_inner_dim
:
1024
decoder_dropout
:
0.1
decoder_attention_dropout
:
0.1
decoder_relu_dropout
:
0.1
outputs_per_step
:
3
num_mels
:
80
postnet_filter_size
:
41
postnet_fsmn_num_layers
:
4
postnet_num_memory_units
:
256
postnet_ffn_inner_dim
:
512
postnet_dropout
:
0.1
postnet_shift
:
17
postnet_lstm_units
:
128
MAS
:
False
optimizer
:
type
:
Adam
params
:
lr
:
0.001
betas
:
[
0.9
,
0.98
]
eps
:
1.0e-9
weight_decay
:
0.0
scheduler
:
type
:
NoamLR
params
:
warmup_steps
:
4000
linguistic_unit
:
cleaners
:
english_cleaners
lfeat_type_list
:
sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list
:
xiaoyue
language
:
Sichuan
####################################################
# LOSS SETTING #
####################################################
Loss
:
MelReconLoss
:
enable
:
True
params
:
loss_type
:
mae
ProsodyReconLoss
:
enable
:
True
params
:
loss_type
:
mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size
:
32
pin_memory
:
False
num_workers
:
4
# FIXME: set > 0 may stuck on macos
remove_short_samples
:
False
allow_cache
:
True
grad_norm
:
1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps
:
1000000
# Number of training steps.
save_interval_steps
:
20000
# Interval steps to save checkpoint.
eval_interval_steps
:
10000
# Interval steps to evaluate the network.
log_interval_steps
:
1000
# Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results
:
4
# Number of results to be saved as intermediate results.
kantts/configs/sybert.yaml
deleted
100644 → 0
View file @
8b4e9acd
model_type
:
sybert
Model
:
#########################################################
# TextsyBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsTextsyBERT
:
params
:
max_len
:
800
embedding_dim
:
512
encoder_num_layers
:
8
encoder_num_heads
:
8
encoder_num_units
:
128
encoder_ffn_inner_dim
:
1024
encoder_dropout
:
0.1
encoder_attention_dropout
:
0.1
encoder_relu_dropout
:
0.1
encoder_projection_units
:
32
mask_ratio
:
0.3
optimizer
:
type
:
Adam
params
:
lr
:
0.0001
betas
:
[
0.9
,
0.98
]
eps
:
1.0e-9
weight_decay
:
0.0
scheduler
:
type
:
NoamLR
params
:
warmup_steps
:
10000
linguistic_unit
:
cleaners
:
english_cleaners
lfeat_type_list
:
sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list
:
F7
####################################################
# LOSS SETTING #
####################################################
Loss
:
SeqCELoss
:
enable
:
True
params
:
loss_type
:
ce
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size
:
32
pin_memory
:
False
num_workers
:
4
# FIXME: set > 0 may stuck on macos
remove_short_samples
:
False
allow_cache
:
True
grad_norm
:
1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps
:
1000000
# Number of training steps.
save_interval_steps
:
20000
# Interval steps to save checkpoint.
eval_interval_steps
:
10000
# Interval steps to evaluate the network.
log_interval_steps
:
1000
# Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results
:
4
# Number of results to be saved as intermediate results.
kantts/datasets/__init__.py
deleted
100644 → 0
View file @
8b4e9acd
kantts/datasets/__pycache__/__init__.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/datasets/__pycache__/dataset.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/datasets/data_types.py
deleted
100644 → 0
View file @
8b4e9acd
import
numpy
as
np
from
scipy.io
import
wavfile
# TODO: add your own data type here as you need.
DATA_TYPE_DICT
=
{
"txt"
:
{
"load_func"
:
np
.
loadtxt
,
"desc"
:
"plain txt file or readable by np.loadtxt"
,
},
"wav"
:
{
"load_func"
:
lambda
x
:
wavfile
.
read
(
x
)[
1
],
"desc"
:
"wav file or readable by soundfile.read"
,
},
"npy"
:
{
"load_func"
:
np
.
load
,
"desc"
:
"any .npy format file"
,
},
# PCM data type can be loaded by binary format
"bin_f32"
:
{
"load_func"
:
lambda
x
:
np
.
fromfile
(
x
,
dtype
=
np
.
float32
),
"desc"
:
"binary file with float32 format"
,
},
"bin_f64"
:
{
"load_func"
:
lambda
x
:
np
.
fromfile
(
x
,
dtype
=
np
.
float64
),
"desc"
:
"binary file with float64 format"
,
},
"bin_i32"
:
{
"load_func"
:
lambda
x
:
np
.
fromfile
(
x
,
dtype
=
np
.
int32
),
"desc"
:
"binary file with int32 format"
,
},
"bin_i16"
:
{
"load_func"
:
lambda
x
:
np
.
fromfile
(
x
,
dtype
=
np
.
int16
),
"desc"
:
"binary file with int16 format"
,
},
}
kantts/datasets/dataset.py
deleted
100644 → 0
View file @
8b4e9acd
import
os
import
torch
import
glob
import
logging
from
multiprocessing
import
Manager
import
librosa
import
numpy
as
np
import
random
import
functools
from
tqdm
import
tqdm
import
math
from
kantts.utils.ling_unit.ling_unit
import
KanTtsLinguisticUnit
,
emotion_types
from
scipy.stats
import
betabinom
DATASET_RANDOM_SEED
=
1234
torch
.
multiprocessing
.
set_sharing_strategy
(
"file_system"
)
@
functools
.
lru_cache
(
maxsize
=
256
)
def
beta_binomial_prior_distribution
(
phoneme_count
,
mel_count
,
scaling
=
1.0
):
P
=
phoneme_count
M
=
mel_count
x
=
np
.
arange
(
0
,
P
)
mel_text_probs
=
[]
for
i
in
range
(
1
,
M
+
1
):
a
,
b
=
scaling
*
i
,
scaling
*
(
M
+
1
-
i
)
rv
=
betabinom
(
P
,
a
,
b
)
mel_i_prob
=
rv
.
pmf
(
x
)
mel_text_probs
.
append
(
mel_i_prob
)
return
torch
.
tensor
(
np
.
array
(
mel_text_probs
))
class
Padder
(
object
):
def
__init__
(
self
):
super
(
Padder
,
self
).
__init__
()
pass
def
_pad1D
(
self
,
x
,
length
,
pad
):
return
np
.
pad
(
x
,
(
0
,
length
-
x
.
shape
[
0
]),
mode
=
"constant"
,
constant_values
=
pad
)
def
_pad2D
(
self
,
x
,
length
,
pad
):
return
np
.
pad
(
x
,
[(
0
,
length
-
x
.
shape
[
0
]),
(
0
,
0
)],
mode
=
"constant"
,
constant_values
=
pad
)
def
_pad_durations
(
self
,
duration
,
max_in_len
,
max_out_len
):
framenum
=
np
.
sum
(
duration
)
symbolnum
=
duration
.
shape
[
0
]
if
framenum
<
max_out_len
:
padframenum
=
max_out_len
-
framenum
duration
=
np
.
insert
(
duration
,
symbolnum
,
values
=
padframenum
,
axis
=
0
)
duration
=
np
.
insert
(
duration
,
symbolnum
+
1
,
values
=
[
0
]
*
(
max_in_len
-
symbolnum
-
1
),
axis
=
0
,
)
else
:
if
symbolnum
<
max_in_len
:
duration
=
np
.
insert
(
duration
,
symbolnum
,
values
=
[
0
]
*
(
max_in_len
-
symbolnum
),
axis
=
0
)
return
duration
def
_round_up
(
self
,
x
,
multiple
):
remainder
=
x
%
multiple
return
x
if
remainder
==
0
else
x
+
multiple
-
remainder
def
_prepare_scalar_inputs
(
self
,
inputs
,
max_len
,
pad
):
return
torch
.
from_numpy
(
np
.
stack
([
self
.
_pad1D
(
x
,
max_len
,
pad
)
for
x
in
inputs
])
)
def
_prepare_targets
(
self
,
targets
,
max_len
,
pad
):
return
torch
.
from_numpy
(
np
.
stack
([
self
.
_pad2D
(
t
,
max_len
,
pad
)
for
t
in
targets
])
).
float
()
def
_prepare_durations
(
self
,
durations
,
max_in_len
,
max_out_len
):
return
torch
.
from_numpy
(
np
.
stack
(
[
self
.
_pad_durations
(
t
,
max_in_len
,
max_out_len
)
for
t
in
durations
]
)
).
long
()
class
Voc_Dataset
(
torch
.
utils
.
data
.
Dataset
):
"""
provide (mel, audio) data pair
"""
def
__init__
(
self
,
metafile
,
root_dir
,
config
,
):
self
.
meta
=
[]
self
.
config
=
config
self
.
sampling_rate
=
config
[
"audio_config"
][
"sampling_rate"
]
self
.
n_fft
=
config
[
"audio_config"
][
"n_fft"
]
self
.
hop_length
=
config
[
"audio_config"
][
"hop_length"
]
self
.
batch_max_steps
=
config
[
"batch_max_steps"
]
self
.
batch_max_frames
=
self
.
batch_max_steps
//
self
.
hop_length
self
.
aux_context_window
=
0
# TODO: make it configurable
self
.
start_offset
=
self
.
aux_context_window
self
.
end_offset
=
-
(
self
.
batch_max_frames
+
self
.
aux_context_window
)
self
.
nsf_enable
=
(
config
[
"Model"
][
"Generator"
][
"params"
].
get
(
"nsf_params"
,
None
)
is
not
None
)
if
self
.
nsf_enable
:
self
.
nsf_norm_type
=
config
[
"Model"
][
"Generator"
][
"params"
][
"nsf_params"
].
get
(
"nsf_norm_type"
,
'"mean_std'
)
if
self
.
nsf_norm_type
==
"global"
:
self
.
nsf_f0_global_minimum
=
config
[
"Model"
][
"Generator"
][
"params"
][
"nsf_params"
].
get
(
"nsf_f0_global_minimum"
,
30.0
)
self
.
nsf_f0_global_maximum
=
config
[
"Model"
][
"Generator"
][
"params"
][
"nsf_params"
].
get
(
"nsf_f0_global_maximum"
,
730.0
)
if
not
isinstance
(
metafile
,
list
):
metafile
=
[
metafile
]
if
not
isinstance
(
root_dir
,
list
):
root_dir
=
[
root_dir
]
for
meta_file
,
data_dir
in
zip
(
metafile
,
root_dir
):
if
not
os
.
path
.
exists
(
meta_file
):
logging
.
error
(
"meta file not found: {}"
.
format
(
meta_file
))
raise
ValueError
(
"[Voc_Dataset] meta file: {} not found"
.
format
(
meta_file
)
)
if
not
os
.
path
.
exists
(
data_dir
):
logging
.
error
(
"data directory not found: {}"
.
format
(
data_dir
))
raise
ValueError
(
"[Voc_Dataset] data dir: {} not found"
.
format
(
data_dir
)
)
self
.
meta
.
extend
(
self
.
load_meta
(
meta_file
,
data_dir
))
# Load from training data directory
if
len
(
self
.
meta
)
==
0
and
isinstance
(
root_dir
,
str
):
wav_dir
=
os
.
path
.
join
(
root_dir
,
"wav"
)
mel_dir
=
os
.
path
.
join
(
root_dir
,
"mel"
)
if
not
os
.
path
.
exists
(
wav_dir
)
or
not
os
.
path
.
exists
(
mel_dir
):
raise
ValueError
(
"wav or mel directory not found"
)
self
.
meta
.
extend
(
self
.
load_meta_from_dir
(
wav_dir
,
mel_dir
))
elif
len
(
self
.
meta
)
==
0
and
isinstance
(
root_dir
,
list
):
for
d
in
root_dir
:
wav_dir
=
os
.
path
.
join
(
d
,
"wav"
)
mel_dir
=
os
.
path
.
join
(
d
,
"mel"
)
if
not
os
.
path
.
exists
(
wav_dir
)
or
not
os
.
path
.
exists
(
mel_dir
):
raise
ValueError
(
"wav or mel directory not found"
)
self
.
meta
.
extend
(
self
.
load_meta_from_dir
(
wav_dir
,
mel_dir
))
self
.
allow_cache
=
config
[
"allow_cache"
]
if
self
.
allow_cache
:
self
.
manager
=
Manager
()
self
.
caches
=
self
.
manager
.
list
()
self
.
caches
+=
[()
for
_
in
range
(
len
(
self
.
meta
))]
@
staticmethod
def
gen_metafile
(
wav_dir
,
out_dir
,
split_ratio
=
0.98
):
wav_files
=
glob
.
glob
(
os
.
path
.
join
(
wav_dir
,
"*.wav"
))
frame_f0_dir
=
os
.
path
.
join
(
out_dir
,
"frame_f0"
)
frame_uv_dir
=
os
.
path
.
join
(
out_dir
,
"frame_uv"
)
mel_dir
=
os
.
path
.
join
(
out_dir
,
"mel"
)
random
.
seed
(
DATASET_RANDOM_SEED
)
random
.
shuffle
(
wav_files
)
num_train
=
int
(
len
(
wav_files
)
*
split_ratio
)
-
1
with
open
(
os
.
path
.
join
(
out_dir
,
"train.lst"
),
"w"
)
as
f
:
for
wav_file
in
wav_files
[:
num_train
]:
index
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
wav_file
))[
0
]
if
(
not
os
.
path
.
exists
(
os
.
path
.
join
(
frame_f0_dir
,
index
+
".npy"
))
or
not
os
.
path
.
exists
(
os
.
path
.
join
(
frame_uv_dir
,
index
+
".npy"
))
or
not
os
.
path
.
exists
(
os
.
path
.
join
(
mel_dir
,
index
+
".npy"
))
):
continue
f
.
write
(
"{}
\n
"
.
format
(
index
))
with
open
(
os
.
path
.
join
(
out_dir
,
"valid.lst"
),
"w"
)
as
f
:
for
wav_file
in
wav_files
[
num_train
:]:
index
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
wav_file
))[
0
]
if
(
not
os
.
path
.
exists
(
os
.
path
.
join
(
frame_f0_dir
,
index
+
".npy"
))
or
not
os
.
path
.
exists
(
os
.
path
.
join
(
frame_uv_dir
,
index
+
".npy"
))
or
not
os
.
path
.
exists
(
os
.
path
.
join
(
mel_dir
,
index
+
".npy"
))
):
continue
f
.
write
(
"{}
\n
"
.
format
(
index
))
def
load_meta
(
self
,
metafile
,
data_dir
):
with
open
(
metafile
,
"r"
)
as
f
:
lines
=
f
.
readlines
()
wav_dir
=
os
.
path
.
join
(
data_dir
,
"wav"
)
mel_dir
=
os
.
path
.
join
(
data_dir
,
"mel"
)
frame_f0_dir
=
os
.
path
.
join
(
data_dir
,
"frame_f0"
)
frame_uv_dir
=
os
.
path
.
join
(
data_dir
,
"frame_uv"
)
if
not
os
.
path
.
exists
(
wav_dir
)
or
not
os
.
path
.
exists
(
mel_dir
):
raise
ValueError
(
"wav or mel directory not found"
)
items
=
[]
logging
.
info
(
"Loading metafile..."
)
for
name
in
tqdm
(
lines
):
name
=
name
.
strip
()
mel_file
=
os
.
path
.
join
(
mel_dir
,
name
+
".npy"
)
wav_file
=
os
.
path
.
join
(
wav_dir
,
name
+
".wav"
)
frame_f0_file
=
os
.
path
.
join
(
frame_f0_dir
,
name
+
".npy"
)
frame_uv_file
=
os
.
path
.
join
(
frame_uv_dir
,
name
+
".npy"
)
items
.
append
((
wav_file
,
mel_file
,
frame_f0_file
,
frame_uv_file
))
return
items
def
load_meta_from_dir
(
self
,
wav_dir
,
mel_dir
):
wav_files
=
glob
.
glob
(
os
.
path
.
join
(
wav_dir
,
"*.wav"
))
items
=
[]
for
wav_file
in
wav_files
:
mel_file
=
os
.
path
.
join
(
mel_dir
,
os
.
path
.
basename
(
wav_file
))
if
os
.
path
.
exists
(
mel_file
):
items
.
append
((
wav_file
,
mel_file
))
return
items
def
__len__
(
self
):
return
len
(
self
.
meta
)
def
__getitem__
(
self
,
idx
):
if
self
.
allow_cache
and
len
(
self
.
caches
[
idx
])
!=
0
:
return
self
.
caches
[
idx
]
wav_file
,
mel_file
,
frame_f0_file
,
frame_uv_file
=
self
.
meta
[
idx
]
f0_mean_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
frame_f0_file
)),
"f0"
,
"f0_mean.txt"
)
f0_std_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
frame_f0_file
)),
"f0"
,
"f0_std.txt"
)
wav_data
=
librosa
.
core
.
load
(
wav_file
,
sr
=
self
.
sampling_rate
)[
0
]
mel_data
=
np
.
load
(
mel_file
)
if
self
.
nsf_enable
:
# denorm f0; default frame_f0_data using mean_std norm
frame_f0_data
=
np
.
load
(
frame_f0_file
).
reshape
(
-
1
,
1
)
f0_mean
=
np
.
loadtxt
(
f0_mean_file
)
f0_std
=
np
.
loadtxt
(
f0_std_file
)
frame_f0_data
=
frame_f0_data
*
f0_std
+
f0_mean
frame_uv_data
=
np
.
load
(
frame_uv_file
).
reshape
(
-
1
,
1
)
mel_data
=
np
.
concatenate
((
mel_data
,
frame_f0_data
,
frame_uv_data
),
axis
=
1
)
# make sure mel_data length greater than batch_max_frames at least 1 frame
if
mel_data
.
shape
[
0
]
<=
self
.
batch_max_frames
:
mel_data
=
np
.
concatenate
(
(
mel_data
,
np
.
zeros
(
(
self
.
batch_max_frames
-
mel_data
.
shape
[
0
]
+
1
,
mel_data
.
shape
[
1
],
)
),
),
axis
=
0
,
)
wav_cache
=
np
.
zeros
(
mel_data
.
shape
[
0
]
*
self
.
hop_length
,
dtype
=
np
.
float32
)
wav_cache
[:
len
(
wav_data
)]
=
wav_data
wav_data
=
wav_cache
else
:
# make sure the audio length and feature length are matched
wav_data
=
np
.
pad
(
wav_data
,
(
0
,
self
.
n_fft
),
mode
=
"reflect"
)
wav_data
=
wav_data
[:
len
(
mel_data
)
*
self
.
hop_length
]
assert
len
(
mel_data
)
*
self
.
hop_length
==
len
(
wav_data
)
if
self
.
allow_cache
:
self
.
caches
[
idx
]
=
(
wav_data
,
mel_data
)
return
(
wav_data
,
mel_data
)
def
collate_fn
(
self
,
batch
):
wav_data
,
mel_data
=
[
item
[
0
]
for
item
in
batch
],
[
item
[
1
]
for
item
in
batch
]
mel_lengths
=
[
len
(
mel
)
for
mel
in
mel_data
]
start_frames
=
np
.
array
(
[
np
.
random
.
randint
(
self
.
start_offset
,
length
+
self
.
end_offset
)
for
length
in
mel_lengths
]
)
wav_start
=
start_frames
*
self
.
hop_length
wav_end
=
wav_start
+
self
.
batch_max_steps
# aux window works as padding
mel_start
=
start_frames
-
self
.
aux_context_window
mel_end
=
mel_start
+
self
.
batch_max_frames
+
self
.
aux_context_window
wav_batch
=
[
x
[
start
:
end
]
for
x
,
start
,
end
in
zip
(
wav_data
,
wav_start
,
wav_end
)
]
mel_batch
=
[
c
[
start
:
end
]
for
c
,
start
,
end
in
zip
(
mel_data
,
mel_start
,
mel_end
)
]
# (B, 1, T)
wav_batch
=
torch
.
tensor
(
np
.
asarray
(
wav_batch
),
dtype
=
torch
.
float32
).
unsqueeze
(
1
)
# (B, C, T)
mel_batch
=
torch
.
tensor
(
np
.
asarray
(
mel_batch
),
dtype
=
torch
.
float32
).
transpose
(
2
,
1
)
return
wav_batch
,
mel_batch
def
get_voc_datasets
(
config
,
root_dir
,
split_ratio
=
0.98
,
):
if
isinstance
(
root_dir
,
str
):
root_dir
=
[
root_dir
]
train_meta_lst
=
[]
valid_meta_lst
=
[]
for
data_dir
in
root_dir
:
train_meta
=
os
.
path
.
join
(
data_dir
,
"train.lst"
)
valid_meta
=
os
.
path
.
join
(
data_dir
,
"valid.lst"
)
if
not
os
.
path
.
exists
(
train_meta
)
or
not
os
.
path
.
exists
(
valid_meta
):
Voc_Dataset
.
gen_metafile
(
os
.
path
.
join
(
data_dir
,
"wav"
),
data_dir
,
split_ratio
)
train_meta_lst
.
append
(
train_meta
)
valid_meta_lst
.
append
(
valid_meta
)
train_dataset
=
Voc_Dataset
(
train_meta_lst
,
root_dir
,
config
,
)
valid_dataset
=
Voc_Dataset
(
valid_meta_lst
[:
50
],
root_dir
,
config
,
)
return
train_dataset
,
valid_dataset
# TODO(Yuxuan): refine the logic, you'd better not use emotion tag, it's ambiguous.
def
get_fp_label
(
aug_ling_txt
):
token_lst
=
aug_ling_txt
.
split
(
" "
)
emo_lst
=
[
token
.
strip
(
"{}"
).
split
(
"$"
)[
4
]
for
token
in
token_lst
]
syllable_lst
=
[
token
.
strip
(
"{}"
).
split
(
"$"
)[
0
]
for
token
in
token_lst
]
# EOS token append
emo_lst
.
append
(
emotion_types
[
0
])
syllable_lst
.
append
(
"EOS"
)
# According to the original emotion tag, set each token's fp label.
if
emo_lst
[
0
]
!=
emotion_types
[
3
]:
emo_lst
[
0
]
=
emotion_types
[
0
]
emo_lst
[
1
]
=
emotion_types
[
0
]
for
i
in
range
(
len
(
emo_lst
)
-
2
,
1
,
-
1
):
if
emo_lst
[
i
]
!=
emotion_types
[
3
]
and
emo_lst
[
i
-
1
]
!=
emotion_types
[
3
]:
emo_lst
[
i
]
=
emotion_types
[
0
]
elif
emo_lst
[
i
]
!=
emotion_types
[
3
]
and
emo_lst
[
i
-
1
]
==
emotion_types
[
3
]:
emo_lst
[
i
]
=
emotion_types
[
3
]
if
syllable_lst
[
i
-
2
]
==
"ga"
:
emo_lst
[
i
+
1
]
=
emotion_types
[
1
]
elif
syllable_lst
[
i
-
2
]
==
"ge"
and
syllable_lst
[
i
-
1
]
==
"en_c"
:
emo_lst
[
i
+
1
]
=
emotion_types
[
2
]
else
:
emo_lst
[
i
+
1
]
=
emotion_types
[
4
]
fp_label
=
[]
for
i
in
range
(
len
(
emo_lst
)):
if
emo_lst
[
i
]
==
emotion_types
[
0
]:
fp_label
.
append
(
0
)
elif
emo_lst
[
i
]
==
emotion_types
[
1
]:
fp_label
.
append
(
1
)
elif
emo_lst
[
i
]
==
emotion_types
[
2
]:
fp_label
.
append
(
2
)
elif
emo_lst
[
i
]
==
emotion_types
[
3
]:
continue
elif
emo_lst
[
i
]
==
emotion_types
[
4
]:
fp_label
.
append
(
3
)
else
:
pass
return
np
.
array
(
fp_label
)
class
AM_Dataset
(
torch
.
utils
.
data
.
Dataset
):
"""
provide (ling, emo, speaker, mel) pair
"""
def
__init__
(
self
,
config
,
metafile
,
root_dir
,
allow_cache
=
False
,
):
self
.
meta
=
[]
self
.
config
=
config
self
.
with_duration
=
True
self
.
nsf_enable
=
self
.
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
].
get
(
"NSF"
,
False
)
if
self
.
nsf_enable
:
self
.
nsf_norm_type
=
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
].
get
(
"nsf_norm_type"
,
"mean_std"
)
if
self
.
nsf_norm_type
==
"global"
:
self
.
nsf_f0_global_minimum
=
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
].
get
(
"nsf_f0_global_minimum"
,
30.0
)
self
.
nsf_f0_global_maximum
=
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
].
get
(
"nsf_f0_global_maximum"
,
730.0
)
self
.
se_enable
=
self
.
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
].
get
(
"SE"
,
False
)
self
.
fp_enable
=
self
.
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
].
get
(
"FP"
,
False
)
self
.
mas_enable
=
self
.
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
].
get
(
"MAS"
,
False
)
if
not
isinstance
(
metafile
,
list
):
metafile
=
[
metafile
]
if
not
isinstance
(
root_dir
,
list
):
root_dir
=
[
root_dir
]
for
meta_file
,
data_dir
in
zip
(
metafile
,
root_dir
):
if
not
os
.
path
.
exists
(
meta_file
):
logging
.
error
(
"meta file not found: {}"
.
format
(
meta_file
))
raise
ValueError
(
"[AM_Dataset] meta file: {} not found"
.
format
(
meta_file
)
)
if
not
os
.
path
.
exists
(
data_dir
):
logging
.
error
(
"data dir not found: {}"
.
format
(
data_dir
))
raise
ValueError
(
"[AM_Dataset] data dir: {} not found"
.
format
(
data_dir
))
self
.
meta
.
extend
(
self
.
load_meta
(
meta_file
,
data_dir
))
self
.
allow_cache
=
allow_cache
self
.
ling_unit
=
KanTtsLinguisticUnit
(
config
)
self
.
padder
=
Padder
()
self
.
r
=
self
.
config
[
"Model"
][
"KanTtsSAMBERT"
][
"params"
][
"outputs_per_step"
]
# TODO: feat window
if
allow_cache
:
self
.
manager
=
Manager
()
self
.
caches
=
self
.
manager
.
list
()
self
.
caches
+=
[()
for
_
in
range
(
len
(
self
.
meta
))]
def
__len__
(
self
):
return
len
(
self
.
meta
)
def
__getitem__
(
self
,
idx
):
if
self
.
allow_cache
and
len
(
self
.
caches
[
idx
])
!=
0
:
return
self
.
caches
[
idx
]
(
ling_txt
,
mel_file
,
dur_file
,
f0_file
,
energy_file
,
frame_f0_file
,
frame_uv_file
,
aug_ling_txt
,
se_path
,
)
=
self
.
meta
[
idx
]
f0_mean_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
frame_f0_file
)),
"f0"
,
"f0_mean.txt"
)
f0_std_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
frame_f0_file
)),
"f0"
,
"f0_std.txt"
)
ling_data
=
self
.
ling_unit
.
encode_symbol_sequence
(
ling_txt
)
mel_data
=
np
.
load
(
mel_file
)
dur_data
=
np
.
load
(
dur_file
)
if
dur_file
is
not
None
else
None
f0_data
=
np
.
load
(
f0_file
)
energy_data
=
np
.
load
(
energy_file
)
se_data
=
np
.
load
(
se_path
)
if
self
.
se_enable
else
None
# generate fp position label according to fpadd_meta
if
self
.
fp_enable
and
aug_ling_txt
is
not
None
:
fp_label
=
get_fp_label
(
aug_ling_txt
)
else
:
fp_label
=
None
if
self
.
with_duration
:
attn_prior
=
None
else
:
attn_prior
=
beta_binomial_prior_distribution
(
len
(
ling_data
[
0
]),
mel_data
.
shape
[
0
]
)
# Concat frame-level f0 and uv to mel_data
if
self
.
nsf_enable
:
# origin f0 data is mean std normed
frame_f0_data
=
np
.
load
(
frame_f0_file
).
reshape
(
-
1
,
1
)
# default f0 data is mean std normed; re-norm here
if
self
.
nsf_norm_type
==
"global"
:
# denorm f0
f0_mean
=
np
.
loadtxt
(
f0_mean_file
)
f0_std
=
np
.
loadtxt
(
f0_std_file
)
f0_origin
=
frame_f0_data
*
f0_std
+
f0_mean
# renorm f0
frame_f0_data
=
(
f0_origin
-
self
.
nsf_f0_global_minimum
)
/
(
self
.
nsf_f0_global_maximum
-
self
.
nsf_f0_global_minimum
)
frame_uv_data
=
np
.
load
(
frame_uv_file
).
reshape
(
-
1
,
1
)
mel_data
=
np
.
concatenate
([
mel_data
,
frame_f0_data
,
frame_uv_data
],
axis
=
1
)
if
self
.
allow_cache
:
self
.
caches
[
idx
]
=
(
ling_data
,
mel_data
,
dur_data
,
f0_data
,
energy_data
,
attn_prior
,
fp_label
,
se_data
,
)
return
(
ling_data
,
mel_data
,
dur_data
,
f0_data
,
energy_data
,
attn_prior
,
fp_label
,
se_data
,
)
def
load_meta
(
self
,
metafile
,
data_dir
):
with
open
(
metafile
,
"r"
)
as
f
:
lines
=
f
.
readlines
()
aug_ling_dict
=
{}
if
self
.
fp_enable
:
add_fp_metafile
=
metafile
.
replace
(
"fprm"
,
"fpadd"
)
with
open
(
add_fp_metafile
,
"r"
)
as
f
:
fpadd_lines
=
f
.
readlines
()
for
line
in
fpadd_lines
:
index
,
aug_ling_txt
=
line
.
split
(
"
\t
"
)
aug_ling_dict
[
index
]
=
aug_ling_txt
mel_dir
=
os
.
path
.
join
(
data_dir
,
"mel"
)
dur_dir
=
os
.
path
.
join
(
data_dir
,
"duration"
)
f0_dir
=
os
.
path
.
join
(
data_dir
,
"f0"
)
energy_dir
=
os
.
path
.
join
(
data_dir
,
"energy"
)
frame_f0_dir
=
os
.
path
.
join
(
data_dir
,
"frame_f0"
)
frame_uv_dir
=
os
.
path
.
join
(
data_dir
,
"frame_uv"
)
se_dir
=
os
.
path
.
join
(
data_dir
,
"se"
)
if
self
.
mas_enable
:
self
.
with_duration
=
False
else
:
self
.
with_duration
=
os
.
path
.
exists
(
dur_dir
)
items
=
[]
logging
.
info
(
"Loading metafile..."
)
for
line
in
tqdm
(
lines
):
line
=
line
.
strip
()
index
,
ling_txt
=
line
.
split
(
"
\t
"
)
mel_file
=
os
.
path
.
join
(
mel_dir
,
index
+
".npy"
)
if
self
.
with_duration
:
dur_file
=
os
.
path
.
join
(
dur_dir
,
index
+
".npy"
)
else
:
dur_file
=
None
f0_file
=
os
.
path
.
join
(
f0_dir
,
index
+
".npy"
)
energy_file
=
os
.
path
.
join
(
energy_dir
,
index
+
".npy"
)
frame_f0_file
=
os
.
path
.
join
(
frame_f0_dir
,
index
+
".npy"
)
frame_uv_file
=
os
.
path
.
join
(
frame_uv_dir
,
index
+
".npy"
)
aug_ling_txt
=
aug_ling_dict
.
get
(
index
,
None
)
if
self
.
fp_enable
and
aug_ling_txt
is
None
:
logging
.
warning
(
f
"Missing fpadd meta for
{
index
}
"
)
continue
se_path
=
os
.
path
.
join
(
se_dir
,
"se.npy"
)
if
self
.
se_enable
:
if
not
os
.
path
.
exists
(
se_path
):
logging
.
warning
(
"Missing se meta"
)
continue
items
.
append
(
(
ling_txt
,
mel_file
,
dur_file
,
f0_file
,
energy_file
,
frame_f0_file
,
frame_uv_file
,
aug_ling_txt
,
se_path
,
)
)
return
items
def
load_fpadd_meta
(
self
,
metafile
):
with
open
(
metafile
,
"r"
)
as
f
:
lines
=
f
.
readlines
()
items
=
[]
logging
.
info
(
"Loading fpadd metafile..."
)
for
line
in
tqdm
(
lines
):
line
=
line
.
strip
()
index
,
ling_txt
=
line
.
split
(
"
\t
"
)
items
.
append
((
ling_txt
,))
return
items
@staticmethod
def gen_metafile(
    raw_meta_file,
    out_dir,
    train_meta_file,
    valid_meta_file,
    badlist=None,
    split_ratio=0.98,
    se_enable=False,
):
    """Split a raw metafile into train/valid lists, dropping unusable items.

    An item is kept only when its frame_f0 / frame_uv / mel features exist
    under ``out_dir``, it is not on ``badlist``, and (when the corresponding
    directories exist) its duration file and the speaker-embedding file are
    present.

    Improvement over the previous version: the train and valid loops shared a
    fully duplicated filter; it is factored into a single ``_usable`` helper.

    Args:
        raw_meta_file: path of the raw ``index<TAB>ling`` metafile.
        out_dir: feature root containing mel/duration/frame_f0/frame_uv/se.
        train_meta_file: output path for the training list.
        valid_meta_file: output path for the validation list.
        badlist: optional collection of indices to exclude.
        split_ratio: fraction of (shuffled) lines that go to training.
        se_enable: also require the speaker-embedding file ``se/se.npy``.
    """
    with open(raw_meta_file, "r") as f:
        lines = f.readlines()

    se_dir = os.path.join(out_dir, "se")
    frame_f0_dir = os.path.join(out_dir, "frame_f0")
    frame_uv_dir = os.path.join(out_dir, "frame_uv")
    mel_dir = os.path.join(out_dir, "mel")
    duration_dir = os.path.join(out_dir, "duration")

    def _usable(index):
        # Single place for the keep/drop decision used by both splits.
        if badlist is not None and index in badlist:
            return False
        for feat_dir in (frame_f0_dir, frame_uv_dir, mel_dir):
            if not os.path.exists(os.path.join(feat_dir, index + ".npy")):
                return False
        if os.path.exists(duration_dir) and not os.path.exists(
            os.path.join(duration_dir, index + ".npy")
        ):
            return False
        if se_enable and os.path.exists(se_dir) and not os.path.exists(
            os.path.join(se_dir, "se.npy")
        ):
            return False
        return True

    # Deterministic shuffle so repeated runs reproduce the same split.
    random.seed(DATASET_RANDOM_SEED)
    random.shuffle(lines)
    num_train = int(len(lines) * split_ratio) - 1

    for meta_path, subset in (
        (train_meta_file, lines[:num_train]),
        (valid_meta_file, lines[num_train:]),
    ):
        with open(meta_path, "w") as f:
            for line in subset:
                if _usable(line.split("\t")[0]):
                    f.write(line)
def collate_fn(self, batch):
    """Collate a list of dataset items into one padded batch dict.

    Index layout of each item ``x`` (inferred from the indexing below —
    NOTE(review): confirm against this dataset's ``__getitem__``):
      x[0] per-type linguistic id sequences, x[1] mel target, x[2] durations,
      x[3] pitch contour, x[4] energy contour, x[5] attention prior,
      x[6] fp labels, x[7] speaker embedding (when ``se_enable``).
    """
    data_dict = {}
    # Longest symbol sequence in the batch fixes the input padding length.
    max_input_length = max((len(x[0][0]) for x in batch))
    if self.with_duration:
        # +1 pads one extra duration slot — presumably for the appended
        # "~" terminator (see the length-1 note below); confirm.
        max_dur_length = max((x[2].shape[0] for x in batch)) + 1
    # lfeat_type_index walks through self.ling_unit._lfeat_type_list in
    # order; the increments below must stay in sync with that list.
    lfeat_type_index = 0
    lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
    if self.ling_unit.using_byte():
        # for byte-based model only: a single byte-id stream replaces the
        # four linguistic sub-streams.
        inputs_byte_index = self.padder._prepare_scalar_inputs(
            [x[0][lfeat_type_index] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        data_dict["input_lings"] = torch.stack([inputs_byte_index], dim=2)
    else:
        # pure linguistic info: sy|tone|syllable_flag|word_segment
        # sy
        inputs_sy = self.padder._prepare_scalar_inputs(
            [x[0][lfeat_type_index] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # tone
        lfeat_type_index = lfeat_type_index + 1
        lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
        inputs_tone = self.padder._prepare_scalar_inputs(
            [x[0][lfeat_type_index] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # syllable_flag
        lfeat_type_index = lfeat_type_index + 1
        lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
        inputs_syllable_flag = self.padder._prepare_scalar_inputs(
            [x[0][lfeat_type_index] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # word_segment
        lfeat_type_index = lfeat_type_index + 1
        lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
        inputs_ws = self.padder._prepare_scalar_inputs(
            [x[0][lfeat_type_index] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # Stacked along dim=2 -> (batch, time, 4 sub-streams).
        data_dict["input_lings"] = torch.stack(
            [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2
        )
    # emotion category
    lfeat_type_index = lfeat_type_index + 1
    lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
    data_dict["input_emotions"] = self.padder._prepare_scalar_inputs(
        [x[0][lfeat_type_index] for x in batch],
        max_input_length,
        self.ling_unit._sub_unit_pad[lfeat_type],
    ).long()
    # speaker category
    lfeat_type_index = lfeat_type_index + 1
    lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
    if self.se_enable:
        # Speaker embedding vector broadcast (repeated) over every symbol
        # position, then padded like a target sequence.
        data_dict["input_speakers"] = self.padder._prepare_targets(
            [x[7].repeat(len(x[0][0]), axis=0) for x in batch],
            max_input_length,
            0.0,
        )
    else:
        data_dict["input_speakers"] = self.padder._prepare_scalar_inputs(
            [x[0][lfeat_type_index] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
    # fp label category (filled-pause labels), only when enabled.
    if self.fp_enable:
        data_dict["fp_label"] = self.padder._prepare_scalar_inputs(
            [x[6] for x in batch],
            max_input_length,
            0,
        ).long()
    data_dict["valid_input_lengths"] = torch.as_tensor(
        [len(x[0][0]) - 1 for x in batch], dtype=torch.long
    )
    # The input symbol sequence gets a trailing "~" appended, which would
    # skew duration computation, so the valid length is reduced by 1.
    data_dict["valid_output_lengths"] = torch.as_tensor(
        [len(x[1]) for x in batch], dtype=torch.long
    )
    # Round the output (mel) padding length up to a multiple of self.r.
    max_output_length = torch.max(data_dict["valid_output_lengths"]).item()
    max_output_round_length = self.padder._round_up(max_output_length, self.r)
    data_dict["mel_targets"] = self.padder._prepare_targets(
        [x[1] for x in batch], max_output_round_length, 0.0
    )
    if self.with_duration:
        data_dict["durations"] = self.padder._prepare_durations(
            [x[2] for x in batch], max_dur_length, max_output_round_length
        )
    else:
        data_dict["durations"] = None
    # Frame-level feature padding length depends on the alignment mode.
    if self.with_duration:
        if self.fp_enable:
            feats_padding_length = max_dur_length
        else:
            feats_padding_length = max_input_length
    else:
        feats_padding_length = max_output_round_length
    data_dict["pitch_contours"] = self.padder._prepare_scalar_inputs(
        [x[3] for x in batch], feats_padding_length, 0.0
    ).float()
    data_dict["energy_contours"] = self.padder._prepare_scalar_inputs(
        [x[4] for x in batch], feats_padding_length, 0.0
    ).float()
    if self.with_duration:
        data_dict["attn_priors"] = None
    else:
        # MAS mode: zero-pad each (out_len, in_len) attention prior into a
        # common (batch, max_out, max_in) tensor.
        data_dict["attn_priors"] = torch.zeros(
            len(batch), max_output_round_length, max_input_length
        )
        for i in range(len(batch)):
            attn_prior = batch[i][5]
            data_dict["attn_priors"][
                i, : attn_prior.shape[0], : attn_prior.shape[1]
            ] = attn_prior
    return data_dict
def get_am_datasets(
    metafile,
    root_dir,
    config,
    allow_cache,
    split_ratio=0.98,
    se_enable=False,
):
    """Build (train, valid) AM datasets, generating split metafiles on demand.

    Args:
        metafile: raw metafile path or list of paths (one per data dir).
        root_dir: feature root directory or list of directories.
        config: full training config; FP support is read from
            config["Model"]["KanTtsSAMBERT"]["params"].
        allow_cache: forwarded to AM_Dataset (in-memory sample caching).
        split_ratio: train fraction used when generating split files.
        se_enable: forwarded to AM_Dataset.gen_metafile.

    Returns:
        (train_dataset, valid_dataset) tuple of AM_Dataset instances.
    """
    if not isinstance(root_dir, list):
        root_dir = [root_dir]
    if not isinstance(metafile, list):
        metafile = [metafile]

    train_meta_lst = []
    valid_meta_lst = []

    fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
    if fp_enable:
        am_train_fn = "am_fprm_train.lst"
        am_valid_fn = "am_fprm_valid.lst"
    else:
        am_train_fn = "am_train.lst"
        am_valid_fn = "am_valid.lst"

    for raw_metafile, data_dir in zip(metafile, root_dir):
        train_meta = os.path.join(data_dir, am_train_fn)
        valid_meta = os.path.join(data_dir, am_valid_fn)
        if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
            # Bug fix: gen_metafile takes `badlist` before `split_ratio`;
            # the previous positional call bound split_ratio to badlist and
            # se_enable to split_ratio (making num_train = -1). Keywords
            # bind both values to the intended parameters.
            AM_Dataset.gen_metafile(
                raw_metafile,
                data_dir,
                train_meta,
                valid_meta,
                split_ratio=split_ratio,
                se_enable=se_enable,
            )
        train_meta_lst.append(train_meta)
        valid_meta_lst.append(valid_meta)

    train_dataset = AM_Dataset(config, train_meta_lst, root_dir, allow_cache)
    # NOTE(review): [:50] truncates the list of metafiles (not samples); with
    # a single data_dir it is a no-op — confirm the intent.
    valid_dataset = AM_Dataset(config, valid_meta_lst[:50], root_dir, allow_cache)
    return train_dataset, valid_dataset
class MaskingActor(object):
    """BERT-style random masking for symbol-id sequences.

    ``mask_ratio`` is the probability that a position is selected for
    corruption; of the selected positions, a ``p2`` fraction becomes the
    [MASK] id, the next ``p3`` fraction becomes a random symbol id, and the
    remainder keeps its original value.
    """

    def __init__(self, mask_ratio=0.15):
        super(MaskingActor, self).__init__()
        # Default selection probability used by callers via self.mask_ratio.
        self.mask_ratio = mask_ratio

    def _get_random_mask(self, length, p1=0.15):
        """Return a float array of length ``length``; 1.0 marks a position
        selected for masking, each independently with probability ``p1``.
        """
        # Vectorized replacement for the former per-element while-loop:
        # identical RNG consumption (one uniform draw per position) and
        # identical 0.0/1.0 float output.
        return (np.random.uniform(0, 1, length) < p1).astype(np.float64)

    def _input_bert_masking(
        self,
        sequence_array,
        nb_symbol_category,
        mask_symbol_id,
        mask,
        p2=0.8,
        p3=0.1,
        p4=0.1,
    ):
        """Apply BERT input corruption to positions where ``mask == 1``.

        Args:
            sequence_array: numpy array of symbol ids (not modified).
            nb_symbol_category: size of the symbol vocabulary.
            mask_symbol_id: id of the [MASK] symbol.
            mask: 0/1 array from ``_get_random_mask``.
            p2: fraction of selected positions replaced by [MASK].
            p3: fraction replaced by a random symbol id.
            p4: implied keep-original fraction (accepted but unused).

        Returns:
            A corrupted copy of ``sequence_array``.
        """
        sequence_array_mask = sequence_array.copy()
        mask_id = np.where(mask == 1)[0]
        mask_len = len(mask_id)
        # Shuffle the selected positions so the p2/p3 partitions are random.
        rand = np.arange(mask_len)
        np.random.shuffle(rand)
        n_mask = int(math.floor(mask_len * p2))
        n_rand = int(math.floor(mask_len * p3))
        # [MASK] replacement.
        mask_id_p2 = mask_id[rand[0:n_mask]]
        if len(mask_id_p2) > 0:
            sequence_array_mask[mask_id_p2] = mask_symbol_id
        # Random-symbol replacement.
        # NOTE(review): one random id is shared by all p3 positions; standard
        # BERT draws a fresh id per position — confirm whether intended.
        mask_id_p3 = mask_id[rand[n_mask : n_mask + n_rand]]
        if len(mask_id_p3) > 0:
            sequence_array_mask[mask_id_p3] = random.randint(0, nb_symbol_category - 1)
        # Remaining selected positions intentionally keep their original id.
        return sequence_array_mask
class BERT_Text_Dataset(torch.utils.data.Dataset):
    """
    provide (ling, ling_sy_masked, bert_mask) pair
    """

    def __init__(
        self,
        config,
        metafile,
        root_dir,
        allow_cache=False,
    ):
        """Load linguistic text metafiles for BERT-style sy masking training.

        Args:
            config: full training config; the mask ratio is read from
                config["Model"]["KanTtsTextsyBERT"]["params"]["mask_ratio"].
            metafile: metafile path or list of paths (paired with root_dir).
            root_dir: data directory or list of directories.
            allow_cache: cache encoded linguistic data across workers via a
                multiprocessing Manager list.

        Raises:
            ValueError: if a metafile or data dir does not exist.
        """
        self.meta = []
        self.config = config
        if not isinstance(metafile, list):
            metafile = [metafile]
        if not isinstance(root_dir, list):
            root_dir = [root_dir]
        for meta_file, data_dir in zip(metafile, root_dir):
            if not os.path.exists(meta_file):
                logging.error("meta file not found: {}".format(meta_file))
                raise ValueError(
                    "[BERT_Text_Dataset] meta file: {} not found".format(meta_file)
                )
            if not os.path.exists(data_dir):
                logging.error("data dir not found: {}".format(data_dir))
                raise ValueError(
                    "[BERT_Text_Dataset] data dir: {} not found".format(data_dir)
                )
            self.meta.extend(self.load_meta(meta_file, data_dir))
        self.allow_cache = allow_cache
        self.ling_unit = KanTtsLinguisticUnit(config)
        self.padder = Padder()
        self.masking_actor = MaskingActor(
            self.config["Model"]["KanTtsTextsyBERT"]["params"]["mask_ratio"]
        )
        if allow_cache:
            # Manager-backed list so DataLoader worker processes share it.
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.meta))]

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        """Return (ling_data, ling_sy_masked_data, bert_mask) for item idx.

        Only the encoded ling_data is cached; a fresh random mask is drawn
        on every access.
        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            ling_data = self.caches[idx][0]
            bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
            return (ling_data, ling_sy_masked_data, bert_mask)
        ling_txt = self.meta[idx]
        ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
        bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
        if self.allow_cache:
            self.caches[idx] = (ling_data,)
        return (ling_data, ling_sy_masked_data, bert_mask)

    def load_meta(self, metafile, data_dir):
        """Read a metafile and return the ling text of each line.

        Lines are ``index<TAB>ling_txt``; the index is discarded.
        """
        with open(metafile, "r") as f:
            lines = f.readlines()
        items = []
        logging.info("Loading metafile...")
        for line in tqdm(lines):
            line = line.strip()
            index, ling_txt = line.split("\t")
            # Note: (ling_txt) is NOT a tuple — plain strings are stored,
            # matching the scalar access in __getitem__.
            items.append((ling_txt))
        return items

    @staticmethod
    def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98):
        """Shuffle the raw metafile deterministically and write
        bert_train.lst / bert_valid.lst under out_dir."""
        with open(raw_meta_file, "r") as f:
            lines = f.readlines()
        random.seed(DATASET_RANDOM_SEED)
        random.shuffle(lines)
        num_train = int(len(lines) * split_ratio) - 1
        with open(os.path.join(out_dir, "bert_train.lst"), "w") as f:
            for line in lines[:num_train]:
                f.write(line)
        with open(os.path.join(out_dir, "bert_valid.lst"), "w") as f:
            for line in lines[num_train:]:
                f.write(line)

    def bert_masking(self, ling_data):
        """Draw a random mask over the sy stream and corrupt it BERT-style.

        Returns (mask, corrupted sy sequence); ling_data is unchanged.
        """
        length = len(ling_data[0])
        mask = self.masking_actor._get_random_mask(
            length, p1=self.masking_actor.mask_ratio
        )
        # Never mask the last symbol (the appended terminator).
        mask[-1] = 0
        # sy_masked: replace selected sy ids with the [MASK] symbol id.
        sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0]
        ling_sy_masked_data = self.masking_actor._input_bert_masking(
            ling_data[0],
            self.ling_unit.get_unit_size()["sy"],
            sy_mask_symbol_id,
            mask,
            p2=0.8,
            p3=0.1,
            p4=0.1,
        )
        return (mask, ling_sy_masked_data)

    def collate_fn(self, batch):
        """Collate (ling, sy_masked, mask) items into padded batch tensors."""
        data_dict = {}
        max_input_length = max((len(x[0][0]) for x in batch))
        # pure linguistic info: sy|tone|syllable_flag|word_segment
        # sy (uncorrupted) — these are the prediction targets.
        lfeat_type = self.ling_unit._lfeat_type_list[0]
        targets_sy = self.padder._prepare_scalar_inputs(
            [x[0][0] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # sy masked — the corrupted model inputs.
        inputs_sy = self.padder._prepare_scalar_inputs(
            [x[1] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # tone
        lfeat_type = self.ling_unit._lfeat_type_list[1]
        inputs_tone = self.padder._prepare_scalar_inputs(
            [x[0][1] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # syllable_flag
        lfeat_type = self.ling_unit._lfeat_type_list[2]
        inputs_syllable_flag = self.padder._prepare_scalar_inputs(
            [x[0][2] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # word_segment
        lfeat_type = self.ling_unit._lfeat_type_list[3]
        inputs_ws = self.padder._prepare_scalar_inputs(
            [x[0][3] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        data_dict["input_lings"] = torch.stack(
            [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2
        )
        data_dict["valid_input_lengths"] = torch.as_tensor(
            [len(x[0][0]) - 1 for x in batch], dtype=torch.long
        )
        # The input symbol sequence gets a trailing "~" appended, which would
        # skew duration computation, so the valid length is reduced by 1.
        data_dict["targets"] = targets_sy
        data_dict["bert_masks"] = self.padder._prepare_scalar_inputs(
            [x[2] for x in batch], max_input_length, 0.0
        )
        return data_dict
def get_bert_text_datasets(
    metafile,
    root_dir,
    config,
    allow_cache,
    split_ratio=0.98,
):
    """Create (train, valid) BERT text datasets, generating split lists
    under each data dir when they are missing."""
    root_dirs = root_dir if isinstance(root_dir, list) else [root_dir]
    metafiles = metafile if isinstance(metafile, list) else [metafile]

    train_metas = []
    valid_metas = []
    for raw_meta, data_dir in zip(metafiles, root_dirs):
        train_lst = os.path.join(data_dir, "bert_train.lst")
        valid_lst = os.path.join(data_dir, "bert_valid.lst")
        # (Re)generate the split lists when either file is absent.
        if not (os.path.exists(train_lst) and os.path.exists(valid_lst)):
            BERT_Text_Dataset.gen_metafile(raw_meta, data_dir, split_ratio)
        train_metas.append(train_lst)
        valid_metas.append(valid_lst)

    train_dataset = BERT_Text_Dataset(config, train_metas, root_dirs, allow_cache)
    valid_dataset = BERT_Text_Dataset(config, valid_metas, root_dirs, allow_cache)
    return train_dataset, valid_dataset
kantts/models/__init__.py
deleted
100644 → 0
View file @
8b4e9acd
import
torch
from
torch.nn.parallel
import
DistributedDataParallel
from
kantts.models.hifigan.hifigan
import
(
# NOQA
Generator
,
# NOQA
MultiScaleDiscriminator
,
# NOQA
MultiPeriodDiscriminator
,
# NOQA
MultiSpecDiscriminator
,
# NOQA
)
import
kantts
import
kantts.train.scheduler
from
kantts.models.sambert.kantts_sambert
import
KanTtsSAMBERT
,
KanTtsTextsyBERT
# NOQA
from
kantts.utils.ling_unit.ling_unit
import
get_fpdict
from
.pqmf
import
PQMF
def optimizer_builder(model_params, opt_name, opt_params):
    """Instantiate a torch.optim optimizer by class name.

    Args:
        model_params: iterable of parameters (or param groups) to optimize.
        opt_name: optimizer class name looked up on torch.optim (e.g. "Adam").
        opt_params: keyword arguments forwarded to the optimizer constructor.

    Returns:
        The constructed optimizer instance.
    """
    return getattr(torch.optim, opt_name)(model_params, **opt_params)
def scheduler_builder(optimizer, sche_name, sche_params):
    """Instantiate a kantts LR scheduler by class name.

    Args:
        optimizer: the optimizer the scheduler will drive.
        sche_name: class name looked up on kantts.train.scheduler.
        sche_params: keyword arguments forwarded to the scheduler constructor.

    Returns:
        The constructed scheduler instance.
    """
    sche_cls = getattr(kantts.train.scheduler, sche_name)
    return sche_cls(optimizer, **sche_params)
def hifigan_model_builder(config, device, rank, distributed):
    """Build the HiFi-GAN generator and discriminator ensemble from config.

    Every key under config["Model"] other than "Generator" is treated as a
    discriminator and instantiated by resolving its key name via globals().

    Returns:
        (model, optimizer, scheduler) dicts with keys "generator" and
        "discriminator" (a per-name sub-dict), plus "pqmf" in model when the
        generator emits more than one output channel.
    """
    model = {}
    optimizer = {}
    scheduler = {}
    model["discriminator"] = {}
    optimizer["discriminator"] = {}
    scheduler["discriminator"] = {}
    for model_name in config["Model"].keys():
        if model_name == "Generator":
            params = config["Model"][model_name]["params"]
            model["generator"] = Generator(**params).to(device)
            optimizer["generator"] = optimizer_builder(
                model["generator"].parameters(),
                config["Model"][model_name]["optimizer"].get("type", "Adam"),
                config["Model"][model_name]["optimizer"].get("params", {}),
            )
            scheduler["generator"] = scheduler_builder(
                optimizer["generator"],
                config["Model"][model_name]["scheduler"].get("type", "StepLR"),
                config["Model"][model_name]["scheduler"].get("params", {}),
            )
        else:
            params = config["Model"][model_name]["params"]
            # The discriminator class is resolved by its config key name.
            model["discriminator"][model_name] = globals()[model_name](**params).to(
                device
            )
            optimizer["discriminator"][model_name] = optimizer_builder(
                model["discriminator"][model_name].parameters(),
                config["Model"][model_name]["optimizer"].get("type", "Adam"),
                config["Model"][model_name]["optimizer"].get("params", {}),
            )
            scheduler["discriminator"][model_name] = scheduler_builder(
                optimizer["discriminator"][model_name],
                config["Model"][model_name]["scheduler"].get("type", "StepLR"),
                config["Model"][model_name]["scheduler"].get("params", {}),
            )
    # Multi-band output: attach a PQMF module for subband processing.
    out_channels = config["Model"]["Generator"]["params"]["out_channels"]
    if out_channels > 1:
        model["pqmf"] = PQMF(subbands=out_channels, **config.get("pqmf", {})).to(device)
    # FIXME: pywavelets buffer leads to gradient error in DDP training
    # Solution: https://github.com/pytorch/pytorch/issues/22095
    # (broadcast_buffers=False below is that workaround)
    if distributed:
        model["generator"] = DistributedDataParallel(
            model["generator"],
            device_ids=[rank],
            output_device=rank,
            broadcast_buffers=False,
        )
        for model_name in model["discriminator"].keys():
            model["discriminator"][model_name] = DistributedDataParallel(
                model["discriminator"][model_name],
                device_ids=[rank],
                output_device=rank,
                broadcast_buffers=False,
            )
    return model, optimizer, scheduler
def sambert_model_builder(config, device, rank, distributed):
    """Construct the SAMBERT acoustic model with its optimizer and scheduler.

    Returns three dicts, each keyed by "KanTtsSAMBERT".
    """
    name = "KanTtsSAMBERT"
    model_cfg = config["Model"][name]

    net = KanTtsSAMBERT(model_cfg["params"]).to(device)

    # Optional filled-pause support: preload fp token tensors onto the device.
    if model_cfg["params"].get("FP", False):
        net.fp_dict = {
            key: torch.from_numpy(arr).long().unsqueeze(0).to(device)
            for key, arr in get_fpdict(config).items()
        }

    opt = optimizer_builder(
        net.parameters(),
        model_cfg["optimizer"].get("type", "Adam"),
        model_cfg["optimizer"].get("params", {}),
    )
    sche = scheduler_builder(
        opt,
        model_cfg["scheduler"].get("type", "StepLR"),
        model_cfg["scheduler"].get("params", {}),
    )

    if distributed:
        net = DistributedDataParallel(net, device_ids=[rank], output_device=rank)

    return {name: net}, {name: opt}, {name: sche}
def sybert_model_builder(config, device, rank, distributed):
    """Construct the text sy-BERT model with its optimizer and scheduler.

    Returns three dicts, each keyed by "KanTtsTextsyBERT".
    """
    name = "KanTtsTextsyBERT"
    model_cfg = config["Model"][name]

    net = KanTtsTextsyBERT(model_cfg["params"]).to(device)

    opt = optimizer_builder(
        net.parameters(),
        model_cfg["optimizer"].get("type", "Adam"),
        model_cfg["optimizer"].get("params", {}),
    )
    sche = scheduler_builder(
        opt,
        model_cfg["scheduler"].get("type", "StepLR"),
        model_cfg["scheduler"].get("params", {}),
    )

    if distributed:
        net = DistributedDataParallel(net, device_ids=[rank], output_device=rank)

    return {name: net}, {name: opt}, {name: sche}
# Registry mapping config["model_type"] to its builder function;
# model_builder() dispatches through this table.
model_dict = {
    "hifigan": hifigan_model_builder,
    "sambert": sambert_model_builder,
    "sybert": sybert_model_builder,
}
def model_builder(config, device="cpu", rank=0, distributed=False):
    """Dispatch to the builder registered for config["model_type"].

    Returns the (model, optimizer, scheduler) triple produced by that
    builder. Raises KeyError for an unregistered model type.
    """
    build = model_dict[config["model_type"]]
    return build(config, device, rank, distributed)
kantts/models/__pycache__/__init__.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/__pycache__/pqmf.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/__pycache__/utils.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/hifigan/__pycache__/hifigan.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/hifigan/__pycache__/layers.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
kantts/models/hifigan/hifigan.py
deleted
100644 → 0
View file @
8b4e9acd
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torch.nn
as
nn
from
torch.nn.utils
import
weight_norm
,
spectral_norm
from
distutils.version
import
LooseVersion
from
pytorch_wavelets
import
DWT1DForward
from
.layers
import
(
Conv1d
,
CausalConv1d
,
ConvTranspose1d
,
CausalConvTranspose1d
,
ResidualBlock
,
SourceModule
,
)
from
kantts.utils.audio_torch
import
stft
import
copy
# Version flag: True when the installed torch is >= 1.7.
# NOTE(review): distutils.version.LooseVersion is deprecated (distutils is
# removed in Python 3.12) — consider packaging.version. Behavior kept as-is.
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
class Generator(torch.nn.Module):
    """HiFi-GAN style generator with dual (repeat + transposed-conv)
    upsampling paths and optional NSF excitation input.

    When ``nsf_params`` is given, the last two input channels are the pitch
    and voiced/unvoiced tracks, converted to an excitation signal that is
    injected at every upsampling stage.
    """

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        channels=512,
        kernel_size=7,
        upsample_scales=(8, 8, 2, 2),
        upsample_kernal_sizes=(16, 16, 4, 4),
        resblock_kernel_sizes=(3, 7, 11),
        resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
        repeat_upsample=True,
        bias=True,
        causal=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        use_weight_norm=True,
        nsf_params=None,
    ):
        super(Generator, self).__init__()
        # check hyperparameters are valid
        assert kernel_size % 2 == 1, "Kernal size must be odd number."
        assert len(upsample_scales) == len(upsample_kernal_sizes)
        assert len(resblock_dilations) == len(resblock_kernel_sizes)
        self.upsample_scales = upsample_scales
        self.repeat_upsample = repeat_upsample
        self.num_upsamples = len(upsample_kernal_sizes)
        self.num_kernels = len(resblock_kernel_sizes)
        self.out_channels = out_channels
        # NSF excitation path is active only when nsf_params is provided.
        self.nsf_enable = nsf_params is not None
        self.transpose_upsamples = torch.nn.ModuleList()
        # for repeat upsampling (nearest-neighbor + conv path)
        self.repeat_upsamples = torch.nn.ModuleList()
        self.conv_blocks = torch.nn.ModuleList()
        # Causal variants keep the model streamable.
        conv_cls = CausalConv1d if causal else Conv1d
        conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
        self.conv_pre = conv_cls(
            in_channels, channels, kernel_size, 1, padding=(kernel_size - 1) // 2
        )
        # Channel count halves at each upsampling stage: channels // 2**i.
        for i in range(len(upsample_kernal_sizes)):
            self.transpose_upsamples.append(
                torch.nn.Sequential(
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                    conv_transposed_cls(
                        channels // (2 ** i),
                        channels // (2 ** (i + 1)),
                        upsample_kernal_sizes[i],
                        upsample_scales[i],
                        padding=(upsample_kernal_sizes[i] - upsample_scales[i]) // 2,
                    ),
                )
            )
            if repeat_upsample:
                self.repeat_upsamples.append(
                    nn.Sequential(
                        nn.Upsample(mode="nearest", scale_factor=upsample_scales[i]),
                        getattr(torch.nn, nonlinear_activation)(
                            **nonlinear_activation_params
                        ),
                        conv_cls(
                            channels // (2 ** i),
                            channels // (2 ** (i + 1)),
                            kernel_size=kernel_size,
                            stride=1,
                            padding=(kernel_size - 1) // 2,
                        ),
                    )
                )
            # num_kernels residual blocks per stage, indexed in forward() as
            # conv_blocks[i * num_kernels + j].
            for j in range(len(resblock_kernel_sizes)):
                self.conv_blocks.append(
                    ResidualBlock(
                        channels=channels // (2 ** (i + 1)),
                        kernel_size=resblock_kernel_sizes[j],
                        dilation=resblock_dilations[j],
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        causal=causal,
                    )
                )
        self.conv_post = conv_cls(
            channels // (2 ** (i + 1)),
            out_channels,
            kernel_size,
            1,
            padding=(kernel_size - 1) // 2,
        )
        if self.nsf_enable:
            self.source_module = SourceModule(
                nb_harmonics=nsf_params["nb_harmonics"],
                upsample_ratio=np.cumprod(self.upsample_scales)[-1],
                sampling_rate=nsf_params["sampling_rate"],
            )
            self.source_downs = nn.ModuleList()
            # NOTE(review): list + tuple concatenation raises TypeError when
            # upsample_scales is a tuple (the default); this path appears to
            # rely on config supplying a list — confirm.
            self.downsample_rates = [1] + self.upsample_scales[::-1][:-1]
            self.downsample_cum_rates = np.cumprod(self.downsample_rates)
            # One downsampler per stage, matching the excitation to each
            # stage's temporal resolution and channel width.
            for i, u in enumerate(self.downsample_cum_rates[::-1]):
                if u == 1:
                    self.source_downs.append(
                        Conv1d(1, channels // (2 ** (i + 1)), 1, 1)
                    )
                else:
                    self.source_downs.append(
                        conv_cls(
                            1,
                            channels // (2 ** (i + 1)),
                            u * 2,
                            u,
                            padding=u // 2,
                        )
                    )

    def forward(self, x):
        """Generate a waveform from mel (plus pitch/uv rows in NSF mode).

        Args:
            x: (B, in_channels[+2], T) input; in NSF mode the last two rows
               are pitch and uv.

        Returns:
            (B, out_channels, T * prod(upsample_scales)) tanh-bounded output.
        """
        if self.nsf_enable:
            # Split mel from the appended pitch / voiced-unvoiced rows.
            mel = x[:, :-2, :]
            pitch = x[:, -2:-1, :]
            uv = x[:, -1:, :]
            excitation = self.source_module(pitch, uv)
        else:
            mel = x
        x = self.conv_pre(mel)
        for i in range(self.num_upsamples):
            # FIXME: sin function here seems to be causing issues
            # NOTE(review): x + sin(x) acts as a periodic activation; confirm
            # it is intended at every stage.
            x = torch.sin(x) + x
            # Two parallel upsampling paths: nearest-repeat and transposed
            # conv; the transposed-conv output is cropped to match.
            rep = self.repeat_upsamples[i](x)
            # transconv
            up = self.transpose_upsamples[i](x)
            if self.nsf_enable:
                # Downsampling the excitation signal
                e = self.source_downs[i](excitation)
                # augment inputs with the excitation
                x = rep + e + up[:, :, : rep.shape[-1]]
            else:
                x = rep + up[:, :, : rep.shape[-1]]
            # Average the parallel residual blocks of this stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.conv_blocks[i * self.num_kernels + j](x)
                else:
                    xs += self.conv_blocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Strip weight norm from all submodules (for inference export)."""
        print("Removing weight norm...")
        for layer in self.transpose_upsamples:
            # layer[-1] is the conv; layer[0] is the activation.
            layer[-1].remove_weight_norm()
        for layer in self.repeat_upsamples:
            layer[-1].remove_weight_norm()
        for layer in self.conv_blocks:
            layer.remove_weight_norm()
        self.conv_pre.remove_weight_norm()
        self.conv_post.remove_weight_norm()
        if self.nsf_enable:
            self.source_module.remove_weight_norm()
            for layer in self.source_downs:
                layer.remove_weight_norm()
class PeriodDiscriminator(torch.nn.Module):
    """HiFi-GAN period discriminator.

    Folds a (B, C, T) waveform into a 2-D (T/period, period) map and applies
    a stack of strided 2-D convolutions, returning flattened logits and the
    per-layer feature maps.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        period=3,
        kernel_sizes=[5, 3],
        channels=32,
        downsample_scales=[3, 3, 3, 3, 1],
        max_downsample_channels=1024,
        bias=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        use_spectral_norm=False,
    ):
        """Build the conv stack; channels grow 4x per stage up to
        ``max_downsample_channels``."""
        super(PeriodDiscriminator, self).__init__()
        self.period = period
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList()
        in_chs, out_chs = in_channels, channels
        for downsample_scale in downsample_scales:
            self.convs.append(
                torch.nn.Sequential(
                    norm_f(
                        nn.Conv2d(
                            in_chs,
                            out_chs,
                            (kernel_sizes[0], 1),
                            (downsample_scale, 1),
                            padding=((kernel_sizes[0] - 1) // 2, 0),
                        )
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            )
            in_chs = out_chs
            out_chs = min(out_chs * 4, max_downsample_channels)
        # Fix: conv_post must consume the LAST conv's output channels, i.e.
        # in_chs after the loop. The previous code used out_chs, which only
        # matched because the defaults saturate at max_downsample_channels;
        # with fewer downsample stages it crashed on a channel mismatch.
        # NOTE(review): conv_post is not norm_f-wrapped (unlike
        # ScaleDiscriminator); wrapping it would rename checkpoint keys, so
        # it is deliberately left as-is.
        self.conv_post = nn.Conv2d(
            in_chs,
            out_channels,
            (kernel_sizes[1] - 1, 1),
            1,
            padding=((kernel_sizes[1] - 1) // 2, 0),
        )

    def forward(self, x):
        """Run the discriminator.

        Args:
            x: (B, C, T) waveform tensor.

        Returns:
            (logits flattened to (B, -1), list of feature maps per layer).
        """
        fmap = []
        # 1d to 2d: reflect-pad T up to a multiple of the period first.
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)
        for layer in self.convs:
            x = layer(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of PeriodDiscriminators, one per configured period."""

    def __init__(
        self,
        periods=[2, 3, 5, 7, 11],
        discriminator_params={
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 1024,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_spectral_norm": False,
        },
    ):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList()
        for period in periods:
            # Each sub-discriminator gets its own copy of the shared params
            # with only the period differing.
            sub_params = copy.deepcopy(discriminator_params)
            sub_params["period"] = period
            self.discriminators.append(PeriodDiscriminator(**sub_params))

    def forward(self, y):
        """Return per-discriminator logits and feature-map lists for y."""
        outs = []
        fmaps = []
        for disc in self.discriminators:
            out, fmap = disc(y)
            outs.append(out)
            fmaps.append(fmap)
        return outs, fmaps
class ScaleDiscriminator(torch.nn.Module):
    """HiFi-GAN scale discriminator: a stack of strided, grouped 1-D convs
    over the raw waveform, returning flattened logits and feature maps."""

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_sizes=[15, 41, 5, 3],
        channels=128,
        max_downsample_channels=1024,
        max_groups=16,
        bias=True,
        downsample_scales=[2, 2, 4, 4, 1],
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        use_spectral_norm=False,
    ):
        super(ScaleDiscriminator, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        # Exactly four kernel sizes (input / downsample / pre-post / post),
        # each odd so symmetric padding preserves alignment.
        assert len(kernel_sizes) == 4
        for ks in kernel_sizes:
            assert ks % 2 == 1

        act_cls = getattr(torch.nn, nonlinear_activation)
        self.convs = nn.ModuleList()
        # Input projection conv.
        self.convs.append(
            torch.nn.Sequential(
                norm_f(
                    nn.Conv1d(
                        in_channels,
                        channels,
                        kernel_sizes[0],
                        bias=bias,
                        padding=(kernel_sizes[0] - 1) // 2,
                    )
                ),
                act_cls(**nonlinear_activation_params),
            )
        )
        # Strided, grouped downsampling convs; channels double (capped at
        # max_downsample_channels) and groups quadruple (capped at
        # max_groups) per stage.
        in_chs = channels
        out_chs = channels
        groups = 4
        for scale in downsample_scales:
            self.convs.append(
                torch.nn.Sequential(
                    norm_f(
                        nn.Conv1d(
                            in_chs,
                            out_chs,
                            kernel_size=kernel_sizes[1],
                            stride=scale,
                            padding=(kernel_sizes[1] - 1) // 2,
                            groups=groups,
                            bias=bias,
                        )
                    ),
                    act_cls(**nonlinear_activation_params),
                )
            )
            in_chs = out_chs
            out_chs = min(in_chs * 2, max_downsample_channels)
            groups = min(groups * 4, max_groups)
        # Final non-strided conv before the output projection.
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.convs.append(
            torch.nn.Sequential(
                norm_f(
                    nn.Conv1d(
                        in_chs,
                        out_chs,
                        kernel_size=kernel_sizes[2],
                        stride=1,
                        padding=(kernel_sizes[2] - 1) // 2,
                        bias=bias,
                    )
                ),
                act_cls(**nonlinear_activation_params),
            )
        )
        self.conv_post = norm_f(
            nn.Conv1d(
                out_chs,
                out_channels,
                kernel_size=kernel_sizes[3],
                stride=1,
                padding=(kernel_sizes[3] - 1) // 2,
                bias=bias,
            )
        )

    def forward(self, x):
        """Run the discriminator on (B, C, T) input; returns (logits, fmaps)."""
        fmap = []
        for block in self.convs:
            x = block(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class MultiScaleDiscriminator(torch.nn.Module):
    """Multi-scale waveform discriminator.

    Runs ``scales`` ScaleDiscriminators on progressively downsampled views of
    the input.  Downsampling is either a learned DWT path (discrete wavelet
    transform + 1x conv merge) or plain average pooling.

    Args:
        scales (int): Number of sub-discriminators.  Only two downsampling
            modules are built, so at most 3 scales are supported as written.
        downsample_pooling (str): "DWT" for the wavelet path, anything else
            for AvgPool1d.
        downsample_pooling_params (dict): NOTE(review): accepted but never
            used — the AvgPool1d branch hard-codes (4, 2, padding=2).
        discriminator_params (dict): Kwargs for each ScaleDiscriminator.
        follow_official_norm (bool): If True, the first (full-resolution)
            discriminator uses spectral norm, the rest weight norm.
    """

    def __init__(
        self,
        scales=3,
        downsample_pooling="DWT",
        # follow the official implementation setting
        downsample_pooling_params={
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        discriminator_params={
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 1024,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
        },
        follow_official_norm=False,
    ):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = torch.nn.ModuleList()
        # add discriminators
        for i in range(scales):
            # deepcopy so the shared default dict is never mutated
            params = copy.deepcopy(discriminator_params)
            if follow_official_norm:
                # spectral norm only on the full-resolution discriminator
                params["use_spectral_norm"] = True if i == 0 else False
            self.discriminators += [ScaleDiscriminator(**params)]
        if downsample_pooling == "DWT":
            # DWT1DForward comes from the pytorch_wavelets package (imported
            # elsewhere in this file); each halves the time resolution.
            self.meanpools = nn.ModuleList(
                [DWT1DForward(wave="db3", J=1), DWT1DForward(wave="db3", J=1)]
            )
            # 1x convs merging the 2-channel (low, high) wavelet bands back
            # into a single-channel waveform.
            self.aux_convs = nn.ModuleList(
                [
                    weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
                    weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
                ]
            )
        else:
            self.meanpools = nn.ModuleList(
                [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
            )
            # Sentinel: forward() branches on aux_convs being None.
            self.aux_convs = None

    def forward(self, y):
        """Run all scale discriminators.

        Args:
            y (Tensor): Input waveform (B, 1, T).

        Returns:
            tuple: (list of per-scale logits, list of per-scale feature maps).

        Note:
            ``y`` is downsampled cumulatively — scale i sees the input
            downsampled i times.
        """
        y_d_rs = []
        fmap_rs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                if self.aux_convs is None:
                    # Average-pooling path.
                    y = self.meanpools[i - 1](y)
                else:
                    # Wavelet path: concat low/high bands, merge with a conv.
                    yl, yh = self.meanpools[i - 1](y)
                    y = torch.cat([yl, yh[0]], dim=1)
                    y = self.aux_convs[i - 1](y)
                    y = F.leaky_relu(y, 0.1)
            y_d_r, fmap_r = d(y)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
        return y_d_rs, fmap_rs
class SpecDiscriminator(torch.nn.Module):
    """Discriminator operating on an STFT magnitude spectrogram.

    The waveform is converted (without gradient) to an STFT magnitude, then
    frequency bins are treated as Conv2d input channels and convolutions run
    along the frame axis only (all kernels are (k, 1)).

    Args:
        channels (int): Hidden channel count.
        init_kernel (int): Kernel length of the input conv.
        kernel_size (int): Kernel length of the strided middle convs.
        stride (int): Frame-axis stride of the middle convs.
        use_spectral_norm (bool): Use spectral norm instead of weight norm.
        fft_size (int): FFT size for the STFT front end.
        shift_size (int): STFT hop size.
        win_length (int): STFT window length.
        window (str): Name of a torch window function, e.g. "hann_window".
        nonlinear_activation (str): Name of a ``torch.nn`` activation class.
        nonlinear_activation_params (dict): Kwargs for that activation.
    """

    def __init__(
        self,
        channels=32,
        init_kernel=15,
        kernel_size=11,
        stride=2,
        use_spectral_norm=False,
        fft_size=1024,
        shift_size=120,
        win_length=600,
        window="hann_window",
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
    ):
        super(SpecDiscriminator, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Input channel count of the first conv is fft_size // 2 + 1 bins.
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        final_kernel = 5
        post_conv_kernel = 3
        blocks = 3
        # TODO: remove hard code here
        self.convs = nn.ModuleList()
        # Input conv: frequency bins -> hidden channels, stride 1.
        self.convs.append(
            torch.nn.Sequential(
                norm_f(
                    nn.Conv2d(
                        fft_size // 2 + 1,
                        channels,
                        (init_kernel, 1),
                        (1, 1),
                        padding=(init_kernel - 1) // 2,
                    )
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        )
        # Strided blocks downsampling along the frame axis.
        for i in range(blocks):
            self.convs.append(
                torch.nn.Sequential(
                    norm_f(
                        nn.Conv2d(
                            channels,
                            channels,
                            (kernel_size, 1),
                            (stride, 1),
                            padding=(kernel_size - 1) // 2,
                        )
                    ),
                    getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                )
            )
        # Final stride-1 conv before the logits projection.
        self.convs.append(
            torch.nn.Sequential(
                norm_f(
                    nn.Conv2d(
                        channels,
                        channels,
                        (final_kernel, 1),
                        (1, 1),
                        padding=(final_kernel - 1) // 2,
                    )
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        )
        self.conv_post = norm_f(
            nn.Conv2d(
                channels,
                1,
                (post_conv_kernel, 1),
                (1, 1),
                padding=((post_conv_kernel - 1) // 2, 0),
            )
        )
        # Buffer so the window moves with the module across devices.
        self.register_buffer("window", getattr(torch, window)(win_length))

    def forward(self, wav):
        """Discriminate a waveform via its spectrogram.

        Args:
            wav (Tensor): Input waveform (B, 1, T).

        Returns:
            tuple: (logits Tensor, list of intermediate feature maps).
        """
        # STFT front end is a fixed feature extractor — no gradients needed.
        with torch.no_grad():
            wav = torch.squeeze(wav, 1)
            # `stft` is a module-level helper defined elsewhere in this file;
            # assumed to return magnitudes shaped (B, frames, fft_size//2+1)
            # — TODO(review): confirm against the stft helper.
            x_mag = stft(
                wav, self.fft_size, self.shift_size, self.win_length, self.window
            )
            # (B, bins, frames, 1): bins become Conv2d channels.
            x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
        fmap = []
        for layer in self.convs:
            x = layer(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = x.squeeze(-1)
        return x, fmap
class MultiSpecDiscriminator(torch.nn.Module):
    """Multi-resolution STFT spectrogram discriminator.

    Wraps several ``SpecDiscriminator`` instances, one per STFT resolution
    (fft size / hop size / window length triple), and collects their outputs.

    Args:
        fft_sizes (list): FFT size per sub-discriminator.
        hop_sizes (list): Hop (shift) size per sub-discriminator.
        win_lengths (list): Window length per sub-discriminator.
        discriminator_params (dict): Common kwargs forwarded to every
            ``SpecDiscriminator``; the per-resolution STFT values are
            filled in per instance.
    """

    def __init__(
        self,
        fft_sizes=[1024, 2048, 512],
        hop_sizes=[120, 240, 50],
        win_lengths=[600, 1200, 240],
        discriminator_params={
            "channels": 15,
            "init_kernel": 1,
            # Bug fix: this key was previously "kernel_sizes", which
            # SpecDiscriminator.__init__ does not accept (its parameter is
            # "kernel_size"), so constructing with default params raised a
            # TypeError.
            "kernel_size": 11,
            "stride": 2,
            "use_spectral_norm": False,
            "window": "hann_window",
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
        },
    ):
        super(MultiSpecDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList()
        for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
            # deepcopy so the shared default dict is never mutated
            params = copy.deepcopy(discriminator_params)
            params["fft_size"] = fft_size
            params["shift_size"] = hop_size
            params["win_length"] = win_length
            self.discriminators += [SpecDiscriminator(**params)]

    def forward(self, y):
        """Run every spectral discriminator on waveform ``y``.

        Args:
            y (Tensor): Input waveform (B, 1, T).

        Returns:
            tuple: (list of per-resolution logits,
            list of per-resolution feature-map lists).
        """
        y_d = []
        fmap = []
        for d in self.discriminators:
            x, x_map = d(y)
            y_d.append(x)
            fmap.append(x_map)
        return y_d, fmap
kantts/models/hifigan/layers.py
deleted
100644 → 0
View file @
8b4e9acd
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn.utils
import
weight_norm
,
remove_weight_norm
from
torch.distributions.uniform
import
Uniform
from
torch.distributions.normal
import
Normal
from
kantts.models.utils
import
init_weights
def get_padding(kernel_size, dilation=1):
    """Return the "same" padding for an odd-sized dilated 1D convolution.

    Args:
        kernel_size (int): Convolution kernel size (expected odd).
        dilation (int): Dilation factor.

    Returns:
        int: Padding such that output length equals input length at stride 1.
    """
    # Integer floor division instead of int(float_div): same result for the
    # non-negative ints used here, and avoids the float round-trip.
    return (kernel_size - 1) * dilation // 2
class Conv1d(torch.nn.Module):
    """Thin wrapper around ``nn.Conv1d`` with weight norm and custom init.

    Mirrors the ``nn.Conv1d`` constructor signature; the wrapped conv is
    weight-normalized and initialized with ``init_weights`` from
    ``kantts.models.utils``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(Conv1d, self).__init__()
        inner = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        self.conv1d = weight_norm(inner)
        self.conv1d.apply(init_weights)

    def forward(self, x):
        """Apply the wrapped convolution to ``x`` of shape (B, C, T)."""
        return self.conv1d(x)

    def remove_weight_norm(self):
        """Strip the weight-norm parametrization (for inference export)."""
        remove_weight_norm(self.conv1d)
class CausalConv1d(torch.nn.Module):
    """Weight-normalized causal 1D convolution.

    Left-pads the input by (kernel_size - 1) * dilation so that each output
    frame depends only on current and past input frames; the wrapped conv
    runs with padding=0.

    Note:
        The ``padding`` argument is accepted for signature compatibility but
        is not used — causal padding is derived from kernel_size/dilation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(CausalConv1d, self).__init__()
        # Amount of past-side (left) padding needed for causality.
        self.pad = (kernel_size - 1) * dilation
        self.conv1d = weight_norm(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                dilation=dilation,
                groups=groups,
                bias=bias,
                padding_mode=padding_mode,
            )
        )
        # init_weights comes from kantts.models.utils.
        self.conv1d.apply(init_weights)

    def forward(self, x):
        # Input layout: (B, D, T) — "bdt".
        # F.pad's tuple is described starting from the last dimension and
        # moving forward: pad only the left side of the time axis.
        x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
        # x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
        # NOTE(review): after the left pad, the conv output already has the
        # original length T, so slicing to the padded length is a no-op kept
        # for safety.
        x = self.conv1d(x)[:, :, : x.size(2)]
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm parametrization (for inference export)."""
        remove_weight_norm(self.conv1d)
class ConvTranspose1d(torch.nn.Module):
    """Thin wrapper around ``nn.ConvTranspose1d`` with weight norm and
    custom init (``init_weights`` from ``kantts.models.utils``).

    Args:
        in_channels (int): Input channel count.
        out_channels (int): Output channel count.
        kernel_size (int): Transposed-conv kernel size.
        stride (int): Upsampling stride.
        padding (int): Padding of the transposed conv.
        output_padding (int): Extra size added to one side of the output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        output_padding=0,
    ):
        super(ConvTranspose1d, self).__init__()
        self.deconv = weight_norm(
            nn.ConvTranspose1d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=padding,
                # Bug fix: this was hard-coded to 0, silently ignoring the
                # caller-supplied output_padding argument.  Default is 0, so
                # existing callers are unaffected.
                output_padding=output_padding,
            )
        )
        self.deconv.apply(init_weights)

    def forward(self, x):
        """Upsample ``x`` of shape (B, C, T)."""
        return self.deconv(x)

    def remove_weight_norm(self):
        """Strip the weight-norm parametrization (for inference export)."""
        remove_weight_norm(self.deconv)
# FIXME: HACK to get shape right
class CausalConvTranspose1d(torch.nn.Module):
    """CausalConvTranspose1d module with customized initialization.

    Upsamples with a weight-normalized transposed conv, then crops
    ``kernel_size - stride`` samples from the right so the output length is
    exactly ``stride * T_in`` and depends only on past input.

    Note:
        The ``padding`` and ``output_padding`` arguments are accepted for
        signature compatibility but are not forwarded — the transposed conv
        always runs with padding=0 and the causal crop fixes the length.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        output_padding=0,
    ):
        """Initialize CausalConvTranspose1d module."""
        super(CausalConvTranspose1d, self).__init__()
        self.deconv = weight_norm(
            nn.ConvTranspose1d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                output_padding=0,
            )
        )
        self.stride = stride
        # init_weights comes from kantts.models.utils.
        self.deconv.apply(init_weights)
        # Samples to crop from the right to restore the causal length.
        self.pad = kernel_size - stride

    def forward(self, x):
        """Calculate forward propagation.
        Args:
            x (Tensor): Input tensor (B, in_channels, T_in).
        Returns:
            Tensor: Output tensor (B, out_channels, T_out).
        """
        # x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
        # Crop the right-hand (future-looking) tail of the upsampled signal.
        return self.deconv(x)[:, :, : -self.pad]
        # return self.deconv(x)

    def remove_weight_norm(self):
        """Strip the weight-norm parametrization (for inference export)."""
        remove_weight_norm(self.deconv)
class ResidualBlock(torch.nn.Module):
    """HiFi-GAN residual block: pairs of dilated + undilated convolutions.

    Each stage computes ``x + c2(act(c1(act(x))))`` where ``c1`` is dilated
    (per the ``dilation`` tuple) and ``c2`` has dilation 1.  Channel count is
    preserved throughout.

    Args:
        channels (int): Channel count of every conv.
        kernel_size (int): Odd conv kernel size.
        dilation (tuple): Dilation of each ``convs1`` stage; its length sets
            the number of stages.
        nonlinear_activation (str): Name of a ``torch.nn`` activation class.
        nonlinear_activation_params (dict): Kwargs for that activation.
        causal (bool): Use CausalConv1d instead of (same-padded) Conv1d.
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        causal=False,
    ):
        super(ResidualBlock, self).__init__()
        assert kernel_size % 2 == 1, "Kernal size must be odd number."
        # Both wrappers share the same constructor signature.
        conv_cls = CausalConv1d if causal else Conv1d
        # Dilated convolutions, one per entry in `dilation`.
        self.convs1 = nn.ModuleList(
            [
                conv_cls(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[i],
                    padding=get_padding(kernel_size, dilation[i]),
                )
                for i in range(len(dilation))
            ]
        )
        # Matching undilated convolutions.
        self.convs2 = nn.ModuleList(
            [
                conv_cls(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                )
                for i in range(len(dilation))
            ]
        )
        self.activation = getattr(torch.nn, nonlinear_activation)(
            **nonlinear_activation_params
        )

    def forward(self, x):
        """Apply all residual stages to ``x`` of shape (B, C, T)."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = self.activation(x)
            xt = c1(xt)
            xt = self.activation(xt)
            xt = c2(xt)
            # Residual connection per stage.
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight norm from every wrapped conv (for inference export)."""
        for layer in self.convs1:
            layer.remove_weight_norm()
        for layer in self.convs2:
            layer.remove_weight_norm()
class SourceModule(torch.nn.Module):
    """NSF-style harmonic source-signal generator.

    From frame-level F0 and voiced/unvoiced flags it synthesizes a
    sample-level excitation: a sum of ``nb_harmonics + 1`` sinusoids plus
    Gaussian noise in voiced regions, scaled noise in unvoiced regions,
    merged to a single channel by a 1x1 conv + tanh.

    Args:
        nb_harmonics (int): Number of harmonics above the fundamental.
        upsample_ratio (int): Samples per frame (frame -> sample upsampling).
        sampling_rate (int): Audio sampling rate in Hz.
        alpha (float): Amplitude of the voiced sinusoids.
        sigma (float): Std-dev of the additive Gaussian noise.
    """

    def __init__(self, nb_harmonics, upsample_ratio, sampling_rate, alpha=0.1, sigma=0.003):
        super(SourceModule, self).__init__()
        self.nb_harmonics = nb_harmonics
        self.upsample_ratio = upsample_ratio
        self.sampling_rate = sampling_rate
        self.alpha = alpha
        self.sigma = sigma
        # Merges fundamental + harmonics into one excitation channel.
        self.ffn = nn.Sequential(
            weight_norm(
                nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)
            ),
            nn.Tanh(),
        )

    def forward(self, pitch, uv):
        """
        :param pitch: [B, 1, frame_len], Hz
        :param uv: [B, 1, frame_len] vuv flag
        :return: [B, 1, sample_len]
        """
        # The excitation itself is a fixed (random) signal — no gradients;
        # only the merging ffn below is trained.
        with torch.no_grad():
            # Nearest-neighbour upsample frame-level curves to sample level.
            pitch_samples = F.interpolate(
                pitch, scale_factor=(self.upsample_ratio), mode="nearest"
            )
            uv_samples = F.interpolate(
                uv, scale_factor=(self.upsample_ratio), mode="nearest"
            )
            # Per-harmonic normalized frequency f_k = F0 * (k+1) / fs.
            F_mat = torch.zeros(
                (pitch_samples.size(0), self.nb_harmonics + 1, pitch_samples.size(-1))
            ).to(pitch_samples.device)
            for i in range(self.nb_harmonics + 1):
                F_mat[:, i : i + 1, :] = pitch_samples * (i + 1) / self.sampling_rate
            # Instantaneous phase: 2*pi * (cumulative normalized freq mod 1).
            theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
            # Random initial phase per harmonic, except the fundamental
            # (index 0), which is pinned to phase 0 below.
            u_dist = Uniform(low=-np.pi, high=np.pi)
            phase_vec = u_dist.sample(
                sample_shape=(pitch.size(0), self.nb_harmonics + 1, 1)
            ).to(F_mat.device)
            phase_vec[:, 0, :] = 0
            # Additive Gaussian noise shared by both branches.
            n_dist = Normal(loc=0.0, scale=self.sigma)
            noise = n_dist.sample(
                sample_shape=(
                    pitch_samples.size(0),
                    self.nb_harmonics + 1,
                    pitch_samples.size(-1),
                )
            ).to(F_mat.device)
            # Voiced: sinusoid + noise; unvoiced: amplitude-matched noise.
            e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
            e_unvoice = self.alpha / 3 / self.sigma * noise
            e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
        return self.ffn(e)

    def remove_weight_norm(self):
        """Strip weight norm from the merging conv (for inference export)."""
        remove_weight_norm(self.ffn[0])
kantts/models/pqmf.py
deleted
100644 → 0
View file @
8b4e9acd
# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Pseudo QMF modules."""
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
scipy.signal
import
kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of
    prototype filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps (must be even).
        cutoff_ratio (float): Cut-off frequency ratio, in (0, 1).
        beta (float): Beta coefficient for the Kaiser window.

    Returns:
        ndarray: Impulse response of the prototype filter (taps + 1,).

    .. _`A Kaiser window approach for the design of prototype filters of
       cosine modulated filterbanks`:
       https://ieeexplore.ieee.org/abstract/document/681427
    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be an even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # `scipy.signal.kaiser` was removed in SciPy 1.13; the window now lives
    # in `scipy.signal.windows`.  Import locally with a fallback so the
    # function works on both old and new SciPy.
    try:
        from scipy.signal.windows import kaiser as _kaiser
    except ImportError:  # pragma: no cover - very old scipy
        from scipy.signal import kaiser as _kaiser

    # make initial (ideal low-pass, sinc) filter
    omega_c = np.pi * cutoff_ratio
    with np.errstate(invalid="ignore"):
        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
            np.pi * (np.arange(taps + 1) - 0.5 * taps)
        )
    # fix nan at the center tap due to the 0/0 indeterminate form
    h_i[taps // 2] = np.cos(0) * cutoff_ratio

    # apply kaiser window
    w = _kaiser(taps + 1, beta)
    h = h_i * w

    return h
class PQMF(torch.nn.Module):
    """PQMF module.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122
    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
        """Initialize PQMF module.

        The cutoff_ratio and beta parameters are optimized for #subbands = 4.
        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.
        """
        super(PQMF, self).__init__()

        # build analysis & synthesis filter coefficients by cosine-modulating
        # the shared low-pass prototype (one row per subband)
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        h_analysis = np.zeros((subbands, len(h_proto)))
        h_synthesis = np.zeros((subbands, len(h_proto)))
        for k in range(subbands):
            h_analysis[k] = (
                2
                * h_proto
                * np.cos(
                    (2 * k + 1)
                    * (np.pi / (2 * subbands))
                    * (np.arange(taps + 1) - (taps / 2))
                    + (-1) ** k * np.pi / 4
                )
            )
            # synthesis differs from analysis only by the sign of the
            # +/- pi/4 phase term
            h_synthesis[k] = (
                2
                * h_proto
                * np.cos(
                    (2 * k + 1)
                    * (np.pi / (2 * subbands))
                    * (np.arange(taps + 1) - (taps / 2))
                    - (-1) ** k * np.pi / 4
                )
            )

        # convert to tensor; shapes suit F.conv1d:
        # analysis (subbands, 1, taps+1), synthesis (1, subbands, taps+1)
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffer (moved with the module, not trained)
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # identity filter used for downsampling & upsampling by `subbands`
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        for k in range(subbands):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info: centers the filters on the signal
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).
        """
        # filter into subbands, then decimate by `subbands`
        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
        return F.conv1d(x, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).
        """
        # NOTE(kan-bayashi): Power will be decreased so here multiply by
        # #subbands.  Not sure this is the correct way, it is better to
        # check again.
        # TODO(kan-bayashi): Understand the reconstruction procedure
        x = F.conv_transpose1d(
            x, self.updown_filter * self.subbands, stride=self.subbands
        )
        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
kantts/models/sambert/__init__.py
deleted
100644 → 0
View file @
8b4e9acd
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
class
ScaledDotProductAttention
(
nn
.
Module
):
""" Scaled Dot-Product Attention """
def
__init__
(
self
,
temperature
,
dropatt
=
0.0
):
super
().
__init__
()
self
.
temperature
=
temperature
self
.
softmax
=
nn
.
Softmax
(
dim
=
2
)
self
.
dropatt
=
nn
.
Dropout
(
dropatt
)
def
forward
(
self
,
q
,
k
,
v
,
mask
=
None
):
attn
=
torch
.
bmm
(
q
,
k
.
transpose
(
1
,
2
))
attn
=
attn
/
self
.
temperature
if
mask
is
not
None
:
attn
=
attn
.
masked_fill
(
mask
,
-
np
.
inf
)
attn
=
self
.
softmax
(
attn
)
attn
=
self
.
dropatt
(
attn
)
output
=
torch
.
bmm
(
attn
,
v
)
return
output
,
attn
class
Prenet
(
nn
.
Module
):
def
__init__
(
self
,
in_units
,
prenet_units
,
out_units
=
0
):
super
(
Prenet
,
self
).
__init__
()
self
.
fcs
=
nn
.
ModuleList
()
for
in_dim
,
out_dim
in
zip
([
in_units
]
+
prenet_units
[:
-
1
],
prenet_units
):
self
.
fcs
.
append
(
nn
.
Linear
(
in_dim
,
out_dim
))
self
.
fcs
.
append
(
nn
.
ReLU
())
self
.
fcs
.
append
(
nn
.
Dropout
(
0.5
))
if
out_units
:
self
.
fcs
.
append
(
nn
.
Linear
(
prenet_units
[
-
1
],
out_units
))
def
forward
(
self
,
input
):
output
=
input
for
layer
in
self
.
fcs
:
output
=
layer
(
output
)
return
output
class MultiHeadSelfAttention(nn.Module):
    """Multi-Head SelfAttention module.

    Pre-LayerNorm multi-head self-attention with a fused QKV projection.
    The residual connection is applied only when the output width equals the
    input width (d_model == d_in).

    Args:
        n_head (int): Number of attention heads.
        d_in (int): Input feature size.
        d_model (int): Output feature size.
        d_head (int): Per-head feature size.
        dropout (float): Dropout on the output projection.
        dropatt (float): Dropout on the attention weights.
    """

    def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0):
        super().__init__()
        self.n_head = n_head
        self.d_head = d_head
        self.d_in = d_in
        self.d_model = d_model
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        # Single projection producing Q, K and V for all heads at once.
        self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head)
        # Scale factor sqrt(d_head), standard dot-product scaling.
        self.attention = ScaledDotProductAttention(
            temperature=np.power(d_head, 0.5), dropatt=dropatt
        )
        self.fc = nn.Linear(n_head * d_head, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, mask=None):
        """Self-attend over ``input``.

        Args:
            input (Tensor): (B, L, d_in).
            mask (BoolTensor, optional): Attention mask broadcast over heads.

        Returns:
            tuple: (output (B, L, d_model), attention weights).
        """
        d_head, n_head = self.d_head, self.n_head
        sz_b, len_in, _ = input.size()

        residual = input

        x = self.layer_norm(input)
        qkv = self.w_qkv(x)
        q, k, v = qkv.chunk(3, -1)
        q = q.view(sz_b, len_in, n_head, d_head)
        k = k.view(sz_b, len_in, n_head, d_head)
        v = v.view(sz_b, len_in, n_head, d_head)
        # Fold heads into the batch dimension: (n*b) x l x d
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head)
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head)
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head)

        if mask is not None:
            # Replicate the mask once per head: (n*b) x .. x ..
            mask = mask.repeat(n_head, 1, 1)
        output, attn = self.attention(q, k, v, mask=mask)

        # Unfold heads and concatenate them feature-wise: b x l x (n*d)
        output = output.view(n_head, sz_b, len_in, d_head)
        output = (
            output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
        )

        output = self.dropout(self.fc(output))
        # Residual only when shapes agree (i.e. d_model == d_in).
        if output.size(-1) == residual.size(-1):
            output = output + residual

        return output, attn
class PositionwiseConvFeedForward(nn.Module):
    """A two-feed-forward-layer module.

    Pre-LayerNorm position-wise feed-forward implemented with two Conv1d
    layers over the time axis (kernel sizes per ``kernel_size`` tuple), with
    ReLU + dropout in between and a residual connection at the end.

    Args:
        d_in (int): Input/output feature size.
        d_hid (int): Hidden feature size.
        kernel_size (tuple): Kernel sizes of the two convs.
        dropout_inner (float): Dropout after the inner ReLU.
        dropout (float): Dropout before the residual add.
    """

    def __init__(self, d_in, d_hid, kernel_size=(3, 1), dropout_inner=0.1, dropout=0.1):
        super().__init__()
        # Use Conv1d as position-wise linear layers (kernel may span
        # neighbouring frames when kernel_size > 1).
        self.w_1 = nn.Conv1d(
            d_in,
            d_hid,
            kernel_size=kernel_size[0],
            padding=(kernel_size[0] - 1) // 2,
        )
        # position-wise
        self.w_2 = nn.Conv1d(
            d_hid,
            d_in,
            kernel_size=kernel_size[1],
            padding=(kernel_size[1] - 1) // 2,
        )
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout_inner = nn.Dropout(dropout_inner)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the feed-forward block.

        Args:
            x (Tensor): (B, L, d_in).
            mask (BoolTensor, optional): (B, L) padding mask; masked frames
                are zeroed after the inner conv.

        Returns:
            Tensor: (B, L, d_in) with residual added.
        """
        residual = x
        x = self.layer_norm(x)
        # Convs operate on (B, C, L).
        output = x.transpose(1, 2)
        output = F.relu(self.w_1(output))
        if mask is not None:
            # Zero padded frames so the second conv does not mix them in.
            output = output.masked_fill(mask.unsqueeze(1), 0)
        output = self.dropout_inner(output)
        output = self.w_2(output)
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = output + residual
        return output
class FFTBlock(nn.Module):
    """FFT Block: multi-head self-attention followed by a position-wise
    convolutional feed-forward layer, with padding masking after each stage.

    Args:
        d_in (int): Input feature size (fed to the self-attention).
        d_model (int): Model feature size (attention output / FFN width).
        n_head (int): Number of attention heads.
        d_head (int): Per-head feature size.
        d_inner (int): Hidden size of the feed-forward layer.
        kernel_size (tuple): Conv kernel sizes of the feed-forward layer.
        dropout (float): Output dropout for both sub-layers.
        dropout_attn (float): Dropout on attention weights.
        dropout_relu (float): Dropout after the FFN's inner ReLU.
    """

    def __init__(
        self,
        d_in,
        d_model,
        n_head,
        d_head,
        d_inner,
        kernel_size,
        dropout,
        dropout_attn=0.0,
        dropout_relu=0.0,
    ):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadSelfAttention(
            n_head, d_in, d_model, d_head, dropout=dropout, dropatt=dropout_attn
        )
        self.pos_ffn = PositionwiseConvFeedForward(
            d_model, d_inner, kernel_size, dropout_inner=dropout_relu, dropout=dropout
        )

    def forward(self, input, mask=None, slf_attn_mask=None):
        """Run attention then feed-forward.

        Args:
            input (Tensor): (B, L, d_in).
            mask (BoolTensor, optional): (B, L) padding mask; masked frames
                are zeroed after each sub-layer.
            slf_attn_mask (BoolTensor, optional): Attention mask.

        Returns:
            tuple: (output (B, L, d_model), self-attention weights).
        """
        output, slf_attn = self.slf_attn(input, mask=slf_attn_mask)
        if mask is not None:
            output = output.masked_fill(mask.unsqueeze(-1), 0)

        output = self.pos_ffn(output, mask=mask)
        if mask is not None:
            output = output.masked_fill(mask.unsqueeze(-1), 0)

        return output, slf_attn
class MultiHeadPNCAAttention(nn.Module):
    """Multi-Head Attention PNCA module.

    Stateful attention for incremental (step-by-step) decoding: each forward
    call attends (a) over all decoder steps seen so far (cached self-attention
    keys/values in ``x_k``/``x_v``) and (b) over an encoder memory ``h``
    (keys/values cached in ``h_k``/``h_v``), then sums the two attention
    outputs, applies dropout and a residual connection.

    Note:
        The cache attributes (x_k, x_v, x_state_size, h_k, h_v, h_state_size)
        are created by ``reset_state`` — it must be called before the first
        ``forward`` and before every new sequence.

    Args:
        n_head (int): Number of attention heads.
        d_model (int): Decoder feature size (input and output width).
        d_mem (int): Encoder memory feature size.
        d_head (int): Per-head feature size.
        dropout (float): Dropout on the summed attention output.
        dropatt (float): Dropout on the attention weights.
    """

    def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0):
        super().__init__()
        self.n_head = n_head
        self.d_head = d_head
        self.d_model = d_model
        self.d_mem = d_mem
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        # Fused Q/K/V projection over the decoder input.
        self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head)
        self.fc_x = nn.Linear(n_head * d_head, d_model)
        # K/V projection over the encoder memory (queries come from x).
        self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head)
        self.fc_h = nn.Linear(n_head * d_head, d_model)
        self.attention = ScaledDotProductAttention(
            temperature=np.power(d_head, 0.5), dropatt=dropatt
        )
        self.dropout = nn.Dropout(dropout)

    def update_x_state(self, x):
        """Project ``x`` to Q/K/V and append K/V to the running caches.

        Args:
            x (Tensor): New decoder frames (B, len_x, d_model), already
                layer-normed by the caller.

        Returns:
            tuple: (x_q, x_k, x_v) for the new frames only, each shaped
            (n_head * B, len_x, d_head).
        """
        d_head, n_head = self.d_head, self.n_head
        sz_b, len_x, _ = x.size()

        x_qkv = self.w_x_qkv(x)
        x_q, x_k, x_v = x_qkv.chunk(3, -1)
        x_q = x_q.view(sz_b, len_x, n_head, d_head)
        x_k = x_k.view(sz_b, len_x, n_head, d_head)
        x_v = x_v.view(sz_b, len_x, n_head, d_head)
        # Fold heads into the batch dim: (n*b) x len_x x d_head.
        x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
        x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
        x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)

        if self.x_state_size:
            # Append the new steps to the cached self-attention K/V.
            self.x_k = torch.cat([self.x_k, x_k], dim=1)
            self.x_v = torch.cat([self.x_v, x_v], dim=1)
        else:
            # First step after reset_state: initialize the caches.
            self.x_k = x_k
            self.x_v = x_v

        self.x_state_size += len_x

        return x_q, x_k, x_v

    def update_h_state(self, h):
        """(Re)project the encoder memory K/V if the memory length changed.

        Args:
            h (Tensor): Encoder memory (B, len_h, d_mem).

        Returns:
            tuple: (h_k, h_v) head-shaped projections, or (None, None) when
            the cache already matches ``h``'s length and nothing was done.
        """
        if self.h_state_size == h.size(1):
            # Memory unchanged since last call — keep the cached K/V.
            return None, None
        d_head, n_head = self.d_head, self.n_head
        # H
        sz_b, len_h, _ = h.size()

        h_kv = self.w_h_kv(h)
        h_k, h_v = h_kv.chunk(2, -1)
        h_k = h_k.view(sz_b, len_h, n_head, d_head)
        h_v = h_v.view(sz_b, len_h, n_head, d_head)
        # Fold heads into the batch dim and cache.
        self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
        self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)

        self.h_state_size += len_h

        return h_k, h_v

    def reset_state(self):
        """Clear all K/V caches.  Must be called before each new sequence."""
        self.h_k = None
        self.h_v = None
        self.h_state_size = 0
        self.x_k = None
        self.x_v = None
        self.x_state_size = 0

    def forward(self, x, h, mask_x=None, mask_h=None):
        """Attend over cached decoder steps and encoder memory.

        Args:
            x (Tensor): New decoder frames (B, len_x, d_model).
            h (Tensor): Encoder memory (B, len_h, d_mem).
            mask_x (BoolTensor, optional): Self-attention mask.
            mask_h (BoolTensor, optional): Memory-attention mask.

        Returns:
            tuple: (output (B, len_x, d_model), self-attn weights,
            memory-attn weights).
        """
        residual = x

        self.update_h_state(h)
        x_q, x_k, x_v = self.update_x_state(self.layer_norm(x))

        d_head, n_head = self.d_head, self.n_head
        sz_b, len_in, _ = x.size()

        # X: self-attention over all cached decoder steps.
        if mask_x is not None:
            mask_x = mask_x.repeat(n_head, 1, 1)  # (n*b) x .. x ..
        output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x)
        output_x = output_x.view(n_head, sz_b, len_in, d_head)
        output_x = (
            output_x.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
        )  # b x l x (n*d)
        output_x = self.fc_x(output_x)

        # H: cross-attention over the encoder memory.
        if mask_h is not None:
            mask_h = mask_h.repeat(n_head, 1, 1)
        output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h)
        output_h = output_h.view(n_head, sz_b, len_in, d_head)
        output_h = (
            output_h.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
        )  # b x l x (n*d)
        output_h = self.fc_h(output_h)

        # Sum both attention branches, then dropout + residual.
        output = output_x + output_h
        output = self.dropout(output)
        output = output + residual

        return output, attn_x, attn_h
class PNCABlock(nn.Module):
    """PNCA Block: stateful PNCA attention followed by a position-wise
    convolutional feed-forward layer, with padding masking after each stage.

    Args:
        d_model (int): Decoder feature size.
        d_mem (int): Encoder memory feature size.
        n_head (int): Number of attention heads.
        d_head (int): Per-head feature size.
        d_inner (int): Hidden size of the feed-forward layer.
        kernel_size (tuple): Conv kernel sizes of the feed-forward layer.
        dropout (float): Output dropout for both sub-layers.
        dropout_attn (float): Dropout on attention weights.
        dropout_relu (float): Dropout after the FFN's inner ReLU.
    """

    def __init__(
        self,
        d_model,
        d_mem,
        n_head,
        d_head,
        d_inner,
        kernel_size,
        dropout,
        dropout_attn=0.0,
        dropout_relu=0.0,
    ):
        super(PNCABlock, self).__init__()
        self.pnca_attn = MultiHeadPNCAAttention(
            n_head, d_model, d_mem, d_head, dropout=dropout, dropatt=dropout_attn
        )
        self.pos_ffn = PositionwiseConvFeedForward(
            d_model, d_inner, kernel_size, dropout_inner=dropout_relu, dropout=dropout
        )

    def forward(
        self, input, memory, mask=None, pnca_x_attn_mask=None, pnca_h_attn_mask=None
    ):
        """Run PNCA attention then feed-forward.

        Args:
            input (Tensor): Decoder frames (B, L, d_model).
            memory (Tensor): Encoder memory (B, L_mem, d_mem).
            mask (BoolTensor, optional): (B, L) padding mask.
            pnca_x_attn_mask (BoolTensor, optional): Self-attention mask.
            pnca_h_attn_mask (BoolTensor, optional): Memory-attention mask.

        Returns:
            tuple: (output, self-attn weights, memory-attn weights).
        """
        output, pnca_attn_x, pnca_attn_h = self.pnca_attn(
            input, memory, pnca_x_attn_mask, pnca_h_attn_mask
        )
        if mask is not None:
            output = output.masked_fill(mask.unsqueeze(-1), 0)

        output = self.pos_ffn(output, mask=mask)
        if mask is not None:
            output = output.masked_fill(mask.unsqueeze(-1), 0)

        return output, pnca_attn_x, pnca_attn_h

    def reset_state(self):
        """Clear the PNCA attention caches before decoding a new sequence."""
        self.pnca_attn.reset_state()
kantts/models/sambert/__pycache__/__init__.cpython-38.pyc
deleted
100644 → 0
View file @
8b4e9acd
File deleted
Prev
1
2
3
4
5
6
7
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment