# fastpitch_ssl.yaml
# This config contains the default values for training a FastPitch model with an aligner on the LJSpeech dataset.
# To train the model on another dataset, change the config values to match that dataset.
# Most dataset-specific arguments are at the top of the config file; see below.

# Experiment name; reused further down as exp_manager.name through ${name}.
name: FastPitch

# Keys set to ??? have no default in this file and are expected to be supplied
# by the user (e.g. as launch-time overrides) — TODO confirm against the training script.
# train_dataset / validation_datasets become manifest_filepath of the train / validation datasets.
train_dataset: ???
validation_datasets: ???
# Forwarded to model.ssl_model_ckpt_path.
ssl_model_ckpt_path: ???
# NOTE(review): not referenced anywhere else in this config; presumably read by the training code — confirm.
hifi_ckpt_path: ???
# Forwarded to both datasets as sup_data_dir.
sup_data_dir: null

# LJSpeech stats (per frame); the commented-out numbers are presumably the LJSpeech values.
# Ignored if pitch_normalization: speaker_wise (which is the value set further down in this file).
pitch_mean: ??? # 212.35873413085938
pitch_std: ??? # 68.52806091308594

# Default values for a dataset with sample_rate=22050.
# Each key in this group is forwarded into the model section through a ${...} reference of the same name.
sample_rate: 22050
n_mel_channels: 80
n_window_size: 1024
n_window_stride: 256
n_fft: 1024
lowfreq: 0
highfreq: 8000
window: hann


# Content-embedding variant; forwarded to both datasets as ssl_content_emb_type.
ssl_content_emb_type: embedding_and_probs
# Forwarded to both datasets as speaker_stats_pitch_fp.
speaker_stats_pitch_fp: null
# Forwarded to both datasets; per the note on pitch_mean / pitch_std, speaker_wise makes those two values unused.
pitch_normalization: speaker_wise
# Also drives model.use_duration_predictor.
use_unique_tokens: true
# Forwarded to both datasets as speaker_conditioning_type.
speaker_conditioning_type: per_sample
# NOTE(review): not referenced anywhere else in this config; presumably read by the training code — confirm.
segment_speaker_embedding: true
# How many mel-spectrogram frames map to one content embedding in the SSL model
ssl_downsampling_factor: 4

