Commit ff01be0f authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot

Update ssl example (#3060)

Summary:
- Rename the current `ssl` example to `self_supervised_learning`
- Add README to demonstrate how to run the recipe with hubert task

Pull Request resolved: https://github.com/pytorch/audio/pull/3060

Reviewed By: mthrok

Differential Revision: D43287868

Pulled By: nateanl

fbshipit-source-id: 10352682485ef147ca32f4c4c9f9cde995444aa0
parent 73b29fc9
# Modularized Self-supervised Learning Recipe
This directory contains the modularized training recipe for audio/speech self-supervised learning. The principle is to let users easily inject a new component (model, data module, loss function, etc.) into the existing recipe for different tasks (e.g., Wav2Vec 2.0, HuBERT).
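The injection idea can be sketched in a few lines. This is an illustrative toy, not the actual recipe API: the class and function names (`Recipe`, `l2_loss`, `l1_loss`) are hypothetical stand-ins showing how swapping one component changes the task without touching the rest of the pipeline.

```python
# Hypothetical sketch of the "inject a component" idea: the recipe wires a
# run from interchangeable parts, so swapping the loss (or model, or data
# module) is a one-argument change. Names are illustrative, not the recipe API.
from dataclasses import dataclass
from typing import Callable

@dataclass
class Recipe:
    model: Callable[[float], float]           # stand-in for an nn.Module
    loss_fn: Callable[[float, float], float]  # stand-in for e.g. hubert_loss

    def step(self, x: float, target: float) -> float:
        return self.loss_fn(self.model(x), target)

def l2_loss(pred, target):
    return (pred - target) ** 2

def l1_loss(pred, target):
    return abs(pred - target)

# Same recipe skeleton, different injected loss -- no other code changes.
hubert_like = Recipe(model=lambda x: 2 * x, loss_fn=l2_loss)
wav2vec_like = Recipe(model=lambda x: 2 * x, loss_fn=l1_loss)
print(hubert_like.step(3.0, 4.0))   # (6 - 4)^2 = 4.0
print(wav2vec_like.step(3.0, 4.0))  # |6 - 4| = 2.0
```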
## HuBERT Pre-training Example
To get the K-Means labels for HuBERT pre-training, please check the [pre-processing step](../hubert/README.md#pre-processing-1st-iteration) in the HuBERT example.
To run the HuBERT pre-training script for the first iteration, go to the `examples` directory and run the following SLURM command:
```
cd examples
srun \
--gpus-per-node=8 \
--ntasks-per-node=8 \
-N 4 \
--cpus-per-task=10 \
python -m self_supervised_learning.train_hubert \
--dataset-path hubert/exp/data/mfcc/ \
--exp-dir self_supervised_learning/exp_iter1 \
--feature-type mfcc \
--num-class 100 \
--max-updates 250000 \
--learning-rate 0.0005 \
--gpus 8 \
--num-nodes 4
```
```
@@ -6,10 +6,11 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torchaudio
-from lightning import Batch
 from torch import Tensor
 from torch.utils.data import BatchSampler, Dataset, DistributedSampler
+from ..lightning import Batch
```
```
class BucketizeBatchSampler(BatchSampler):
    """Bucketized BatchSampler for sequential data with different lengths to reduce the number of paddings.
    ...
```
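The idea behind bucketized batch sampling can be shown with a simplified sketch (this is not the actual torchaudio `BucketizeBatchSampler`): sort sample indices by sequence length, then cut the sorted order into batches, so each batch holds sequences of similar length and needs little padding.

```python
# Simplified sketch of bucketized batching (not the actual torchaudio
# BucketizeBatchSampler): sort sample indices by sequence length, then cut
# the sorted order into fixed-size batches so sequences within a batch have
# similar lengths and padding overhead stays small.
from typing import List

def bucketize_batches(lengths: List[int], batch_size: int) -> List[List[int]]:
    # Indices sorted by the length of the sequence they refer to.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

lengths = [50, 400, 60, 390, 55, 410]
batches = bucketize_batches(lengths, batch_size=2)
print(batches)  # [[0, 4], [2, 3], [1, 5]] -- short clips together, long clips together
```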
```
-from _linear_decay import LinearDecayLRScheduler
+from ._linear_decay import LinearDecayLRScheduler
 __all__ = [
     "LinearDecayLRScheduler",
     ...
```
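The name `LinearDecayLRScheduler` suggests a warmup-then-linear-decay multiplier on the base learning rate, a common schedule for speech SSL pre-training. The shape below is an assumption for illustration, not the recipe's exact implementation; `warmup` is a hypothetical parameter.

```python
# Hedged sketch of a warmup + linear-decay LR multiplier (assumed shape,
# not the recipe's exact code): ramp linearly from 0 to 1 over `warmup`
# updates, then decay linearly back to 0 at `max_updates`.
def linear_decay_factor(step: int, warmup: int, max_updates: int) -> float:
    if step < warmup:
        return step / warmup
    remaining = max_updates - step
    return max(0.0, remaining / (max_updates - warmup))

base_lr = 5e-4  # matches --learning-rate 0.0005 above
print(base_lr * linear_decay_factor(16000, warmup=32000, max_updates=250000))   # halfway through warmup
print(base_lr * linear_decay_factor(250000, warmup=32000, max_updates=250000))  # fully decayed: 0.0
```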
```
@@ -6,14 +6,15 @@ from typing import Dict, Tuple
 import torch
 import torchaudio.models
-from data_modules import HuBERTDataModule
-from lightning import SSLPretrainModule
-from losses import hubert_loss
-from lr_schedulers import LinearDecayLRScheduler
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.seed import seed_everything
+from .data_modules import HuBERTDataModule
+from .lightning import SSLPretrainModule
+from .losses import hubert_loss
+from .lr_schedulers import LinearDecayLRScheduler
 class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
     # To use ArgumentDefaultsHelpFormatter as the formatter_class and
     ...
```
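The `_Formatter` class above uses a standard `argparse` idiom: combining two formatter classes via multiple inheritance yields a help formatter with both behaviors (defaults appended to each option's help text, and the description left verbatim rather than re-wrapped). A self-contained sketch of the same pattern:

```python
# Standard argparse idiom: multiple inheritance merges two help-formatter
# behaviors. ArgumentDefaultsHelpFormatter appends "(default: ...)" to each
# option; RawDescriptionHelpFormatter keeps the description's line breaks.
import argparse

class Formatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
    pass

parser = argparse.ArgumentParser(
    description="Line one.\nLine two stays on its own line.",
    formatter_class=Formatter,
)
parser.add_argument("--learning-rate", type=float, default=5e-4, help="peak lr")
help_text = parser.format_help()
print("(default: 0.0005)" in help_text)                 # True: defaults are shown
print("Line two stays on its own line." in help_text)   # True: description kept verbatim
```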