#!/usr/bin/env python3
"""Train the HuBERTPretrainModel by using labels generated by KMeans clustering.

Example:
python train.py --dataset-path ./exp/data/mfcc/ --feature-type mfcc --num-classes 100
"""

import logging
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter

from lightning import HuBERTPreTrainModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything

logger = logging.getLogger(__name__)


class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
    # Use ArgumentDefaultsHelpFormatter to show default values and
    # RawDescriptionHelpFormatter to keep the custom formatting of the description/epilog.
    # Check: https://stackoverflow.com/a/18462760
    pass


def run_train(args):
    # Fix the random seed for reproducibility across processes.
    seed_everything(1337)
    checkpoint_dir = args.exp_dir / f"checkpoints_{args.dataset}_{args.model_name}"
    # Keep the five best checkpoints ranked by validation loss and by training loss.
    checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="val_loss",
        mode="min",
        save_top_k=5,
        save_weights_only=False,
        verbose=True,
    )
    train_checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="train_loss",
        mode="min",
        save_top_k=5,
        save_weights_only=False,
        verbose=True,
    )
    callbacks = [
        checkpoint,
        train_checkpoint,
    ]
    # Multi-node DDP training; sampler replacement is disabled so the module's own
    # batch sampler is used as-is under DDP.
    trainer = Trainer(
        default_root_dir=args.exp_dir,
        max_steps=args.max_updates,
        num_nodes=args.num_nodes,
        gpus=args.gpus,
        accelerator="gpu",
        strategy="ddp",
        replace_sampler_ddp=False,
        callbacks=callbacks,
        reload_dataloaders_every_n_epochs=1,
    )
    model = HuBERTPreTrainModule(
        model_name=args.model_name,
        feature_grad_mult=args.feature_grad_mult,
        num_classes=args.num_classes,
        dataset=args.dataset,
        dataset_path=args.dataset_path,
        feature_type=args.feature_type,
        seconds_per_batch=args.seconds_per_batch,
        learning_rate=args.learning_rate,
        betas=tuple(args.betas),
        eps=args.eps,
        weight_decay=args.weight_decay,
        clip_norm=args.clip_norm,
        warmup_updates=args.warmup_updates,
        max_updates=args.max_updates,
    )
    trainer.fit(model, ckpt_path=args.resume_checkpoint)


def _parse_args():
    parser = ArgumentParser(
        description=__doc__,
        formatter_class=_Formatter,
    )
    parser.add_argument(
        "--dataset-path",
        type=pathlib.Path,
        required=True,
        help="Path to the feature and label directories.",
    )
    parser.add_argument(
        "--resume-checkpoint",
        type=pathlib.Path,
        default=None,
        help="Path to the checkpoint to resume training from. (Default: None)",
    )
    parser.add_argument(
        "--feature-type",
        choices=["mfcc", "hubert"],
        type=str,
        required=True,
        help="The feature type used for generating the KMeans cluster labels.",
    )
    parser.add_argument(
        "--feature-grad-mult",
        default=0.1,
        type=float,
        help="The scaling factor applied to the feature extractor gradient. (Default: 0.1)",
    )
    parser.add_argument(
        "--num-classes",
        choices=[100, 500],
        type=int,
        required=True,
        help="The ``num_class`` used when building the hubert_pretrain_base model.",
    )
    parser.add_argument(
        "--model-name",
        default="hubert_pretrain_base",
        choices=["hubert_pretrain_base", "hubert_pretrain_large", "hubert_pretrain_xlarge"],
        type=str,
        help="The HuBERT model to train. (Default: 'hubert_pretrain_base')",
    )
    parser.add_argument(
        "--exp-dir",
        default=pathlib.Path("./exp"),
        type=pathlib.Path,
        help="Directory to save checkpoints and logs to. (Default: './exp')",
    )
    parser.add_argument(
        "--dataset",
        default="librispeech",
        choices=["librispeech", "librilight"],
        type=str,
        help="The dataset for training. (Default: 'librispeech')",
    )
    parser.add_argument(
        "--learning-rate",
        default=0.0005,
        type=float,
        help="The peak learning rate. (Default: 0.0005)",
    )
    parser.add_argument(
        "--betas",
        default=(0.9, 0.98),
        nargs=2,
        type=float,
        help="The coefficients for computing running averages of the gradient and its square. (Default: (0.9, 0.98))",
    )
    parser.add_argument(
        "--eps",
        default=1e-6,
        type=float,
        help="Epsilon value in the Adam optimizer. (Default: 1e-6)",
    )
    parser.add_argument(
        "--weight-decay",
        default=0.01,
        type=float,
        help="Weight decay (L2 penalty). (Default: 0.01)",
    )
    parser.add_argument(
        "--clip-norm",
        default=10.0,
        type=float,
        help="The maximum gradient norm for clipping. (Default: 10.0)",
    )
    parser.add_argument(
        "--num-nodes",
        default=4,
        type=int,
        help="Number of nodes to use for training. (Default: 4)",
    )
    parser.add_argument(
        "--gpus",
        default=8,
        type=int,
        help="Number of GPUs per node to use for training. (Default: 8)",
    )
    parser.add_argument(
        "--warmup-updates",
        default=32000,
        type=int,
        help="Number of steps to warm up the learning rate. (Default: 32000)",
    )
    parser.add_argument(
        "--max-updates",
        default=250000,
        type=int,
        help="Total number of training steps. (Default: 250000)",
    )
    parser.add_argument(
        "--seconds-per-batch",
        default=87.5,
        type=float,
        help="Number of seconds of audio in a mini-batch. (Default: 87.5)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Whether to use debug level for logging.",
    )
    return parser.parse_args()


def _init_logger(debug):
    fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def cli_main():
    args = _parse_args()
    _init_logger(args.debug)
    run_train(args)


if __name__ == "__main__":
    cli_main()