Commit 94f5027e authored by nateanl, committed by Facebook GitHub Bot

Add multi-channel DNN beamforming training recipe (#3036)

Summary:
This PR adds a training recipe for multi-channel speech enhancement with DNN beamforming.

Pull Request resolved: https://github.com/pytorch/audio/pull/3036

Reviewed By: hwangjeff

Differential Revision: D45061841

Pulled By: nateanl

fbshipit-source-id: 48ede5dd579efe200669dbc83e9cb4dea809e4b4
# Time-Frequency Mask based DNN MVDR Beamforming Example
This directory contains sample implementations of training and evaluation pipelines for a DNN beamforming model.
The `DNNBeamformer` model is composed of the following components (a minimal sketch of how they compose appears after the list):
+ [`torchaudio.transforms.Spectrogram`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html#spectrogram) that applies Short-time Fourier Transform (STFT) to the waveform.
+ A ConvTasNet-style mask network (without the encoder/decoder) that predicts time-frequency (T-F) masks for the speech and noise components.
+ [`torchaudio.transforms.PSD`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.PSD.html#psd) that computes covariance matrices for speech and noise.
+ [`torchaudio.transforms.SoudenMVDR`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.SoudenMVDR.html#soudenmvdr) that estimates the complex-valued STFT for the enhanced speech.
+ [`torchaudio.transforms.InverseSpectrogram`](https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseSpectrogram.html#inversespectrogram) that applies inverse STFT (iSTFT) to generate the enhanced waveform.
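The snippet below gives a minimal sketch of how these components chain together. It mirrors the forward pass in [`model.py`](./model.py), but replaces the ConvTasNet mask network with random placeholder masks purely for illustration; shapes and parameter values follow the recipe defaults.
```python
import torch
from torchaudio.transforms import InverseSpectrogram, PSD, SoudenMVDR, Spectrogram

n_fft, hop_length, ref_channel = 1024, 256, 0
stft = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)  # complex STFT
istft = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
psd = PSD()
beamformer = SoudenMVDR()

waveform = torch.randn(1, 8, 64000)  # (batch, channel, time): the two Ambisonics recordings concatenated
spectrum = stft(waveform)  # (batch, channel, freq, time), complex-valued
# In the recipe these masks are predicted by the ConvTasNet mask network;
# random masks are used here only to show the expected shapes.
mask_speech = torch.rand(1, n_fft // 2 + 1, spectrum.size(-1))
mask_noise = 1.0 - mask_speech
psd_speech = psd(spectrum, mask_speech)  # speech covariance matrices
psd_noise = psd(spectrum, mask_noise)  # noise covariance matrices
enhanced_stft = beamformer(spectrum, psd_speech, psd_noise, ref_channel)  # (batch, freq, time)
enhanced = istft(enhanced_stft, length=waveform.size(-1))  # (batch, time)
```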
## Usage
### Training
[`train.py`](./train.py) trains a [`DNNBeamformer`](./model.py) model using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and to provide the path to the [`L3DAS22`](https://www.kaggle.com/datasets/l3dasteam/l3das22) dataset.
### Evaluation
[`eval.py`](./eval.py) evaluates a trained [`DNNBeamformer`](./model.py) on the test subset of the L3DAS22 dataset.
### L3DAS22
Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python train.py --dataset-path ./datasets/L3DAS22 --checkpoint-path ./exp/checkpoints
```
Sample SLURM command for evaluation:
```
srun python eval.py --checkpoint-path ./exp/checkpoints/epoch=97-step=780472.ckpt --dataset-path ./datasets/L3DAS22 --use-cuda
```
Using the sample training command above, [`train.py`](./train.py) produces a model with 5.0M parameters (57.5 MB).
The table below contains Ci-SDR, STOI, and PESQ results for the test subset of the `L3DAS22` dataset.

| Ci-SDR | STOI | PESQ |
|:-------------------:|-------------:|-------------:|
| 19.00 | 0.82 | 2.46 |
If you find this training recipe useful, please cite as:
```bibtex
@article{yang2021torchaudio,
title={TorchAudio: Building Blocks for Audio and Speech Processing},
author={Yao-Yuan Yang and Moto Hira and Zhaoheng Ni and Anjali Chourdia and Artyom Astafurov and Caroline Chen and Ching-Feng Yeh and Christian Puhrsch and David Pollack and Dmitriy Genzel and Donny Greenberg and Edward Z. Yang and Jason Lian and Jay Mahadeokar and Jeff Hwang and Ji Chen and Peter Goldsborough and Prabhat Roy and Sean Narenthiran and Shinji Watanabe and Soumith Chintala and Vincent Quenneville-Bélair and Yangyang Shi},
journal={arXiv preprint arXiv:2110.15018},
year={2021}
}
@inproceedings{lu2022towards,
title={Towards low-distortion multi-channel speech enhancement: The ESPNet-SE submission to the L3DAS22 challenge},
author={Lu, Yen-Ju and Cornell, Samuele and Chang, Xuankai and Zhang, Wangyou and Li, Chenda and Ni, Zhaoheng and Wang, Zhong-Qiu and Watanabe, Shinji},
booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={9201--9205},
year={2022},
organization={IEEE}
}
```
from pathlib import Path
from typing import Tuple, Union
import lightning.pytorch as pl
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from utils import CollateFnL3DAS22
_PREFIX = "L3DAS22_Task1_"
_SUBSETS = {
"train360": ["train360_1", "train360_2"],
"train100": ["train100"],
"dev": ["dev"],
"test": ["test"],
}
_SAMPLE_RATE = 16000
class L3DAS22(Dataset):
def __init__(
self,
root: Union[str, Path],
subset: str = "train360",
min_len: int = 64000,
):
self._walker = []
if subset not in _SUBSETS:
raise ValueError(f"Expect subset to be one of ('train360', 'train100', 'dev', 'test'). Found {subset}.")
for sub_dir in _SUBSETS[subset]:
path = Path(root) / f"{_PREFIX}{sub_dir}" / "data"
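# Keep only recordings with at least ``min_len`` samples so that random cropping in the collate function always has enough audio.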
files = [str(p) for p in path.glob("*_A.wav") if torchaudio.info(p).num_frames >= min_len]
if len(files) == 0:
raise RuntimeError(
f"Directory {path} is not found. Please check if the zip file has been downloaded and extracted."
)
self._walker += files
def __len__(self):
return len(self._walker)
def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, str]:
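# Each example pairs the two Ambisonics recordings (mic A and mic B), a clean reference waveform under labels/, and a transcript with the same stem.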
noisy_path_A = Path(self._walker[n])
noisy_path_B = str(noisy_path_A).replace("_A.wav", "_B.wav")
clean_path = noisy_path_A.parent.parent / "labels" / noisy_path_A.name.replace("_A.wav", ".wav")
transcript_path = clean_path.with_suffix(".txt")
waveform_noisy_A, sample_rate1 = torchaudio.load(noisy_path_A)
waveform_noisy_B, sample_rate2 = torchaudio.load(noisy_path_B)
waveform_noisy = torch.cat((waveform_noisy_A, waveform_noisy_B), dim=0)
waveform_clean, sample_rate3 = torchaudio.load(clean_path)
assert sample_rate1 == _SAMPLE_RATE and sample_rate2 == _SAMPLE_RATE and sample_rate3 == _SAMPLE_RATE
with open(transcript_path, "r") as f:
transcript = f.readline()
return waveform_noisy, waveform_clean, _SAMPLE_RATE, transcript
class L3DAS22DataModule(pl.LightningDataModule):
def __init__(
self,
dataset_path: str,
batch_size: int,
):
super().__init__()
self.dataset_path = dataset_path
self.batch_size = batch_size
def train_dataloader(self):
dataset = torch.utils.data.ConcatDataset(
[
L3DAS22(self.dataset_path, "train360"),
L3DAS22(self.dataset_path, "train100"),
]
)
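# Batches are formed from random 4-second (64000-sample) crops; see CollateFnL3DAS22 in utils.py.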
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
shuffle=True,
num_workers=20,
)
def val_dataloader(self):
dataset = L3DAS22(self.dataset_path, "dev")
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
shuffle=False,
num_workers=1,
)
def test_dataloader(self):
dataset = L3DAS22(self.dataset_path, "test", min_len=0)
return torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=1,
)
import logging
import pathlib
from argparse import ArgumentParser
import ci_sdr
import torch
from datamodule import L3DAS22DataModule
from model import DNNBeamformer
from pesq import pesq
from pystoi import stoi
logger = logging.getLogger()
def run_eval(args):
model = DNNBeamformer()
checkpoint = torch.load(args.checkpoint_path)
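# The Lightning checkpoint stores the beamformer weights under a "model." prefix and also contains loss-related entries;
# strip the prefix and drop the loss keys so the bare DNNBeamformer can load the state dict.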
new_state_dict = {}
for k in checkpoint["state_dict"].keys():
if "loss" not in k:
new_state_dict[k.replace("model.", "")] = checkpoint["state_dict"][k]
model.load_state_dict(new_state_dict)
model.eval()
data_module = L3DAS22DataModule(dataset_path=args.dataset_path, batch_size=args.batch_size)
if args.use_cuda:
model = model.to(device="cuda")
CI_SDR = 0.0
STOI = 0.0
PESQ = 0.0
with torch.no_grad():
for idx, batch in enumerate(data_module.test_dataloader()):
mixture, clean, _, _ = batch
if args.use_cuda:
mixture = mixture.cuda()
clean = clean[0]
estimate = model(mixture).cpu()
ci_sdr_v = (
-ci_sdr.pt.ci_sdr(estimate, clean, compute_permutation=False, filter_length=512, change_sign=False)
.mean()
.item()
)
clean, estimate = clean[0].numpy(), estimate[0].numpy()
stoi_v = stoi(clean, estimate, 16000, extended=False)
pesq_v = pesq(16000, clean, estimate, "wb")
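# Incrementally update the running means: mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n.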
CI_SDR += (1.0 / float(idx + 1)) * (ci_sdr_v - CI_SDR)
STOI += (1.0 / float(idx + 1)) * (stoi_v - STOI)
PESQ += (1.0 / float(idx + 1)) * (pesq_v - PESQ)
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; Ci-SDR: {CI_SDR}, stoi: {STOI}, pesq: {PESQ}")
# visualize and save results
results = {"Ci-SDR": CI_SDR, "stoi": STOI, "pesq": PESQ}
print("*******************************")
print("RESULTS")
for metric, value in results.items():
print(metric, value)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
required=True,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--dataset-path",
type=pathlib.Path,
help="Path to L3DAS22 datasets.",
)
parser.add_argument(
"--batch_size",
default=4,
type=int,
help="Batch size for training. (Default: 4)",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
import ci_sdr
import lightning.pytorch as pl
import torch
from asteroid.losses.stoi import NegSTOILoss
from asteroid.masknn import TDConvNet
from torchaudio.transforms import InverseSpectrogram, PSD, SoudenMVDR, Spectrogram
class DNNBeamformer(torch.nn.Module):
def __init__(self, n_fft: int = 1024, hop_length: int = 256, ref_channel: int = 0):
super().__init__()
self.stft = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
self.istft = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
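# Mask estimator: the TCN from ConvTasNet (Asteroid's TDConvNet) without the learnable encoder/decoder;
# it predicts two T-F masks (speech and noise) over the n_fft // 2 + 1 frequency bins.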
self.mask_net = TDConvNet(
n_fft // 2 + 1,
2,
out_chan=n_fft // 2 + 1,
causal=False,
mask_act="linear",
norm_type="gLN",
)
self.beamformer = SoudenMVDR()
self.psd = PSD()
self.ref_channel = ref_channel
def forward(self, mixture) -> torch.Tensor:
spectrum = self.stft(mixture)  # (batch, channel, freq, time)
batch, _, freq, time = spectrum.shape
input_feature = torch.log(spectrum[:, self.ref_channel].abs() + 1e-8) # (batch, freq, time)
mask = torch.nn.functional.relu(self.mask_net(input_feature)) # (batch, 2, freq, time)
mask_speech = mask[:, 0]
mask_noise = mask[:, 1]
psd_speech = self.psd(spectrum, mask_speech)
psd_noise = self.psd(spectrum, mask_noise)
enhanced_stft = self.beamformer(spectrum, psd_speech, psd_noise, self.ref_channel)
enhanced_waveform = self.istft(enhanced_stft, length=mixture.shape[-1])
return enhanced_waveform
class DNNBeamformerLightningModule(pl.LightningModule):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
self.loss_stoi = NegSTOILoss(16000)
def training_step(self, batch, batch_idx):
mixture, clean = batch
estimate = self.model(mixture)
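# Combined objective: CI-SDR loss plus the negative STOI loss, weighted by 10.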
loss_cisdr = ci_sdr.pt.ci_sdr_loss(estimate, clean, compute_permutation=False, filter_length=512).mean()
loss_stoi = self.loss_stoi(estimate, clean).mean()
loss = loss_cisdr + loss_stoi * 10
self.log("train/loss_cisdr", loss_cisdr.item())
self.log("train/loss_stoi", loss_stoi.item())
self.log("train/loss", loss.item())
return loss
def validation_step(self, batch, batch_idx):
mixture, clean = batch
estimate = self.model(mixture)
loss_cisdr = ci_sdr.pt.ci_sdr_loss(estimate, clean, compute_permutation=False, filter_length=512).mean()
loss_stoi = self.loss_stoi(estimate, clean).mean()
loss = loss_cisdr + loss_stoi * 10
self.log("val/loss_cisdr", loss_cisdr.item())
self.log("val/loss_stoi", loss_stoi.item())
self.log("val/loss", loss.item())
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=1e-8)
return {
"optimizer": optimizer,
}
import pathlib
from argparse import ArgumentParser
import lightning.pytorch as pl
from datamodule import L3DAS22DataModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from model import DNNBeamformer, DNNBeamformerLightningModule
def run_train(args):
pl.seed_everything(1)
logger = TensorBoardLogger(args.exp_dir)
callbacks = [
ModelCheckpoint(
args.checkpoint_path,
monitor="val/loss",
save_top_k=5,
mode="min",
save_last=True,
),
]
trainer = pl.Trainer(
max_epochs=args.epochs,
callbacks=callbacks,
accelerator="gpu",
devices=args.gpus,
accumulate_grad_batches=1,
logger=logger,
gradient_clip_val=5,
check_val_every_n_epoch=1,
num_sanity_val_steps=0,
log_every_n_steps=1,
)
model = DNNBeamformer()
model_module = DNNBeamformerLightningModule(model)
data_module = L3DAS22DataModule(dataset_path=args.dataset_path, batch_size=args.batch_size)
trainer.fit(model_module, datamodule=data_module)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp/"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp/')",
)
parser.add_argument(
"--dataset-path",
type=pathlib.Path,
help="Path to L3DAS22 datasets.",
required=True,
)
parser.add_argument(
"--batch_size",
default=4,
type=int,
help="Batch size for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=1,
type=int,
help="Number of GPUs per node to use for training. (Default: 1)",
)
parser.add_argument(
"--epochs",
default=100,
type=int,
help="Number of epochs to train for. (Default: 100)",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
from typing import Dict, List, Tuple
import torch
from torch import Tensor
class CollateFnL3DAS22:
"""The collate class for L3DAS22 dataset.
Args:
pad (bool): If ``True``, the waveforms and labels will be padded to the
max length in the mini-batch. If ``pad`` is False, the waveforms
and labels will be cropped to the minimum length in the mini-batch.
(Default: False)
rand_crop (bool): if ``True``, the starting index of the waveform
and label is random if the length is longer than the minimum
length in the mini-batch.
"""
def __init__(
self,
audio_length: int = 16000 * 4,
rand_crop: bool = True,
) -> None:
self.audio_length = audio_length
self.rand_crop = rand_crop
def __call__(self, batch: List[Tuple[Tensor, Tensor, int, str]]) -> Dict:
"""
Args:
batch (List[Tuple(Tensor, Tensor, int)]):
The list of tuples that contains:
- mixture waveforms
- clean waveform
- sample rate
- transcript
Returns:
Dictionary
"input": Tuple of waveforms and lengths.
waveforms Tensor with dimensions `(batch, time)`.
lengths Tensor with dimension `(batch,)`.
"label": None
"""
waveforms_noisy, waveforms_clean = [], []
for sample in batch:
waveform_noisy, waveform_clean, _, _ = sample
if self.rand_crop:
diff = waveform_noisy.size(-1) - self.audio_length
frame_offset = torch.randint(diff, size=(1,)) if diff > 0 else 0
else:
frame_offset = 0
waveform_noisy = waveform_noisy[:, frame_offset : frame_offset + self.audio_length]
waveform_clean = waveform_clean[:, frame_offset : frame_offset + self.audio_length]
waveforms_noisy.append(waveform_noisy.unsqueeze(0))
waveforms_clean.append(waveform_clean)
waveforms_noisy = torch.cat(waveforms_noisy)
waveforms_clean = torch.cat(waveforms_clean)
return waveforms_noisy, waveforms_clean