Commit ee10550a authored by liugh5's avatar liugh5
Browse files

Initial commit

parents
Pipeline #790 canceled with stages
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 80
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 80
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
AttentionCTCLoss:
enable: True
AttentionBinarizationLoss:
enable: True
params:
start_epoch: 0
warmup_epoch: 100
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 80
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: True
using_byte: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: byte_index,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
AttentionCTCLoss:
enable: True
AttentionBinarizationLoss:
enable: True
params:
start_epoch: 0
warmup_epoch: 100
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 8
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 80
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 900
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 128
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 80
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
FP: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7,F74,M7,FBYN,FRXL,xiaoyu
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
FpCELoss:
enable: True
params:
loss_type: ce
weight: [1,4,4,8]
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 16
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 82
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
NSF: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7,F74,FBYN,FRXL,M7,xiaoyu
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 10000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 2300500 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 82
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
NSF: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 192
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 82
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
NSF: True
nsf_norm_type: global
nsf_f0_global_minimum: 30.0
nsf_f0_global_maximum: 730.0
SE: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: False
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1760101 # Number of training steps.
save_interval_steps: 100 # Interval steps to save checkpoint.
eval_interval_steps: 1000000000000 # Interval steps to evaluate the network.
log_interval_steps: 10 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 192
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 82
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
NSF: True
nsf_norm_type: global
nsf_f0_global_minimum: 30.0
nsf_f0_global_maximum: 730.0
SE: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: False
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 2500000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 1000000000000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 80
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: xiaoyue
language: Sichuan
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sybert
Model:
#########################################################
# TextsyBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsTextsyBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
mask_ratio: 0.3
optimizer:
type: Adam
params:
lr: 0.0001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 10000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
SeqCELoss:
enable: True
params:
loss_type: ce
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting > 0 may get stuck on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
import numpy as np
from scipy.io import wavfile

# Registry of supported on-disk data formats.
# Each entry maps a format name to:
#   load_func: callable(path) -> np.ndarray with the file's contents
#   desc:      human-readable description of the accepted file format
# TODO: add your own data type here as you need.
DATA_TYPE_DICT = {
    "txt": {
        "load_func": np.loadtxt,
        "desc": "plain txt file or readable by np.loadtxt",
    },
    "wav": {
        # scipy.io.wavfile.read returns (sample_rate, data); keep data only.
        "load_func": lambda x: wavfile.read(x)[1],
        # Fixed: the loader is scipy.io.wavfile.read, not soundfile.read.
        "desc": "wav file or readable by scipy.io.wavfile.read",
    },
    "npy": {
        "load_func": np.load,
        "desc": "any .npy format file",
    },
    # PCM data type can be loaded by binary format
    "bin_f32": {
        "load_func": lambda x: np.fromfile(x, dtype=np.float32),
        "desc": "binary file with float32 format",
    },
    "bin_f64": {
        "load_func": lambda x: np.fromfile(x, dtype=np.float64),
        "desc": "binary file with float64 format",
    },
    "bin_i32": {
        "load_func": lambda x: np.fromfile(x, dtype=np.int32),
        "desc": "binary file with int32 format",
    },
    "bin_i16": {
        "load_func": lambda x: np.fromfile(x, dtype=np.int16),
        "desc": "binary file with int16 format",
    },
}
import os
import torch
import glob
import logging
from multiprocessing import Manager
import librosa
import numpy as np
import random
import functools
from tqdm import tqdm
import math
from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit, emotion_types
from scipy.stats import betabinom

# Fixed seed so the random train/valid splits below are reproducible.
DATASET_RANDOM_SEED = 1234
# NOTE(review): file_system sharing is commonly used to avoid fd/shm limits
# with many DataLoader workers — confirm that is the intent here.
torch.multiprocessing.set_sharing_strategy("file_system")
@functools.lru_cache(maxsize=256)
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
    """Return a (mel_count, phoneme_count) beta-binomial alignment prior.

    Row i (1-based) is the pmf of betabinom(phoneme_count, scaling*i,
    scaling*(mel_count+1-i)) evaluated at phoneme indices 0..phoneme_count-1.
    Cached because the same length pairs recur across a dataset.
    """
    support = np.arange(0, phoneme_count)
    rows = [
        betabinom(phoneme_count, scaling * i, scaling * (mel_count + 1 - i)).pmf(support)
        for i in range(1, mel_count + 1)
    ]
    return torch.tensor(np.array(rows))
class Padder(object):
    """Numpy/torch padding helpers shared by the dataset collate functions."""

    def __init__(self):
        super(Padder, self).__init__()

    def _pad1D(self, x, length, pad):
        # Right-pad a 1-D array to `length` with the constant `pad`.
        return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad)

    def _pad2D(self, x, length, pad):
        # Right-pad the first axis of a 2-D array to `length`.
        return np.pad(
            x, [(0, length - x.shape[0]), (0, 0)], mode="constant", constant_values=pad
        )

    def _pad_durations(self, duration, max_in_len, max_out_len):
        # Extend a duration sequence to `max_in_len` symbols whose frame total
        # reaches `max_out_len`: missing frames go into one extra symbol,
        # remaining slots are zero.
        total_frames = np.sum(duration)
        n_symbols = duration.shape[0]
        if total_frames < max_out_len:
            duration = np.insert(
                duration, n_symbols, values=max_out_len - total_frames, axis=0
            )
            duration = np.insert(
                duration,
                n_symbols + 1,
                values=[0] * (max_in_len - n_symbols - 1),
                axis=0,
            )
        elif n_symbols < max_in_len:
            duration = np.insert(
                duration, n_symbols, values=[0] * (max_in_len - n_symbols), axis=0
            )
        return duration

    def _round_up(self, x, multiple):
        # Smallest multiple of `multiple` that is >= x.
        rem = x % multiple
        if rem == 0:
            return x
        return x + multiple - rem

    def _prepare_scalar_inputs(self, inputs, max_len, pad):
        # Stack 1-D sequences into a (batch, max_len) tensor.
        padded = [self._pad1D(seq, max_len, pad) for seq in inputs]
        return torch.from_numpy(np.stack(padded))

    def _prepare_targets(self, targets, max_len, pad):
        # Stack 2-D feature matrices into a float (batch, max_len, dim) tensor.
        padded = [self._pad2D(t, max_len, pad) for t in targets]
        return torch.from_numpy(np.stack(padded)).float()

    def _prepare_durations(self, durations, max_in_len, max_out_len):
        # Stack duration rows into a long (batch, max_in_len) tensor.
        padded = [self._pad_durations(d, max_in_len, max_out_len) for d in durations]
        return torch.from_numpy(np.stack(padded)).long()
