# masking.yaml
# Example configuration for training a multichannel speech enhancement model.
# A mask is estimated from the input spectrogram (model.mask_estimator) and
# applied to the reference channel (model.mask_processor).
#
# `name` is also reused as the experiment name through `${name}` in exp_manager.
name: "masking"

# Model section: datasets, signal-processing modules, loss, metrics and optimizer.
model:
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1 # also passed to mask_estimator.num_outputs via interpolation

  # Training data.
  # `???` is a placeholder (OmegaConf mandatory-value marker) that must be replaced with the manifest path.
  train_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    target_key: target_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # target signal is the first channel from files in target_key
    audio_duration: 4.0 # in seconds, audio segment duration for training
    random_offset: true # if the file is longer than audio_duration, use random offset to select a subsegment
    min_duration: ${model.train_ds.audio_duration} # resolves to audio_duration above
    batch_size: 64 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true

  # Validation data; uses the same manifest keys as train_ds.
  validation_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    target_key: target_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # target signal is the first channel from files in target_key
    batch_size: 64 # batch size may be increased based on the available memory
    shuffle: false
    num_workers: 4
    pin_memory: true

  # Test data; uses the same manifest keys as train_ds.
  test_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    target_key: target_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # target signal is the first channel from files in target_key
    batch_size: 1 # batch size may be increased based on the available memory
    shuffle: false
    num_workers: 4
    pin_memory: true

  # Audio -> spectrogram transform.
  encoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
    fft_length: 512 # Length of the window and FFT for calculating spectrogram
    hop_length: 256 # Hop length for calculating spectrogram
    power: null

  # Spectrogram -> audio transform; fft_length and hop_length are kept equal to the encoder values.
  decoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
    fft_length: 512 # Length of the window and FFT for calculating spectrogram
    hop_length: 256 # Hop length for calculating spectrogram

  # RNN-based mask estimator.
  mask_estimator:
    _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorRNN
    num_outputs: ${model.num_outputs}
    # NOTE(review): 257 presumably equals encoder.fft_length / 2 + 1; confirm and keep in sync with the encoder.
    num_subbands: 257 # Number of subbands of the input spectrogram
    num_features: 256 # Number of features at RNN input
    num_layers: 5 # Number of RNN layers
    bidirectional: true # Use bi-directional RNN

  mask_processor:
    _target_: nemo.collections.asr.modules.audio_modules.MaskReferenceChannel # Apply mask on the reference channel
    ref_channel: 0 # Reference channel for the output

  loss:
    _target_: nemo.collections.asr.losses.SDRLoss
    scale_invariant: true # Use scale-invariant SDR

  # Metrics evaluated on the validation (val) and test sets.
  metrics:
    val:
      sdr: # output SDR
        _target_: torchmetrics.audio.SignalDistortionRatio
    test:
      sdr_ch0: # SDR on output channel 0
        _target_: torchmetrics.audio.SignalDistortionRatio
        channel: 0

  optim:
    name: adamw
    # Exponent notation is written with an explicit decimal point so that
    # YAML 1.1 loaders (e.g. PyYAML) resolve the values as floats, not strings.
    lr: 1.0e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1.0e-3

# Trainer settings (presumably forwarded to the PyTorch Lightning Trainer; confirm against the training script).
# Checkpointing and logging are switched off here because exp_manager provides them.
trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
  log_every_n_steps: 25  # Interval of logging.
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps run as a sanity check before training starts; 0 disables the check
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false  # Provided by exp_manager
  logger: false  # Provided by exp_manager

# Experiment manager: output location, loggers, checkpointing and resuming.
exp_manager:
  exp_dir: null
  name: ${name}

  # Loggers. The W&B entries below may be filled in to create a W&B logger.
  create_tensorboard_logger: true
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null

  # Checkpointing.
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # With multiple validation sets, the first one is used.
    monitor: "val_loss"
    mode: "min"
    save_top_k: 5
    # Saves the checkpoints as nemo files instead of PTL checkpoints.
    always_save_nemo: true

  # Resuming.
  # Path to a checkpoint file to continue the training from; restores the whole
  # state including the epoch, step, LR schedulers, apex, etc.
  resume_from_checkpoint: null
  # Both of the following need to be set to true to continue the training.
  resume_if_exists: false
  resume_ignore_no_checkpoint: false