Commit 928248d7 authored by Zhaoheng Ni

Improve hubert recipe for pre-training and fine-tuning (#2744)

Summary:
Follow-up to PR https://github.com/pytorch/audio/issues/2716.
- For preprocessing
  - The HuBERT features take a large amount of memory and may not fit on some machines. Add an option to train the k-means model on a subset of the features.

- For pre-training
  - Normalize the loss by the total number of masked frames across all GPUs (see the note after the WER table below).
  - Use mixed-precision training, managed manually, since fp16 training is not well supported in pytorch_lightning.
  - Log accuracies of masked/unmasked frames during training.
  - Clip gradients to a maximum norm of `10.0`.

- For ASR fine-tuning
  - Normalize the loss by the total batch size (number of utterances) across all GPUs, same as in TorchAudio's Conformer recipe.
  - Use mixed precision training.
  - Append "|" to the end of each transcript to capture the silence/word termination, same as in the fairseq recipe.

- Update the WER results on the LibriSpeech dev and test sets:

|                   | WER% (Viterbi)|  WER% (KenLM) |
|:-----------------:|--------------:|--------------:|
| dev-clean         |       10.9    |       4.2     |
| dev-other         |       17.5    |       9.4     |
| test-clean        |       10.9    |       4.4     |
| test-other        |       17.8    |       9.5     |
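
Both normalizations follow the same pattern: each GPU scales its local loss before DDP averages the gradients. A sketch of the arithmetic, using the notation from the new `training_step` docstrings: with world size $N$, per-GPU loss $L_k$, and $F$ the total number of masked frames (pre-training) or the total number of utterances in the batch (fine-tuning) gathered across all GPUs, each GPU backpropagates $\frac{N}{F} L_k$, so DDP's gradient averaging yields

$$
\frac{1}{N}\sum_{k=1}^{N}\nabla\!\left(\frac{N}{F}\,L_k\right)
= \frac{1}{F}\sum_{k=1}^{N}\nabla L_k,
$$

i.e. the summed gradient divided by the total number of masked frames (or utterances), independent of how the variable-length batches happen to be sharded.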

Pull Request resolved: https://github.com/pytorch/audio/pull/2744

Reviewed By: carolineechen

Differential Revision: D40282322

Pulled By: nateanl

fbshipit-source-id: 4723584c912e70e8970149fe09de005385eaab90
parent 97baba1b
@@ -29,7 +29,7 @@ After the first iteration of pre-training, the intermediate transformer layer's
 Sample SLURM command for the second iteration of preprocessing. The 6th transformer layer's output is used as the input feature for training the KMeans model. Note that the number of clusters is increased to 500 to improve the performance.
 ```
-srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type hubert --exp-dir ./exp --layer-index 6 --checkpoint-path ./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt --num-cluster 500
+srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type hubert --exp-dir ./exp --layer-index 6 --checkpoint-path ./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt --num-rank 40 --num-cluster 500 --percent 0.1
 ```
 ### Pre-training (2nd iteration)
@@ -67,9 +67,9 @@ srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_fi
 ### CTC Decoding with language model
 torchaudio provides a CTCDecoder feature that is based on [Flashlight](https://github.com/flashlight/flashlight). The decoder supports the KenLM language model. Use `--use-lm` to enable CTC decoding with a KenLM 4-gram language model.
-Sample SLURM command for evaluation with KenLM language model:
+Sample SLURM command for evaluation with KenLM language model (use the checkpoint that has the lowest validation loss):
 ```
-srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean --use-lm --beam-size 1500 --lm-weight 2.46 --word-score -0.59
+srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=106-step\=19500.ckpt --split test-clean --use-lm --beam-size 1500 --lm-weight 2.46 --word-score -0.59
 ```
 ### WER results
@@ -77,7 +77,7 @@ The table below contains WER results for fine-tuning HuBERT Base model on `10h`
 |                   | WER% (Viterbi)|  WER% (KenLM) |
 |:-----------------:|--------------:|--------------:|
-| dev-clean         |       10.7    |       4.4     |
-| dev-other         |       18.3    |       9.7     |
-| test-clean        |       10.8    |       4.4     |
-| test-other        |       18.5    |       10.1    |
+| dev-clean         |       10.9    |       4.2     |
+| dev-other         |       17.5    |       9.4     |
+| test-clean        |       10.9    |       4.4     |
+| test-other        |       17.8    |       9.5     |
@@ -463,6 +463,8 @@ class CollateFnLibriLightLimited:
         label2id = _get_label2id()
         for sample in batch:
             waveform, transcript = sample[0], sample[2]
+            # add one "|" symbol after the end of transcription as the word termination
+            transcript = transcript + "|"
             label = torch.tensor([label2id[e] for e in transcript.replace(" ", "|").upper()])
             audio_length = waveform.size(1)
             label_length = label.size(0)
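
With this change a transcript such as `hello world` is mapped to the character sequence `HELLO|WORLD|`, so the trailing `|` marks the end of the final word the same way the in-between `|` marks word boundaries. A minimal sketch of the mapping; the 29-symbol vocabulary below is an assumption standing in for the recipe's `_get_label2id()`:

```python
# Minimal sketch of the label construction above; the vocabulary
# (blank, "|", "'", A-Z) is an assumption standing in for _get_label2id().
import torch

label2id = {c: i for i, c in enumerate("-|'ABCDEFGHIJKLMNOPQRSTUVWXYZ")}

transcript = "hello world"
transcript = transcript + "|"  # append the word-termination symbol, as in the new collate function
label = torch.tensor([label2id[e] for e in transcript.replace(" ", "|").upper()])
print(label)  # indices of H E L L O | W O R L D |
```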
...
 import argparse
 import logging
-from typing import Dict, List, Optional
+from typing import Dict, List
 import torch
 import torch.nn.functional as F
@@ -36,10 +36,9 @@ def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0)
     Returns:
         (List of str): The decoding result. List of string in lower case.
     """
-    hypothesis = F.log_softmax(emission, dim=-1)
-    hypothesis = hypothesis.argmax(-1).unique_consecutive()
+    hypothesis = emission.argmax(-1).unique_consecutive()
     hypothesis = hypothesis[hypothesis != blank_idx]
-    hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ")
+    hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ").strip()
     return hypothesis.split()
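
The `log_softmax` is dropped here because it does not change the result of `argmax`; the greedy path is the same over raw logits as over log-probabilities. A toy walk-through of the collapse-and-strip logic, with a hypothetical four-token vocabulary:

```python
# Toy illustration of the greedy CTC decoding above (hypothetical 4-token vocabulary).
import torch

id2token = {0: "-", 1: "|", 2: "A", 3: "B"}  # 0 is the CTC blank, "|" is the word boundary
blank_idx = 0

# Per-frame best token ids, as produced by emission.argmax(-1):
frame_ids = torch.tensor([2, 2, 0, 3, 3, 3, 1, 0, 2, 1])
hypothesis = frame_ids.unique_consecutive()       # collapse repeats -> [2, 0, 3, 1, 0, 2, 1]
hypothesis = hypothesis[hypothesis != blank_idx]  # drop blanks      -> [2, 3, 1, 2, 1]
words = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ").strip().split()
print(words)  # ['AB', 'A']
```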
@@ -47,7 +46,7 @@ def _ctc_decode(emission, decoder: CTCDecoder) -> List[str]:
     """Run CTC decoding with a KenLM language model.
     Args:
-        emission (torch.Tensor): Output of CTC layer. Tensor with dimensions (..., time, num_tokens).
+        emission (torch.Tensor): Output of CTC layer. Tensor with dimensions `(..., time, num_tokens)`.
         decoder (CTCDecoder): The initialized CTCDecoder.
     Returns:
@@ -55,13 +54,19 @@ def _ctc_decode(emission, decoder: CTCDecoder) -> List[str]:
     """
     hypothesis = decoder(emission)
     hypothesis = hypothesis[0][0].words
+    hypothesis = [word for word in hypothesis if word != " "]
     return hypothesis
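
For context, the `decoder` passed in here is torchaudio's Flashlight-based beam-search CTC decoder combined with a KenLM 4-gram model. A sketch of how such a decoder can be built from torchaudio's pretrained LibriSpeech files; the parameter values are taken from the README command above, and the exact construction in this recipe may differ:

```python
# Hedged sketch: build a KenLM-backed CTC beam-search decoder with torchaudio.
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")  # lexicon, tokens, 4-gram KenLM

decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=1,
    beam_size=1500,    # README: --beam-size 1500
    lm_weight=2.46,    # README: --lm-weight 2.46
    word_score=-0.59,  # README: --word-score -0.59
)

# emission: (batch, time, num_tokens) log-probabilities on CPU, then:
# words = decoder(emission)[0][0].words
```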
 def run_inference(args):
+    if args.use_gpu:
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
     # Load the fine-tuned HuBERTPretrainModel from checkpoint.
     model = _load_checkpoint(args.checkpoint)
-    model.eval()
+    model.eval().to(device)
     if args.use_lm:
         # get decoder files
@@ -92,13 +97,14 @@ def run_inference(args):
         transcript = transcript.strip().lower().strip().replace("\n", "")
         with torch.inference_mode():
-            emission, _ = model(waveform)
+            emission, _ = model(waveform.to(device))
+            emission = F.log_softmax(emission, dim=-1)
         if args.use_lm:
-            hypothesis = _ctc_decode(emission, decoder)
+            hypothesis = _ctc_decode(emission.cpu(), decoder)
         else:
             hypothesis = _viterbi_decode(emission, id2token)
-        total_edit_distance += torchaudio.functional.edit_distance(transcript.split(), hypothesis)
+        total_edit_distance += torchaudio.functional.edit_distance(hypothesis, transcript.split())
         total_length += len(transcript.split())
         if idx % 100 == 0:
@@ -138,9 +144,9 @@ def _parse_args():
     )
     parser.add_argument(
         "--beam-size-token",
-        type=Optional[int],
-        default=None,
-        help="Number of tokens to consider at each beam search step. (Default: None)",
+        type=int,
+        default=29,
+        help="Number of tokens to consider at each beam search step. (Default: 29)",
     )
     parser.add_argument(
         "--beam-threshold", type=int, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
@@ -161,6 +167,7 @@ def _parse_args():
         "--unk-score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
     )
     parser.add_argument("--sil-score", type=float, default=0, help="Silence insertion score. (Default: 0)")
+    parser.add_argument("--use-gpu", action="store_true", help="Whether to use GPU for decoding.")
     parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
     return parser.parse_args()
...
@@ -16,6 +16,7 @@ from lightning import HuBERTFineTuneModule
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities.seed import seed_everything
 logger = logging.getLogger(__name__)
@@ -29,10 +30,11 @@ class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
 def run_train(args):
+    seed_everything(1337)
     checkpoint_dir = args.exp_dir / f"checkpoints_{args.model_name}"
     checkpoint = ModelCheckpoint(
         checkpoint_dir,
-        monitor="Losses/val_loss",
+        monitor="val_loss",
         mode="min",
         save_top_k=5,
         save_weights_only=False,
@@ -40,7 +42,7 @@ def run_train(args):
     )
     train_checkpoint = ModelCheckpoint(
         checkpoint_dir,
-        monitor="Losses/train_loss",
+        monitor="train_loss",
         mode="min",
         save_top_k=5,
         save_weights_only=False,
@@ -60,7 +62,8 @@ def run_train(args):
         replace_sampler_ddp=False,
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
-        accumulate_grad_batches=args.accumulate_grad_batches,
+        val_check_interval=500,
+        check_val_every_n_epoch=None,
     )
     model = HuBERTFineTuneModule(
@@ -73,6 +76,7 @@ def run_train(args):
         mask_prob=args.mask_prob,
         mask_channel_prob=args.mask_channel_prob,
         mask_channel_length=args.mask_channel_length,
+        num_classes=args.num_classes,
         aux_num_out=args.aux_num_out,
         checkpoint=args.checkpoint,
         dataset_path=args.dataset_path,
@@ -87,7 +91,7 @@ def run_train(args):
         hold_updates=args.hold_updates,
         decay_updates=args.decay_updates,
     )
-    trainer.fit(model)
+    trainer.fit(model, ckpt_path=args.resume_checkpoint)
 def _parse_args():
@@ -101,6 +105,18 @@ def _parse_args():
         required=True,
         help="Path to the LibriSpeech and LibriLightLimited datasets.",
     )
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the pre-trained HuBERTPretrainModel checkpoint as the initialization.",
+    )
+    parser.add_argument(
+        "--resume-checkpoint",
+        default=None,
+        type=str,
+        help="The path to the checkpoint to resume the fine-tuning if training fails in the middle.",
+    )
     parser.add_argument(
         "--exp-dir",
         default=pathlib.Path("./exp_finetune"),
@@ -141,7 +157,7 @@ def _parse_args():
     )
     parser.add_argument(
         "--encoder-layer-drop",
-        default=0.1,
+        default=0.05,
         type=float,
         help="Probability to drop each encoder layer during training. (Default: 0.1)",
     )
@@ -164,10 +180,11 @@ def _parse_args():
         help="Minimum space between spans (if no overlap is enabled) for channel masking." "(Default: 64)",
     )
     parser.add_argument(
-        "--accumulate-grad-batches",
-        default=1,
+        "--num-classes",
+        choices=[100, 500],
         type=int,
-        help="Number of batches to accumulate the gradients during training. (Default: 1)",
+        default=500,
+        help="The ``num_class`` in the pre-trained checkpoint. (Default: 500)",
     )
     parser.add_argument(
         "--aux-num-out",
@@ -176,13 +193,7 @@ def _parse_args():
         help="The dimension of linear layer for CTC training. (Default: 29)",
     )
     parser.add_argument(
-        "--checkpoint",
-        type=str,
-        required=True,
-        help="Path to the pre-trained HuBERTPretrainModel checpoint.",
-    )
-    parser.add_argument(
-        "--learning-rate", default=1e-4, type=float, help="The learning rate of Adam optimizer. (Default: 2e-5)"
+        "--learning-rate", default=5e-5, type=float, help="The learning rate of Adam optimizer. (Default: 5e-5)"
     )
     parser.add_argument(
         "--betas",
@@ -198,7 +209,7 @@ def _parse_args():
     )
     parser.add_argument(
         "--weight-decay",
-        default=1e-6,
+        default=0.0,
         type=float,
         help="Weight decay (L2 penalty) (Default: 0.0)",
     )
...
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 import torch
 import torch.nn.functional as F
@@ -94,6 +94,29 @@ class TriStageLRScheduler(torch.optim.lr_scheduler._LRScheduler):
         return [base_lr * self.final_lr_scale for base_lr in self.base_lrs]
+def _compute_accuracy(logits: torch.Tensor):
+    with torch.no_grad():
+        max = logits.argmax(-1) == 0
+        min = logits.argmin(-1) == 0
+        both = max & min
+        corr = max.long().sum().item() - both.long().sum().item()
+        count = max.numel()
+    return corr, count
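
The accuracy definition relies on the layout of `HuBERTPretrainModel`'s logits: column 0 of `logit_m`/`logit_u` is the score of the target cluster (the positive in HuBERT's noise-contrastive logits), so a frame counts as correct when `argmax(-1) == 0`, and degenerate frames whose scores are all equal (argmax and argmin both land on 0) are excluded. A toy check, with the builtin-shadowing names `max`/`min` renamed for clarity:

```python
# Toy check of the accuracy computation above; index 0 is the target class.
import torch

logits = torch.tensor(
    [
        [2.0, 1.0, 0.5],  # argmax == 0 -> correct frame
        [0.1, 3.0, 0.2],  # argmax == 1 -> incorrect frame
        [1.0, 1.0, 1.0],  # all-equal row: argmax and argmin agree -> excluded
    ]
)
is_max = logits.argmax(-1) == 0
is_min = logits.argmin(-1) == 0
both = is_max & is_min
correct = is_max.long().sum().item() - both.long().sum().item()
count = is_max.numel()
print(correct, count)  # 1 3 -> accuracy 1/3 on this toy batch
```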
+def _reset_stats():
+    return {
+        "train": {
+            "correct": 0.0,
+            "count": 0.0,
+        },
+        "val": {
+            "correct": 0.0,
+            "count": 0.0,
+        },
+    }
 class HuBERTPreTrainModule(LightningModule):
     def __init__(
         self,
@@ -109,6 +132,7 @@ class HuBERTPreTrainModule(LightningModule):
         betas: Tuple[float, float],
         eps: float,
         weight_decay: float,
+        clip_norm: Optional[float],
         warmup_updates: int,
         max_updates: int,
     ):
@@ -124,29 +148,71 @@ class HuBERTPreTrainModule(LightningModule):
             self.model = torchaudio.models.hubert_pretrain_xlarge()
         else:
             raise ValueError(f"Unsupported model name: {model_name}")
+        self.automatic_optimization = False
+        self.scaler = torch.cuda.amp.GradScaler()
         self.loss = hubert_loss
         self.optimizer = torch.optim.AdamW(
             self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
         )
+        self.clip_norm = clip_norm
         self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
         self.dataset = dataset
         self.dataset_path = dataset_path
         self.feature_type = feature_type
         self.seconds_per_batch = seconds_per_batch
+        self.mask_stats = _reset_stats()
+        self.unmask_stats = _reset_stats()
+        self.nan_loss_count = 0.0
     def _step(self, batch: Batch, batch_idx, step_type):
         if batch is None:
-            return None
+            return None, None
         waveforms, labels, audio_lengths = batch
-        logit_m, logit_u, feature_penalty = self.model(
-            waveforms,
-            labels,
-            audio_lengths,
-        )
+        if step_type == "val":
+            with torch.no_grad():
+                logit_m, logit_u, feature_penalty = self.model(
+                    waveforms,
+                    labels,
+                    audio_lengths,
+                )
+        else:
+            logit_m, logit_u, feature_penalty = self.model(
+                waveforms,
+                labels,
+                audio_lengths,
+            )
         loss = self.loss(logit_m, logit_u, feature_penalty)
-        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
-        return loss
+        if not torch.isinf(loss) and not torch.isnan(loss):
+            self.log(f"{step_type}_loss", loss.item() / logit_m.size(0), on_step=True, on_epoch=True)
+        else:
+            self.nan_loss_count += 1
+            self.log("nan_loss_count", self.nan_loss_count, on_step=True, on_epoch=True)
+        # log accuracies of masked and unmasked frames
+        correct_m, count_m = _compute_accuracy(logit_m)
+        correct_u, count_u = _compute_accuracy(logit_u)
+        self.mask_stats[step_type]["correct"] += correct_m
+        self.mask_stats[step_type]["count"] += count_m
+        self.unmask_stats[step_type]["correct"] += correct_u
+        self.unmask_stats[step_type]["count"] += count_u
+        self.log(
+            f"{step_type}_masked_accuracy",
+            self.mask_stats[step_type]["correct"] / self.mask_stats[step_type]["count"],
+            on_step=True,
+            on_epoch=True,
+            sync_dist=True,
+            prog_bar=step_type == "train",
+        )
+        self.log(
+            f"{step_type}_unmasked_accuracy",
+            self.unmask_stats[step_type]["correct"] / self.unmask_stats[step_type]["count"],
+            on_step=True,
+            on_epoch=True,
+            sync_dist=True,
+            prog_bar=step_type == "train",
+        )
+        return loss, logit_m.size(0)
     def configure_optimizers(self):
         return (
@@ -160,20 +226,68 @@ class HuBERTPreTrainModule(LightningModule):
         )
     def training_step(self, batch: Batch, batch_idx):
-        return self._step(batch, batch_idx, "train")
+        """Custom training step with loss normalization and automatic mixed precision training.
+        By default, DDP does the following on each train step:
+        - For each GPU, compute loss and gradient on shard of training data.
+        - Sync and average gradients across all GPUs. The final gradient
+          is (sum of gradients across all GPUs) / N, where N is the world
+          size (total number of GPUs).
+        - Update parameters on each GPU.
+        Here, we do the following:
+        - For k-th GPU, compute loss and scale it by (N / num_frames), where num_frames is
+          the sum of masked frames across all GPUs. Compute gradient from scaled loss.
+        - Sync and average gradients across all GPUs. The final gradient
+          is (sum of gradients across all GPUs) / num_frames.
+        - Update parameters on each GPU.
+        Doing so allows us to account for the variability in number of masked frames in
+        variable-length sequential data.
+        """
+        opt = self.optimizers()
+        opt.zero_grad()
+        with torch.cuda.amp.autocast(enabled=False):
+            loss, num_frame = self._step(batch, batch_idx, "train")
+        if torch.isinf(loss) or torch.isnan(loss):
+            opt.zero_grad()
+            return None
+        # normalize the loss based on the sum of num_frame across all GPUs
+        num_frames = self.all_gather(num_frame)
+        self.log("Gathered number of frames", num_frames.float().sum(), on_step=True, on_epoch=True)
+        loss *= num_frames.size(0) / num_frames.sum()  # world size / num_frames
+        # backward the loss and clip the gradients
+        loss = self.scaler.scale(loss)
+        self.manual_backward(loss)
+        self.scaler.unscale_(opt)
+        if self.clip_norm is not None:
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)
+        # optimization
+        self.scaler.step(opt)
+        sch = self.lr_schedulers()
+        sch.step()
+        self.scaler.update()
+        return loss
     def validation_step(self, batch: Batch, batch_idx):
-        return self._step(batch, batch_idx, "val")
+        return self._step(batch, batch_idx, "val")[0]
+    def on_validation_end(self):
+        self.mask_stats = _reset_stats()
+        self.unmask_stats = _reset_stats()
     def train_dataloader(self):
         dataset = HuBERTDataSet(self.dataset_path, self.dataset, "train")
         sampler = BucketizeBatchSampler(
             dataset.len_list,
-            num_buckets=10000,
+            num_buckets=1000,
             max_token_count=self.seconds_per_batch * 16000,
             min_len=32000,
             max_len=250000,
-            shuffle=True,
+            shuffle=False,
         )
         sampler = DistributedBatchSampler(sampler, shuffle=True)
         sampler.set_epoch(self.current_epoch)
@@ -217,6 +331,7 @@ class HuBERTFineTuneModule(LightningModule):
         mask_prob: float,
         mask_channel_prob: float,
         mask_channel_length: float,
+        num_classes: int,
         aux_num_out: int,
         checkpoint: str,
         dataset_path: str,
@@ -243,6 +358,7 @@ class HuBERTFineTuneModule(LightningModule):
                 mask_prob=mask_prob,
                 mask_channel_prob=mask_channel_prob,
                 mask_channel_length=mask_channel_length,
+                num_classes=num_classes,
             )
         elif model_name == "hubert_large":
             self.model = torchaudio.models.hubert_pretrain_large(
@@ -254,6 +370,7 @@ class HuBERTFineTuneModule(LightningModule):
                 mask_prob=mask_prob,
                 mask_channel_prob=mask_channel_prob,
                 mask_channel_length=mask_channel_length,
+                num_classes=num_classes,
             )
         elif model_name == "hubert_xlarge":
             self.model = torchaudio.models.hubert_pretrain_xlarge(
@@ -265,6 +382,7 @@ class HuBERTFineTuneModule(LightningModule):
                 mask_prob=mask_prob,
                 mask_channel_prob=mask_channel_prob,
                 mask_channel_length=mask_channel_length,
+                num_classes=num_classes,
             )
         else:
             raise ValueError(f"Unsupported model name: {model_name}.")
@@ -274,7 +392,7 @@ class HuBERTFineTuneModule(LightningModule):
                 p.requires_grad = False
         self.loss_fn = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
         self.optimizer = torch.optim.AdamW(
-            list(self.aux.parameters()) + list(self.model.wav2vec2.encoder.parameters()),
+            list(self.aux.parameters()) + list(self.model.parameters()),
            lr=learning_rate,
            betas=betas,
            eps=adam_eps,
@@ -285,16 +403,18 @@ class HuBERTFineTuneModule(LightningModule):
         self.dataset_path = dataset_path
         self.seconds_per_batch = seconds_per_batch
         self.subset = subset
+        self.automatic_optimization = False
+        self.scaler = torch.cuda.amp.GradScaler()
     def _load_checkpoint(self, checkpoint):
-        # load pretrain model
+        # load pretrain model from checkpoint
         state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
         state_dict = state_dict["state_dict"]
         s = {}
         for k in state_dict:
-            if "wav2vec2" in k:
-                s[k.replace("model.wav2vec2.", "")] = state_dict[k]
-        self.model.wav2vec2.load_state_dict(s)
+            if "model." in k:
+                s[k.replace("model.", "")] = state_dict[k]
+        self.model.load_state_dict(s)
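
With this change the fine-tuning module restores the full `HuBERTPretrainModel` (feature extractor, encoder, mask generator, and logit generator) from the Lightning checkpoint instead of only the `wav2vec2` encoder. A toy illustration of the key renaming, with hypothetical key names:

```python
# Toy illustration of the checkpoint key renaming above (hypothetical key names).
# Lightning stores the module's parameters under the "model." prefix; stripping it
# yields keys that match HuBERTPretrainModel's own state_dict.
state_dict = {
    "model.wav2vec2.encoder.transformer.layers.0.attention.k_proj.weight": "tensor ...",
    "model.mask_generator.mask_embedding": "tensor ...",
    "model.logit_generator.label_embeddings": "tensor ...",
}
s = {}
for k in state_dict:
    if "model." in k:
        s[k.replace("model.", "")] = state_dict[k]
print(list(s))
# ['wav2vec2.encoder.transformer.layers.0.attention.k_proj.weight',
#  'mask_generator.mask_embedding', 'logit_generator.label_embeddings']
```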
     def _step(self, batch: Batch_FineTune, batch_idx, step_type):
         if batch is None:
@@ -315,8 +435,6 @@ class HuBERTFineTuneModule(LightningModule):
         x, _ = self.model.mask_generator(x, padding_mask)
         x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
         logits = self.aux(x)
-        logits[padding_mask][..., 0] = 0
-        logits[padding_mask][..., 1:] = float("-inf")
         log_probs = F.log_softmax(logits, dim=-1)
         log_probs = log_probs.transpose(0, 1)
         loss = self.loss_fn(
@@ -325,7 +443,7 @@ class HuBERTFineTuneModule(LightningModule):
             out_len,
             label_lengths,
         )
-        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
+        self.log(f"{step_type}_loss", loss.item() / waveforms.size(0), on_step=True, on_epoch=True)
         return loss
     def configure_optimizers(self):
@@ -339,7 +457,47 @@ class HuBERTFineTuneModule(LightningModule):
         )
     def training_step(self, batch: Batch_FineTune, batch_idx):
-        return self._step(batch, batch_idx, "train")
+        """Custom training step with loss normalization and automatic mixed precision training.
+        By default, DDP does the following on each train step:
+        - For each GPU, compute loss and gradient on shard of training data.
+        - Sync and average gradients across all GPUs. The final gradient
+          is (sum of gradients across all GPUs) / N, where N is the world
+          size (total number of GPUs).
+        - Update parameters on each GPU.
+        Here, we do the following:
+        - For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
+          the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
+        - Sync and average gradients across all GPUs. The final gradient
+          is (sum of gradients across all GPUs) / B_total.
+        - Update parameters on each GPU.
+        Doing so allows us to account for the variability in batch sizes that
+        variable-length sequential data commonly yields.
+        """
+        opt = self.optimizers()
+        opt.zero_grad()
+        with torch.cuda.amp.autocast(enabled=False):
+            loss = self._step(batch, batch_idx, "train")
+        # normalize the loss based on the sum of batch_size across all GPUs
+        batch_size = batch[0].size(0)
+        batch_sizes = self.all_gather(batch_size)
+        self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
+        loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
+        # backward the loss and clip the gradients
+        loss = self.scaler.scale(loss)
+        self.manual_backward(loss)
+        self.scaler.unscale_(opt)
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
+        # optimization
+        self.scaler.step(opt)
+        sch = self.lr_schedulers()
+        sch.step()
+        self.scaler.update()
     def validation_step(self, batch: Batch_FineTune, batch_idx):
         return self._step(batch, batch_idx, "val")
...
@@ -60,6 +60,12 @@ def _parse_args():
         type=int,
         help="The number of clusters for KMeans clustering.",
     )
+    parser.add_argument(
+        "--percent",
+        default=-1,
+        type=float,
+        help="The percent of data for KMeans clustering. If negative, use all data. (Default: -1)",
+    )
     args = parser.parse_args()
     return args
@@ -114,6 +120,7 @@ def main(args):
            args.num_rank,
            km_dir,
            args.num_cluster,
+            args.percent,
        )
     # Predict labels for MFCC or HuBERT features
...
@@ -7,11 +7,12 @@ python train.py --dataset-path ./exp/data/mfcc/ --feature-type mfcc --num-classe
 import logging
 import pathlib
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
-from typing import Optional, Tuple
+from typing import Tuple
 from lightning import HuBERTPreTrainModule
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities.seed import seed_everything
 logger = logging.getLogger(__name__)
@@ -25,10 +26,11 @@ class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
 def run_train(args):
+    seed_everything(1337)
     checkpoint_dir = args.exp_dir / f"checkpoints_{args.dataset}_{args.model_name}"
     checkpoint = ModelCheckpoint(
         checkpoint_dir,
-        monitor="Losses/val_loss",
+        monitor="val_loss",
         mode="min",
         save_top_k=5,
         save_weights_only=False,
@@ -36,7 +38,7 @@ def run_train(args):
     )
     train_checkpoint = ModelCheckpoint(
         checkpoint_dir,
-        monitor="Losses/train_loss",
+        monitor="train_loss",
         mode="min",
         save_top_k=5,
         save_weights_only=False,
@@ -54,7 +56,6 @@ def run_train(args):
         accelerator="gpu",
         strategy="ddp",
         replace_sampler_ddp=False,
-        gradient_clip_val=args.clip_norm,
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
     )
@@ -71,6 +72,7 @@ def run_train(args):
         betas=args.betas,
         eps=args.eps,
         weight_decay=args.weight_decay,
+        clip_norm=args.clip_norm,
         warmup_updates=args.warmup_updates,
         max_updates=args.max_updates,
     )
@@ -90,7 +92,7 @@ def _parse_args():
     )
     parser.add_argument(
         "--resume-checkpoint",
-        type=Optional[pathlib.Path],
+        type=pathlib.Path,
         default=None,
         help="Path to the feature and label directories. (Default: None)",
     )
@@ -159,9 +161,9 @@ def _parse_args():
     )
     parser.add_argument(
         "--clip-norm",
-        default=None,
-        type=Optional[float],
-        help="The gradient norm value to clip. (Default: None)",
+        default=10.0,
+        type=float,
+        help="The gradient norm value to clip. (Default: 10.0)",
     )
     parser.add_argument(
         "--num-nodes",
...
@@ -60,12 +60,6 @@ def extract_feature_mfcc(
     ).to(device)
     waveform = waveform[0].to(device)
     mfccs = feature_extractor(waveform)  # (freq, time)
-    # mfccs = torchaudio.compliance.kaldi.mfcc(
-    #     waveform=waveform,
-    #     sample_frequency=sample_rate,
-    #     use_energy=False,
-    # )  # (time, freq)
-    # mfccs = mfccs.transpose(0, 1)  # (freq, time)
     deltas = torchaudio.functional.compute_deltas(mfccs)
     ddeltas = torchaudio.functional.compute_deltas(deltas)
     concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
...
@@ -20,12 +20,14 @@ def load_feature(
     feat_dir: Path,
     split: str,
     num_rank: int,
+    percent: float,
 ) -> Tuple[Tensor, Tensor]:
     r"""Loading features from pre-saved `.pt` files.
     Args:
         feat_dir (Path): The directory that stores the feature files.
         split (str): The split of data. Options: [``train``, ``valid``].
         num_rank (int): The number of ranks for multi-processing in feature extraction.
+        percent (float): The percent of data for training k-means model. If negative, use all data for training.
     Returns:
         (Tensor, Tensor)
@@ -37,9 +39,23 @@ def load_feature(
     for rank in range(1, num_rank + 1):
         feat_path, len_path = _get_feat_lens_paths(feat_dir, split, rank, num_rank)
         feat = torch.load(feat_path)
-        length = torch.load(len_path)
-        feats.append(feat)
-        lens.append(length)
+        length = torch.load(len_path).int()
+        if percent < 0:
+            feats.append(feat)
+            lens.append(length)
+        else:
+            offsets = [0] + torch.cumsum(length, dim=0, dtype=torch.int).tolist()
+            nsample = int(length.shape[0] * percent)
+            indices = torch.randperm(length.shape[0])[0:nsample]
+            indices = torch.sort(indices)[0]
+            mask = []
+            for i in range(indices.shape[0]):
+                index = indices[i]
+                mask += list(range(offsets[index], offsets[index] + length[index]))
+            mask = torch.tensor(mask, dtype=torch.int)
+            feat = torch.index_select(feat, 0, mask)
+            feats.append(feat)
+            lens.append(length[indices])
     feats = torch.cat(feats)
     lens = torch.cat(lens)
     return feats, lens
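
The `percent` option samples whole utterances rather than individual frames: a random subset of utterance indices is drawn, and the cumulative-length offsets are used to pull every frame of each selected utterance. A toy walk-through with made-up lengths:

```python
# Toy walk-through of the percent-based subsampling above (made-up lengths).
import torch

length = torch.tensor([3, 2, 4, 1], dtype=torch.int)                   # frames per utterance
offsets = [0] + torch.cumsum(length, dim=0, dtype=torch.int).tolist()  # [0, 3, 5, 9, 10]

percent = 0.5
nsample = int(length.shape[0] * percent)                 # keep 2 of the 4 utterances
indices = torch.sort(torch.randperm(length.shape[0])[:nsample])[0]

mask = []
for index in indices.tolist():
    mask += list(range(offsets[index], offsets[index] + int(length[index])))
# e.g. indices [0, 2] select frame rows [0, 1, 2] and [5, 6, 7, 8] of the feature matrix
```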
@@ -51,6 +67,7 @@ def learn_kmeans(
     num_rank: int,
     km_dir: Path,
     n_clusters: int,
+    percent: float = -1,
     init: str = "k-means++",
     max_iter: int = 100,
     batch_size: int = 10000,
@@ -66,6 +83,8 @@ def learn_kmeans(
         num_rank (int): The number of ranks for multi-processing in feature extraction.
         km_dir (Path): The directory to store the KMeans clustering model.
         n_clusters (int): The number of clusters.
+        percent (float): The percent of data for training k-means model.
+            If negative, use all data for training. (Default: -1)
         init (str, optional): Method for initialization. Options: [``k-means++``, ``random``].
             (Default: ``k-means++``)
         max_iter (int, optional): Maximum number of iterations over the complete dataset. (Default: 100)
@@ -102,6 +121,7 @@ def learn_kmeans(
         feat_dir,
         split,
         num_rank,
+        percent,
     )
     feats = feats.numpy()
     km_model.fit(feats)
@@ -157,19 +177,16 @@ def get_km_label(
     km_path = _get_model_path(km_dir)
     label_path = label_dir / f"label_{split}.pt"
     apply_kmeans = ApplyKmeans(km_path, device)
-    feats, lens = load_feature(
-        feat_dir,
-        split,
-        num_rank,
-    )
-    feats = feats
-    lens = lens.long()
-    offset = 0
-    assert feats.shape[0] == lens.sum()
     with open(label_path, "w") as f:
-        for i in range(lens.shape[0]):
-            feat = feats[offset : offset + lens[i]].to(device)
-            offset += lens[i]
-            label = apply_kmeans(feat).tolist()
-            f.write(" ".join(map(str, label)) + "\n")
+        for rank in range(1, num_rank + 1):
+            offset = 0
+            feat_path, len_path = _get_feat_lens_paths(feat_dir, split, rank, num_rank)
+            feats = torch.load(feat_path)
+            length = torch.load(len_path).int()
+            assert feats.shape[0] == length.sum()
+            labels = apply_kmeans(feats.to(device)).tolist()
+            for i in range(length.shape[0]):
+                label = labels[offset : offset + length[i]]
+                offset += length[i]
+                f.write(" ".join(map(str, label)) + "\n")
     _LG.info("Finished predicting labels successfully")