class Voc_Dataset(torch.utils.data.Dataset):
    """Vocoder dataset: yields aligned (waveform, mel) pairs.

    Fixes vs. the previous revision:
      * the ``nsf_norm_type`` default was the malformed literal ``'"mean_std'``
        (stray leading quote); it is now ``"mean_std"``.
      * ``load_meta_from_dir`` returned 2-tuples (and a ``.wav``-named mel
        path) that ``__getitem__``'s 4-way unpack could not consume; it now
        emits the same 4-tuple layout as ``load_meta``.
    """

    def __init__(
        self,
        metafile,
        root_dir,
        config,
    ):
        """Collect (wav, mel, frame_f0, frame_uv) file tuples.

        metafile: path or list of paths to *.lst index files.
        root_dir: data directory (or list, zipped with metafile).
        config: training config dict; audio_config and batch_max_steps are read.
        """
        self.meta = []
        self.config = config
        self.sampling_rate = config["audio_config"]["sampling_rate"]
        self.n_fft = config["audio_config"]["n_fft"]
        self.hop_length = config["audio_config"]["hop_length"]
        self.batch_max_steps = config["batch_max_steps"]
        # Number of mel frames covered by one training audio segment.
        self.batch_max_frames = self.batch_max_steps // self.hop_length
        self.aux_context_window = 0  # TODO: make it configurable
        # Valid range for the random crop start frame in collate_fn.
        self.start_offset = self.aux_context_window
        self.end_offset = -(self.batch_max_frames + self.aux_context_window)
        # NSF mode appends frame-level f0/uv channels to the mel features.
        self.nsf_enable = (
            config["Model"]["Generator"]["params"].get("nsf_params", None) is not None
        )
        if self.nsf_enable:
            # FIX: default was the malformed literal '"mean_std'.
            self.nsf_norm_type = config["Model"]["Generator"]["params"][
                "nsf_params"
            ].get("nsf_norm_type", "mean_std")
            if self.nsf_norm_type == "global":
                self.nsf_f0_global_minimum = config["Model"]["Generator"]["params"][
                    "nsf_params"
                ].get("nsf_f0_global_minimum", 30.0)
                self.nsf_f0_global_maximum = config["Model"]["Generator"]["params"][
                    "nsf_params"
                ].get("nsf_f0_global_maximum", 730.0)
        if not isinstance(metafile, list):
            metafile = [metafile]
        if not isinstance(root_dir, list):
            root_dir = [root_dir]
        for meta_file, data_dir in zip(metafile, root_dir):
            if not os.path.exists(meta_file):
                logging.error("meta file not found: {}".format(meta_file))
                raise ValueError(
                    "[Voc_Dataset] meta file: {} not found".format(meta_file)
                )
            if not os.path.exists(data_dir):
                logging.error("data directory not found: {}".format(data_dir))
                raise ValueError(
                    "[Voc_Dataset] data dir: {} not found".format(data_dir)
                )
            self.meta.extend(self.load_meta(meta_file, data_dir))
        # Fallback: no metafile entries — scan the training data directory.
        if len(self.meta) == 0 and isinstance(root_dir, str):
            wav_dir = os.path.join(root_dir, "wav")
            mel_dir = os.path.join(root_dir, "mel")
            if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
                raise ValueError("wav or mel directory not found")
            self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
        elif len(self.meta) == 0 and isinstance(root_dir, list):
            for d in root_dir:
                wav_dir = os.path.join(d, "wav")
                mel_dir = os.path.join(d, "mel")
                if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
                    raise ValueError("wav or mel directory not found")
                self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
        self.allow_cache = config["allow_cache"]
        if self.allow_cache:
            # Manager list so worker processes share one cache.
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.meta))]

    @staticmethod
    def gen_metafile(wav_dir, out_dir, split_ratio=0.98):
        """Write shuffled train.lst / valid.lst indices into out_dir.

        Only utterances with all of frame_f0/frame_uv/mel .npy files present
        are listed; the shuffle is seeded for reproducibility.
        """
        wav_files = glob.glob(os.path.join(wav_dir, "*.wav"))
        frame_f0_dir = os.path.join(out_dir, "frame_f0")
        frame_uv_dir = os.path.join(out_dir, "frame_uv")
        mel_dir = os.path.join(out_dir, "mel")
        random.seed(DATASET_RANDOM_SEED)
        random.shuffle(wav_files)
        num_train = int(len(wav_files) * split_ratio) - 1
        with open(os.path.join(out_dir, "train.lst"), "w") as f:
            for wav_file in wav_files[:num_train]:
                index = os.path.splitext(os.path.basename(wav_file))[0]
                if (
                    not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
                ):
                    continue
                f.write("{}\n".format(index))
        with open(os.path.join(out_dir, "valid.lst"), "w") as f:
            for wav_file in wav_files[num_train:]:
                index = os.path.splitext(os.path.basename(wav_file))[0]
                if (
                    not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
                ):
                    continue
                f.write("{}\n".format(index))

    def load_meta(self, metafile, data_dir):
        """Expand each index in metafile into (wav, mel, frame_f0, frame_uv) paths."""
        with open(metafile, "r") as f:
            lines = f.readlines()
        wav_dir = os.path.join(data_dir, "wav")
        mel_dir = os.path.join(data_dir, "mel")
        frame_f0_dir = os.path.join(data_dir, "frame_f0")
        frame_uv_dir = os.path.join(data_dir, "frame_uv")
        if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
            raise ValueError("wav or mel directory not found")
        items = []
        logging.info("Loading metafile...")
        for name in tqdm(lines):
            name = name.strip()
            mel_file = os.path.join(mel_dir, name + ".npy")
            wav_file = os.path.join(wav_dir, name + ".wav")
            frame_f0_file = os.path.join(frame_f0_dir, name + ".npy")
            frame_uv_file = os.path.join(frame_uv_dir, name + ".npy")
            items.append((wav_file, mel_file, frame_f0_file, frame_uv_file))
        return items

    def load_meta_from_dir(self, wav_dir, mel_dir):
        """Pair wav files with their mel .npy files by basename.

        FIX: emits the same 4-tuple layout as load_meta (wav, mel, frame_f0,
        frame_uv) so __getitem__ can unpack it, and looks for the mel under
        its .npy name instead of the raw .wav basename.
        """
        frame_f0_dir = os.path.join(os.path.dirname(mel_dir), "frame_f0")
        frame_uv_dir = os.path.join(os.path.dirname(mel_dir), "frame_uv")
        wav_files = glob.glob(os.path.join(wav_dir, "*.wav"))
        items = []
        for wav_file in wav_files:
            name = os.path.splitext(os.path.basename(wav_file))[0]
            mel_file = os.path.join(mel_dir, name + ".npy")
            if os.path.exists(mel_file):
                items.append(
                    (
                        wav_file,
                        mel_file,
                        os.path.join(frame_f0_dir, name + ".npy"),
                        os.path.join(frame_uv_dir, name + ".npy"),
                    )
                )
        return items

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        """Return (wav_data, mel_data) with lengths aligned to hop_length."""
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]
        wav_file, mel_file, frame_f0_file, frame_uv_file = self.meta[idx]
        # Utterance-level f0 statistics live next to the frame_f0 directory.
        f0_mean_file = os.path.join(
            os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_mean.txt"
        )
        f0_std_file = os.path.join(
            os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_std.txt"
        )
        wav_data = librosa.core.load(wav_file, sr=self.sampling_rate)[0]
        mel_data = np.load(mel_file)
        if self.nsf_enable:
            # denorm f0; default frame_f0_data using mean_std norm
            frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
            f0_mean = np.loadtxt(f0_mean_file)
            f0_std = np.loadtxt(f0_std_file)
            frame_f0_data = frame_f0_data * f0_std + f0_mean
            frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
            mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data), axis=1)
            # make sure mel_data length greater than batch_max_frames at least 1 frame
            if mel_data.shape[0] <= self.batch_max_frames:
                mel_data = np.concatenate(
                    (
                        mel_data,
                        np.zeros(
                            (
                                self.batch_max_frames - mel_data.shape[0] + 1,
                                mel_data.shape[1],
                            )
                        ),
                    ),
                    axis=0,
                )
            # Zero-pad the waveform to exactly hop_length samples per frame.
            wav_cache = np.zeros(mel_data.shape[0] * self.hop_length, dtype=np.float32)
            wav_cache[: len(wav_data)] = wav_data
            wav_data = wav_cache
        else:
            # make sure the audio length and feature length are matched
            wav_data = np.pad(wav_data, (0, self.n_fft), mode="reflect")
            wav_data = wav_data[: len(mel_data) * self.hop_length]
            assert len(mel_data) * self.hop_length == len(wav_data)
        if self.allow_cache:
            self.caches[idx] = (wav_data, mel_data)
        return (wav_data, mel_data)

    def collate_fn(self, batch):
        """Randomly crop each pair to a fixed-size training segment.

        Returns wav (B, 1, batch_max_steps) and mel (B, C, batch_max_frames).
        """
        wav_data, mel_data = [item[0] for item in batch], [item[1] for item in batch]
        mel_lengths = [len(mel) for mel in mel_data]
        start_frames = np.array(
            [
                np.random.randint(self.start_offset, length + self.end_offset)
                for length in mel_lengths
            ]
        )
        wav_start = start_frames * self.hop_length
        wav_end = wav_start + self.batch_max_steps
        # aux window works as padding
        mel_start = start_frames - self.aux_context_window
        mel_end = mel_start + self.batch_max_frames + self.aux_context_window
        wav_batch = [
            x[start:end] for x, start, end in zip(wav_data, wav_start, wav_end)
        ]
        mel_batch = [
            c[start:end] for c, start, end in zip(mel_data, mel_start, mel_end)
        ]
        # (B, 1, T)
        wav_batch = torch.tensor(np.asarray(wav_batch), dtype=torch.float32).unsqueeze(
            1
        )
        # (B, C, T)
        mel_batch = torch.tensor(np.asarray(mel_batch), dtype=torch.float32).transpose(
            2, 1
        )
        return wav_batch, mel_batch
