Commit 48a0c17a authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot

Add recipe for HuBERT model pre-training (#2198)

Summary:
Replace https://github.com/pytorch/audio/issues/2129

Pull Request resolved: https://github.com/pytorch/audio/pull/2198

Reviewed By: carolineechen

Differential Revision: D36544163

Pulled By: nateanl

fbshipit-source-id: 3f19ba5b0f2c2b9e93b0603c3b4491c1dbc40ef8
# HuBERT Pre-training Example
This directory contains sample implementations of the pre-training pipeline for [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447).
## Usage
The Base architecture of the HuBERT model requires two iterations of pre-training.
### Pre-processing (1st iteration)
[`preprocess.py`](./preprocess.py) generates the file lists of training and validation data, trains a KMeans clustering model on either MFCC features or a transformer layer's output of a pre-trained HuBERT model, and then predicts the cluster IDs for each utterance, which serve as the labels for masked prediction training.
Sample SLURM command for the first iteration of pre-processing, which uses MFCC features to train the KMeans model:
```
srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type mfcc --exp-dir ./exp --num-cluster 100
```
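Conceptually, the label-generation step pools frame-level features over the training set, fits a KMeans model on them, and then assigns every frame its nearest cluster ID. The snippet below is a minimal sketch of that idea using torchaudio's MFCC transform and scikit-learn's `MiniBatchKMeans`; the file path and parameters are illustrative, and the actual `preprocess.py` uses its own feature-extraction and clustering utilities.
```
import torch
import torchaudio
from sklearn.cluster import MiniBatchKMeans

# Hypothetical file list; the real script reads the generated .tsv file lists.
paths = ["/home/datasets/LibriSpeech/train-clean-100/103/1240/103-1240-0000.flac"]

mfcc = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=13)

frames = []
for path in paths:
    waveform, sample_rate = torchaudio.load(path)
    frames.append(mfcc(waveform).squeeze(0).transpose(0, 1))  # (num_frames, n_mfcc)
features = torch.cat(frames).numpy()

# Fit KMeans on the pooled frame-level features, then assign each frame a
# cluster ID; these IDs are the pseudo-labels for masked prediction training.
kmeans = MiniBatchKMeans(n_clusters=100, batch_size=10000).fit(features)
labels = kmeans.predict(features)
```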
### Pre-training (1st iteration)
[`train.py`](./train.py) trains a HuBERTPretrainModel using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training.
The first iteration is trained for 250k steps on 32 GPUs; each GPU holds at most 87.5 seconds of audio in a mini-batch.
Sample SLURM command for the first iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/mfcc/ --exp-dir ./exp_iter1 --feature-type mfcc --num-classes 100 --max-updates 250000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```
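If a run is interrupted, it can be resumed from a saved checkpoint via the `--resume-checkpoint` flag defined in [`train.py`](./train.py). Checkpoints are written under `<exp-dir>/checkpoints_<dataset>_<model-name>/`; the checkpoint file name below is a placeholder.
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/mfcc/ --exp-dir ./exp_iter1 --feature-type mfcc --num-classes 100 --max-updates 250000 --learning-rate 0.0005 --gpus 8 --num-nodes 4 --resume-checkpoint ./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt
```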
### Pre-processing (2nd iteration)
After the first iteration of pre-training, the output of an intermediate transformer layer of the pre-trained HuBERTPretrainModel can be used to train a new KMeans clustering model. The new KMeans model then generates the cluster labels for the second iteration of masked prediction training.
Sample SLURM command for the second iteration of pre-processing. The 6th transformer layer's output is used as the input feature for training the KMeans model. Note that the number of clusters is increased to 500 to improve performance.
```
srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type hubert --exp-dir ./exp --layer-index 6 --checkpoint-path ./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt --num-cluster 500
```
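For reference, the snippet below sketches how the 6th transformer layer's output can be pulled out of a first-iteration checkpoint with torchaudio's `extract_features` API. The checkpoint path mirrors the placeholder above, the key renaming assumes the checkpoint was saved by the Lightning module in this recipe (which stores the network under the `model.` prefix), and `preprocess.py` has its own feature-dumping utilities, so treat this only as an illustration.
```
import torch
import torchaudio

# Placeholder checkpoint path from the first pre-training iteration.
ckpt = torch.load(
    "./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt", map_location="cpu"
)
# Strip the LightningModule's ``model.`` prefix before loading the bare model.
state_dict = {k.replace("model.", "", 1): v for k, v in ckpt["state_dict"].items()}
model = torchaudio.models.hubert_pretrain_base(num_classes=100)
model.load_state_dict(state_dict)
model.eval()

waveform = torch.randn(1, 16000)  # dummy 1-second 16 kHz waveform
with torch.no_grad():
    # ``extract_features`` returns the outputs of the first ``num_layers``
    # transformer layers; the last element is the 6th layer's output.
    features, _ = model.wav2vec2.extract_features(waveform, num_layers=6)
layer6 = features[-1]  # (batch, num_frames, 768) for the Base model
```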
### Pre-training (2nd iteration)
The second iteration is trained for 400k steps.
Sample SLURM command for the second iteration of pre-training:
```
srun --gpus-per-node=8 --ntasks-per-node=8 -N 4 --cpus-per-task=10 python train.py --root-path ./exp/data/hubert_6/ --exp-dir ./exp_iter2 --feature-type hubert --num-classes 500 --max-updates 400000 --learning-rate 0.0005 --gpus 8 --num-nodes 4
```
from typing import Tuple
import torch
import torchaudio
from dataset import (
BucketizeBatchSampler,
CollateFnHubert,
DistributedBatchSampler,
HuBERTDataSet,
)
from loss import hubert_loss
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
Batch = Tuple[Tensor, Tensor, Tensor]
class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler):
"""Linear learning rate scheduler with warm up."""
def __init__(
self,
optimizer: Optimizer,
warmup_updates: int,
max_updates: int,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.max_updates = max_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count <= self.warmup_updates:
return [self._step_count / self.warmup_updates * base_lr for base_lr in self.base_lrs]
elif self._step_count >= self.max_updates:
return [0.0 for _ in self.base_lrs]
else:
pct_remaining = (self.max_updates - self._step_count) / (self.max_updates - self.warmup_updates)
return [base_lr * pct_remaining for base_lr in self.base_lrs]
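# For example, with the recipe defaults (warmup_updates=32000, max_updates=250000,
# learning_rate=5e-4), the learning rate ramps linearly from ~1.6e-8 at step 1 to 5e-4
# at step 32000, then decays linearly back to 0.0 by step 250000.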
class HuBERTPreTrainModule(LightningModule):
def __init__(
self,
*,
model_name: str,
feature_grad_mult: float,
num_classes: int,
dataset: str,
root_path: str,
feature_type: str,
seconds_per_batch: float,
learning_rate: float,
betas: Tuple[float, float],
eps: float,
weight_decay: float,
warmup_updates: int,
max_updates: int,
):
super().__init__()
if model_name == "hubert_pretrain_base":
self.model = torchaudio.models.hubert_pretrain_base(
feature_grad_mult=feature_grad_mult, num_classes=num_classes
)
elif model_name == "hubert_pretrain_large":
self.model = torchaudio.models.hubert_pretrain_large()
elif model_name == "hubert_pretrain_xlarge":
self.model = torchaudio.models.hubert_pretrain_xlarge()
else:
raise ValueError(f"Unsupported model name: {model_name}")
self.loss = hubert_loss
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
)
self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
self.dataset = dataset
self.root_path = root_path
self.feature_type = feature_type
self.seconds_per_batch = seconds_per_batch
def _step(self, batch: Batch, batch_idx, step_type):
if batch is None:
return None
waveforms, labels, audio_lengths = batch
logit_m, logit_u, feature_penalty = self.model(
waveforms,
labels,
audio_lengths,
)
loss = self.loss(logit_m, logit_u, feature_penalty)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
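# ``"interval": "step"`` tells Lightning to step the scheduler after every optimizer
# update, which is what the step-based LinearDecayLRScheduler expects.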
return (
[self.optimizer],
[
{
"scheduler": self.lr_scheduler,
"interval": "step",
},
],
)
def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")
def validation_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "val")
def train_dataloader(self):
dataset = HuBERTDataSet(self.root_path, self.dataset, "train")
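# The bucketized sampler groups utterances of similar length; the per-batch budget is
# ``seconds_per_batch`` seconds of 16 kHz audio, i.e. seconds_per_batch * 16000 samples.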
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=10000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=True,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.current_epoch)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
)
return dataloader
def val_dataloader(self):
dataset = HuBERTDataSet(self.root_path, self.dataset, "valid")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=False,
)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
)
return dataloader
from .hubert_loss import hubert_loss
__all__ = [
"hubert_loss",
]
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
def hubert_loss(
logit_m: Optional[Tensor],
logit_u: Optional[Tensor],
feature_penalty: Tensor,
masked_weight: float = 1.0,
unmasked_weight: float = 0.0,
feature_weight: float = 10.0,
reduction: str = "sum",
) -> Tensor:
"""Compute the cross-entropy loss on HuBERT masked and non-masked logits.
Args:
logit_m (Tensor or None): The masked logit Tensor of dimension `(masked_frames, final_dim)`.
logit_u (Tensor or None): The non-masked logit Tensor of dimension `(unmasked_frames, final_dim)`.
feature_penalty (Tensor): The feature mean value for additional penalty loss.
masked_weight (float, optional): The weight for masked cross-entropy loss (Default: ``1.0``).
unmasked_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``).
feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``).
reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``).
"""
# The logit generator places the target cluster at index 0 of each frame's logits,
# hence the all-zero cross-entropy targets below. The feature penalty is scaled by the
# number of masked frames (falling back to non-masked frames if masked logits are absent).
sample_size = logit_m.shape[0] if logit_m is not None else logit_u.shape[0]
loss = feature_penalty * feature_weight * sample_size
if logit_m is not None:
target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction)
loss += loss_m * masked_weight
if logit_u is not None:
target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_u.device)
loss_u = F.cross_entropy(logit_u, target_u, reduction=reduction)
loss += loss_u * unmasked_weight
return loss
#!/usr/bin/env python3
"""Train the HuBERTPretrainModel by using labels generated by KMeans clustering.
Example:
python train.py --root-path ./exp/data/mfcc/ --feature-type mfcc --num-classes 100
"""
import logging
import pathlib
from argparse import (
ArgumentDefaultsHelpFormatter,
ArgumentParser,
RawDescriptionHelpFormatter,
)
from typing import Optional, Tuple
from lightning import HuBERTPreTrainModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
logger = logging.getLogger(__name__)
class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
# To use ArgumentDefaultsHelpFormatter as the formatter_class and
# RawDescriptionHelpFormatter to add custom formatting to description or epilog.
# Check: https://stackoverflow.com/a/18462760
pass
def run_train(args):
checkpoint_dir = args.exp_dir / f"checkpoints_{args.dataset}_{args.model_name}"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
callbacks = [
checkpoint,
train_checkpoint,
]
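# ``replace_sampler_ddp=False`` keeps the custom DistributedBatchSampler created in
# HuBERTPreTrainModule.train_dataloader; ``reload_dataloaders_every_n_epochs=1`` rebuilds
# the dataloaders each epoch so the sampler's ``set_epoch`` call reshuffles the batches.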
trainer = Trainer(
default_root_dir=args.exp_dir,
max_steps=args.max_updates,
num_nodes=args.num_nodes,
gpus=args.gpus,
accelerator="gpu",
strategy="ddp",
replace_sampler_ddp=False,
gradient_clip_val=args.clip_norm,
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
model = HuBERTPreTrainModule(
model_name=args.model_name,
feature_grad_mult=args.feature_grad_mult,
num_classes=args.num_classes,
dataset=args.dataset,
root_path=args.root_path,
feature_type=args.feature_type,
seconds_per_batch=args.seconds_per_batch,
learning_rate=args.learning_rate,
betas=args.betas,
eps=args.eps,
weight_decay=args.weight_decay,
warmup_updates=args.warmup_updates,
max_updates=args.max_updates,
)
trainer.fit(model, ckpt_path=args.resume_checkpoint)
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=_Formatter,
)
parser.add_argument(
"--root-path",
type=pathlib.Path,
required=True,
help="Path to the feature and label directories.",
)
parser.add_argument(
"--resume-checkpoint",
type=pathlib.Path,
default=None,
help="Path to the checkpoint to resume training from. (Default: None)",
)
parser.add_argument(
"--feature-type",
choices=["mfcc", "hubert"],
type=str,
required=True,
)
parser.add_argument(
"--feature-grad-mult",
default=0.1,
type=float,
help="The scaling factor to multiply the feature extractor gradient. (Default: 0.1)",
)
parser.add_argument(
"--num-classes",
choices=[100, 500],
type=int,
required=True,
help="The ``num_class`` when building the hubert_pretrain_base model.",
)
parser.add_argument(
"--model-name",
default="hubert_pretrain_base",
choices=["hubert_pretrain_base", "hubert_pretrain_large", "hubert_pretrain_xlarge"],
type=str,
help="The HuBERT model to train. (Default: 'hubert_pretrain_base')",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--dataset",
default="librispeech",
choices=["librispeech", "librilight"],
type=str,
help="The dataset for training. (Default: 'librispeech')",
)
parser.add_argument(
"--learning-rate",
default=0.0005,
type=float,
help="The peak learning rate. (Default: 0.0005)",
)
parser.add_argument(
"--betas",
nargs=2,
default=(0.9, 0.98),
type=float,
help="The coefficients for computing running averages of gradient and its square. (Default: (0.9, 0.98))",
)
parser.add_argument(
"--eps",
default=1e-6,
type=float,
help="Epsilon value in Adam optimizer. (Default: 1e-6)",
)
parser.add_argument(
"--weight-decay",
default=0.01,
type=float,
help="Weight decay (L2 penalty) (default: 0.01)",
)
parser.add_argument(
"--clip-norm",
default=None,
type=float,
help="The gradient norm value to clip at. (Default: None)",
)
parser.add_argument(
"--num-nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--warmup-updates",
default=32000,
type=int,
help="Number of steps for warm up the learning rate. (Default: 32000)",
)
parser.add_argument(
"--max-updates",
default=250000,
type=int,
help="Total number of training steps. (Default: 250000)",
)
parser.add_argument(
"--seconds-per-batch",
default=87.5,
type=float,
help="Number of seconds of audio in a mini-batch. (Default: 87.5)",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
run_train(args)
if __name__ == "__main__":
cli_main()