"vscode:/vscode.git/clone" did not exist on "0a0dd34e6a685320b4ecceac1646f4e04c6d39d0"
Commit ffeba11a authored by mayp777's avatar mayp777
Browse files

UPDATE

parent 29deb085
import random
from typing import List
import sentencepiece as spm
import torch
import torchvision
from data_module import LRS3DataModule
from lightning import Batch
from lightning_av import AVBatch
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class AdaptiveTimeMask(torch.nn.Module):
def __init__(self, window, stride):
super().__init__()
self.window = window
self.stride = stride
def forward(self, x):
cloned = x.clone()
length = cloned.size(1)
n_mask = int((length + self.stride - 0.1) // self.stride)
ts = torch.randint(0, self.window, size=(n_mask, 2))
for t, t_end in ts:
if length - t <= 0:
continue
t_start = random.randrange(0, length - t)
if t_start == t_start + t:
continue
t_end += t_start
cloned[:, t_start:t_end] = 0
return cloned
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[-1].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_videos = []
raw_audios = []
for sample in samples:
if args.modality == "visual":
raw_videos.append(sample[0])
if args.modality == "audio":
raw_audios.append(sample[0])
if args.modality == "audiovisual":
length = min(len(sample[0]) // 640, len(sample[1]))
raw_audios.append(sample[0][: length * 640])
raw_videos.append(sample[1][:length])
if args.modality == "visual" or args.modality == "audiovisual":
videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
videos = video_pipeline(videos)
video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32)
if args.modality == "audio" or args.modality == "audiovisual":
audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
audios = audio_pipeline(audios)
audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32)
if args.modality == "visual":
return videos, video_lengths
if args.modality == "audio":
return audios, audio_lengths
if args.modality == "audiovisual":
return audios, videos, audio_lengths, video_lengths
class TrainTransform:
def __init__(self, sp_model_path: str, args):
self.args = args
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.train_video_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: x / 255.0),
torchvision.transforms.RandomCrop(88),
torchvision.transforms.RandomHorizontalFlip(0.5),
FunctionalModule(lambda x: x.transpose(0, 1)),
torchvision.transforms.Grayscale(),
FunctionalModule(lambda x: x.transpose(0, 1)),
AdaptiveTimeMask(10, 25),
torchvision.transforms.Normalize(0.421, 0.165),
)
self.train_audio_pipeline = torch.nn.Sequential(
AdaptiveTimeMask(10, 25),
)
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
if self.args.modality == "audio":
audios, audio_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.modality == "visual":
videos, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
if self.args.modality == "audiovisual":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return AVBatch(audios, videos, audio_lengths, video_lengths, targets, target_lengths)
class ValTransform:
def __init__(self, sp_model_path: str, args):
self.args = args
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.valid_video_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: x / 255.0),
torchvision.transforms.CenterCrop(88),
FunctionalModule(lambda x: x.transpose(0, 1)),
torchvision.transforms.Grayscale(),
FunctionalModule(lambda x: x.transpose(0, 1)),
torchvision.transforms.Normalize(0.421, 0.165),
)
self.valid_audio_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: x),
)
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
if self.args.modality == "audio":
audios, audio_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.modality == "visual":
videos, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
if self.args.modality == "audiovisual":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return AVBatch(audios, videos, audio_lengths, video_lengths, targets, target_lengths)
class TestTransform:
def __init__(self, sp_model_path: str, args):
self.val_transforms = ValTransform(sp_model_path, args)
def __call__(self, sample):
return self.val_transforms([sample]), [sample]
def get_data_module(args, sp_model_path, max_frames=1800):
train_transform = TrainTransform(sp_model_path=sp_model_path, args=args)
val_transform = ValTransform(sp_model_path=sp_model_path, args=args)
test_transform = TestTransform(sp_model_path=sp_model_path, args=args)
return LRS3DataModule(
args=args,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
max_frames=max_frames,
)
# Time-Frequency Mask based DNN MVDR Beamforming Example
This directory contains sample implementations of training and evaluation pipelines for a DNN beamforming model.
The `DNNBeamformer` model comprises the following components (a minimal sketch of how they chain together follows the list):
+ [`torchaudio.transforms.Spectrogram`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html#spectrogram) that applies Short-time Fourier Transform (STFT) to the waveform.
+ ConvTasNet without encoder/decoder that predicts T-F masks for speech and noise, respectively.
+ [`torchaudio.transforms.PSD`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.PSD.html#psd) that computes covariance matrices for speech and noise.
+ [`torchaudio.transforms.SoudenMVDR`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.SoudenMVDR.html#soudenmvdr) that estimates the complex-valued STFT for the enhanced speech.
+ [`torchaudio.transforms.InverseSpectrogram`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseSpectrogram.html#inversespectrogram) that applies inverse STFT (iSTFT) to generate the enhanced waveform.
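The sketch below mirrors [`model.py`](./model.py), with random placeholder masks standing in for the ConvTasNet mask network; the 4-channel input is illustrative.
```python
import torch
from torchaudio.transforms import InverseSpectrogram, PSD, SoudenMVDR, Spectrogram

n_fft, hop_length = 1024, 256
stft = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)  # complex STFT
istft = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
psd = PSD()
mvdr = SoudenMVDR()

mixture = torch.randn(1, 4, 16000)       # (batch, channel, time); 4-channel input is illustrative
spectrum = stft(mixture)                 # (batch, channel, freq, time), complex-valued

# Random placeholder masks; the actual model predicts them with a ConvTasNet mask network.
mask_speech = torch.rand(1, n_fft // 2 + 1, spectrum.size(-1))
mask_noise = 1.0 - mask_speech

psd_speech = psd(spectrum, mask_speech)  # (batch, freq, channel, channel)
psd_noise = psd(spectrum, mask_noise)
enhanced_stft = mvdr(spectrum, psd_speech, psd_noise, 0)          # reference channel 0
enhanced = istft(enhanced_stft, length=mixture.shape[-1])         # (batch, time)
```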
## Usage
### Training
[`train.py`](./train.py) trains a [`DNNBeamformer`](./model.py) model using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the [`L3DAS22`](https://www.kaggle.com/datasets/l3dasteam/l3das22) datasets.
### Evaluation
[`eval.py`](./eval.py) evaluates a trained [`DNNBeamformer`](./model.py) on the test subset of the L3DAS22 dataset.
### L3DAS22
Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python train.py --dataset-path ./datasets/L3DAS22 --checkpoint-path ./exp/checkpoints
```
Sample SLURM command for evaluation:
```
srun python eval.py --checkpoint-path ./exp/checkpoints/epoch=97-step=780472.ckpt --dataset-path ./datasets/L3DAS22 --use-cuda
```
Using the sample training command above, [`train.py`](./train.py) produces a model with 5.0M parameters (57.5MB).
The table below contains Ci-SDR, STOI, and PESQ results on the test subset of the `L3DAS22` dataset.
| Ci-SDR | STOI | PESQ |
|:-------------------:|-------------:|-------------:|
| 19.00 | 0.82 | 2.46 |
If you find this training recipe useful, please cite as:
```bibtex
@article{yang2021torchaudio,
title={TorchAudio: Building Blocks for Audio and Speech Processing},
author={Yao-Yuan Yang and Moto Hira and Zhaoheng Ni and Anjali Chourdia and Artyom Astafurov and Caroline Chen and Ching-Feng Yeh and Christian Puhrsch and David Pollack and Dmitriy Genzel and Donny Greenberg and Edward Z. Yang and Jason Lian and Jay Mahadeokar and Jeff Hwang and Ji Chen and Peter Goldsborough and Prabhat Roy and Sean Narenthiran and Shinji Watanabe and Soumith Chintala and Vincent Quenneville-Bélair and Yangyang Shi},
journal={arXiv preprint arXiv:2110.15018},
year={2021}
}
@inproceedings{lu2022towards,
title={Towards low-distortion multi-channel speech enhancement: The ESPNet-SE submission to the L3DAS22 challenge},
author={Lu, Yen-Ju and Cornell, Samuele and Chang, Xuankai and Zhang, Wangyou and Li, Chenda and Ni, Zhaoheng and Wang, Zhong-Qiu and Watanabe, Shinji},
booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={9201--9205},
year={2022},
organization={IEEE}
}
```
from pathlib import Path
from typing import Tuple, Union
import lightning.pytorch as pl
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from utils import CollateFnL3DAS22
_PREFIX = "L3DAS22_Task1_"
_SUBSETS = {
"train360": ["train360_1", "train360_2"],
"train100": ["train100"],
"dev": ["dev"],
"test": ["test"],
}
_SAMPLE_RATE = 16000
class L3DAS22(Dataset):
def __init__(
self,
root: Union[str, Path],
subset: str = "train360",
min_len: int = 64000,
):
self._walker = []
if subset not in _SUBSETS:
raise ValueError(f"Expect subset to be one of ('train360', 'train100', 'dev', 'test'). Found {subset}.")
for sub_dir in _SUBSETS[subset]:
path = Path(root) / f"{_PREFIX}{sub_dir}" / "data"
files = [str(p) for p in path.glob("*_A.wav") if torchaudio.info(p).num_frames >= min_len]
if len(files) == 0:
raise RuntimeError(
f"Directory {path} is not found. Please check if the zip file has been downloaded and extracted."
)
self._walker += files
def __len__(self):
return len(self._walker)
def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, str]:
noisy_path_A = Path(self._walker[n])
noisy_path_B = str(noisy_path_A).replace("_A.wav", "_B.wav")
clean_path = noisy_path_A.parent.parent / "labels" / noisy_path_A.name.replace("_A.wav", ".wav")
transcript_path = str(clean_path).replace("wav", "txt")
waveform_noisy_A, sample_rate1 = torchaudio.load(noisy_path_A)
waveform_noisy_B, sample_rate2 = torchaudio.load(noisy_path_B)
waveform_noisy = torch.cat((waveform_noisy_A, waveform_noisy_B), dim=0)
waveform_clean, sample_rate3 = torchaudio.load(clean_path)
assert sample_rate1 == _SAMPLE_RATE and sample_rate2 == _SAMPLE_RATE and sample_rate3 == _SAMPLE_RATE
with open(transcript_path, "r") as f:
transcript = f.readline()
return waveform_noisy, waveform_clean, _SAMPLE_RATE, transcript
class L3DAS22DataModule(pl.LightningDataModule):
def __init__(
self,
dataset_path: str,
batch_size: int,
):
super().__init__()
self.dataset_path = dataset_path
self.batch_size = batch_size
def train_dataloader(self):
dataset = torch.utils.data.ConcatDataset(
[
L3DAS22(self.dataset_path, "train360"),
L3DAS22(self.dataset_path, "train100"),
]
)
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
shuffle=True,
num_workers=20,
)
def val_dataloader(self):
dataset = L3DAS22(self.dataset_path, "dev")
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
shuffle=False,
num_workers=1,
)
def test_dataloader(self):
dataset = L3DAS22(self.dataset_path, "test", min_len=0)
return torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=1,
)
import logging
import pathlib
from argparse import ArgumentParser
import ci_sdr
import torch
from datamodule import L3DAS22DataModule
from model import DNNBeamformer
from pesq import pesq
from pystoi import stoi
logger = logging.getLogger()
def run_eval(args):
model = DNNBeamformer()
checkpoint = torch.load(args.checkpoint_path)
new_state_dict = {}
for k in checkpoint["state_dict"].keys():
if "loss" not in k:
new_state_dict[k.replace("model.", "")] = checkpoint["state_dict"][k]
model.load_state_dict(new_state_dict)
model.eval()
data_module = L3DAS22DataModule(dataset_path=args.dataset_path, batch_size=args.batch_size)
if args.use_cuda:
model = model.to(device="cuda")
CI_SDR = 0.0
STOI = 0.0
PESQ = 0
with torch.no_grad():
for idx, batch in enumerate(data_module.test_dataloader()):
mixture, clean, _, _ = batch
if args.use_cuda:
mixture = mixture.cuda()
clean = clean[0]
estimate = model(mixture).cpu()
ci_sdr_v = (
-ci_sdr.pt.ci_sdr(estimate, clean, compute_permutation=False, filter_length=512, change_sign=False)
.mean()
.item()
)
clean, estimate = clean[0].numpy(), estimate[0].numpy()
stoi_v = stoi(clean, estimate, 16000, extended=False)
pesq_v = pesq(16000, clean, estimate, "wb")
CI_SDR += (1.0 / float(idx + 1)) * (ci_sdr_v - CI_SDR)
STOI += (1.0 / float(idx + 1)) * (stoi_v - STOI)
PESQ += (1.0 / float(idx + 1)) * (pesq_v - PESQ)
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; Ci-SDR: {CI_SDR}, stoi: {STOI}, pesq: {PESQ}")
# visualize and save results
results = {"Ci-SDR": CI_SDR, "stoi": STOI, "pesq": PESQ}
print("*******************************")
print("RESULTS")
for i in results:
print(i, results[i])
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
required=True,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--dataset-path",
type=pathlib.Path,
help="Path to L3DAS22 datasets.",
)
parser.add_argument(
"--batch_size",
default=4,
type=int,
help="Batch size for training. (Default: 4)",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
import ci_sdr
import lightning.pytorch as pl
import torch
from asteroid.losses.stoi import NegSTOILoss
from asteroid.masknn import TDConvNet
from torchaudio.transforms import InverseSpectrogram, PSD, SoudenMVDR, Spectrogram
class DNNBeamformer(torch.nn.Module):
def __init__(self, n_fft: int = 1024, hop_length: int = 256, ref_channel: int = 0):
super().__init__()
self.stft = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
self.istft = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
self.mask_net = TDConvNet(
n_fft // 2 + 1,
2,
out_chan=n_fft // 2 + 1,
causal=False,
mask_act="linear",
norm_type="gLN",
)
self.beamformer = SoudenMVDR()
self.psd = PSD()
self.ref_channel = ref_channel
def forward(self, mixture) -> torch.Tensor:
spectrum = self.stft(mixture) # (batch, channel, freq, time)
batch, _, freq, time = spectrum.shape
input_feature = torch.log(spectrum[:, self.ref_channel].abs() + 1e-8) # (batch, freq, time)
mask = torch.nn.functional.relu(self.mask_net(input_feature)) # (batch, 2, freq, time)
mask_speech = mask[:, 0]
mask_noise = mask[:, 1]
psd_speech = self.psd(spectrum, mask_speech)
psd_noise = self.psd(spectrum, mask_noise)
enhanced_stft = self.beamformer(spectrum, psd_speech, psd_noise, self.ref_channel)
enhanced_waveform = self.istft(enhanced_stft, length=mixture.shape[-1])
return enhanced_waveform
class DNNBeamformerLightningModule(pl.LightningModule):
def __init__(self, model: torch.nn.Module):
super(DNNBeamformerLightningModule, self).__init__()
self.model = model
self.loss_stoi = NegSTOILoss(16000)
def training_step(self, batch, batch_idx):
mixture, clean = batch
estimate = self.model(mixture)
loss_cisdr = ci_sdr.pt.ci_sdr_loss(estimate, clean, compute_permutation=False, filter_length=512).mean()
loss_stoi = self.loss_stoi(estimate, clean).mean()
loss = loss_cisdr + loss_stoi * 10
self.log("train/loss_cisdr", loss_cisdr.item())
self.log("train/loss_stoi", loss_stoi.item())
self.log("train/loss", loss.item())
return loss
def validation_step(self, batch, batch_idx):
mixture, clean = batch
estimate = self.model(mixture)
loss_cisdr = ci_sdr.pt.ci_sdr_loss(estimate, clean, compute_permutation=False, filter_length=512).mean()
loss_stoi = self.loss_stoi(estimate, clean).mean()
loss = loss_cisdr + loss_stoi * 10
self.log("val/loss_cisdr", loss_cisdr.item())
self.log("val/loss_stoi", loss_stoi.item())
self.log("val/loss", loss.item())
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=1e-8)
return {
"optimizer": optimizer,
}
import pathlib
from argparse import ArgumentParser
import lightning.pytorch as pl
from datamodule import L3DAS22DataModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from model import DNNBeamformer, DNNBeamformerLightningModule
def run_train(args):
pl.seed_everything(1)
logger = TensorBoardLogger(args.exp_dir)
callbacks = [
ModelCheckpoint(
args.checkpoint_path,
monitor="val/loss",
save_top_k=5,
mode="min",
save_last=True,
),
]
trainer = pl.trainer.trainer.Trainer(
max_epochs=args.epochs,
callbacks=callbacks,
accelerator="gpu",
devices=args.gpus,
accumulate_grad_batches=1,
logger=logger,
gradient_clip_val=5,
check_val_every_n_epoch=1,
num_sanity_val_steps=0,
log_every_n_steps=1,
)
model = DNNBeamformer()
model_module = DNNBeamformerLightningModule(model)
data_module = L3DAS22DataModule(dataset_path=args.dataset_path, batch_size=args.batch_size)
trainer.fit(model_module, datamodule=data_module)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp/"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp/')",
)
parser.add_argument(
"--dataset-path",
type=pathlib.Path,
help="Path to L3DAS22 datasets.",
required=True,
)
parser.add_argument(
"--batch_size",
default=4,
type=int,
help="Batch size for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=1,
type=int,
help="Number of GPUs per node to use for training. (Default: 1)",
)
parser.add_argument(
"--epochs",
default=100,
type=int,
help="Number of epochs to train for. (Default: 100)",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
from typing import Dict, List, Tuple
import torch
from torch import Tensor
class CollateFnL3DAS22:
"""The collate class for L3DAS22 dataset.
Args:
pad (bool): If ``True``, the waveforms and labels will be padded to the
max length in the mini-batch. If ``pad`` is False, the waveforms
and labels will be cropped to the minimum length in the mini-batch.
(Default: False)
rand_crop (bool): if ``True``, the starting index of the waveform
and label is random if the length is longer than the minimum
length in the mini-batch.
"""
def __init__(
self,
audio_length: int = 16000 * 4,
rand_crop: bool = True,
) -> None:
self.audio_length = audio_length
self.rand_crop = rand_crop
def __call__(self, batch: List[Tuple[Tensor, Tensor, int, str]]) -> Dict:
"""
Args:
batch (List[Tuple(Tensor, Tensor, int)]):
The list of tuples that contains:
- mixture waveforms
- clean waveform
- sample rate
- transcript
Returns:
Dictionary
"input": Tuple of waveforms and lengths.
waveforms Tensor with dimensions `(batch, time)`.
lengths Tensor with dimension `(batch,)`.
"label": None
"""
waveforms_noisy, waveforms_clean = [], []
for sample in batch:
waveform_noisy, waveform_clean, _, _ = sample
if self.rand_crop and waveform_noisy.size(-1) > self.audio_length:
diff = waveform_noisy.size(-1) - self.audio_length
frame_offset = int(torch.randint(diff, size=(1,)))
else:
frame_offset = 0
waveform_noisy = waveform_noisy[:, frame_offset : frame_offset + self.audio_length]
waveform_clean = waveform_clean[:, frame_offset : frame_offset + self.audio_length]
waveforms_noisy.append(waveform_noisy.unsqueeze(0))
waveforms_clean.append(waveform_clean)
waveforms_noisy = torch.cat(waveforms_noisy)
waveforms_clean = torch.cat(waveforms_clean)
return waveforms_noisy, waveforms_clean
......@@ -32,6 +32,7 @@ class BucketizeBatchSampler(BatchSampler):
(Default: ``None``)
shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
(Default: True)
seed (int, optional): The seed for initializing the RNG. Only used when `shuffle` is True. (Default: 0)
drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
(Default: False)
......@@ -45,7 +46,7 @@ class BucketizeBatchSampler(BatchSampler):
Note:
if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
in pytorch_lightning Trainer to enable shuffling every epoch.
in pytorch_lightning Trainer and set ``seed`` to ``self.trainer.current_epoch`` to enable shuffling every epoch.
"""
def __init__(
......@@ -57,6 +58,7 @@ class BucketizeBatchSampler(BatchSampler):
max_token_count: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
if max_len is None:
......@@ -82,6 +84,10 @@ class BucketizeBatchSampler(BatchSampler):
self.max_token_count = max_token_count
self.batch_size = batch_size
self.shuffle = shuffle
self.seed = seed
if self.shuffle:
self.g = torch.Generator()
self.g.manual_seed(self.seed)
self.drop_last = drop_last
self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
self._update_iter_list()
......@@ -115,7 +121,7 @@ class BucketizeBatchSampler(BatchSampler):
def _update_iter_list(self) -> None:
if self.shuffle:
for k in self.buckets:
self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))]
self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0), generator=self.g)]
self.iter_list = []
total_len = 0
batch = []
......@@ -193,7 +199,18 @@ class DistributedBatchSampler(DistributedSampler):
self.epoch = 0
self.seed = seed
self.drop_last = drop_last
if shuffle:
self.shuffle = shuffle
indices = self.batch_sampler.iter_list
if self.drop_last and len(indices) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
else:
self.num_samples = math.ceil(len(indices) / self.num_replicas)
def __iter__(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
perm = torch.randperm(len(self.batch_sampler.iter_list), generator=g).tolist()
......@@ -210,7 +227,6 @@ class DistributedBatchSampler(DistributedSampler):
self.subset = indices[self.rank : self.total_size : self.num_replicas]
assert len(self.subset) == self.num_samples
def __iter__(self):
return iter(self.subset)
def __len__(self):
......
......@@ -12,11 +12,10 @@ import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
from typing import Tuple
from lightning import HuBERTFineTuneModule
from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
from lightning_modules import HuBERTFineTuneModule
logger = logging.getLogger(__name__)
......@@ -56,10 +55,10 @@ def run_train(args):
default_root_dir=args.exp_dir,
max_steps=args.max_updates,
num_nodes=args.num_nodes,
gpus=args.gpus,
devices=args.gpus,
accelerator="gpu",
strategy="ddp",
replace_sampler_ddp=False,
strategy="ddp_find_unused_parameters_true",
use_distributed_sampler=False,
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
val_check_interval=500,
......
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.models.wav2vec2.components as components
from dataset import (
_get_lengths_librilightlimited,
_get_lengths_librispeech,
BucketizeBatchSampler,
CollateFnHubert,
CollateFnLibriLightLimited,
DistributedBatchSampler,
HuBERTDataSet,
)
from lightning.pytorch import LightningModule
from loss import hubert_loss
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
Batch = Tuple[Tensor, Tensor, Tensor]
Batch_FineTune = Tuple[Tensor, Tensor, Tensor, Tensor]
class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler):
"""Linear learning rate scheduler with warm up."""
def __init__(
self,
optimizer: Optimizer,
warmup_updates: int,
max_updates: int,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.max_updates = max_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count <= self.warmup_updates:
return [self._step_count / self.warmup_updates * base_lr for base_lr in self.base_lrs]
elif self._step_count >= self.max_updates:
return [0.0 for _ in self.base_lrs]
else:
pct_remaining = (self.max_updates - self._step_count) / (self.max_updates - self.warmup_updates)
return [base_lr * pct_remaining for base_lr in self.base_lrs]
class TriStageLRScheduler(torch.optim.lr_scheduler._LRScheduler):
"""Linear learning rate scheduler with warmup, hold, and decay."""
def __init__(
self,
optimizer: Optimizer,
warmup_updates: int,
hold_updates: int,
decay_updates: int,
init_lr_scale: float = 0.01,
final_lr_scale: float = 0.05,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.hold_updates = hold_updates
self.decay_updates = decay_updates
self.init_lr_scale = init_lr_scale
self.final_lr_scale = final_lr_scale
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count <= self.warmup_updates:
return [
base_lr * (self.init_lr_scale + self._step_count / self.warmup_updates * (1 - self.init_lr_scale))
for base_lr in self.base_lrs
]
elif self.warmup_updates < self._step_count <= (self.warmup_updates + self.hold_updates):
return list(self.base_lrs)
elif self._step_count <= (self.warmup_updates + self.hold_updates + self.decay_updates):
return [
base_lr
* math.exp(
math.log(self.final_lr_scale)
* (self._step_count - self.warmup_updates - self.hold_updates)
/ self.decay_updates
)
for base_lr in self.base_lrs
]
else:
return [base_lr * self.final_lr_scale for base_lr in self.base_lrs]
def _compute_accuracy(logits: torch.Tensor):
with torch.no_grad():
argmax_hit = logits.argmax(-1) == 0
argmin_hit = logits.argmin(-1) == 0
both = argmax_hit & argmin_hit
corr = argmax_hit.long().sum().item() - both.long().sum().item()
count = argmax_hit.numel()
return corr, count
def _reset_stats():
return {
"train": {
"correct": 0.0,
"count": 0.0,
},
"val": {
"correct": 0.0,
"count": 0.0,
},
}
class HuBERTPreTrainModule(LightningModule):
def __init__(
self,
*,
model_name: str,
feature_grad_mult: float,
num_classes: int,
dataset: str,
dataset_path: str,
feature_type: str,
seconds_per_batch: float,
learning_rate: float,
betas: Tuple[float, float],
eps: float,
weight_decay: float,
clip_norm: Optional[float],
warmup_updates: int,
max_updates: int,
):
super().__init__()
if model_name == "hubert_pretrain_base":
self.model = torchaudio.models.hubert_pretrain_base(
feature_grad_mult=feature_grad_mult, num_classes=num_classes
)
elif model_name == "hubert_pretrain_large":
self.model = torchaudio.models.hubert_pretrain_large()
elif model_name == "hubert_pretrain_xlarge":
self.model = torchaudio.models.hubert_pretrain_xlarge()
else:
raise ValueError(f"Unsupported model name: {model_name}")
self.automatic_optimization = False
self.scaler = torch.cuda.amp.GradScaler()
self.loss = hubert_loss
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
)
self.clip_norm = clip_norm
self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
self.dataset = dataset
self.dataset_path = dataset_path
self.feature_type = feature_type
self.seconds_per_batch = seconds_per_batch
self.mask_stats = _reset_stats()
self.unmask_stats = _reset_stats()
self.nan_loss_count = 0.0
def _step(self, batch: Batch, batch_idx, step_type):
if batch is None:
return None, None
waveforms, labels, audio_lengths = batch
if step_type == "val":
with torch.no_grad():
logit_m, logit_u, feature_penalty = self.model(
waveforms,
labels,
audio_lengths,
)
else:
logit_m, logit_u, feature_penalty = self.model(
waveforms,
labels,
audio_lengths,
)
loss = self.loss(logit_m, logit_u, feature_penalty)
if not torch.isinf(loss) and not torch.isnan(loss):
self.log(f"{step_type}_loss", loss.item() / logit_m.size(0), on_step=True, on_epoch=True)
else:
self.nan_loss_count += 1
self.log("nan_loss_count", self.nan_loss_count, on_step=True, on_epoch=True)
# log accuracies of masked and unmasked frames
correct_m, count_m = _compute_accuracy(logit_m)
correct_u, count_u = _compute_accuracy(logit_u)
self.mask_stats[step_type]["correct"] += correct_m
self.mask_stats[step_type]["count"] += count_m
self.unmask_stats[step_type]["correct"] += correct_u
self.unmask_stats[step_type]["count"] += count_u
self.log(
f"{step_type}_masked_accuracy",
self.mask_stats[step_type]["correct"] / self.mask_stats[step_type]["count"],
on_step=True,
on_epoch=True,
sync_dist=True,
prog_bar=step_type == "train",
)
self.log(
f"{step_type}_unmasked_accuracy",
self.unmask_stats[step_type]["correct"] / self.unmask_stats[step_type]["count"],
on_step=True,
on_epoch=True,
sync_dist=True,
prog_bar=step_type == "train",
)
return loss, logit_m.size(0)
def configure_optimizers(self):
return (
[self.optimizer],
[
{
"scheduler": self.lr_scheduler,
"interval": "step",
},
],
)
def training_step(self, batch: Batch, batch_idx):
"""Custom training step with loss normalization and automatic mixed precision training.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / num_frames), where num_frames is
the sum of masked frames across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / num_frames.
- Update parameters on each GPU.
Doing so allows us to account for the variability in number of masked frames in
variable-length sequential data.
"""
opt = self.optimizers()
opt.zero_grad()
with torch.cuda.amp.autocast(enabled=True):
loss, num_frame = self._step(batch, batch_idx, "train")
if torch.isinf(loss) or torch.isnan(loss):
opt.zero_grad()
return None
# normalize the loss based on the sum of num_frame across all GPUs
num_frames = self.all_gather(num_frame)
self.log("Gathered number of frames", num_frames.float().sum(), on_step=True, on_epoch=True)
loss *= num_frames.size(0) / num_frames.sum() # world size / num_frames
# backward the loss and clip the gradients
loss = self.scaler.scale(loss)
self.manual_backward(loss)
self.scaler.unscale_(opt)
if self.clip_norm is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)
# optimization
self.scaler.step(opt)
sch = self.lr_schedulers()
sch.step()
self.scaler.update()
return loss
def validation_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "val")[0]
def on_validation_end(self):
self.mask_stats = _reset_stats()
self.unmask_stats = _reset_stats()
def train_dataloader(self):
dataset = HuBERTDataSet(self.dataset_path, self.dataset, "train")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=True,
seed=self.trainer.current_epoch,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.current_epoch)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
)
return dataloader
def val_dataloader(self):
dataset = HuBERTDataSet(self.dataset_path, self.dataset, "valid")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=False,
)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=10,
)
return dataloader
class HuBERTFineTuneModule(LightningModule):
def __init__(
self,
*,
model_name: str,
encoder_projection_dropout: float,
encoder_attention_dropout: float,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_drop: float,
mask_prob: float,
mask_channel_prob: float,
mask_channel_length: float,
num_classes: int,
aux_num_out: int,
checkpoint: str,
dataset_path: str,
seconds_per_batch: float,
subset: str,
learning_rate: float,
betas: Tuple[float, float],
adam_eps: float,
weight_decay: float,
freeze_encoder_updates: int,
warmup_updates: int,
hold_updates: int,
decay_updates: int,
):
super().__init__()
if model_name == "hubert_pretrain_base":
self.model = torchaudio.models.hubert_pretrain_base(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_drop=encoder_layer_drop,
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
self.aux = torch.nn.Linear(768, aux_num_out)
elif model_name == "hubert_pretrain_large":
self.model = torchaudio.models.hubert_pretrain_large(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_drop=encoder_layer_drop,
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
self.aux = torch.nn.Linear(1024, aux_num_out)
elif model_name == "hubert_pretrain_xlarge":
self.model = torchaudio.models.hubert_pretrain_xlarge(
encoder_projection_dropout=encoder_projection_dropout,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_drop=encoder_layer_drop,
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
self.aux = torch.nn.Linear(1280, aux_num_out)
else:
raise ValueError(f"Unsupported model name: {model_name}.")
self._load_checkpoint(checkpoint)
for p in self.model.wav2vec2.feature_extractor.parameters():
p.requires_grad = False
self.loss_fn = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
self.optimizer = torch.optim.AdamW(
list(self.aux.parameters()) + list(self.model.parameters()),
lr=learning_rate,
betas=betas,
eps=adam_eps,
weight_decay=weight_decay,
)
self.freeze_encoder_updates = freeze_encoder_updates
self.lr_scheduler = TriStageLRScheduler(self.optimizer, warmup_updates, hold_updates, decay_updates)
self.dataset_path = dataset_path
self.seconds_per_batch = seconds_per_batch
self.subset = subset
self.automatic_optimization = False
self.scaler = torch.cuda.amp.GradScaler()
def _load_checkpoint(self, checkpoint):
# load pretrain model from checkpoint
state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
state_dict = state_dict["state_dict"]
s = {}
for k in state_dict:
if "model." in k:
s[k.replace("model.", "")] = state_dict[k]
self.model.load_state_dict(s)
def _step(self, batch: Batch_FineTune, batch_idx, step_type):
if batch is None:
return None
waveforms, labels, audio_lengths, label_lengths = batch
if self.global_step <= self.freeze_encoder_updates:
with torch.no_grad():
x, out_len = self.model.wav2vec2.feature_extractor(waveforms, audio_lengths)
padding_mask = components._get_padding_mask(x, out_len)
x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)
x, _ = self.model.mask_generator(x, padding_mask)
x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
else:
with torch.no_grad():
x, out_len = self.model.wav2vec2.feature_extractor(waveforms, audio_lengths)
padding_mask = components._get_padding_mask(x, out_len)
x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)
x, _ = self.model.mask_generator(x, padding_mask)
x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
logits = self.aux(x)
log_probs = F.log_softmax(logits, dim=-1)
log_probs = log_probs.transpose(0, 1)
loss = self.loss_fn(
log_probs,
labels,
out_len,
label_lengths,
)
self.log(f"{step_type}_loss", loss.item() / waveforms.size(0), on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[
self.optimizer,
],
[
{"scheduler": self.lr_scheduler, "interval": "step"},
],
)
def training_step(self, batch: Batch_FineTune, batch_idx):
"""Custom training step with loss normalization and automatic mixed precision training.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
with torch.cuda.amp.autocast(enabled=True):
loss = self._step(batch, batch_idx, "train")
# normalize the loss based on the sum of batch_size across all GPUs
batch_size = batch[0].size(0)
batch_sizes = self.all_gather(batch_size)
self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
# backward the loss and clip the gradients
loss = self.scaler.scale(loss)
self.manual_backward(loss)
self.scaler.unscale_(opt)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
# optimization
self.scaler.step(opt)
sch = self.lr_schedulers()
sch.step()
self.scaler.update()
def validation_step(self, batch: Batch_FineTune, batch_idx):
return self._step(batch, batch_idx, "val")
def train_dataloader(self):
dataset = torchaudio.datasets.LibriLightLimited(self.dataset_path, self.subset)
lengths = _get_lengths_librilightlimited(dataset._fileids_paths, dataset._path, dataset._ext_audio)
sampler = BucketizeBatchSampler(
lengths,
num_buckets=100,
max_token_count=self.seconds_per_batch * 16000,
shuffle=True,
seed=self.global_step,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.global_step)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnLibriLightLimited(),
num_workers=10,
)
return dataloader
def val_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(self.dataset_path, "dev-other")
lengths = _get_lengths_librispeech(dataset._walker, dataset._path, dataset._ext_audio)
sampler = BucketizeBatchSampler(
lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=False
)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnLibriLightLimited(),
num_workers=10,
)
return dataloader
......@@ -9,10 +9,10 @@ import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
from typing import Tuple
from lightning import HuBERTPreTrainModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning_modules import HuBERTPreTrainModule
logger = logging.getLogger(__name__)
......@@ -52,10 +52,10 @@ def run_train(args):
default_root_dir=args.exp_dir,
max_steps=args.max_updates,
num_nodes=args.num_nodes,
gpus=args.gpus,
devices=args.gpus,
accelerator="gpu",
strategy="ddp",
replace_sampler_ddp=False,
strategy="ddp_find_unused_parameters_true",
use_distributed_sampler=False,
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
......
......@@ -35,7 +35,7 @@ class Pipeline(torch.nn.Module):
rir, _ = torchaudio.sox_effects.apply_effects_tensor(
self.rir, self.rir_sample_rate, effects=[["rate", str(sample_rate)]]
)
rir = rir / torch.norm(rir, p=2)
rir = rir / torch.linalg.vector_norm(rir, ord=2)
rir = torch.flip(rir, [1])
# 4. Apply RIR filter
......
......@@ -70,8 +70,8 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
_phonemizer = None
available_symbol_set = set(["english_characters", "english_phonemes"])
available_phonemizers = set(["DeepPhonemizer"])
available_symbol_set = {"english_characters", "english_phonemes"}
available_phonemizers = {"DeepPhonemizer"}
def get_symbol_list(symbol_list: str = "english_characters", cmudict_root: Optional[str] = "./") -> List[str]:
......
......@@ -130,6 +130,8 @@ def parse_args():
help="learning rate exponential decay constant",
)
parser.add_argument("--momentum", default=0.8, type=float, metavar="M", help="momentum")
parser.add_argument("--beta_1", default=0.9, type=float, metavar="BETA_1", help="beta_1")
parser.add_argument("--beta_2", default=0.999, type=float, metavar="BETA_2", help="beta_2")
parser.add_argument("--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay")
parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8)
parser.add_argument("--rho", metavar="RHO", type=float, default=0.95)
......@@ -472,15 +474,17 @@ def main(rank, args):
optimizer = Adam(
model.parameters(),
lr=args.learning_rate,
momentum=args.momentum,
betas=(args.beta_1, args.beta_2),
weight_decay=args.weight_decay,
eps=args.eps,
)
elif args.optimizer == "adamw":
optimizer = AdamW(
model.parameters(),
lr=args.learning_rate,
momentum=args.momentum,
betas=(args.beta_1, args.beta_2),
weight_decay=args.weight_decay,
eps=args.eps,
)
else:
raise ValueError("Selected optimizer not supported")
......
# Modularized Self-supervised Learning Recipe
This directory contains the modularized training recipe for audio/speech self-supervised learning. The principle is to let users easily inject a new component (model, data module, loss function, etc.) into the existing recipe for different tasks (e.g., Wav2Vec 2.0, HuBERT).
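For example, the sketch below (import paths and the dataset subclass are hypothetical; only the `hubert_cls` hook comes from the `HuBERTDataModule` shipped in this recipe) injects a custom dataset without touching the training loop:
```python
# Illustrative only: swap in a custom dataset by overriding the `hubert_cls`
# class attribute of the recipe's HuBERTDataModule. The import paths below are
# assumptions and depend on how the package is laid out.
from data_modules import HuBERTDataModule      # hypothetical import path
from data_modules._utils import HuBERTDataSet  # hypothetical import path


class MyHuBERTDataSet(HuBERTDataSet):
    """Hypothetical dataset variant; override audio/label loading as needed."""


class MyHuBERTDataModule(HuBERTDataModule):
    hubert_cls = MyHuBERTDataSet  # the dataloaders now instantiate the custom dataset


# Constructor arguments mirror HuBERTDataModule; the values here are illustrative.
data_module = MyHuBERTDataModule(
    dataset_path="hubert/exp/data/mfcc/",
    dataset="librispeech",
    feature_type="mfcc",
    seconds_per_batch=87.5,
)
```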
## HuBERT Pre-training Example
To get the K-Means labels for HuBERT pre-training, please check the [pre-processing step](../hubert/README.md#pre-processing-1st-iteration) in the hubert example.
To run the HuBERT pre-training script for the first iteration, go to the `examples` directory and run the following SLURM command:
```
cd examples
srun \
--gpus-per-node=8 \
--ntasks-per-node=8 \
-N 4 \
--cpus-per-task=10 \
python -m self_supervised_learning.train_hubert \
--dataset-path hubert/exp/data/mfcc/ \
--exp-dir self_supervised_learning/exp_iter1 \
--feature-type mfcc \
--num-class 100 \
--max-updates 250000 \
--learning-rate 0.0005 \
--gpus 8 \
--num-nodes 4
```
from ._hubert_datamodule import HuBERTDataModule
__all__ = [
"HuBERTDataModule",
"Wav2Vec2DataModule",
]
import torch
from pytorch_lightning import LightningDataModule
from ._utils import BucketizeBatchSampler, CollateFnHubert, DistributedBatchSampler, HuBERTDataSet
class HuBERTDataModule(LightningDataModule):
hubert_cls = HuBERTDataSet
def __init__(
self,
*,
dataset_path,
dataset,
feature_type,
seconds_per_batch,
train_shuffle=True,
num_workers=10,
):
super().__init__()
self.dataset_path = dataset_path
self.dataset = dataset
self.feature_type = feature_type
self.seconds_per_batch = seconds_per_batch
self.train_shuffle = train_shuffle
self.num_workers = num_workers
def train_dataloader(self):
dataset = self.hubert_cls(self.dataset_path, self.dataset, "train")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=10000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=True,
seed=self.trainer.current_epoch,
)
sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
sampler.set_epoch(self.trainer.current_epoch)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=self.num_workers,
)
return dataloader
def val_dataloader(self):
dataset = self.hubert_cls(self.dataset_path, self.dataset, "valid")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=False,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
num_workers=self.num_workers,
)
return dataloader
import math
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torchaudio
from torch import Tensor
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
from ..lightning_modules import Batch
class BucketizeBatchSampler(BatchSampler):
"""Buketized BatchSampler for sequential data with different lengths to reduce number of paddings.
Args:
lengths (List[int]): The lengths of the samples in the dataset.
num_buckets (int): The number of buckets to split the data samples.
min_len (int, optional): The minimum sample lengths to keep.
(Default: 0)
max_len (int or None, optional): The maximum sample lengths to keep. Inferred if not provided.
(Default: ``None``)
max_token_count (int or None, optional): The max number of tokens in one mini-batch.
(Default: ``None``)
batch_size (int or None, optional): The number of samples in one mini-batch.
(Default: ``None``)
shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling.
(Default: True)
seed (int, optional): The seed for initializing the RNG. Only used when `shuffle` is True. (Default: 0)
drop_last (bool, optional): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
(Default: False)
Note:
``max_token_count`` and ``batch_size`` are mutually exclusive. Only one argument of the two
should have value.
Note:
``drop_last`` is only valid when ``batch_size`` argument is given.
Note:
if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
in pytorch_lightning Trainer and set ``seed`` to ``self.trainer.current_epoch`` to enable shuffling every epoch.
"""
def __init__(
self,
lengths: List[int],
num_buckets: int,
min_len: int = 0,
max_len: Optional[int] = None,
max_token_count: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
if max_len is None:
max_len = max(lengths)
if not (0 <= min_len <= max_len):
raise AssertionError("``min_len`` should be non-negative and smaller than ``max_len``")
if max_token_count is not None and batch_size is not None:
raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
if max_token_count is None and batch_size is None:
raise AssertionError("One of ``max_token_count`` or ``batch_size`` must be set.")
if max_token_count is not None:
assert (
max_len <= max_token_count
), "The ``max_token_count`` must be greater than or equal to the maximum value of ``lengths``."
# Filter out samples which are outside the bounds of [min_len, max_len]
filtered_length_idx = [(length, i) for i, length in enumerate(lengths) if min_len <= length <= max_len]
if len(filtered_length_idx) == 0:
raise AssertionError("``lengths`` cannot be empty after filtering.")
sorted_filtered_length_idx = sorted(filtered_length_idx, key=lambda x: x[0])
self.lengths = [e[0] for e in sorted_filtered_length_idx]
self.indices = [e[1] for e in sorted_filtered_length_idx]
self.max_token_count = max_token_count
self.batch_size = batch_size
self.shuffle = shuffle
self.seed = seed
if self.shuffle:
self.g = torch.Generator()
self.g.manual_seed(self.seed)
self.drop_last = drop_last
self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len)
self._update_iter_list()
def _get_buckets(self, lengths: List[int], num_buckets: int, min_len: int, max_len: int) -> Dict[int, Tensor]:
"""Generate buckets based on the dataset.
Args:
lengths (List[int]): The lengths of the samples in the dataset.
num_buckets (int): The number of buckets.
min_len (int): The lower bound of the evenly spaced length intervals to determine bucket width.
max_len (int): The upper bound of the evenly spaced length intervals to determine bucket width.
Returns:
(dict[int, Tensor]): A dictionary in which the key is the bucket index, the value is
the Tensor of corresponding sample indices.
"""
buckets = {}
boundaries = torch.linspace(min_len - 1, max_len + 1, num_buckets + 1)
bucket_ids = torch.bucketize(torch.tensor(lengths), boundaries)
for i in range(bucket_ids.size(0)):
bucket_id = int(bucket_ids[i])
if bucket_id in buckets:
buckets[bucket_id].append(i)
else:
buckets[bucket_id] = [i]
for k in buckets:
buckets[k] = torch.as_tensor(buckets[k], dtype=torch.int)
buckets = {k: v for k, v in sorted(buckets.items())}
return buckets
def _update_iter_list(self) -> None:
if self.shuffle:
for k in self.buckets:
self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0), generator=self.g)]
self.iter_list = []
total_len = 0
batch = []
max_batch_size = self.max_token_count if self.max_token_count else self.batch_size
for k in self.buckets:
for i in range(self.buckets[k].size(0)):
index = int(self.buckets[k][i])
sample_length = self.lengths[index] if self.max_token_count else 1
if total_len + sample_length <= max_batch_size:
batch.append(self.indices[index])
total_len += sample_length
else:
self.iter_list.append(batch)
batch = [self.indices[index]]
total_len = sample_length
if len(batch) > 0 and (self.max_token_count or not self.drop_last):
self.iter_list.append(batch)
def __iter__(self) -> Iterator[List[int]]:
return iter(self.iter_list)
def __len__(self):
if self.batch_size or (self.max_token_count and not self.shuffle):
return len(self.iter_list)
class DistributedBatchSampler(DistributedSampler):
"""`BucketizeBatchSampler` wrapper that distributes across each processor.
Args:
batch_sampler (BucketizeBatchSampler): the initialized bucketize batch sampler.
num_replicas (int, optional): Number of processes participating in
distributed training. By default, :attr:`world_size` is retrieved from the
current distributed group.
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
By default, :attr:`rank` is retrieved from the current distributed
group.
shuffle (bool, optional): if ``True``, the list of batch indices will be shuffled.
(Default: ``True``)
seed (int, optional): random seed used to shuffle the batch_sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. (Default: ``0``)
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. (Default: ``False``)
Note:
if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1``
in pytorch_lightning Trainer, and set `sampler.set_epoch(self.current_epoch)` before DataLoader initialization
in `train_dataloader` method to enable shuffling every epoch.
"""
def __init__(
self,
batch_sampler: BucketizeBatchSampler,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
self.batch_sampler = batch_sampler
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.num_replicas = num_replicas
self.rank = rank
self.shuffle = shuffle
self.epoch = 0
self.seed = seed
self.drop_last = drop_last
self.shuffle = shuffle
indices = self.batch_sampler.iter_list
if self.drop_last and len(indices) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
else:
self.num_samples = math.ceil(len(indices) / self.num_replicas)
def __iter__(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
perm = torch.randperm(len(self.batch_sampler.iter_list), generator=g).tolist()
indices = [self.batch_sampler.iter_list[i] for i in perm]
else:
indices = self.batch_sampler.iter_list
if self.drop_last:
self.total_size = len(indices) - len(indices) % self.num_replicas
else:
padding_size = self.num_replicas - len(indices) % self.num_replicas
indices += indices[:padding_size]
self.total_size = len(indices)
self.num_samples = self.total_size // self.num_replicas
self.subset = indices[self.rank : self.total_size : self.num_replicas]
assert len(self.subset) == self.num_samples
return iter(self.subset)
def __len__(self):
return self.num_samples
class HuBERTDataSet(Dataset):
"""Create a Dataset for HuBERT model training and fine-tuning.
Args:
exp_dir (str or Path): The root directory of the ``.tsv`` file list.
dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
subset (str): The subset of the dataset. Options: [``train``, ``valid``].
"""
def __init__(
self,
exp_dir: Union[str, Path],
dataset: str,
subset: str,
) -> None:
self.exp_dir = Path(exp_dir)
tsv_dir = self.exp_dir / "tsv"
label_dir = self.exp_dir / "label"
f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset)
self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list
self.labels = self._load_labels(label_dir, dataset, subset)
def __len__(self):
return len(self.f_list)
def _get_lists(
self,
tsv_dir: Path,
dataset: str,
subset: str,
) -> Tuple[List[Path], List[int], List[int]]:
"""Get the list of paths for iteration.
Args:
tsv_dir (Path): The root directory of the ``.tsv`` file list.
dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
subset (str): The subset of the dataset. Options: [``train``, ``valid``].
Returns:
(numpy.array) List of file paths.
(numpy.array) List of indices.
(numpy.array) List of waveform lengths.
"""
f_ind_len_list = []
with open(tsv_dir / f"{dataset}_{subset}.tsv") as f:
root = f.readline().rstrip()
for index, line in enumerate(f):
path, nsample = line.split("\t")
path = f"{root}/{path}"
nsample = int(nsample)
f_ind_len_list.append((path, index, nsample))
f_list, ind_list, len_list = [], [], []
for ele in f_ind_len_list:
f_list.append(ele[0])
ind_list.append(ele[1])
len_list.append(ele[2])
return np.asarray(f_list), np.asarray(ind_list), np.asarray(len_list)
def _load_audio(self, index: int) -> Tensor:
"""Load waveform given the sample index of the dataset.
Args:
index (int): The sample index.
Returns:
(Tensor): The corresponding waveform Tensor.
"""
wav_path = self.f_list[index]
waveform, sample_rate = torchaudio.load(wav_path)
assert waveform.shape[1] == self.len_list[index]
return waveform
def _load_labels(self, label_dir: Path, dataset: str, subset: str) -> np.array:
"""Load all labels to memory into a numpy array.
Args:
label_dir (Path): The directory that contains the label file.
dataset (str): The dataset for training. Options: [``librispeech``, ``librilight``].
subset (str): The subset of the dataset. Options: [``train``, ``valid``].
Returns:
(np.array): The numpy array that contains the labels for each audio file.
"""
with open(label_dir / f"label_{subset}.pt") as f:
labels = [line.rstrip() for line in f]
labels = [labels[i] for i in self.ind_list]
return np.asarray(labels, dtype=np.string_)
def __getitem__(self, index):
waveform = self._load_audio(index)
length = waveform.shape[1]
label = [int(ele) for ele in self.labels[index].split()]
label = torch.tensor(label)
return (waveform, label, length)
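
# Illustrative sketch (not part of the original file): reading one sample from
# the dataset. The ``exp_dir`` path is hypothetical and is assumed to contain
# the ``tsv`` and ``label`` sub-directories described in the class docstring.
def _example_hubert_dataset():
    dataset = HuBERTDataSet("exp/hubert", dataset="librispeech", subset="train")
    waveform, label, length = dataset[0]
    # waveform: (1, time) Tensor, label: 1-D Tensor of KMeans cluster IDs,
    # length: number of audio samples in the waveform
    return waveform.shape, label.shape, length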
def _crop_audio_label(
waveform: Tensor,
label: Optional[Tensor],
length: Tensor,
num_frames: int,
rand_crop: bool,
) -> Tuple[Tensor, Optional[Tensor], Tensor]:
"""Collate the audio and label at the same time.
Args:
waveform (Tensor): The waveform Tensor with dimensions `(1, time)`.
        label (Tensor, optional): The label Tensor with dimensions `(seq,)`.
length (Tensor): The length Tensor with dimension `(1,)`.
num_frames (int): The final length of the waveform.
rand_crop (bool): if ``rand_crop`` is True, the starting index of the
waveform and label is random if the length is longer than the minimum
length in the mini-batch.
Returns:
        (Tuple[Tensor, Optional[Tensor], Tensor]): The cropped waveform, the
            cropped label (or ``None``), and the waveform length.
"""
    kernel_size = 25  # receptive field of the feature extractor, in milliseconds
    stride = 20  # hop of the feature extractor, in milliseconds
    sample_rate = 16  # audio samples per millisecond (16 kHz)
frame_offset = 0
waveform = waveform[0]
if waveform.size(0) > num_frames and rand_crop:
diff = waveform.size(0) - num_frames
frame_offset = torch.randint(diff, size=(1,))
elif waveform.size(0) < num_frames:
num_frames = waveform.size(0)
if label is not None:
label_offset = max(
math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate)) + 1,
0,
)
num_label = math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate)) + 1
label = label[label_offset : label_offset + num_label]
waveform = waveform[frame_offset : frame_offset + num_frames]
length = num_frames
return waveform, label, length
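
# Worked example (illustrative, not part of the original file). With 16 kHz audio
# the kernel is 25 ms = 400 samples and the stride is 20 ms = 320 samples, so a
# crop starting at frame_offset = 3200 samples drops the first
# floor((3200 - 400) / 320) + 1 = 9 labels, and a crop of num_frames = 48000
# samples keeps floor((48000 - 400) / 320) + 1 = 149 labels.
def _example_crop_audio_label():
    waveform = torch.randn(1, 50000)
    label = torch.arange(155)  # one dummy label per 20 ms frame
    waveform, label, length = _crop_audio_label(
        waveform, label, 50000, num_frames=48000, rand_crop=True
    )
    assert waveform.numel() == length == 48000
    assert label.numel() == 149
    return waveform, label, length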
class CollateFnHubert:
"""The collate class for HuBERT pre-training and fine-tuning.
Args:
feature_type (str): The type of features for KMeans clustering.
Options: [``mfcc``, ``hubert``].
pad (bool): If ``True``, the waveforms and labels will be padded to the
max length in the mini-batch. If ``pad`` is False, the waveforms
and labels will be cropped to the minimum length in the mini-batch.
(Default: False)
rand_crop (bool): if ``True``, the starting index of the waveform
and label is random if the length is longer than the minimum
length in the mini-batch.
"""
def __init__(
self,
feature_type: str,
pad: bool = False,
rand_crop: bool = True,
) -> None:
self.feature_type = feature_type
self.pad = pad
self.rand_crop = rand_crop
    def __call__(self, batch: List[Tuple[Tensor, Tensor, int]]) -> Batch:
        """
        Args:
            batch (List[Tuple(Tensor, Tensor, int)]):
                The list of tuples that contains the waveforms, labels, and audio lengths.
        Returns:
            (Batch): A ``Batch`` namedtuple. Its first field holds the model inputs
                ``(waveforms, labels, lengths)``, where the waveforms Tensor has
                dimensions `(batch, time)`, the labels Tensor has dimensions
                `(batch, seq)`, and the lengths Tensor has dimension `(batch,)`.
                Its second field holds the target tuple ``(labels,)``.
        """
if self.pad:
num_frames = max([sample[0].shape[1] for sample in batch])
else:
num_frames = min([sample[0].shape[1] for sample in batch])
waveforms, labels, lengths = [], [], []
for sample in batch:
waveform, label, length = sample
# The MFCC feature is 10ms per frame, while the HuBERT's transformer output
# is 20ms per frame. Downsample the KMeans label if it's generated by MFCC features.
if self.feature_type == "mfcc":
label = label[::2]
waveform, label, length = _crop_audio_label(waveform, label, length, num_frames, self.rand_crop)
waveforms.append(waveform)
lengths.append(length)
labels.append(label)
        # If zero-padding is not applied, make sure all waveforms and labels in the batch have identical shapes.
if not self.pad:
assert all(
[waveform.shape[0] == waveforms[0].shape[0] for waveform in waveforms]
), "The dimensions of the waveforms should be identical in the same batch."
assert all(
[label.shape[0] == labels[0].shape[0] for label in labels]
), "The dimensions of the labels should be identical in the same batch."
waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
lengths = torch.tensor(lengths)
batch = Batch((waveforms, labels, lengths), (labels,))
return batch
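
# Illustrative sketch (not part of the original file): pairing HuBERTDataSet with
# CollateFnHubert in a DataLoader. The ``exp_dir`` path is hypothetical. With
# ``pad=False`` every utterance in the batch is cropped to the shortest one, so
# no padding mask is needed downstream.
def _example_collate_fn_hubert():
    dataset = HuBERTDataSet("exp/hubert", dataset="librispeech", subset="valid")
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        collate_fn=CollateFnHubert(feature_type="mfcc", pad=False, rand_crop=True),
    )
    batch = next(iter(loader))
    (waveforms, labels, lengths), _ = batch  # Batch(inputs, labels), see __call__ above
    return waveforms.shape, labels.shape, lengths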
class CollateFnWav2Vec2:
"""The collate class for Wav2Vec2 pre-training and fine-tuning.
Args:
pad (bool): If ``True``, the waveforms and labels will be padded to the
max length in the mini-batch. If ``pad`` is False, the waveforms
and labels will be cropped to the minimum length in the mini-batch.
(Default: False)
rand_crop (bool): if ``True``, the starting index of the waveform
and label is random if the length is longer than the minimum
length in the mini-batch.
"""
def __init__(
self,
pad: bool = False,
rand_crop: bool = True,
) -> None:
self.pad = pad
self.rand_crop = rand_crop
    def __call__(self, batch: List[Tuple[Tensor, int]]) -> Batch:
        """
        Args:
            batch (List[Tuple(Tensor, int)]):
                The list of tuples that contains the waveforms and audio lengths.
        Returns:
            (Batch): A ``Batch`` namedtuple. Its first field holds the model inputs
                ``(waveforms, lengths)``, where the waveforms Tensor has dimensions
                `(batch, time)` and the lengths Tensor has dimension `(batch,)`.
                Its second field is ``(None,)`` because pre-training uses no
                external labels.
        """
if self.pad:
num_frames = max([sample[0].shape[1] for sample in batch])
else:
num_frames = min([sample[0].shape[1] for sample in batch])
waveforms, lengths = [], []
for sample in batch:
waveform, length = sample
waveform, _, length = _crop_audio_label(waveform, None, length, num_frames, self.rand_crop)
waveforms.append(waveform)
lengths.append(length)
        # If zero-padding is not applied, make sure all waveforms in the batch have identical shapes.
if not self.pad:
assert all(
[waveform.shape[0] == waveforms[0].shape[0] for waveform in waveforms]
), "The dimensions of the waveforms should be identical in the same batch."
waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
lengths = torch.tensor(lengths)
batch = Batch((waveforms, lengths), (None,))
return batch
import torch
from pytorch_lightning import LightningDataModule
from torchaudio.datasets.librispeech import LIBRISPEECH
from ._utils import BucketizeBatchSampler, CollateFnWav2Vec2, DistributedBatchSampler
class Wav2Vec2DataModule(LightningDataModule):
librispeech_cls = LIBRISPEECH
def __init__(
self,
*,
dataset_path,
seconds_per_batch,
train_shuffle=True,
num_workers=10,
):
super().__init__()
self.dataset_path = dataset_path
self.seconds_per_batch = seconds_per_batch
self.train_shuffle = train_shuffle
self.num_workers = num_workers
def train_dataloader(self):
dataset = torch.utils.data.ConcatDataset(
[
self.librispeech_cls(self.dataset_path, url="train-clean-360"),
self.librispeech_cls(self.dataset_path, url="train-clean-100"),
self.librispeech_cls(self.dataset_path, url="train-other-500"),
]
)
len_list = [d[0].size(1) for d in dataset]
sampler = BucketizeBatchSampler(
len_list,
num_buckets=10000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=True,
)
sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
sampler.set_epoch(self.trainer.current_epoch)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnWav2Vec2(pad=False, rand_crop=True),
num_workers=self.num_workers,
)
return dataloader
def val_dataloader(self):
dataset = torch.utils.data.ConcatDataset(
[
self.librispeech_cls(self.librispeech_path, url="dev-clean"),
self.librispeech_cls(self.librispeech_path, url="dev-other"),
]
)
len_list = [d[0].size(1) for d in dataset]
sampler = BucketizeBatchSampler(
len_list,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=False,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=CollateFnWav2Vec2(pad=False, rand_crop=True),
num_workers=self.num_workers,
)
return dataloader
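
# Illustrative sketch (not part of the original file): how this DataModule is
# typically driven. ``model_module`` is a placeholder for a LightningModule that
# wraps a wav2vec2 model (e.g. the SSLPretrainModule defined further below); the
# dataset root, device count, and step budget are hypothetical.
def _example_run_wav2vec2_data_module(model_module):
    from pytorch_lightning import Trainer

    data_module = Wav2Vec2DataModule(
        dataset_path="/datasets/librispeech",  # hypothetical LibriSpeech root
        seconds_per_batch=87.5,  # max_token_count = 87.5 s * 16000 Hz = 1.4M samples
    )
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp", max_steps=400000)
    trainer.fit(model_module, datamodule=data_module)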
from collections import namedtuple
from typing import Callable, Optional
import lightning.pytorch as pl
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
Batch = namedtuple("Batch", ["inputs", "labels"])
class SSLPretrainModule(pl.LightningModule):
def __init__(
self,
model: nn.Module,
loss_fn: Callable,
optimizer: Optimizer,
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
):
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
def log_metrics(self, batch: Batch, output, loss, step_type):
"""Log useful information to TensorBoard. Users are expected to
write their customized `log_metrics` method to log information
such as loss values, metric scores, etc.
Args:
batch (Batch): Batch tuple from the dataloader.
output: Output generated by the model.
            loss (Tensor): The loss value computed by ``loss_fn`` for the batch.
step_type (str): Type of step. Choices are "train", "val", and "test".
"""
pass
def training_step(self, batch: Batch, batch_idx):
out = self.model(*batch.inputs)
loss, num_frame = self.loss_fn(*out, *batch.labels)
self.log_metric(batch, out, loss, "train")
# normalize the loss based on the sum of num_frame across all GPUs
num_frames = self.all_gather(num_frame)
self.log(
"Gathered number of frames",
num_frames.float().sum(),
on_step=True,
on_epoch=True,
)
loss *= num_frames.size(0) / num_frames.sum() # world size / num_frames
return loss
def validation_step(self, batch, batch_idx):
out = self.model(*batch.inputs)
loss, _ = self.loss_fn(*out, *batch.labels)
self.log_metric(batch, out, loss, "val")
return loss
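
    # The original snippet ends after ``validation_step`` even though the
    # constructor stores ``optimizer`` and ``lr_scheduler``; a minimal
    # ``configure_optimizers`` along these lines is assumed so that Lightning
    # actually uses them (illustrative sketch, not confirmed by the source).
    def configure_optimizers(self):
        if self.lr_scheduler is None:
            return self.optimizer
        return (
            [self.optimizer],
            [{"scheduler": self.lr_scheduler, "interval": "step"}],
        )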