def get_voc_datasets(
    config,
    root_dir,
    split_ratio=0.98,
):
    """Build the (train, valid) vocoder datasets for `root_dir`.

    Missing train.lst / valid.lst index files are generated from the wav
    directory before the datasets are constructed.
    """
    if isinstance(root_dir, str):
        root_dir = [root_dir]
    train_metas = []
    valid_metas = []
    for data_dir in root_dir:
        train_meta = os.path.join(data_dir, "train.lst")
        valid_meta = os.path.join(data_dir, "valid.lst")
        if not (os.path.exists(train_meta) and os.path.exists(valid_meta)):
            Voc_Dataset.gen_metafile(
                os.path.join(data_dir, "wav"), data_dir, split_ratio
            )
        train_metas.append(train_meta)
        valid_metas.append(valid_meta)
    train_dataset = Voc_Dataset(train_metas, root_dir, config)
    valid_dataset = Voc_Dataset(valid_metas[:50], root_dir, config)
    return train_dataset, valid_dataset
# TODO(Yuxuan): refine the logic, you'd better not use emotion tag, it's ambiguous.
def get_fp_label(aug_ling_txt):
    """Derive per-token filled-pause (fp) labels from an augmented ling line.

    Each token looks like "{sy$tone$...$emo$...}"; field 4 is the emotion tag
    and field 0 the syllable. Tokens tagged emotion_types[3] are treated as fp
    markers: they emit no label themselves, and the following token's label is
    chosen from the marker's syllables ("ga" -> 1, "ge en_c" -> 2, else 3).
    All other tokens get label 0. Returns a numpy int array, one entry per
    non-marker token (so its length can be shorter than the token count).
    NOTE(review): the exact emotion_types semantics come from
    kantts.utils.ling_unit and are not visible here — verify against that module.
    """
    token_lst = aug_ling_txt.split(" ")
    emo_lst = [token.strip("{}").split("$")[4] for token in token_lst]
    syllable_lst = [token.strip("{}").split("$")[0] for token in token_lst]
    # EOS token append
    emo_lst.append(emotion_types[0])
    syllable_lst.append("EOS")
    # According to the original emotion tag, set each token's fp label.
    if emo_lst[0] != emotion_types[3]:
        emo_lst[0] = emotion_types[0]
        emo_lst[1] = emotion_types[0]
    # Walk backwards so a marker can rewrite the tag of the token after it
    # before that token itself is visited.
    for i in range(len(emo_lst) - 2, 1, -1):
        if emo_lst[i] != emotion_types[3] and emo_lst[i - 1] != emotion_types[3]:
            emo_lst[i] = emotion_types[0]
        elif emo_lst[i] != emotion_types[3] and emo_lst[i - 1] == emotion_types[3]:
            emo_lst[i] = emotion_types[3]
            # The fp class depends on the marker's syllable content.
            if syllable_lst[i - 2] == "ga":
                emo_lst[i + 1] = emotion_types[1]
            elif syllable_lst[i - 2] == "ge" and syllable_lst[i - 1] == "en_c":
                emo_lst[i + 1] = emotion_types[2]
            else:
                emo_lst[i + 1] = emotion_types[4]
    # Map the rewritten tags to integer labels; marker tokens are skipped.
    fp_label = []
    for i in range(len(emo_lst)):
        if emo_lst[i] == emotion_types[0]:
            fp_label.append(0)
        elif emo_lst[i] == emotion_types[1]:
            fp_label.append(1)
        elif emo_lst[i] == emotion_types[2]:
            fp_label.append(2)
        elif emo_lst[i] == emotion_types[3]:
            continue
        elif emo_lst[i] == emotion_types[4]:
            fp_label.append(3)
        else:
            pass
    return np.array(fp_label)
