Commit ab5edfcd authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot

Add HuBERT fine-tuning recipe (#2352)

Summary:
This PR adds the CTC fine-tuning recipe for the HuBERT Base model.
The files include:
- lightning module
- training script
- README and the result table
- evaluation scripts

Pull Request resolved: https://github.com/pytorch/audio/pull/2352

Reviewed By: hwangjeff

Differential Revision: D36915712

Pulled By: nateanl

fbshipit-source-id: 0249635ad5e81a8aa2d228c1d5fe84d78b62a15b
parent 4c19e2cb
# HuBERT Pre-training and Fine-tuning Examples
This directory contains sample implementations of the pre-training and fine-tuning pipelines for [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447).
## Pre-training Usage
The Base architecture of the HuBERT model requires two iterations of pre-training.
### Pre-processing (1st iteration)
...@@ -21,7 +21,7 @@ The first iteration is trained for 250k steps on 32 GPUs, each GPU has at most 8
Sample SLURM command for the first iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --dataset-path ./exp/data/mfcc/ --exp-dir ./exp_iter1 --feature-type mfcc --num-class 100 --max-updates 250000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```
### Pre-processing (2nd iteration)
...@@ -37,5 +37,47 @@ The second iteration is trained for 400k steps.
Sample SLURM command for the second iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --dataset-path ./exp/data/hubert_6/ --exp-dir ./exp_iter2 --feature-type hubert --num-class 500 --max-updates 400000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```
## Fine-tuning Usage
After finishing the pre-training step, the model can be validated by fine-tuning it on the `LibriLightLimited` dataset (the supervised subset of the [Libri-Light](https://github.com/facebookresearch/libri-light) dataset) with an extra feed-forward layer on top of the transformer layers.
Throughout fine-tuning, the feature extraction layers are frozen (i.e., no gradients are backpropagated to them). For the first 10k fine-tuning iterations, the transformer layers are also frozen and only the CTC layer is trained. After 10k iterations, the transformer layers are fine-tuned along with the CTC layer.
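For reference, the sketch below illustrates this freezing schedule (a minimal sketch only, not the recipe code; the actual logic lives in `HuBERTFineTuneModule` in `lightning.py`):
```
import torch
import torchaudio

# Pre-trained Base model plus a CTC head (29 characters, matching --aux-num-out).
model = torchaudio.models.hubert_pretrain_base()
ctc_head = torch.nn.Linear(768, 29)

# The convolutional feature extractor stays frozen for the whole fine-tuning run.
for p in model.wav2vec2.feature_extractor.parameters():
    p.requires_grad = False

def trainable_parameters(step, freeze_encoder_updates=10_000):
    """Parameters that receive gradients at the given update step."""
    params = list(ctc_head.parameters())
    if step > freeze_encoder_updates:
        # After 10k updates, the transformer encoder is fine-tuned together with the CTC head.
        params += list(model.wav2vec2.encoder.parameters())
    return params
```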
Sample SLURM command for fine-tuning on the `10h` subset of the `LibriLightLimited` dataset:
```
srun --gpus-per-node=1 -N 1 --ntasks-per-node=1 --cpus-per-task=10 \
python finetune.py --dataset-path /root/datasets/ --exp-dir ./exp_finetune \
--checkpoint /exp_iter2/checkpoints_librispeech_hubert_pretrain_base/epoch=361-step=399999.ckpt \
--gpus 1 --debug --warmup-updates 2000 --hold-updates 8000 --decay-updates 10000 --max-updates 20000 --learning-rate 5e-5
```
## Decoding
### Viterbi Decoding
The output of the CTC layer contains repeated letters, the blank symbol ("-"), and the silence symbol ("|"). Viterbi decoding merges consecutive repeated letters into one, removes the blank symbol, and splits the string into a list of words at the silence symbol.
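As a toy illustration of these three steps (the input string below is assumed for illustration, not produced by the recipe):
```
import itertools

ctc_output = "HH-EE-LL-LLO-|WW-ORRL-D|"
collapsed = "".join(ch for ch, _ in itertools.groupby(ctc_output))  # merge repeated letters
no_blank = collapsed.replace("-", "")                               # drop the blank symbol
words = [w for w in no_blank.split("|") if w]                       # split at the silence symbol
print(words)  # ['HELLO', 'WORLD']
```
Note that the blank symbol between the two `L` groups is what preserves the double letter in `HELLO`.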
Sample SLURM command for evaluation with Viterbi decoding:
```
srun python evaluate.py --librispeech-path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean
```
### CTC Decoding with language model
torchaudio provides a `CTCDecoder` feature based on [Flashlight](https://github.com/flashlight/flashlight). The decoder supports KenLM language models. Pass `--use-lm` to enable CTC decoding with a KenLM 4-gram language model.
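For reference, a minimal sketch of how such a decoder can be built from the pretrained LibriSpeech 4-gram files (the values mirror the defaults used below; see `evaluate.py` for the full set of options):
```
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files

# Download the lexicon, token list, and KenLM 4-gram LM trained on LibriSpeech.
files = download_pretrained_files("librispeech-4-gram")
decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    beam_size=1500,
    lm_weight=2.46,
    word_score=-0.59,
)
# `emission` is the CTC-head output of shape (batch, time, num_tokens);
# the best hypothesis is available as decoder(emission)[0][0].words.
```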
Sample SLURM command for evaluation with KenLM language model:
```
srun python evaluate.py --librispeech-path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean --use-lm --beam-size 1500 --lm-weight 2.46 --word-score -0.59
```
### WER results
The table below contains WER results for fine-tuning the HuBERT Base model on the `10h` subset of the `LibriLightLimited` dataset.
| | WER% (Viterbi)| WER% (KenLM) |
|:-----------------:|--------------:|--------------:|
| dev-clean | 10.7 | 4.4 |
| dev-other | 18.3 | 9.7 |
| test-clean | 10.8 | 4.4 |
| test-other | 18.5 | 10.1 |
from .hubert_dataset import (
_get_lengths_librilightlimited,
_get_lengths_librispeech,
BucketizeBatchSampler,
CollateFnHubert,
CollateFnLibriLightLimited,
HuBERTDataSet,
)
__all__ = [
"_get_lengths_librilightlimited",
"_get_lengths_librispeech",
"BucketizeBatchSampler",
"CollateFnHubert",
"CollateFnLibriLightLimited",
"HuBERTDataSet",
]
import math
import os
import sys
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
...@@ -9,6 +12,9 @@ import torchaudio
from torch import Tensor
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
sys.path.append("..")
from utils import _get_label2id
class BucketizeBatchSampler(BatchSampler):
"""Bucketized BatchSampler for sequential data with different lengths to reduce number of paddings.
...@@ -407,3 +413,68 @@ class CollateFnHubert:
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
lengths = torch.tensor(lengths)
return waveforms, labels, lengths
def _get_lengths_librilightlimited(files: List[Tuple[str, str]]) -> List[int]:
lengths = []
for file_path, fileid in files:
speaker_id, chapter_id, utterance_id = fileid.split("-")
# Load audio
file_audio = f"{speaker_id}-{chapter_id}-{utterance_id}.flac"
file_audio = os.path.join(file_path, speaker_id, chapter_id, file_audio)
length = torchaudio.info(file_audio).num_frames
lengths.append(length)
return lengths
def _get_lengths_librispeech(files: List[str], path: str, ext_audio: str) -> List[int]:
lengths = []
for file_path in files:
speaker_id, chapter_id, utterance_id = file_path.split("-")
fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
file_audio = fileid_audio + ext_audio
file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
length = torchaudio.info(file_audio).num_frames
lengths.append(length)
return lengths
class CollateFnLibriLightLimited:
"""The collate class for LibriSpeech or LibriLightLimited dataset."""
def __call__(self, batch: List[Tuple[Tensor, int, str, int, int, int]]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Args:
batch (List(Tuple(Tensor, int, str, int, int, int))):
The list of tuples that contains
waveform, sample_rate, transcript, speaker_id, chapter_id, and utterance_id.
Returns:
(Tuple(Tensor, Tensor, Tensor, Tensor)):
The Tensor of waveforms with dimensions `(batch, time)`.
The Tensor of labels with dimensions `(batch, seq)`.
The Tensor of audio lengths with dimensions `(batch,)`.
The Tensor of label lengths with dimensions `(batch,)`.
"""
audio_sizes = [sample[0].shape[1] for sample in batch]
audio_size = max(audio_sizes)
waveforms, labels, audio_lengths, label_lengths = [], [], [], []
label2id = _get_label2id()
for sample in batch:
waveform, transcript = sample[0], sample[2]
label = torch.tensor([label2id[e] for e in transcript.replace(" ", "|").upper()])
audio_length = waveform.size(1)
label_length = label.size(0)
waveforms.append(waveform)
audio_lengths.append(audio_length)
label_lengths.append(label_length)
labels.append(label)
data = torch.zeros(len(batch), audio_size)
for i in range(len(waveforms)):
data[i][0 : waveforms[i].shape[1]] = waveforms[i]
audio_lengths = torch.tensor(audio_lengths)
label_lengths = torch.tensor(label_lengths)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-1)
return data, labels.int(), audio_lengths.int(), label_lengths.int()
import argparse
import logging
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.models.decoder import ctc_decoder, CTCDecoder, download_pretrained_files
from utils import _get_id2label
logger = logging.getLogger(__name__)
def _load_checkpoint(checkpoint: str) -> torch.nn.Module:
model = torchaudio.models.hubert_base(aux_num_out=29)
checkpoint = torch.load(checkpoint, map_location="cpu")
state_dict = checkpoint["state_dict"]
new_state_dict = {}
for k in state_dict:
if "model.wav2vec2" in k:
new_state_dict[k.replace("model.wav2vec2.", "")] = state_dict[k]
elif "aux" in k:
new_state_dict[k] = state_dict[k]
model.load_state_dict(new_state_dict)
return model
def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0) -> List[str]:
"""Run greedy (Viterbi) decoding over CTC emissions.
Args:
emission (torch.Tensor): Output of CTC layer. Tensor with dimensions (..., time, num_tokens).
id2token (Dictionary): The dictionary that maps indices of emission's last dimension
to the corresponding tokens.
blank_idx (int, optional): Index of the CTC blank token in the emission. (Default: 0)
Returns:
(List of str): The decoding result. List of string in lower case.
"""
hypothesis = F.log_softmax(emission, dim=-1)
hypothesis = hypothesis.argmax(-1).unique_consecutive()
hypothesis = hypothesis[hypothesis != blank_idx]
hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ")
return hypothesis.split()
def _ctc_decode(emission, decoder: CTCDecoder) -> List[str]:
"""Run CTC decoding with a KenLM language model.
Args:
emission (torch.Tensor): Output of CTC layer. Tensor with dimensions (..., time, num_tokens).
decoder (CTCDecoder): The initialized CTCDecoder.
Returns:
(List of str): The decoding result. List of string in lower case.
"""
hypothesis = decoder(emission)
hypothesis = hypothesis[0][0].words
return hypothesis
def run_inference(args):
# Load the fine-tuned weights from the HuBERTPretrainModel checkpoint into a hubert_base model with a CTC head.
model = _load_checkpoint(args.checkpoint)
model.eval()
if args.use_lm:
# get decoder files
files = download_pretrained_files("librispeech-4-gram")
decoder = ctc_decoder(
lexicon=files.lexicon,
tokens=files.tokens,
lm=files.lm,
nbest=args.nbest,
beam_size=args.beam_size,
beam_size_token=args.beam_size_token,
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
word_score=args.word_score,
unk_score=args.unk_score,
sil_score=args.sil_score,
log_add=False,
)
else:
id2token = _get_id2label()
dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split)
total_edit_distance = 0
total_length = 0
for idx, sample in enumerate(dataset):
waveform, _, transcript, _, _, _ = sample
transcript = transcript.strip().lower().replace("\n", "")
with torch.inference_mode():
emission, _ = model(waveform)
if args.use_lm:
hypothesis = _ctc_decode(emission, decoder)
else:
hypothesis = _viterbi_decode(emission, id2token)
total_edit_distance += torchaudio.functional.edit_distance(transcript.split(), hypothesis)
total_length += len(transcript.split())
if idx % 100 == 0:
logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.info(f"Final WER: {total_edit_distance / total_length}")
def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--librispeech-path",
type=str,
help="Folder where LibriSpeech dataset is stored.",
)
parser.add_argument(
"--split",
type=str,
choices=["dev-clean", "dev-other", "test-clean", "test-other"],
help="LibriSpeech dataset split. (Default: 'test-clean')",
default="test-clean",
)
parser.add_argument(
"--checkpoint",
type=str,
help="The checkpoint path of fine-tuned HuBERTPretrainModel.",
)
parser.add_argument("--use-lm", action="store_true", help="Whether to use language model for decoding.")
parser.add_argument("--nbest", type=int, default=1, help="Number of best hypotheses to return.")
parser.add_argument(
"--beam-size",
type=int,
default=1500,
help="Beam size for determining number of hypotheses to store. (Default: 1500)",
)
parser.add_argument(
"--beam-size-token",
type=int,
default=None,
help="Number of tokens to consider at each beam search step. (Default: None)",
)
parser.add_argument(
"--beam-threshold", type=int, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
)
parser.add_argument(
"--lm-weight",
type=float,
default=2.46,
help="Language model weight in decoding. (Default: 2.46)",
)
parser.add_argument(
"--word-score",
type=float,
default=-0.59,
help="Word insertion score in decoding. (Default: -0.59)",
)
parser.add_argument(
"--unk-score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
)
parser.add_argument("--sil-score", type=float, default=0, help="Silence insertion score. (Default: 0)")
parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def _main():
args = _parse_args()
_init_logger(args.debug)
run_inference(args)
if __name__ == "__main__":
_main()
#!/usr/bin/env python3
"""Fine-tune the HuBERTPretrainModel on 10 hours of LibriLightLimited dataset.
Example:
python finetune.py --dataset-path ./root/datasets/ --exp-dir ./exp_finetune \
--checkpoint /exp_iter2/checkpoints_librispeech_hubert_pretrain_base/epoch=361-step=399999.ckpt \
--gpus 1 --debug --warmup-updates 2000 --hold-updates 8000 --decay-updates 10000 \
--max-updates 20000 --learning-rate 5e-5
"""
import logging
import pathlib
from argparse import (
ArgumentDefaultsHelpFormatter,
ArgumentParser,
RawDescriptionHelpFormatter,
)
from typing import Tuple
from lightning import HuBERTFineTuneModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
logger = logging.getLogger(__name__)
class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
# To use ArgumentDefaultsHelpFormatter as the formatter_class and
# RawDescriptionHelpFormatter to add custom formatting to description or epilog.
# Check: https://stackoverflow.com/a/18462760
pass
def run_train(args):
checkpoint_dir = args.exp_dir / f"checkpoints_{args.model_name}"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
callbacks = [
checkpoint,
train_checkpoint,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_steps=args.max_updates,
num_nodes=args.num_nodes,
gpus=args.gpus,
accelerator="gpu",
strategy="ddp",
replace_sampler_ddp=False,
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
accumulate_grad_batches=args.accumulate_grad_batches,
)
model = HuBERTFineTuneModule(
model_name=args.model_name,
encoder_projection_dropout=args.encoder_projection_dropout,
encoder_attention_dropout=args.encoder_attention_dropout,
encoder_ff_interm_dropout=args.encoder_ff_interm_dropout,
encoder_dropout=args.encoder_dropout,
encoder_layer_drop=args.encoder_layer_drop,
mask_prob=args.mask_prob,
mask_channel_prob=args.mask_channel_prob,
mask_channel_length=args.mask_channel_length,
aux_num_out=args.aux_num_out,
checkpoint=args.checkpoint,
dataset_paths=args.dataset_path,
seconds_per_batch=args.seconds_per_batch,
subset=args.subset,
learning_rate=args.learning_rate,
betas=args.betas,
adam_eps=args.adam_eps,
weight_decay=args.weight_decay,
freeze_encoder_updates=args.freeze_encoder_updates,
warmup_updates=args.warmup_updates,
hold_updates=args.hold_updates,
decay_updates=args.decay_updates,
)
trainer.fit(model)
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=_Formatter,
)
parser.add_argument(
"--dataset-path",
type=pathlib.Path,
required=True,
help="Path to the LibriSpeech and LibriLightLimited datasets.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp_finetune"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp_finetune')",
)
parser.add_argument(
"--model-name",
default="hubert_pretrain_base",
choices=["hubert_pretrain_base", "hubert_pretrain_large", "hubert_pretrain_xlarge"],
type=str,
help="The HuBERTPretrainModel to fine-tune. (Default: 'hubert_pretrain_base')",
)
parser.add_argument(
"--encoder-projection-dropout",
default=0.0,
type=float,
help="The dropout probability applied after the input feature "
"is projected to ``encoder_embed_dim``. (Default: 0.0)",
)
parser.add_argument(
"--encoder-attention-dropout",
default=0.0,
type=float,
help="The dropout probability applied after softmax in self-attention layer. (Default: 0.0)",
)
parser.add_argument(
"--encoder-ff-interm-dropout",
default=0.1,
type=float,
help="The dropout probability applied in feedforward layer. (Default: 0.1)",
)
parser.add_argument(
"--encoder-dropout",
default=0.0,
type=float,
help="The dropout probability applied at the end of feed forward layer. (Default: 0.0)",
)
parser.add_argument(
"--encoder-layer-drop",
default=0.1,
type=float,
help="Probability to drop each encoder layer during training. (Default: 0.1)",
)
parser.add_argument(
"--mask-prob",
default=0.65,
type=float,
help="Probability to mask the frames of the convolutional layer feature. (Default: 0.65)",
)
parser.add_argument(
"--mask-channel-prob",
default=0.5,
type=float,
help="Probability to mask the feature dimension of the convolutional layer feature. (Default: 0.5)",
)
parser.add_argument(
"--mask-channel-length",
default=64,
type=int,
help="Minimum space between spans (if no overlap is enabled) for channel masking. (Default: 64)",
)
parser.add_argument(
"--accumulate-grad-batches",
default=1,
type=int,
help="Number of batches to accumulate the gradients during training. (Default: 1)",
)
parser.add_argument(
"--aux-num-out",
default=29,
type=int,
help="The dimension of linear layer for CTC training. (Default: 29)",
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the pre-trained HuBERTPretrainModel checkpoint.",
)
parser.add_argument(
"--learning-rate", default=1e-4, type=float, help="The learning rate of Adam optimizer. (Default: 1e-4)"
)
parser.add_argument(
"--betas",
default=(0.9, 0.98),
nargs=2,
type=float,
help="The coefficients for computing running averages of gradient and its square. (Default: (0.9, 0.98))",
)
parser.add_argument(
"--adam-eps",
default=1e-8,
type=float,
help="Epsilon value in Adam optimizer. (Default: 1e-8)",
)
parser.add_argument(
"--weight-decay",
default=1e-6,
type=float,
help="Weight decay (L2 penalty). (Default: 1e-6)",
)
parser.add_argument(
"--num-nodes",
default=1,
type=int,
help="Number of nodes to use for fine-tuning. (Default: 1)",
)
parser.add_argument(
"--gpus",
default=1,
type=int,
help="Number of GPUs per node to use for fine-tuning. (Default: 1)",
)
parser.add_argument(
"--freeze-encoder-updates",
default=10000,
type=int,
help="Number of steps to freeze the transformer encoder in HuBERT. (Default: 10000)",
)
parser.add_argument(
"--warmup-updates",
default=2000,
type=int,
help="Number of steps for warming up the learning rate. (Default: 2000)",
)
parser.add_argument(
"--hold-updates",
default=8000,
type=int,
help="Number of steps for keeping the peak learning rate. (Default: 8000)",
)
parser.add_argument(
"--decay-updates",
default=10000,
type=int,
help="Number of steps for decreasing the learning rate. (Default: 10000)",
)
parser.add_argument(
"--max-updates",
default=20000,
type=int,
help="Total number of training steps. (Default: 20000)",
)
parser.add_argument(
"--seconds-per-batch",
default=200,
type=float,
help="Number of seconds of audio in a mini-batch. (Default: 200)",
)
parser.add_argument(
"--subset",
default="10h",
type=str,
choices=["10min", "1h", "10h"],
help="The subset of LibriLightLimited dataset for fine-tuning. (Default: '10h')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
run_train(args)
if __name__ == "__main__":
cli_main()
import math
from typing import Tuple
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.models.wav2vec2.components as components
from dataset import (
_get_lengths_librilightlimited,
_get_lengths_librispeech,
BucketizeBatchSampler,
CollateFnHubert,
CollateFnLibriLightLimited,
DistributedBatchSampler,
HuBERTDataSet,
)
...@@ -16,6 +22,7 @@ from torch.utils.data import DataLoader
Batch = Tuple[Tensor, Tensor, Tensor]
Batch_FineTune = Tuple[Tensor, Tensor, Tensor, Tensor]
class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler):
...@@ -43,6 +50,50 @@ class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler):
return [base_lr * pct_remaining for base_lr in self.base_lrs]
class TriStageLRScheduler(torch.optim.lr_scheduler._LRScheduler):
"""Tri-stage learning rate scheduler with linear warmup, constant hold, and exponential decay stages."""
def __init__(
self,
optimizer: Optimizer,
warmup_updates: int,
hold_updates: int,
decay_updates: int,
init_lr_scale: float = 0.01,
final_lr_scale: float = 0.05,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.hold_updates = hold_updates
self.decay_updates = decay_updates
self.init_lr_scale = init_lr_scale
self.final_lr_scale = final_lr_scale
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count <= self.warmup_updates:
return [
base_lr * (self.init_lr_scale + self._step_count / self.warmup_updates * (1 - self.init_lr_scale))
for base_lr in self.base_lrs
]
elif self.warmup_updates < self._step_count <= (self.warmup_updates + self.hold_updates):
return list(self.base_lrs)
elif self._step_count <= (self.warmup_updates + self.hold_updates + self.decay_updates):
return [
base_lr
* math.exp(
math.log(self.final_lr_scale)
* (self._step_count - self.warmup_updates - self.hold_updates)
/ self.decay_updates
)
for base_lr in self.base_lrs
]
else:
return [base_lr * self.final_lr_scale for base_lr in self.base_lrs]
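# Example usage (illustrative; the values match the fine-tuning defaults):
#   optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
#   scheduler = TriStageLRScheduler(optimizer, warmup_updates=2000, hold_updates=8000, decay_updates=10000)
# Calling scheduler.step() after each update ramps the LR linearly from 0.01 * lr to lr during warmup,
# holds it at lr, then decays it exponentially to 0.05 * lr over the decay stage.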
class HuBERTPreTrainModule(LightningModule):
def __init__(
self,
...@@ -51,7 +102,7 @@ class HuBERTPreTrainModule(LightningModule):
feature_grad_mult: float,
num_classes: int,
dataset: str,
dataset_path: str,
feature_type: str,
seconds_per_batch: float,
learning_rate: float,
...@@ -80,7 +131,7 @@ class HuBERTPreTrainModule(LightningModule):
)
self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
self.dataset = dataset
self.dataset_path = dataset_path
self.feature_type = feature_type
self.seconds_per_batch = seconds_per_batch
...@@ -115,7 +166,7 @@ class HuBERTPreTrainModule(LightningModule):
return self._step(batch, batch_idx, "val")
def train_dataloader(self):
dataset = HuBERTDataSet(self.dataset_path, self.dataset, "train")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=10000,
...@@ -135,7 +186,7 @@ class HuBERTPreTrainModule(LightningModule):
return dataloader
def val_dataloader(self):
dataset = HuBERTDataSet(self.dataset_path, self.dataset, "valid")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=1000,
...@@ -151,3 +202,174 @@ class HuBERTPreTrainModule(LightningModule):
num_workers=10,
)
return dataloader
class HuBERTFineTuneModule(LightningModule):
def __init__(
self,
*,
model_name: str,
encoder_projection_dropout: float,
encoder_attention_dropout: float,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_drop: float,
mask_prob: float,
mask_channel_prob: float,
mask_channel_length: float,
aux_num_out: int,
checkpoint: str,
dataset_path: str,
seconds_per_batch: float,
subset: str,
learning_rate: float,
betas: Tuple[float, float],
adam_eps: float,
weight_decay: float,
freeze_encoder_updates: int,
warmup_updates: int,
hold_updates: int,
decay_updates: int,
):
super().__init__()
if model_name == "hubert_pretrain_base":
self.model = torchaudio.models.hubert_pretrain_base(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_drop=encoder_layer_drop,
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
)
elif model_name == "hubert_large":
self.model = torchaudio.models.hubert_pretrain_large(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_drop=encoder_layer_drop,
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
)
elif model_name == "hubert_xlarge":
self.model = torchaudio.models.hubert_pretrain_xlarge(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_drop=encoder_layer_drop,
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
)
else:
raise ValueError(f"Unsupported model name: {model_name}.")
self.aux = torch.nn.Linear(768, aux_num_out)  # 768 is the encoder embedding dimension of the Base model
self._load_checkpoint(checkpoint)
for p in self.model.wav2vec2.feature_extractor.parameters():
p.requires_grad = False
self.loss_fn = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
self.optimizer = torch.optim.Adam(
list(self.aux.parameters()) + list(self.model.wav2vec2.encoder.parameters()),
lr=learning_rate,
betas=betas,
eps=adam_eps,
weight_decay=weight_decay,
)
self.freeze_encoder_updates = freeze_encoder_updates
self.lr_scheduler = TriStageLRScheduler(self.optimizer, warmup_updates, hold_updates, decay_updates)
self.dataset_path = dataset_path
self.seconds_per_batch = seconds_per_batch
self.subset = subset
def _load_checkpoint(self, checkpoint):
# load pretrain model
state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
state_dict = state_dict["state_dict"]
s = {}
for k in state_dict:
if "wav2vec2" in k:
s[k.replace("model.wav2vec2.", "")] = state_dict[k]
self.model.wav2vec2.load_state_dict(s)
def _step(self, batch: Batch_FineTune, batch_idx, step_type):
if batch is None:
return None
waveforms, labels, audio_lengths, label_lengths = batch
if self.global_step <= self.freeze_encoder_updates:
with torch.no_grad():
x, out_len = self.model.wav2vec2.feature_extractor(waveforms, audio_lengths)
padding_mask = components._get_padding_mask(x, out_len)
x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)
x, _ = self.model.mask_generator(x, padding_mask)
x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
else:
with torch.no_grad():
x, out_len = self.model.wav2vec2.feature_extractor(waveforms, audio_lengths)
padding_mask = components._get_padding_mask(x, out_len)
x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)
x, _ = self.model.mask_generator(x, padding_mask)
x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
logits = self.aux(x)
# Force padded frames to emit the blank token (index 0); boolean indexing returns a copy, so assign whole rows.
blank_row = torch.full((logits.size(-1),), float("-inf"), dtype=logits.dtype, device=logits.device)
blank_row[0] = 0.0
logits[padding_mask] = blank_row
log_probs = F.log_softmax(logits, dim=-1)
log_probs = log_probs.transpose(0, 1)
loss = self.loss_fn(
log_probs,
labels,
out_len,
label_lengths,
)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[
self.optimizer,
],
[
{"scheduler": self.lr_scheduler, "interval": "step"},
],
)
def training_step(self, batch: Batch_FineTune, batch_idx):
return self._step(batch, batch_idx, "train")
def validation_step(self, batch: Batch_FineTune, batch_idx):
return self._step(batch, batch_idx, "val")
def train_dataloader(self):
dataset = torchaudio.datasets.LibriLightLimited(self.dataset_path, self.subset)
lengths = _get_lengths_librilightlimited(dataset._fileids_paths)
sampler = BucketizeBatchSampler(
lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=True
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.global_step)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnLibriLightLimited(),
num_workers=10,
)
return dataloader
def val_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(self.dataset_path, "dev-other")
lengths = _get_lengths_librispeech(dataset._walker, dataset._path, dataset._ext_audio)
sampler = BucketizeBatchSampler(
lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=False
)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnLibriLightLimited(),
num_workers=10,
)
return dataloader
#!/usr/bin/env python3
"""Train the HuBERTPretrainModel by using labels generated by KMeans clustering.
Example:
python train.py --dataset-path ./exp/data/mfcc/ --feature-type mfcc --num-classes 100
"""
import logging
...@@ -68,7 +68,7 @@ def run_train(args):
feature_grad_mult=args.feature_grad_mult,
num_classes=args.num_classes,
dataset=args.dataset,
dataset_path=args.dataset_path,
feature_type=args.feature_type,
seconds_per_batch=args.seconds_per_batch,
learning_rate=args.learning_rate,
...@@ -87,7 +87,7 @@ def _parse_args():
formatter_class=_Formatter,
)
parser.add_argument(
"--dataset-path",
type=pathlib.Path,
required=True,
help="Path to the feature and label directories.",
...
from .common_utils import _get_id2label, _get_label2id, create_tsv
from .feature_utils import dump_features
from .kmeans import get_km_label, learn_kmeans
__all__ = [
"create_tsv",
"_get_id2label",
"_get_label2id",
"dump_features",
"learn_kmeans",
"get_km_label",
...
...@@ -9,7 +9,7 @@ Data pre-processing: create tsv files for training (and validation).
import logging
import re
from pathlib import Path
from typing import Dict, Tuple, Union
import torch
import torchaudio
...@@ -94,3 +94,17 @@ def _get_model_path(km_dir: Path) -> Path:
Path: The file path of the model.
"""
return km_dir / "model.pt"
def _get_id2label() -> Dict:
"""Get the dictionary that maps indices of ASR model's last layer dimension to the corresponding labels."""
bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
labels = bundle.get_labels()
return {i: char.lower() for i, char in enumerate(labels)}
def _get_label2id() -> Dict:
"""Get the dictionary that maps the labels to the corresponding indices in ASR model's last dimension."""
bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
labels = bundle.get_labels()
return {char: i for i, char in enumerate(labels)}
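# Example usage (illustrative): round-trip a transcript through the two mappings.
#   label2id = _get_label2id()
#   ids = [label2id[c] for c in "HELLO WORLD".replace(" ", "|")]
#   id2label = _get_id2label()
#   "".join(id2label[i] for i in ids)  # -> "hello|world"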