# g2p_t5.yaml
# Config for training/evaluating a T5/ByT5-based grapheme-to-phoneme (G2P) model (T5G2P).
name: T5G2P

# Dataset info
# `???` is an OmegaConf mandatory value: it must be supplied (e.g. as an override)
# before it is read.
train_manifest: ???
validation_manifest: ???
test_manifest: null
do_training: true
do_testing: false
pretrained_model: null  # path to .nemo file or model name from list_available_models()

model:
  # One of: google/byt5-small/base/large/xl or t5-small/base/large/3b/11b
  model_name: "google/byt5-small"
  # Max input / output sequence lengths (presumably in tokenizer tokens; confirm).
  max_source_len: 256
  max_target_len: 512
  do_lower: false

  train_ds:
    manifest_filepath: ${train_manifest}
    dataset:
      _target_: "nemo.collections.tts.g2p.data.t5.T5G2PDataset"
      phoneme_field: "text"  # name of the field in manifest_filepath for ground truth phonemes
      grapheme_field: "text_graphemes"  # name of the field in manifest_filepath for input grapheme text
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 20
      num_workers: 4

  validation_ds:
    manifest_filepath: ${validation_manifest}
    dataset:
      _target_: "nemo.collections.tts.g2p.data.t5.T5G2PDataset"
      phoneme_field: "text"  # name of the field in manifest_filepath for ground truth phonemes
      grapheme_field: "text_graphemes"  # name of the field in manifest_filepath for input grapheme text
    dataloader_params:
      drop_last: false
      shuffle: false
      batch_size: 20
      num_workers: 4

  test_ds:
    manifest_filepath: ${test_manifest}
    dataset:
      _target_: "nemo.collections.tts.g2p.data.t5.T5G2PDataset"
      phoneme_field: "text"  # name of the field in manifest_filepath for ground truth phonemes
      grapheme_field: "text_graphemes"  # name of the field in manifest_filepath for input grapheme text
    dataloader_params:
      drop_last: false
      shuffle: false
      batch_size: 20
      num_workers: 4

  optim:
    name: adamw
    # Written with a decimal point on purpose: YAML 1.1 loaders (e.g. PyYAML)
    # resolve `2e-4` as the string "2e-4", whereas `2.0e-4` is a float everywhere.
    lr: 2.0e-4
    weight_decay: 0.01
    # scheduler setup
    sched:
      name: WarmupAnnealing

      # pytorch lightning args
      # `monitor` is only consulted when reduce_on_plateau is true.
      # NOTE(review): differs from exp_manager's checkpoint monitor (val_per);
      # confirm the model actually logs val_token_precision.
      monitor: val_token_precision
      reduce_on_plateau: false

      # scheduler config override
      # Presumably warmup_ratio (fraction of total steps) applies because
      # warmup_steps is null; confirm against the scheduler implementation.
      warmup_steps: null
      warmup_ratio: 0.1
      last_epoch: -1

trainer:
  devices: 1  # number of GPUs
  max_epochs: 5
  num_nodes: 1
  accelerator: gpu
  strategy: ddp
  accumulate_grad_batches: 1
  enable_checkpointing: false  # Provided by exp_manager
  logger: false  # Provided by exp_manager
  log_every_n_steps: 200
  check_val_every_n_epoch: 1

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # Keep only the single checkpoint with the lowest val_per.
    save_top_k: 1
    monitor: "val_per"
    mode: "min"
    save_best_model: true