model:
  # SSL settings forwarded from the top-level keys of the same name.
  ssl_model_ckpt_path: ${ssl_model_ckpt_path}
  ssl_downsampling_factor: ${ssl_downsampling_factor}
  use_encoder: true
  # Follows the top-level use_unique_tokens flag.
  use_duration_predictor: ${use_unique_tokens}
  pitch_conditioning: true
  pitch_loss_scale: 1.0
  learn_alignment: true
  bin_loss_warmup_epochs: 100

  n_speakers: 1
  n_datasets: 1
  max_token_duration: 75
  # Shared width: reused as d_model of encoder / output_fft and as input_size of both predictors.
  symbols_embedding_dim: 384
  pitch_embedding_kernel_size: 3

  # Audio / mel-spectrogram parameters, forwarded from the top-level defaults.
  sample_rate: ${sample_rate}
  n_mel_channels: ${n_mel_channels}
  n_window_size: ${n_window_size}
  n_window_stride: ${n_window_stride}
  n_fft: ${n_fft}
  lowfreq: ${lowfreq}
  highfreq: ${highfreq}
  window: ${window}

  # NOTE(review): presumably the input / output sizes of the content and speaker embeddings — confirm.
  content_emb_indim: 174
  speaker_emb_indim: 256
  content_emb_outdim: 192
  speaker_emb_outdim: 192
  
  # Training data: dataset constructor arguments plus dataloader settings.
  train_ds:
    dataset:
      _target_: nemo.collections.tts.data.dataset.FastPitchSSLDataset
      manifest_filepath: ${train_dataset}
      sample_rate: ${model.sample_rate}
      ssl_content_emb_type: ${ssl_content_emb_type}
      pitch_conditioning: true
      pitch_normalization: ${pitch_normalization}
      # pitch_mean / pitch_std resolve to the top-level keys, which are set to ??? in this file.
      pitch_mean: ${pitch_mean}
      pitch_std: ${pitch_std}
      speaker_stats_pitch_fp: ${speaker_stats_pitch_fp}
      # NOTE(review): presumably duration bounds in seconds — confirm against FastPitchSSLDataset.
      min_duration: 0.5
      max_duration: 16.0
      pad_multiple: 1024
      speaker_conditioning_type: ${speaker_conditioning_type}
      sup_data_dir: ${sup_data_dir}

    # Training loader shuffles and uses 8 workers (the validation loader does neither).
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 2
      num_workers: 8
      pin_memory: true

  # Validation data: same dataset arguments as train_ds except for the manifest path.
  validation_ds:
    dataset:
      _target_: nemo.collections.tts.data.dataset.FastPitchSSLDataset
      manifest_filepath: ${validation_datasets}
      sample_rate: ${model.sample_rate}
      ssl_content_emb_type: ${ssl_content_emb_type}
      pitch_conditioning: true
      pitch_normalization: ${pitch_normalization}
      # pitch_mean / pitch_std resolve to the top-level keys, which are set to ??? in this file.
      pitch_mean: ${pitch_mean}
      pitch_std: ${pitch_std}
      speaker_stats_pitch_fp: ${speaker_stats_pitch_fp}
      # NOTE(review): presumably duration bounds in seconds — confirm against FastPitchSSLDataset.
      min_duration: 0.5
      max_duration: 16.0
      pad_multiple: 1024
      speaker_conditioning_type: ${speaker_conditioning_type}
      sup_data_dir: ${sup_data_dir}

    # Validation loader: no shuffling, and num_workers is 0 (train_ds uses 8).
    dataloader_params:
      drop_last: false
      shuffle: false
      batch_size: 2
      num_workers: 0
      pin_memory: true

  # Both encoder and decoder have the same architecture, FFTransformerDecoder
  # (the decoder is presumably the output_fft section, which uses the same _target_).
  encoder: # n_embed and padding_idx are added by the model
    _target_: nemo.collections.tts.modules.transformer.FFTransformerDecoder
    n_layer: 6
    n_head: 1
    # Tied to model.symbols_embedding_dim.
    d_model: ${model.symbols_embedding_dim}
    d_head: 64
    d_inner: 1536
    kernel_size: 3
    dropout: 0.1
    dropatt: 0.1
    dropemb: 0.0

  # Same _target_ and hyperparameters as the encoder section.
  output_fft:
    _target_: nemo.collections.tts.modules.transformer.FFTransformerDecoder
    n_layer: 6
    n_head: 1
    # Tied to model.symbols_embedding_dim.
    d_model: ${model.symbols_embedding_dim}
    d_head: 64
    d_inner: 1536
    kernel_size: 3
    dropout: 0.1
    dropatt: 0.1
    dropemb: 0.0

  # TemporalPredictor settings for durations; identical values to pitch_predictor.
  # model.use_duration_predictor is the flag that presumably decides whether it is used — confirm.
  duration_predictor:
    _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
    # Tied to model.symbols_embedding_dim.
    input_size: ${model.symbols_embedding_dim}
    kernel_size: 3
    filter_size: 256
    dropout: 0.1
    n_layers: 2

  # TemporalPredictor settings for pitch; identical values to duration_predictor.
  pitch_predictor:
    _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
    # Tied to model.symbols_embedding_dim.
    input_size: ${model.symbols_embedding_dim}
    kernel_size: 3
    filter_size: 256
    dropout: 0.1
    n_layers: 2

  # Optimizer: AdamW with an explicit learning rate and betas; no other arguments are set here.
  optim:
    _target_: torch.optim.AdamW
    lr: 0.0002
    betas: [0.8, 0.99]

# Trainer settings. NOTE(review): the keys look like PyTorch Lightning Trainer arguments — confirm.
# Checkpointing and logging are switched off here because exp_manager provides them.
trainer:
  num_nodes: 1
  # NOTE(review): -1 presumably means "use all available devices" — confirm against the trainer docs.
  devices: -1
  accelerator: gpu
  strategy: ddp
  precision: 32
  max_epochs: 1000
  accumulate_grad_batches: 1
  gradient_clip_val: 1000.0
  enable_checkpointing: false # Provided by exp_manager
  logger: false # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 5
  benchmark: false

# Experiment manager: supplies the logger and checkpoint callback that the trainer section turns off.
exp_manager:
  exp_dir: null
  # Reuses the top-level experiment name.
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # Metric name watched by the checkpoint callback.
    monitor: v_loss
  resume_if_exists: false
  resume_ignore_no_checkpoint: false