class AM_Dataset(torch.utils.data.Dataset):
    """
    Acoustic-model dataset: provides
    (ling, mel, duration, f0, energy, attn_prior, fp_label, se) tuples,
    collated into the dict the SAMBERT trainer consumes.
    """
    def __init__(
        self,
        config,
        metafile,
        root_dir,
        allow_cache=False,
    ):
        """Read feature toggles from config and index (metafile, root_dir) pairs.

        metafile / root_dir may be single paths or parallel lists; each pair
        contributes its entries to self.meta.
        """
        self.meta = []
        self.config = config
        # May be flipped off in load_meta when MAS is on or no duration dir exists.
        self.with_duration = True
        # NSF: append frame-level f0/uv channels to the mel target.
        self.nsf_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
            "NSF", False
        )
        if self.nsf_enable:
            self.nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get(
                "nsf_norm_type", "mean_std"
            )
            if self.nsf_norm_type == "global":
                self.nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"][
                    "params"
                ].get("nsf_f0_global_minimum", 30.0)
                self.nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"][
                    "params"
                ].get("nsf_f0_global_maximum", 730.0)
        # SE: replace speaker ids with a speaker-embedding vector.
        self.se_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
            "SE", False
        )
        # FP: filled-pause label prediction (needs the fpadd metafile variant).
        self.fp_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
            "FP", False
        )
        # MAS: learn alignment instead of using precomputed durations.
        self.mas_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
            "MAS", False
        )
        if not isinstance(metafile, list):
            metafile = [metafile]
        if not isinstance(root_dir, list):
            root_dir = [root_dir]
        for meta_file, data_dir in zip(metafile, root_dir):
            if not os.path.exists(meta_file):
                logging.error("meta file not found: {}".format(meta_file))
                raise ValueError(
                    "[AM_Dataset] meta file: {} not found".format(meta_file)
                )
            if not os.path.exists(data_dir):
                logging.error("data dir not found: {}".format(data_dir))
                raise ValueError("[AM_Dataset] data dir: {} not found".format(data_dir))
            self.meta.extend(self.load_meta(meta_file, data_dir))
        self.allow_cache = allow_cache
        self.ling_unit = KanTtsLinguisticUnit(config)
        self.padder = Padder()
        # Decoder reduction factor; mel lengths are rounded up to a multiple of r.
        self.r = self.config["Model"]["KanTtsSAMBERT"]["params"]["outputs_per_step"]
        # TODO: feat window
        if allow_cache:
            # Manager list so DataLoader worker processes share one cache.
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.meta))]
    def __len__(self):
        return len(self.meta)
    def __getitem__(self, idx):
        """Load one utterance's features (cached after first access if enabled)."""
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]
        (
            ling_txt,
            mel_file,
            dur_file,
            f0_file,
            energy_file,
            frame_f0_file,
            frame_uv_file,
            aug_ling_txt,
            se_path,
        ) = self.meta[idx]
        # Corpus-level f0 statistics stored next to the frame_f0 directory.
        f0_mean_file = os.path.join(
            os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_mean.txt"
        )
        f0_std_file = os.path.join(
            os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_std.txt"
        )
        ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
        mel_data = np.load(mel_file)
        dur_data = np.load(dur_file) if dur_file is not None else None
        f0_data = np.load(f0_file)
        energy_data = np.load(energy_file)
        se_data = np.load(se_path) if self.se_enable else None
        # generate fp position label according to fpadd_meta
        if self.fp_enable and aug_ling_txt is not None:
            fp_label = get_fp_label(aug_ling_txt)
        else:
            fp_label = None
        if self.with_duration:
            attn_prior = None
        else:
            # MAS path: beta-binomial soft alignment prior (text x mel frames).
            attn_prior = beta_binomial_prior_distribution(
                len(ling_data[0]), mel_data.shape[0]
            )
        # Concat frame-level f0 and uv to mel_data
        if self.nsf_enable:
            # origin f0 data is mean std normed
            frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
            # default f0 data is mean std normed; re-norm here
            if self.nsf_norm_type == "global":
                # denorm f0
                f0_mean = np.loadtxt(f0_mean_file)
                f0_std = np.loadtxt(f0_std_file)
                f0_origin = frame_f0_data * f0_std + f0_mean
                # renorm f0
                frame_f0_data = (f0_origin - self.nsf_f0_global_minimum) / (
                    self.nsf_f0_global_maximum - self.nsf_f0_global_minimum
                )
            frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
            mel_data = np.concatenate([mel_data, frame_f0_data, frame_uv_data], axis=1)
        if self.allow_cache:
            self.caches[idx] = (
                ling_data,
                mel_data,
                dur_data,
                f0_data,
                energy_data,
                attn_prior,
                fp_label,
                se_data,
            )
        return (
            ling_data,
            mel_data,
            dur_data,
            f0_data,
            energy_data,
            attn_prior,
            fp_label,
            se_data,
        )
    def load_meta(self, metafile, data_dir):
        """Expand each "index<TAB>ling" metafile line into a 9-tuple of
        feature paths/strings; utterances missing required fp/se data are
        skipped with a warning."""
        with open(metafile, "r") as f:
            lines = f.readlines()
        aug_ling_dict = {}
        if self.fp_enable:
            # The fp-augmented ling lines live in the sibling "fpadd" metafile.
            add_fp_metafile = metafile.replace("fprm", "fpadd")
            with open(add_fp_metafile, "r") as f:
                fpadd_lines = f.readlines()
            for line in fpadd_lines:
                index, aug_ling_txt = line.split("\t")
                aug_ling_dict[index] = aug_ling_txt
        mel_dir = os.path.join(data_dir, "mel")
        dur_dir = os.path.join(data_dir, "duration")
        f0_dir = os.path.join(data_dir, "f0")
        energy_dir = os.path.join(data_dir, "energy")
        frame_f0_dir = os.path.join(data_dir, "frame_f0")
        frame_uv_dir = os.path.join(data_dir, "frame_uv")
        se_dir = os.path.join(data_dir, "se")
        # MAS replaces explicit durations with a learned alignment.
        if self.mas_enable:
            self.with_duration = False
        else:
            self.with_duration = os.path.exists(dur_dir)
        items = []
        logging.info("Loading metafile...")
        for line in tqdm(lines):
            line = line.strip()
            index, ling_txt = line.split("\t")
            mel_file = os.path.join(mel_dir, index + ".npy")
            if self.with_duration:
                dur_file = os.path.join(dur_dir, index + ".npy")
            else:
                dur_file = None
            f0_file = os.path.join(f0_dir, index + ".npy")
            energy_file = os.path.join(energy_dir, index + ".npy")
            frame_f0_file = os.path.join(frame_f0_dir, index + ".npy")
            frame_uv_file = os.path.join(frame_uv_dir, index + ".npy")
            aug_ling_txt = aug_ling_dict.get(index, None)
            if self.fp_enable and aug_ling_txt is None:
                logging.warning(f"Missing fpadd meta for {index}")
                continue
            # A single shared speaker embedding file is used for all utterances.
            se_path = os.path.join(se_dir, "se.npy")
            if self.se_enable:
                if not os.path.exists(se_path):
                    logging.warning("Missing se meta")
                    continue
            items.append(
                (
                    ling_txt,
                    mel_file,
                    dur_file,
                    f0_file,
                    energy_file,
                    frame_f0_file,
                    frame_uv_file,
                    aug_ling_txt,
                    se_path,
                )
            )
        return items
    def load_fpadd_meta(self, metafile):
        """Read a fpadd metafile and return its ling strings as 1-tuples."""
        with open(metafile, "r") as f:
            lines = f.readlines()
        items = []
        logging.info("Loading fpadd metafile...")
        for line in tqdm(lines):
            line = line.strip()
            index, ling_txt = line.split("\t")
            items.append((ling_txt,))
        return items
    @staticmethod
    def gen_metafile(
        raw_meta_file,
        out_dir,
        train_meta_file,
        valid_meta_file,
        badlist=None,
        split_ratio=0.98,
        se_enable=False,
    ):
        """Split raw_meta_file into shuffled train/valid metafiles, keeping only
        lines whose required feature files exist and that are not in badlist.

        NOTE(review): `badlist` sits between the path arguments and
        `split_ratio` — callers must pass split_ratio/se_enable as keywords
        (get_am_datasets passes them positionally, which misbinds them).
        """
        with open(raw_meta_file, "r") as f:
            lines = f.readlines()
        se_dir = os.path.join(out_dir, "se")
        frame_f0_dir = os.path.join(out_dir, "frame_f0")
        frame_uv_dir = os.path.join(out_dir, "frame_uv")
        mel_dir = os.path.join(out_dir, "mel")
        duration_dir = os.path.join(out_dir, "duration")
        # Seeded shuffle so the split is reproducible.
        random.seed(DATASET_RANDOM_SEED)
        random.shuffle(lines)
        num_train = int(len(lines) * split_ratio) - 1
        with open(train_meta_file, "w") as f:
            for line in lines[:num_train]:
                index = line.split("\t")[0]
                if badlist is not None and index in badlist:
                    continue
                if (
                    not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
                ):
                    continue
                if os.path.exists(duration_dir) and not os.path.exists(
                    os.path.join(duration_dir, index + ".npy")
                ):
                    continue
                if se_enable:
                    if os.path.exists(se_dir) and not os.path.exists(
                        os.path.join(se_dir, "se.npy")
                    ):
                        continue
                f.write(line)
        with open(valid_meta_file, "w") as f:
            for line in lines[num_train:]:
                index = line.split("\t")[0]
                if badlist is not None and index in badlist:
                    continue
                if (
                    not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
                    or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
                ):
                    continue
                if os.path.exists(duration_dir) and not os.path.exists(
                    os.path.join(duration_dir, index + ".npy")
                ):
                    continue
                if se_enable:
                    if os.path.exists(se_dir) and not os.path.exists(
                        os.path.join(se_dir, "se.npy")
                    ):
                        continue
                f.write(line)
    def collate_fn(self, batch):
        """Pad a list of __getitem__ tuples into the batched tensor dict used
        by the trainer (input_lings, input_emotions, input_speakers, fp_label,
        valid_*_lengths, mel_targets, durations, pitch/energy_contours,
        attn_priors)."""
        data_dict = {}
        max_input_length = max((len(x[0][0]) for x in batch))
        if self.with_duration:
            # +1 leaves room for the pad symbol added by _pad_durations.
            max_dur_length = max((x[2].shape[0] for x in batch)) + 1
        lfeat_type_index = 0
        lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
        if self.ling_unit.using_byte():
            # for byte-based model only
            inputs_byte_index = self.padder._prepare_scalar_inputs(
                [x[0][lfeat_type_index] for x in batch],
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()
            data_dict["input_lings"] = torch.stack([inputs_byte_index], dim=2)
        else:
            # pure linguistic info: sy|tone|syllable_flag|word_segment
            # sy
            inputs_sy = self.padder._prepare_scalar_inputs(
                [x[0][lfeat_type_index] for x in batch],
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()
            # tone
            lfeat_type_index = lfeat_type_index + 1
            lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
            inputs_tone = self.padder._prepare_scalar_inputs(
                [x[0][lfeat_type_index] for x in batch],
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()
            # syllable_flag
            lfeat_type_index = lfeat_type_index + 1
            lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
            inputs_syllable_flag = self.padder._prepare_scalar_inputs(
                [x[0][lfeat_type_index] for x in batch],
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()
            # word_segment
            lfeat_type_index = lfeat_type_index + 1
            lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
            inputs_ws = self.padder._prepare_scalar_inputs(
                [x[0][lfeat_type_index] for x in batch],
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()
            data_dict["input_lings"] = torch.stack(
                [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2
            )
        # emotion category
        lfeat_type_index = lfeat_type_index + 1
        lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
        data_dict["input_emotions"] = self.padder._prepare_scalar_inputs(
            [x[0][lfeat_type_index] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # speaker category
        lfeat_type_index = lfeat_type_index + 1
        lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
        if self.se_enable:
            # SE mode: broadcast the utterance's speaker-embedding vector
            # across every input position instead of using speaker ids.
            data_dict["input_speakers"] = self.padder._prepare_targets(
                [x[7].repeat(len(x[0][0]), axis=0) for x in batch],
                max_input_length,
                0.0,
            )
        else:
            data_dict["input_speakers"] = self.padder._prepare_scalar_inputs(
                [x[0][lfeat_type_index] for x in batch],
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()
        # fp label category
        if self.fp_enable:
            data_dict["fp_label"] = self.padder._prepare_scalar_inputs(
                [x[6] for x in batch],
                max_input_length,
                0,
            ).long()
        data_dict["valid_input_lengths"] = torch.as_tensor(
            [len(x[0][0]) - 1 for x in batch], dtype=torch.long
        )  # a "~" (EOS) is appended to the symbol sequence and would skew the duration computation, hence length - 1
        data_dict["valid_output_lengths"] = torch.as_tensor(
            [len(x[1]) for x in batch], dtype=torch.long
        )
        max_output_length = torch.max(data_dict["valid_output_lengths"]).item()
        # Round up to a multiple of the decoder reduction factor r.
        max_output_round_length = self.padder._round_up(max_output_length, self.r)
        data_dict["mel_targets"] = self.padder._prepare_targets(
            [x[1] for x in batch], max_output_round_length, 0.0
        )
        if self.with_duration:
            data_dict["durations"] = self.padder._prepare_durations(
                [x[2] for x in batch], max_dur_length, max_output_round_length
            )
        else:
            data_dict["durations"] = None
        # f0/energy are symbol-level with durations, frame-level without.
        if self.with_duration:
            if self.fp_enable:
                feats_padding_length = max_dur_length
            else:
                feats_padding_length = max_input_length
        else:
            feats_padding_length = max_output_round_length
        data_dict["pitch_contours"] = self.padder._prepare_scalar_inputs(
            [x[3] for x in batch], feats_padding_length, 0.0
        ).float()
        data_dict["energy_contours"] = self.padder._prepare_scalar_inputs(
            [x[4] for x in batch], feats_padding_length, 0.0
        ).float()
        if self.with_duration:
            data_dict["attn_priors"] = None
        else:
            # MAS: copy each utterance's prior into a zero-padded batch tensor.
            data_dict["attn_priors"] = torch.zeros(
                len(batch), max_output_round_length, max_input_length
            )
            for i in range(len(batch)):
                attn_prior = batch[i][5]
                data_dict["attn_priors"][
                    i, : attn_prior.shape[0], : attn_prior.shape[1]
                ] = attn_prior
        return data_dict
def get_am_datasets(
    metafile,
    root_dir,
    config,
    allow_cache,
    split_ratio=0.98,
    se_enable=False,
):
    """Build the (train, valid) acoustic-model datasets.

    Missing am[_fprm]_train/valid.lst metafiles are generated from the raw
    metafile first. FP-enabled models read the "fprm" metafile variants.

    Bug fix: AM_Dataset.gen_metafile was called with split_ratio/se_enable
    passed positionally, so split_ratio (a float) landed in the `badlist`
    parameter (raising TypeError on `index in badlist`) and se_enable landed
    in `split_ratio`. They are now passed as keywords.
    """
    if not isinstance(root_dir, list):
        root_dir = [root_dir]
    if not isinstance(metafile, list):
        metafile = [metafile]
    train_meta_lst = []
    valid_meta_lst = []
    fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
    if fp_enable:
        am_train_fn = "am_fprm_train.lst"
        am_valid_fn = "am_fprm_valid.lst"
    else:
        am_train_fn = "am_train.lst"
        am_valid_fn = "am_valid.lst"
    for raw_metafile, data_dir in zip(metafile, root_dir):
        train_meta = os.path.join(data_dir, am_train_fn)
        valid_meta = os.path.join(data_dir, am_valid_fn)
        if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
            AM_Dataset.gen_metafile(
                raw_metafile,
                data_dir,
                train_meta,
                valid_meta,
                split_ratio=split_ratio,
                se_enable=se_enable,
            )
        train_meta_lst.append(train_meta)
        valid_meta_lst.append(valid_meta)
    train_dataset = AM_Dataset(config, train_meta_lst, root_dir, allow_cache)
    valid_dataset = AM_Dataset(config, valid_meta_lst[:50], root_dir, allow_cache)
    return train_dataset, valid_dataset
class MaskingActor(object):
    """BERT-style random masking for symbol sequences."""

    def __init__(self, mask_ratio=0.15):
        super(MaskingActor, self).__init__()
        self.mask_ratio = mask_ratio

    def _get_random_mask(self, length, p1=0.15):
        # One uniform draw per position; roughly a p1 fraction becomes 1.0.
        draws = np.random.uniform(0, 1, length)
        return (draws < p1).astype(draws.dtype)

    def _input_bert_masking(
        self,
        sequence_array,
        nb_symbol_category,
        mask_symbol_id,
        mask,
        p2=0.8,
        p3=0.1,
        p4=0.1,
    ):
        """Apply BERT corruption at positions where mask == 1.

        Of the selected positions: a p2 share is replaced with the [MASK]
        symbol id, a p3 share with one random symbol id, and the remainder
        (p4 share) is left unchanged.
        """
        corrupted = sequence_array.copy()
        selected = np.where(mask == 1)[0]
        n_selected = len(selected)
        order = np.arange(n_selected)
        np.random.shuffle(order)
        n_mask = int(math.floor(n_selected * p2))
        n_rand = int(math.floor(n_selected * p3))
        # [MASK]
        mask_positions = selected[order[0:n_mask]]
        if len(mask_positions) > 0:
            corrupted[mask_positions] = mask_symbol_id
        # rand
        rand_positions = selected[order[n_mask : n_mask + n_rand]]
        if len(rand_positions) > 0:
            # NOTE: a single random symbol is shared by all these positions.
            corrupted[rand_positions] = random.randint(0, nb_symbol_category - 1)
        # ori: remaining positions are deliberately left as-is.
        return corrupted
class BERT_Text_Dataset(torch.utils.data.Dataset):
    """
    provide (ling, ling_sy_masked, bert_mask) pair
    """
    def __init__(
        self,
        config,
        metafile,
        root_dir,
        allow_cache=False,
    ):
        """Index the metafile(s) and set up the linguistic unit, padder and
        masking actor (mask_ratio taken from the KanTtsTextsyBERT config)."""
        self.meta = []
        self.config = config
        if not isinstance(metafile, list):
            metafile = [metafile]
        if not isinstance(root_dir, list):
            root_dir = [root_dir]
        for meta_file, data_dir in zip(metafile, root_dir):
            if not os.path.exists(meta_file):
                logging.error("meta file not found: {}".format(meta_file))
                raise ValueError(
                    "[BERT_Text_Dataset] meta file: {} not found".format(meta_file)
                )
            if not os.path.exists(data_dir):
                logging.error("data dir not found: {}".format(data_dir))
                raise ValueError(
                    "[BERT_Text_Dataset] data dir: {} not found".format(data_dir)
                )
            self.meta.extend(self.load_meta(meta_file, data_dir))
        self.allow_cache = allow_cache
        self.ling_unit = KanTtsLinguisticUnit(config)
        self.padder = Padder()
        self.masking_actor = MaskingActor(
            self.config["Model"]["KanTtsTextsyBERT"]["params"]["mask_ratio"]
        )
        if allow_cache:
            # Manager list so DataLoader worker processes share one cache.
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.meta))]
    def __len__(self):
        return len(self.meta)
    def __getitem__(self, idx):
        """Return (ling, ling_sy_masked, bert_mask). Only the encoded ling
        sequence is cached; the masking is re-sampled on every access so each
        epoch sees a fresh corruption."""
        if self.allow_cache and len(self.caches[idx]) != 0:
            ling_data = self.caches[idx][0]
            bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
            return (ling_data, ling_sy_masked_data, bert_mask)
        ling_txt = self.meta[idx]
        ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
        bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
        if self.allow_cache:
            self.caches[idx] = (ling_data,)
        return (ling_data, ling_sy_masked_data, bert_mask)
    def load_meta(self, metafile, data_dir):
        """Read "index<TAB>ling" lines and keep only the ling strings.
        (Note: `(ling_txt)` is just a parenthesized string, not a tuple —
        self.meta holds plain strings.)"""
        with open(metafile, "r") as f:
            lines = f.readlines()
        items = []
        logging.info("Loading metafile...")
        for line in tqdm(lines):
            line = line.strip()
            index, ling_txt = line.split("\t")
            items.append((ling_txt))
        return items
    @staticmethod
    def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98):
        """Shuffle raw_meta_file (seeded) and split it into
        bert_train.lst / bert_valid.lst under out_dir."""
        with open(raw_meta_file, "r") as f:
            lines = f.readlines()
        random.seed(DATASET_RANDOM_SEED)
        random.shuffle(lines)
        num_train = int(len(lines) * split_ratio) - 1
        with open(os.path.join(out_dir, "bert_train.lst"), "w") as f:
            for line in lines[:num_train]:
                f.write(line)
        with open(os.path.join(out_dir, "bert_valid.lst"), "w") as f:
            for line in lines[num_train:]:
                f.write(line)
    def bert_masking(self, ling_data):
        """Sample a random mask over the sy sequence and corrupt it BERT-style.
        Returns (mask, masked_sy_sequence)."""
        length = len(ling_data[0])
        mask = self.masking_actor._get_random_mask(
            length, p1=self.masking_actor.mask_ratio
        )
        # Never mask the final (EOS) position.
        mask[-1] = 0
        # sy_masked
        sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0]
        ling_sy_masked_data = self.masking_actor._input_bert_masking(
            ling_data[0],
            self.ling_unit.get_unit_size()["sy"],
            sy_mask_symbol_id,
            mask,
            p2=0.8,
            p3=0.1,
            p4=0.1,
        )
        return (mask, ling_sy_masked_data)
    def collate_fn(self, batch):
        """Pad a batch of (ling, ling_sy_masked, bert_mask) into the tensor
        dict used by the BERT trainer (input_lings, valid_input_lengths,
        targets, bert_masks)."""
        data_dict = {}
        max_input_length = max((len(x[0][0]) for x in batch))
        # pure linguistic info: sy|tone|syllable_flag|word_segment
        # sy
        lfeat_type = self.ling_unit._lfeat_type_list[0]
        targets_sy = self.padder._prepare_scalar_inputs(
            [x[0][0] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # sy masked
        inputs_sy = self.padder._prepare_scalar_inputs(
            [x[1] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # tone
        lfeat_type = self.ling_unit._lfeat_type_list[1]
        inputs_tone = self.padder._prepare_scalar_inputs(
            [x[0][1] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # syllable_flag
        lfeat_type = self.ling_unit._lfeat_type_list[2]
        inputs_syllable_flag = self.padder._prepare_scalar_inputs(
            [x[0][2] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        # word_segment
        lfeat_type = self.ling_unit._lfeat_type_list[3]
        inputs_ws = self.padder._prepare_scalar_inputs(
            [x[0][3] for x in batch],
            max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type],
        ).long()
        data_dict["input_lings"] = torch.stack(
            [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2
        )
        data_dict["valid_input_lengths"] = torch.as_tensor(
            [len(x[0][0]) - 1 for x in batch], dtype=torch.long
        )  # a "~" (EOS) is appended to the symbol sequence and would skew the duration computation, hence length - 1
        data_dict["targets"] = targets_sy
        data_dict["bert_masks"] = self.padder._prepare_scalar_inputs(
            [x[2] for x in batch], max_input_length, 0.0
        )
        return data_dict
def get_bert_text_datasets(
    metafile,
    root_dir,
    config,
    allow_cache,
    split_ratio=0.98,
):
    """Build the (train, valid) BERT text datasets.

    Missing bert_train.lst / bert_valid.lst metafiles are generated from the
    raw metafile of each data directory first.
    """
    dirs = root_dir if isinstance(root_dir, list) else [root_dir]
    metas = metafile if isinstance(metafile, list) else [metafile]
    train_metas = []
    valid_metas = []
    for raw_meta, data_dir in zip(metas, dirs):
        train_meta = os.path.join(data_dir, "bert_train.lst")
        valid_meta = os.path.join(data_dir, "bert_valid.lst")
        if not (os.path.exists(train_meta) and os.path.exists(valid_meta)):
            BERT_Text_Dataset.gen_metafile(raw_meta, data_dir, split_ratio)
        train_metas.append(train_meta)
        valid_metas.append(valid_meta)
    return (
        BERT_Text_Dataset(config, train_metas, dirs, allow_cache),
        BERT_Text_Dataset(config, valid_metas, dirs, allow_cache),
    )
import torch
from torch.nn.parallel import DistributedDataParallel
from kantts.models.hifigan.hifigan import ( # NOQA
Generator, # NOQA
MultiScaleDiscriminator, # NOQA
MultiPeriodDiscriminator, # NOQA
MultiSpecDiscriminator, # NOQA
)
import kantts
import kantts.train.scheduler
from kantts.models.sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT # NOQA
from kantts.utils.ling_unit.ling_unit import get_fpdict
from .pqmf import PQMF
def optimizer_builder(model_params, opt_name, opt_params):
    """Instantiate a torch.optim optimizer class by name.

    model_params: iterable of parameters to optimize.
    opt_name: class name inside torch.optim (e.g. "Adam").
    opt_params: keyword arguments forwarded to the optimizer constructor.
    """
    return getattr(torch.optim, opt_name)(model_params, **opt_params)
def scheduler_builder(optimizer, sche_name, sche_params):
    """Instantiate a kantts.train.scheduler class by name.

    sche_name: class name inside kantts.train.scheduler (e.g. "NoamLR").
    sche_params: keyword arguments forwarded to the scheduler constructor.
    """
    sched_cls = getattr(kantts.train.scheduler, sche_name)
    return sched_cls(optimizer, **sche_params)
def hifigan_model_builder(config, device, rank, distributed):
    """Build the HiFi-GAN generator and discriminators with their optimizers
    and schedulers.

    Returns (model, optimizer, scheduler) dicts keyed by "generator" and
    "discriminator" (the latter a sub-dict keyed by discriminator class name);
    a "pqmf" entry is added for multi-band generators.
    """
    model = {}
    optimizer = {}
    scheduler = {}
    model["discriminator"] = {}
    optimizer["discriminator"] = {}
    scheduler["discriminator"] = {}
    for model_name in config["Model"].keys():
        if model_name == "Generator":
            params = config["Model"][model_name]["params"]
            model["generator"] = Generator(**params).to(device)
            optimizer["generator"] = optimizer_builder(
                model["generator"].parameters(),
                config["Model"][model_name]["optimizer"].get("type", "Adam"),
                config["Model"][model_name]["optimizer"].get("params", {}),
            )
            scheduler["generator"] = scheduler_builder(
                optimizer["generator"],
                config["Model"][model_name]["scheduler"].get("type", "StepLR"),
                config["Model"][model_name]["scheduler"].get("params", {}),
            )
        else:
            # Every non-Generator entry is a discriminator; its class is
            # resolved by name from this module's globals (imported above).
            params = config["Model"][model_name]["params"]
            model["discriminator"][model_name] = globals()[model_name](**params).to(
                device
            )
            optimizer["discriminator"][model_name] = optimizer_builder(
                model["discriminator"][model_name].parameters(),
                config["Model"][model_name]["optimizer"].get("type", "Adam"),
                config["Model"][model_name]["optimizer"].get("params", {}),
            )
            scheduler["discriminator"][model_name] = scheduler_builder(
                optimizer["discriminator"][model_name],
                config["Model"][model_name]["scheduler"].get("type", "StepLR"),
                config["Model"][model_name]["scheduler"].get("params", {}),
            )
    out_channels = config["Model"]["Generator"]["params"]["out_channels"]
    if out_channels > 1:
        # Multi-band generator: PQMF synthesis filter recombines the sub-bands.
        model["pqmf"] = PQMF(subbands=out_channels, **config.get("pqmf", {})).to(device)
    # FIXME: pywavelets buffer leads to gradient error in DDP training
    # Solution: https://github.com/pytorch/pytorch/issues/22095
    if distributed:
        model["generator"] = DistributedDataParallel(
            model["generator"],
            device_ids=[rank],
            output_device=rank,
            broadcast_buffers=False,
        )
        for model_name in model["discriminator"].keys():
            model["discriminator"][model_name] = DistributedDataParallel(
                model["discriminator"][model_name],
                device_ids=[rank],
                output_device=rank,
                broadcast_buffers=False,
            )
    return model, optimizer, scheduler
# TODO: some parsing
def sambert_model_builder(config, device, rank, distributed):
    """Build the KanTtsSAMBERT acoustic model with its optimizer and scheduler.

    Args:
        config: full training configuration; reads config["Model"]["KanTtsSAMBERT"].
        device: target device the model is moved to.
        rank: local rank used as the DDP device id when distributed.
        distributed: if True, wrap the model in DistributedDataParallel.

    Returns:
        Tuple of (model, optimizer, scheduler) dicts, each keyed by "KanTtsSAMBERT".
    """
    cfg = config["Model"]["KanTtsSAMBERT"]
    model, optimizer, scheduler = {}, {}, {}

    net = KanTtsSAMBERT(cfg["params"]).to(device)
    model["KanTtsSAMBERT"] = net

    # Optional "FP" lookup tensors attached to the network — presumably a
    # phone/prosody dictionary produced by get_fpdict; TODO confirm semantics.
    if cfg["params"].get("FP", False):
        net.fp_dict = {
            key: torch.from_numpy(arr).long().unsqueeze(0).to(device)
            for key, arr in get_fpdict(config).items()
        }

    optimizer["KanTtsSAMBERT"] = optimizer_builder(
        net.parameters(),
        cfg["optimizer"].get("type", "Adam"),
        cfg["optimizer"].get("params", {}),
    )
    scheduler["KanTtsSAMBERT"] = scheduler_builder(
        optimizer["KanTtsSAMBERT"],
        cfg["scheduler"].get("type", "StepLR"),
        cfg["scheduler"].get("params", {}),
    )

    if distributed:
        model["KanTtsSAMBERT"] = DistributedDataParallel(
            net, device_ids=[rank], output_device=rank
        )
    return model, optimizer, scheduler
def sybert_model_builder(config, device, rank, distributed):
    """Build the KanTtsTextsyBERT model with its optimizer and scheduler.

    Args:
        config: full training configuration; reads config["Model"]["KanTtsTextsyBERT"].
        device: target device the model is moved to.
        rank: local rank used as the DDP device id when distributed.
        distributed: if True, wrap the model in DistributedDataParallel.

    Returns:
        Tuple of (model, optimizer, scheduler) dicts, each keyed by "KanTtsTextsyBERT".
    """
    cfg = config["Model"]["KanTtsTextsyBERT"]
    model, optimizer, scheduler = {}, {}, {}

    net = KanTtsTextsyBERT(cfg["params"]).to(device)
    model["KanTtsTextsyBERT"] = net

    optimizer["KanTtsTextsyBERT"] = optimizer_builder(
        net.parameters(),
        cfg["optimizer"].get("type", "Adam"),
        cfg["optimizer"].get("params", {}),
    )
    scheduler["KanTtsTextsyBERT"] = scheduler_builder(
        optimizer["KanTtsTextsyBERT"],
        cfg["scheduler"].get("type", "StepLR"),
        cfg["scheduler"].get("params", {}),
    )

    if distributed:
        model["KanTtsTextsyBERT"] = DistributedDataParallel(
            net, device_ids=[rank], output_device=rank
        )
    return model, optimizer, scheduler
# TODO: implement a builder for specific model
# Registry mapping config["model_type"] to the builder function that
# constructs that model family's (model, optimizer, scheduler) dicts.
model_dict = {
    "hifigan": hifigan_model_builder,
    "sambert": sambert_model_builder,
    "sybert": sybert_model_builder,
}
def model_builder(config, device="cpu", rank=0, distributed=False):
    """Dispatch to the builder registered for config["model_type"].

    Args:
        config: full training configuration; must contain "model_type".
        device: target device, forwarded to the specific builder.
        rank: local rank, forwarded to the specific builder.
        distributed: whether to enable DDP wrapping, forwarded as-is.

    Returns:
        Tuple of (model, optimizer, scheduler) dicts from the chosen builder.
    """
    # Unknown model_type raises KeyError here, same as the original lookup.
    build = model_dict[config["model_type"]]
    return build(config, device, rank, distributed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment