import logging
import os
import shutil
from collections import defaultdict, deque
import torch
class MetricLogger:
r"""Logger for model metrics
"""
def __init__(self, group, print_freq=1):
self.print_freq = print_freq
self._iter = 0
self.data = defaultdict(lambda: deque(maxlen=self.print_freq))
self.data["group"].append(group)
def __setitem__(self, key, value):
self.data[key].append(value)
def _get_last(self):
return {k: v[-1] for k, v in self.data.items()}
def __str__(self):
return str(self._get_last())
def __call__(self):
self._iter = (self._iter + 1) % self.print_freq
if not self._iter:
print(self, flush=True)
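# Example usage of MetricLogger (illustrative sketch):
#
#   logger = MetricLogger("train", print_freq=10)
#   logger["loss"] = loss.item()   # record the latest value of a metric
#   logger()                       # prints the most recent values once every `print_freq` calls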
def save_checkpoint(state, is_best, filename):
r"""Save the model to a temporary file first,
then copy it to filename, in case the signal interrupts
the torch.save() process.
"""
if filename == "":
return
tempfile = filename + ".temp"
# Remove tempfile in case of an interruption during the copy from tempfile to filename
if os.path.isfile(tempfile):
os.remove(tempfile)
torch.save(state, tempfile)
if os.path.isfile(tempfile):
os.rename(tempfile, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
logging.info("Checkpoint: saved")
def count_parameters(model):
r"""Count the total number of parameters in the model
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# *****************************************************************************
# Copyright (c) 2019 fatchord (https://github.com/fatchord)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# *****************************************************************************
from torchaudio.models.wavernn import WaveRNN
import torch
import torchaudio
from torch import Tensor
from processing import normalized_waveform_to_bits
def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
r'''Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold().
x = [[h1, h2, ... hn]]
Where each h is a vector of conditioning channels
Eg: timesteps=2, overlap=1 with x.size(2)=10
folded = [[h1, h2, h3, h4],
[h4, h5, h6, h7],
[h7, h8, h9, h10]]
Args:
x (tensor): Upsampled conditioning channels of size (1, channel, total_time).
timesteps (int): Timesteps for each index of batch.
overlap (int): Timesteps for both xfade and rnn warmup.
Return:
folded (tensor): folded tensor of size (n_folds, channel, timesteps + 2 * overlap).
'''
_, channels, total_len = x.size()
# Calculate variables needed
n_folds = (total_len - overlap) // (timesteps + overlap)
extended_len = n_folds * (overlap + timesteps) + overlap
remaining = total_len - extended_len
# Pad if some time steps poking out
if remaining != 0:
n_folds += 1
padding = timesteps + 2 * overlap - remaining
x = torch.nn.functional.pad(x, (0, padding))
folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device)
# Get the values for the folded tensor
for i in range(n_folds):
start = i * (timesteps + overlap)
end = start + timesteps + 2 * overlap
folded[i] = x[0, :, start:end]
return folded
def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
r'''Applies a crossfade and unfolds into a 1d array.
y = [[seq1],
[seq2],
[seq3]]
Apply a gain envelope at both ends of the sequences
y = [[seq1_in, seq1_timesteps, seq1_out],
[seq2_in, seq2_timesteps, seq2_out],
[seq3_in, seq3_timesteps, seq3_out]]
Stagger and add up the groups of samples:
[seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]
Args:
y (Tensor): Batched sequences of audio samples of size
(num_folds, channels, timesteps + 2 * overlap).
overlap (int): Timesteps for both xfade and rnn warmup.
Returns:
unfolded waveform (Tensor) : waveform in a 1d tensor of size (channels, total_len).
'''
num_folds, channels, length = y.shape
timesteps = length - 2 * overlap
total_len = num_folds * (timesteps + overlap) + overlap
# Need some silence for the rnn warmup
silence_len = overlap // 2
fade_len = overlap - silence_len
silence = torch.zeros((silence_len), dtype=y.dtype, device=y.device)
linear = torch.ones((silence_len), dtype=y.dtype, device=y.device)
# Equal power crossfade
t = torch.linspace(-1, 1, fade_len, dtype=y.dtype, device=y.device)
fade_in = torch.sqrt(0.5 * (1 + t))
fade_out = torch.sqrt(0.5 * (1 - t))
# Concat the silence to the fades
fade_in = torch.cat([silence, fade_in])
fade_out = torch.cat([linear, fade_out])
# Apply the gain to the overlap samples
y[:, :, :overlap] *= fade_in
y[:, :, -overlap:] *= fade_out
unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)
# Loop to add up all the samples
for i in range(num_folds):
start = i * (timesteps + overlap)
end = start + timesteps + 2 * overlap
unfolded[:, start:end] += y[i]
return unfolded
class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN):
super().__init__()
self.wavernn_model = wavernn
def forward(self,
specgram: Tensor,
mulaw: bool = True,
batched: bool = True,
timesteps: int = 100,
overlap: int = 5) -> Tensor:
r"""Inference function for WaveRNN.
Based on the implementation from
https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py.
Currently only supports multinomial sampling.
Args:
specgram (Tensor): spectrogram of size (n_mels, n_time)
mulaw (bool, optional): Whether to perform mulaw decoding (Default: ``True``).
batched (bool, optional): Whether to perform batch prediction. Using batch prediction
will significantly increase the inference speed (Default: ``True``).
timesteps (int, optional): The time steps for each batch. Only used when `batched`
is set to True (Default: ``100``).
overlap (int, optional): The overlapping time steps between batches. Only used when
`batched` is set to True (Default: ``5``).
Returns:
waveform (Tensor): Reconstructed waveform of size (1, n_time).
The leading 1 represents a single channel.
"""
specgram = specgram.unsqueeze(0)
if batched:
specgram = _fold_with_overlap(specgram, timesteps, overlap)
output = self.wavernn_model.infer(specgram).cpu()
if mulaw:
output = normalized_waveform_to_bits(output, self.wavernn_model.n_bits)
output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)
if batched:
output = _xfade_and_unfold(output, overlap)
else:
output = output[0]
return output
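# Example usage (illustrative sketch; the checkpoint path, model hyperparameters, and the
# `specgram` tensor of shape (n_freq, n_time) are hypothetical):
#
#   wavernn = WaveRNN(upsample_scales=[5, 5, 11], n_classes=256, hop_length=275, n_freq=80)
#   wavernn.load_state_dict(torch.load("wavernn_checkpoint.pth", map_location="cpu"))
#   wrapper = WaveRNNInferenceWrapper(wavernn).eval()
#   with torch.no_grad():
#       waveform = wrapper(specgram, mulaw=True, batched=True, timesteps=100, overlap=5)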
# Source Separation Example
This directory contains reference implementations for source separation. For the details of each model, please check out the following.
- [Conv-TasNet](./conv_tasnet/README.md)
## Usage
### Overview
To train a model, you can use [`lightning_train.py`](./lightning_train.py). This script takes the form of
`lightning_train.py [parameters]`
```
python lightning_train.py \
[--data-dir DATA_DIR] \
[--num-gpu NUM_GPU] \
[--num-workers NUM_WORKERS] \
...
# For details of the parameter values, use:
python lightning_train.py --help
```
This script runs training with the PyTorch Lightning framework, using the Distributed Data Parallel (DDP) backend.
### SLURM
<details><summary>Example scripts for running the training on a SLURM cluster</summary>
- **launch_job.sh**
```bash
#!/bin/bash
#SBATCH --job-name=source_separation
#SBATCH --output=/checkpoint/%u/jobs/%x/%j.out
#SBATCH --error=/checkpoint/%u/jobs/%x/%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=8
#SBATCH --mem-per-cpu=16G
#SBATCH --gpus-per-node=2
#srun env
srun wrapper.sh $@
```
- **wrapper.sh**
```bash
#!/bin/bash
num_speakers=2
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
exp_dir="/checkpoint/${USER}/exp/"
dataset_dir="/dataset/Libri${num_speakers}mix//wav8k/min"
mkdir -p "${exp_dir}"
python -u \
"${this_dir}/lightning_train.py" \
--num-speakers "${num_speakers}" \
--sample-rate 8000 \
--data-dir "${dataset_dir}" \
--exp-dir "${exp_dir}" \
--batch-size $((16 / SLURM_NTASKS))
```
</details>
# Conv-TasNet
This is a reference implementation of Conv-TasNet.
> Luo, Yi, and Nima Mesgarani. "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation." IEEE/ACM Transactions on Audio, Speech, and Language Processing 27.8 (2019): 1256-1266. Crossref. Web.
This implementation is based on [arXiv:1809.07454v3](https://arxiv.org/abs/1809.07454v3) and [the reference implementation](https://github.com/naplab/Conv-TasNet) provided by the authors.
For the usage, please check out the [source separation README](../README.md).
## (Default) Training Configurations
The default training/model configurations follow the non-causal implementation from [Asteroid](https://github.com/asteroid-team/asteroid/tree/master/egs/librimix/ConvTasNet). (The causal configuration is not implemented.) A sketch of how these hyperparameters map onto the model constructor follows the list.
- Sample rate: 8000 Hz
- Batch size: total 12 over distributed training workers
- Epochs: 200
- Initial learning rate: 1e-3
- Gradient clipping: maximum L2 norm of 5.0
- Optimizer: Adam
- Learning rate scheduling: Halved after 5 epochs of no improvement in the validation loss.
- Objective function: SI-SNR
- Reported metrics: SI-SNRi, SDRi
- Sample audio length: 3 seconds (randomized position)
- Encoder/Decoder feature dimension (N): 512
- Encoder/Decoder convolution kernel size (L): 16
- TCN bottleneck/output feature dimension (B): 128
- TCN hidden feature dimension (H): 512
- TCN skip connection feature dimension (Sc): 128
- TCN convolution kernel size (P): 3
- The number of TCN convolution block layers (X): 8
- The number of TCN convolution blocks (R): 3
- The mask activation function: ReLU
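The hyperparameters above correspond to the arguments of `torchaudio.models.ConvTasNet` (see `_get_model` in `lightning_train.py`). Below is a minimal sketch of constructing the model with these values; it is illustrative only and not part of the training scripts.
```python
import torchaudio

# Minimal sketch: the default, non-causal Conv-TasNet used in this recipe.
model = torchaudio.models.ConvTasNet(
    num_sources=2,             # number of speakers
    enc_kernel_size=16,        # L
    enc_num_feats=512,         # N
    msk_kernel_size=3,         # P
    msk_num_feats=128,         # B (also used as Sc)
    msk_num_hidden_feats=512,  # H
    msk_num_layers=8,          # X
    msk_num_stacks=3,          # R
    msk_activate="relu",       # mask activation
)
print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")
```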
## Evaluation
The following is the evaluation result of training the model on the Libri2Mix dataset.
### LibriMix 2 speakers
| | Si-SNRi (dB) | SDRi (dB) | Epoch |
|:-------------------:|-------------:|----------:|------:|
| Reference (Asteroid)| 14.7 | 15.1 | 200 |
| torchaudio | 15.3 | 15.6 | 200 |
from . import (
train,
trainer
)
__all__ = ['train', 'trainer']
#!/usr/bin/env python3
"""Train Conv-TasNet"""
import time
import pathlib
import argparse
import torch
import torchaudio
import torchaudio.models
import conv_tasnet
from utils import dist_utils
from utils.dataset import utils as dataset_utils
_LG = dist_utils.getLogger(__name__)
def _parse_args(args):
parser = argparse.ArgumentParser(description=__doc__,)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug behavior. Each epoch will end with just one batch.")
group = parser.add_argument_group("Model Options")
group.add_argument(
"--num-speakers", required=True, type=int, help="The number of speakers."
)
group = parser.add_argument_group("Dataset Options")
group.add_argument(
"--sample-rate",
required=True,
type=int,
help="Sample rate of audio files in the given dataset.",
)
group.add_argument(
"--dataset",
default="wsj0mix",
choices=["wsj0mix"],
help='Dataset type. (default: "wsj0mix")',
)
group.add_argument(
"--dataset-dir",
required=True,
type=pathlib.Path,
help=(
"Directory where dataset is found. "
'If the dataset type is "wsj0mix", then this is the directory where '
'"cv", "tt" and "tr" subdirectories are found.'
),
)
group = parser.add_argument_group("Save Options")
group.add_argument(
"--save-dir",
required=True,
type=pathlib.Path,
help=(
"Directory where the checkpoints and logs are saved. "
"Though, only the worker 0 saves checkpoint data, "
"all the worker processes must have access to the directory."
),
)
group = parser.add_argument_group("Dataloader Options")
group.add_argument(
"--batch-size",
type=int,
help="Batch size. (default: 16 // world_size)",
)
group = parser.add_argument_group("Training Options")
group.add_argument(
"--epochs",
metavar="NUM_EPOCHS",
default=100,
type=int,
help="The number of epochs to train. (default: 100)",
)
group.add_argument(
"--learning-rate",
default=1e-3,
type=float,
help="Initial learning rate. (default: 1e-3)",
)
group.add_argument(
"--grad-clip",
metavar="CLIP_VALUE",
default=5.0,
type=float,
help="Gradient clip value (l2 norm). (default: 5.0)",
)
group.add_argument(
"--resume",
metavar="CHECKPOINT_PATH",
help="Previous checkpoint file from which the training is resumed.",
)
args = parser.parse_args(args)
# Delay the default value initialization until parse_args is done, because
# if `--help` is given, the distributed process group is not initialized
# and the world size is unavailable.
if args.batch_size is None:
args.batch_size = 16 // torch.distributed.get_world_size()
return args
def _get_model(
num_sources,
enc_kernel_size=16,
enc_num_feats=512,
msk_kernel_size=3,
msk_num_feats=128,
msk_num_hidden_feats=512,
msk_num_layers=8,
msk_num_stacks=3,
):
model = torchaudio.models.ConvTasNet(
num_sources=num_sources,
enc_kernel_size=enc_kernel_size,
enc_num_feats=enc_num_feats,
msk_kernel_size=msk_kernel_size,
msk_num_feats=msk_num_feats,
msk_num_hidden_feats=msk_num_hidden_feats,
msk_num_layers=msk_num_layers,
msk_num_stacks=msk_num_stacks,
)
_LG.info_on_master("Model Configuration:")
_LG.info_on_master(" - N: %d", enc_num_feats)
_LG.info_on_master(" - L: %d", enc_kernel_size)
_LG.info_on_master(" - B: %d", msk_num_feats)
_LG.info_on_master(" - H: %d", msk_num_hidden_feats)
_LG.info_on_master(" - Sc: %d", msk_num_feats)
_LG.info_on_master(" - P: %d", msk_kernel_size)
_LG.info_on_master(" - X: %d", msk_num_layers)
_LG.info_on_master(" - R: %d", msk_num_stacks)
_LG.info_on_master(
" - Receptive Field: %s [samples]", model.mask_generator.receptive_field,
)
return model
def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_size, task=None):
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, dataset_dir, num_speakers, sample_rate, task
)
train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=4
)
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test')
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
sampler=torch.utils.data.distributed.DistributedSampler(train_dataset),
collate_fn=train_collate_fn,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=batch_size,
sampler=torch.utils.data.distributed.DistributedSampler(valid_dataset),
collate_fn=test_collate_fn,
pin_memory=True,
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset,
batch_size=batch_size,
sampler=torch.utils.data.distributed.DistributedSampler(eval_dataset),
collate_fn=test_collate_fn,
pin_memory=True,
)
return train_loader, valid_loader, eval_loader
def _write_header(log_path, args):
rows = [
[f"# torch: {torch.__version__}", ],
[f"# torchaudio: {torchaudio.__version__}", ]
]
rows.append(["# arguments"])
for key, item in vars(args).items():
rows.append([f"# {key}: {item}"])
dist_utils.write_csv_on_master(log_path, *rows)
def train(args):
args = _parse_args(args)
_LG.info("%s", args)
args.save_dir.mkdir(parents=True, exist_ok=True)
if "sox_io" in torchaudio.list_audio_backends():
torchaudio.set_audio_backend("sox_io")
start_epoch = 1
if args.resume:
checkpoint = torch.load(args.resume)
if args.sample_rate != checkpoint["sample_rate"]:
raise ValueError(
"The provided sample rate ({args.sample_rate}) does not match "
"the sample rate from the check point ({checkpoint['sample_rate']})."
)
if args.num_speakers != checkpoint["num_speakers"]:
raise ValueError(
"The provided #of speakers ({args.num_speakers}) does not match "
"the #of speakers from the check point ({checkpoint['num_speakers']}.)"
)
start_epoch = checkpoint["epoch"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_LG.info("Using: %s", device)
model = _get_model(num_sources=args.num_speakers)
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[device] if torch.cuda.is_available() else None
)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
if args.resume:
_LG.info("Loading parameters from the checkpoint...")
model.module.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
else:
dist_utils.synchronize_params(
str(args.save_dir / "tmp.pt"), device, model, optimizer
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=3
)
train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset,
args.dataset_dir,
args.num_speakers,
args.sample_rate,
args.batch_size,
)
num_train_samples = len(train_loader.dataset)
num_valid_samples = len(valid_loader.dataset)
num_eval_samples = len(eval_loader.dataset)
_LG.info_on_master("Datasets:")
_LG.info_on_master(" - Train: %s", num_train_samples)
_LG.info_on_master(" - Valid: %s", num_valid_samples)
_LG.info_on_master(" - Eval: %s", num_eval_samples)
trainer = conv_tasnet.trainer.Trainer(
model,
optimizer,
train_loader,
valid_loader,
eval_loader,
args.grad_clip,
device,
debug=args.debug,
)
log_path = args.save_dir / "log.csv"
_write_header(log_path, args)
dist_utils.write_csv_on_master(
log_path,
[
"epoch",
"learning_rate",
"valid_si_snri",
"valid_sdri",
"eval_si_snri",
"eval_sdri",
],
)
_LG.info_on_master("Running %s epochs", args.epochs)
for epoch in range(start_epoch, start_epoch + args.epochs):
_LG.info_on_master("=" * 70)
_LG.info_on_master("Epoch: %s", epoch)
_LG.info_on_master("Learning rate: %s", optimizer.param_groups[0]["lr"])
_LG.info_on_master("=" * 70)
t0 = time.monotonic()
trainer.train_one_epoch()
train_sps = num_train_samples / (time.monotonic() - t0)
_LG.info_on_master("-" * 70)
t0 = time.monotonic()
valid_metric = trainer.validate()
valid_sps = num_valid_samples / (time.monotonic() - t0)
_LG.info_on_master("Valid: %s", valid_metric)
_LG.info_on_master("-" * 70)
t0 = time.monotonic()
eval_metric = trainer.evaluate()
eval_sps = num_eval_samples / (time.monotonic() - t0)
_LG.info_on_master(" Eval: %s", eval_metric)
_LG.info_on_master("-" * 70)
_LG.info_on_master("Train: Speed: %6.2f [samples/sec]", train_sps)
_LG.info_on_master("Valid: Speed: %6.2f [samples/sec]", valid_sps)
_LG.info_on_master(" Eval: Speed: %6.2f [samples/sec]", eval_sps)
_LG.info_on_master("-" * 70)
dist_utils.write_csv_on_master(
log_path,
[
epoch,
optimizer.param_groups[0]["lr"],
valid_metric.si_snri,
valid_metric.sdri,
eval_metric.si_snri,
eval_metric.sdri,
],
)
lr_scheduler.step(valid_metric.si_snri)
save_path = args.save_dir / f"epoch_{epoch}.pt"
dist_utils.save_on_master(
save_path,
{
"model": model.module.state_dict(),
"optimizer": optimizer.state_dict(),
"num_speakers": args.num_speakers,
"sample_rate": args.sample_rate,
"epoch": epoch,
},
)
import time
from typing import Tuple
from collections import namedtuple
import torch
import torch.distributed as dist
from utils import dist_utils, metrics
_LG = dist_utils.getLogger(__name__)
Metric = namedtuple("SNR", ["si_snri", "sdri"])
Metric.__str__ = (
lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"
)
def si_sdr_improvement(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the improvement of scale-invariant SDR. (SI-SNRi) and bare SDR (SDRi).
Args:
estimate (torch.Tensor): Estimated source signals.
Shape: [batch, speakers, time frame]
reference (torch.Tensor): Reference (original) source signals.
Shape: [batch, speakers, time frame]
mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
Shape: [batch, speakers == 1, time frame]
mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
Returns:
torch.Tensor: Improvement of SI-SDR (SI-SNRi). Shape: [batch, ]
torch.Tensor: Improvement of SDR (SDRi). Shape: [batch, ]
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
with torch.no_grad():
sdri = metrics.sdri(estimate, reference, mix, mask=mask)
estimate = estimate - estimate.mean(axis=2, keepdim=True)
reference = reference - reference.mean(axis=2, keepdim=True)
mix = mix - mix.mean(axis=2, keepdim=True)
si_sdri = metrics.sdri(estimate, reference, mix, mask=mask)
return si_sdri, sdri
class OccasionalLogger:
"""Simple helper class to log once in a while or when progress is quick enough"""
def __init__(self, time_interval=180, progress_interval=0.1):
self.time_interval = time_interval
self.progress_interval = progress_interval
self.last_time = 0.0
self.last_progress = 0.0
def log(self, metric, progress, force=False):
now = time.monotonic()
if (
force
or now > self.last_time + self.time_interval
or progress > self.last_progress + self.progress_interval
):
self.last_time = now
self.last_progress = progress
_LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress)
class Trainer:
def __init__(
self,
model,
optimizer,
train_loader,
valid_loader,
eval_loader,
grad_clip,
device,
*,
debug,
):
self.model = model
self.optimizer = optimizer
self.train_loader = train_loader
self.valid_loader = valid_loader
self.eval_loader = eval_loader
self.grad_clip = grad_clip
self.device = device
self.debug = debug
def train_one_epoch(self):
self.model.train()
logger = OccasionalLogger()
num_batches = len(self.train_loader)
for i, batch in enumerate(self.train_loader, start=1):
mix = batch.mix.to(self.device)
src = batch.src.to(self.device)
mask = batch.mask.to(self.device)
estimate = self.model(mix)
si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)
si_snri = si_snri.mean()
sdri = sdri.mean()
loss = -si_snri
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.grad_clip, norm_type=2.0
)
self.optimizer.step()
metric = Metric(si_snri.item(), sdri.item())
logger.log(metric, progress=i / num_batches, force=i == num_batches)
if self.debug:
break
def evaluate(self):
with torch.no_grad():
return self._test(self.eval_loader)
def validate(self):
with torch.no_grad():
return self._test(self.valid_loader)
def _test(self, loader):
self.model.eval()
total_si_snri = torch.zeros(1, dtype=torch.float32, device=self.device)
total_sdri = torch.zeros(1, dtype=torch.float32, device=self.device)
for batch in loader:
mix = batch.mix.to(self.device)
src = batch.src.to(self.device)
mask = batch.mask.to(self.device)
estimate = self.model(mix)
si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)
total_si_snri += si_snri.sum()
total_sdri += sdri.sum()
if self.debug:
break
dist.all_reduce(total_si_snri, dist.ReduceOp.SUM)
dist.all_reduce(total_sdri, dist.ReduceOp.SUM)
num_samples = len(loader.dataset)
metric = Metric(total_si_snri.item() / num_samples, total_sdri.item() / num_samples)
return metric
from argparse import ArgumentParser
from pathlib import Path
from lightning_train import _get_model, _get_dataloader, sisdri_metric
import mir_eval
import torch
def _eval(model, data_loader, device):
results = torch.zeros(4)
with torch.no_grad():
for _, batch in enumerate(data_loader):
mix, src, mask = batch
mix, src, mask = mix.to(device), src.to(device), mask.to(device)
est = model(mix)
sisdri = sisdri_metric(est, src, mix, mask)
src = src.cpu().detach().numpy()
est = est.cpu().detach().numpy()
mix = mix.repeat(1, src.shape[1], 1).cpu().detach().numpy()
sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(src[0], est[0])
sdr_mix, sir_mix, sar_mix, _ = mir_eval.separation.bss_eval_sources(src[0], mix[0])
results += torch.tensor([
sdr.mean() - sdr_mix.mean(),
sisdri,
sir.mean() - sir_mix.mean(),
sar.mean() - sar_mix.mean()
])
results /= len(data_loader)
print("SDR improvement: ", results[0].item())
print("Si-SDR improvement: ", results[1].item())
print("SIR improvement: ", results[2].item())
print("SAR improvement: ", results[3].item())
def cli_main():
parser = ArgumentParser()
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument(
"--root-dir",
type=Path,
help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
)
parser.add_argument(
"--librimix-tr-split",
default="train-360",
choices=["train-360", "train-100"],
help="The training partition of librimix dataset. (default: ``train-360``)",
)
parser.add_argument(
"--librimix-task",
default="sep_clean",
type=str,
choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"],
help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)",
)
parser.add_argument(
"--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)"
)
parser.add_argument(
"--sample-rate",
default=8000,
type=int,
help="Sample rate of audio files in the given dataset. (default: 8000)",
)
parser.add_argument(
"--exp-dir",
default=Path("./exp"),
type=Path,
help="The directory to save checkpoints and logs."
)
parser.add_argument(
"--gpu-device",
default=-1,
type=int,
help="The gpu device for model inference. (default: -1)"
)
args = parser.parse_args()
model = _get_model(num_sources=2)
state_dict = torch.load(args.exp_dir / 'best_model.pth')
model.load_state_dict(state_dict)
if args.gpu_device != -1:
device = torch.device('cuda:' + str(args.gpu_device))
else:
device = torch.device('cpu')
model = model.to(device)
_, _, eval_loader = _get_dataloader(
args.dataset,
args.root_dir,
args.num_speakers,
args.sample_rate,
1, # batch size is set to 1 to avoid masking
0, # set num_workers to 0
args.librimix_task,
args.librimix_tr_split,
)
_eval(model, eval_loader, device)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
# pyre-strict
from pathlib import Path
from argparse import ArgumentParser
from typing import (
Any,
Callable,
Dict,
Mapping,
List,
Optional,
Tuple,
TypedDict,
Union,
)
import torch
import torchaudio
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.plugins import DDPPlugin
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from utils import metrics
from utils.dataset import utils as dataset_utils
class Batch(TypedDict):
mix: torch.Tensor # (batch, time)
src: torch.Tensor # (batch, source, time)
mask: torch.Tensor # (batch, source, time)
def sisdri_metric(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor
) -> torch.Tensor:
"""Compute the improvement of scale-invariant SDR. (SI-SDRi).
Args:
estimate (torch.Tensor): Estimated source signals.
Tensor of dimension (batch, speakers, time)
reference (torch.Tensor): Reference (original) source signals.
Tensor of dimension (batch, speakers, time)
mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
Tensor of dimension (batch, speakers == 1, time)
mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
Tensor of dimension (batch, 1, time)
Returns:
float: Improved SI-SDR, averaged over the batch.
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
with torch.no_grad():
estimate = estimate - estimate.mean(axis=2, keepdim=True)
reference = reference - reference.mean(axis=2, keepdim=True)
mix = mix - mix.mean(axis=2, keepdim=True)
si_sdri = metrics.sdri(estimate, reference, mix, mask=mask)
return si_sdri.mean().item()
def sdri_metric(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor,
) -> torch.Tensor:
"""Compute the improvement of SDR. (SDRi).
Args:
estimate (torch.Tensor): Estimated source signals.
Tensor of dimension (batch, speakers, time)
reference (torch.Tensor): Reference (original) source signals.
Tensor of dimension (batch, speakers, time)
mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
Tensor of dimension (batch, speakers == 1, time)
mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
Tensor of dimension (batch, 1, time)
Returns:
float: Improved SDR, averaged over the batch.
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
with torch.no_grad():
sdri = metrics.sdri(estimate, reference, mix, mask=mask)
return sdri.mean().item()
def si_sdr_loss(
estimate: torch.Tensor,
reference: torch.Tensor,
mask: torch.Tensor
) -> torch.Tensor:
"""Compute the Si-SDR loss.
Args:
estimate (torch.Tensor): Estimated source signals.
Tensor of dimension (batch, speakers, time)
reference (torch.Tensor): Reference (original) source signals.
Tensor of dimension (batch, speakers, time)
mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
Tensor of dimension (batch, 1, time)
Returns:
torch.Tensor: Scalar Si-SDR loss (negative mean SI-SDR over the batch).
"""
estimate = estimate - estimate.mean(axis=2, keepdim=True)
reference = reference - reference.mean(axis=2, keepdim=True)
si_sdri = metrics.sdr_pit(estimate, reference, mask=mask)
return -si_sdri.mean()
class ConvTasNetModule(LightningModule):
"""
The Lightning Module for speech separation.
Args:
model (Any): The model to use for the separation task.
train_loader (DataLoader): the training dataloader.
val_loader (DataLoader or None): the validation dataloader.
loss (Any): The loss function to use.
optim (Any): The optimizer to use.
metrics (Dict[str, Callable]): The metrics to track, used for both training and validation.
lr_scheduler (Any or None): The LR Scheduler.
"""
def __init__(
self,
model: Any,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
loss: Any,
optim: Any,
metrics: List[Any],
lr_scheduler: Optional[Any] = None,
) -> None:
super().__init__()
self.model: nn.Module = model
self.loss: nn.Module = loss
self.optim: torch.optim.Optimizer = optim
self.lr_scheduler: Optional[_LRScheduler] = None
if lr_scheduler:
self.lr_scheduler = lr_scheduler
self.metrics: Mapping[str, Callable] = metrics
self.train_metrics: Dict = {}
self.val_metrics: Dict = {}
self.test_metrics: Dict = {}
self.save_hyperparameters()
self.train_loader = train_loader
self.val_loader = val_loader
def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit":
self.train_metrics.update(self.metrics)
self.val_metrics.update(self.metrics)
else:
self.test_metrics.update(self.metrics)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward defines the prediction/inference actions.
"""
return self.model(x)
def training_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
return self._step(batch, batch_idx, "train")
def validation_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
"""
Operates on a single batch of data from the validation set.
"""
return self._step(batch, batch_idx, "val")
def test_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> Optional[Dict[str, Any]]:
"""
Operates on a single batch of data from the test set.
"""
return self._step(batch, batch_idx, "test")
def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Dict[str, Any]:
"""
Common step for training, validation, and testing.
"""
mix, src, mask = batch
pred = self.model(mix)
loss = self.loss(pred, src, mask)
self.log(f"Losses/{phase_type}_loss", loss.item(), on_step=True, on_epoch=True)
metrics_result = self._compute_metrics(pred, src, mix, mask, phase_type)
self.log_dict(metrics_result, on_epoch=True)
return loss
def configure_optimizers(
self,
) -> Tuple[Any]:
lr_scheduler = self.lr_scheduler
if not lr_scheduler:
return self.optim
epoch_schedulers = {
'scheduler': lr_scheduler,
'monitor': 'Losses/val_loss',
'interval': 'epoch'
}
return [self.optim], [epoch_schedulers]
def _compute_metrics(
self,
pred: torch.Tensor,
label: torch.Tensor,
inputs: torch.Tensor,
mask: torch.Tensor,
phase_type: str,
) -> Dict[str, torch.Tensor]:
metrics_dict = getattr(self, f"{phase_type}_metrics")
metrics_result = {}
for name, metric in metrics_dict.items():
metrics_result[f"Metrics/{phase_type}/{name}"] = metric(pred, label, inputs, mask)
return metrics_result
def train_dataloader(self):
"""Training dataloader"""
return self.train_loader
def val_dataloader(self):
"""Validation dataloader"""
return self.val_loader
def _get_model(
num_sources,
enc_kernel_size=16,
enc_num_feats=512,
msk_kernel_size=3,
msk_num_feats=128,
msk_num_hidden_feats=512,
msk_num_layers=8,
msk_num_stacks=3,
msk_activate="relu",
):
model = torchaudio.models.ConvTasNet(
num_sources=num_sources,
enc_kernel_size=enc_kernel_size,
enc_num_feats=enc_num_feats,
msk_kernel_size=msk_kernel_size,
msk_num_feats=msk_num_feats,
msk_num_hidden_feats=msk_num_hidden_feats,
msk_num_layers=msk_num_layers,
msk_num_stacks=msk_num_stacks,
msk_activate=msk_activate,
)
return model
def _get_dataloader(
dataset_type: str,
root_dir: Union[str, Path],
num_speakers: int = 2,
sample_rate: int = 8000,
batch_size: int = 6,
num_workers: int = 4,
librimix_task: Optional[str] = None,
librimix_tr_split: Optional[str] = None,
) -> Tuple[DataLoader]:
"""Get dataloaders for training, validation, and testing.
Args:
dataset_type (str): the dataset to use.
root_dir (str or Path): the root directory of the dataset.
num_speakers (int, optional): the number of speakers in the mixture. (Default: 2)
sample_rate (int, optional): the sample rate of the audio. (Default: 8000)
batch_size (int, optional): the batch size of the dataset. (Default: 6)
num_workers (int, optional): the number of workers for each dataloader. (Default: 4)
librimix_task (str or None, optional): the task in LibriMix dataset.
librimix_tr_split (str or None, optional): the training split in LibriMix dataset.
Returns:
tuple: (train_loader, valid_loader, eval_loader)
"""
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, root_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split
)
train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=3
)
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test', sample_rate=sample_rate)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=train_collate_fn,
num_workers=num_workers,
drop_last=True,
)
valid_loader = DataLoader(
valid_dataset,
batch_size=batch_size,
collate_fn=test_collate_fn,
num_workers=num_workers,
drop_last=True,
)
eval_loader = DataLoader(
eval_dataset,
batch_size=batch_size,
collate_fn=test_collate_fn,
num_workers=num_workers,
)
return train_loader, valid_loader, eval_loader
def cli_main():
parser = ArgumentParser()
parser.add_argument("--batch-size", default=6, type=int)
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument(
"--root-dir",
type=Path,
help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
)
parser.add_argument(
"--librimix-tr-split",
default="train-360",
choices=["train-360", "train-100"],
help="The training partition of librimix dataset. (default: ``train-360``)",
)
parser.add_argument(
"--librimix-task",
default="sep_clean",
type=str,
choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"],
help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)",
)
parser.add_argument(
"--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)"
)
parser.add_argument(
"--sample-rate",
default=8000,
type=int,
help="Sample rate of audio files in the given dataset. (default: 8000)",
)
parser.add_argument(
"--exp-dir",
default=Path("./exp"),
type=Path,
help="The directory to save checkpoints and logs."
)
parser.add_argument(
"--epochs",
metavar="NUM_EPOCHS",
default=200,
type=int,
help="The number of epochs to train. (default: 200)",
)
parser.add_argument(
"--learning-rate",
default=1e-3,
type=float,
help="Initial learning rate. (default: 1e-3)",
)
parser.add_argument(
"--num-gpu",
default=1,
type=int,
help="The number of GPUs for training. (default: 1)",
)
parser.add_argument(
"--num-workers",
default=4,
type=int,
help="The number of workers for dataloader. (default: 4)",
)
args = parser.parse_args()
model = _get_model(num_sources=args.num_speakers)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5
)
train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset,
args.root_dir,
args.num_speakers,
args.sample_rate,
args.batch_size,
args.num_workers,
args.librimix_task,
args.librimix_tr_split,
)
loss = si_sdr_loss
metric_dict = {
"sdri": sdri_metric,
"sisdri": sisdri_metric,
}
model = ConvTasNetModule(
model=model,
train_loader=train_loader,
val_loader=valid_loader,
loss=loss,
optim=optimizer,
metrics=metric_dict,
lr_scheduler=lr_scheduler,
)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True
)
callbacks = [
checkpoint,
EarlyStopping(monitor="Losses/val_loss", mode="min", patience=30, verbose=True),
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
gpus=args.num_gpu,
accelerator="ddp",
plugins=DDPPlugin(find_unused_parameters=False), # make sure there is no unused params
limit_train_batches=1.0, # Useful for fast experiment
gradient_clip_val=5.0,
callbacks=callbacks,
)
trainer.fit(model)
model.load_from_checkpoint(checkpoint.best_model_path)
state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
torch.save(state_dict, args.exp_dir / "best_model.pth")
trainer.test(model, eval_loader)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
"""Launch souce separation training.
This script runs training in Distributed Data Parallel (DDP) framework and has two major
operation modes. This behavior depends on if `--worker-id` argument is given or not.
1. (`--worker-id` is not given) Launchs worker subprocesses that performs the actual training.
2. (`--worker-id` is given) Performs the training as a part of distributed training.
When launching the script without any distributed trainig parameters (operation mode 1),
this script will check the number of GPUs available on the local system and spawns the same
number of training subprocesses (as operaiton mode 2). You can reduce the number of GPUs with
`--num-workers`. If there is no GPU available, only one subprocess is launched.
When launching the script as a worker process of a distributed training, you need to configure
the coordination of the workers.
"""
import sys
import logging
import argparse
import subprocess
import torch
from utils import dist_utils
_LG = dist_utils.getLogger(__name__)
def _parse_args(args=None):
max_world_size = torch.cuda.device_count() or 1
parser = argparse.ArgumentParser(
description=__doc__,
)
parser.add_argument("--debug", action="store_true", help="Enable debug log")
group = parser.add_argument_group("Distributed Training")
group.add_argument(
"--worker-id",
type=int,
help=(
"If not provided, the launched process serves as a master process of "
"single-node, multi-worker training and spawns the worker subprocesses. "
"If provided, the launched process serves as a worker process, which "
"performs the actual training. The valid value is [0, --num-workers)."
),
)
group.add_argument(
"--device-id",
type=int,
help="The CUDA device ID. Allowed only when --worker-id is provided.",
)
group.add_argument(
"--num-workers",
type=int,
default=max_world_size,
help=(
"The size of distributed trainig workers. "
"If launching a training as single-node, multi-worker training, "
"(i.e. --worker-id is not provided) then this value should not exceed "
"the number of available GPUs. "
"If launching the training process as a multi-node, multi-gpu training, "
"(i.e. --worker-id is provided) then the value has to match "
f"the number of workers across nodes. (default: {max_world_size})"
),
)
group.add_argument(
"--sync-protocol",
type=str,
default="env://",
help=(
"Synchronization protocol for distributed training. "
"This value is passed as `init_method` argument of "
"`torch.distributed.init_process_group` function."
'If you are using `"env://"`, you can additionally configure '
'environment variables "MASTER_ADDR" and "MASTER_PORT". '
'If you are using `"file://..."`, then the process has to have '
"the access to the designated file. "
"See the documentation for `torch.distributed` for the detail. "
'If you are running the training in a single node, `"env://"` '
"should do. If you are running the training in multiple nodes, "
"you need to provide the file location where all the nodes have "
'access, using `"file://..."` protocol. (default: "env://")'
),
)
group.add_argument(
"--random-seed",
type=int,
help="Set random seed value. (default: None)",
)
parser.add_argument(
"rest", nargs=argparse.REMAINDER, help="Model-specific arguments."
)
namespace = parser.parse_args(args)
if namespace.worker_id is None:
if namespace.device_id is not None:
raise ValueError(
"`--device-id` cannot be provided when runing as master process."
)
if namespace.num_workers > max_world_size:
raise ValueError(
"--num-workers ({num_workers}) cannot exceed {device_count}."
)
if namespace.rest[:1] == ["--"]:
namespace.rest = namespace.rest[1:]
return namespace
def _main(cli_args):
args = _parse_args(cli_args)
if any(arg in ["--help", "-h"] for arg in args.rest):
_run_training(args.rest)
_init_logger(args.worker_id, args.debug)
if args.worker_id is None:
_run_training_subprocesses(args.num_workers, cli_args)
else:
dist_utils.setup_distributed(
world_size=args.num_workers,
rank=args.worker_id,
local_rank=args.device_id,
backend='nccl' if torch.cuda.is_available() else 'gloo',
init_method=args.sync_protocol,
)
if args.random_seed is not None:
torch.manual_seed(args.random_seed)
if torch.cuda.is_available():
torch.cuda.set_device(args.device_id)
_LG.info("CUDA device set to %s", args.device_id)
_run_training(args.rest)
def _run_training_subprocesses(num_workers, original_args):
workers = []
_LG.info("Spawning %s workers", num_workers)
for i in range(num_workers):
worker_arg = ["--worker-id", f"{i}", "--num-workers", f"{num_workers}"]
device_arg = ["--device-id", f"{i}"] if torch.cuda.is_available() else []
command = (
[sys.executable, "-u", sys.argv[0]]
+ worker_arg
+ device_arg
+ original_args
)
_LG.info("Launching worker %s: `%s`", i, " ".join(command))
worker = subprocess.Popen(command)
workers.append(worker)
num_failed = 0
for worker in workers:
worker.wait()
if worker.returncode != 0:
num_failed += 1
sys.exit(num_failed)
def _run_training(args):
import conv_tasnet.train
conv_tasnet.train.train(args)
def _init_logger(rank=None, debug=False):
worker_fmt = "[master]" if rank is None else f"[worker {rank:2d}]"
message_fmt = (
"%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
)
logging.basicConfig(
level=logging.DEBUG if debug else logging.INFO,
format=f"%(asctime)s: {worker_fmt} {message_fmt}",
)
if __name__ == "__main__":
_main(sys.argv[1:])
from . import (
dataset,
dist_utils,
metrics,
)
__all__ = ['dataset', 'dist_utils', 'metrics']
from . import utils, wsj0mix
__all__ = ['utils', 'wsj0mix']
from typing import List
from functools import partial
from collections import namedtuple
from torchaudio.datasets import LibriMix
import torch
from . import wsj0mix
Batch = namedtuple("Batch", ["mix", "src", "mask"])
def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, librimix_tr_split=None):
if dataset_type == "wsj0mix":
train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
elif dataset_type == "librimix":
train = LibriMix(root_dir, librimix_tr_split, num_speakers, sample_rate, task)
validation = LibriMix(root_dir, "dev", num_speakers, sample_rate, task)
evaluation = LibriMix(root_dir, "test", num_speakers, sample_rate, task)
else:
raise ValueError(f"Unexpected dataset: {dataset_type}")
return train, validation, evaluation
def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_rate: int, random_start=False):
"""Ensure waveform has exact number of frames by slicing or padding"""
mix = sample[1] # [1, time]
src = torch.cat(sample[2], 0) # [num_sources, time]
num_channels, num_frames = src.shape
num_seconds = torch.div(num_frames, sample_rate, rounding_mode='floor')
target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode='floor')
if num_frames >= target_num_frames:
if random_start and num_frames > target_num_frames:
start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
mix = mix[:, start_frame:]
src = src[:, start_frame:]
mix = mix[:, :target_num_frames]
src = src[:, :target_num_frames]
mask = torch.ones_like(mix)
else:
num_padding = target_num_frames - num_frames
pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
mix = torch.cat([mix, pad], 1)
src = torch.cat([src, pad.expand(num_channels, -1)], 1)
mask = torch.ones_like(mix)
mask[..., num_frames:] = 0
return mix, src, mask
def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
target_num_frames = int(duration * sample_rate)
mixes, srcs, masks = [], [], []
for sample in samples:
mix, src, mask = _fix_num_frames(sample, target_num_frames, sample_rate, random_start=True)
mixes.append(mix)
srcs.append(src)
masks.append(mask)
return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))
def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate):
max_num_frames = max(s[1].shape[-1] for s in samples)
mixes, srcs, masks = [], [], []
for sample in samples:
mix, src, mask = _fix_num_frames(sample, max_num_frames, sample_rate, random_start=False)
mixes.append(mix)
srcs.append(src)
masks.append(mask)
return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))
def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
assert mode in ["train", "test"]
if dataset_type in ["wsj0mix", "librimix"]:
if mode == 'train':
if sample_rate is None:
raise ValueError("sample_rate is not given.")
return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
return partial(collate_fn_wsj0mix_test, sample_rate=sample_rate)
raise ValueError(f"Unexpected dataset: {dataset_type}")
from pathlib import Path
from typing import Union, Tuple, List
import torch
from torch.utils.data import Dataset
import torchaudio
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
class WSJ0Mix(Dataset):
"""Create a Dataset for wsj0-mix.
Args:
root (str or Path): Path to the directory where the dataset is found.
num_speakers (int): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios.
sample_rate (int): Expected sample rate of audio files. If any of the audio has a
different sample rate, raises ``ValueError``.
audio_ext (str, optional): The extension of audio files to find. (default: ".wav")
"""
def __init__(
self,
root: Union[str, Path],
num_speakers: int,
sample_rate: int,
audio_ext: str = ".wav",
):
self.root = Path(root)
self.sample_rate = sample_rate
self.mix_dir = (self.root / "mix").resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob(f"*{audio_ext}")]
self.files.sort()
def _load_audio(self, path) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(path)
if sample_rate != self.sample_rate:
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}, "
f"but the requested sample rate is {self.sample_rate}."
)
return waveform
def _load_sample(self, filename) -> SampleType:
mixed = self._load_audio(str(self.mix_dir / filename))
srcs = []
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
if mixed.shape != src.shape:
raise ValueError(
f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
)
srcs.append(src)
return self.sample_rate, mixed, srcs
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, key: int) -> SampleType:
"""Load the n-th sample from the dataset.
Args:
key (int): The index of the sample to be loaded
Returns:
tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
"""
return self._load_sample(self.files[key])
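# Example usage (illustrative; the dataset path is hypothetical):
#
#   dataset = WSJ0Mix("/path/to/wsj0-mix/2speakers/wav8k/min/tr", num_speakers=2, sample_rate=8000)
#   sample_rate, mixture, sources = dataset[0]  # `sources` is a list of per-speaker waveforms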
import os
import csv
import types
import logging
import torch
import torch.distributed as dist
def _info_on_master(self, *args, **kwargs):
if dist.get_rank() == 0:
self.info(*args, **kwargs)
def getLogger(name):
"""Get logging.Logger module with additional ``info_on_master`` method."""
logger = logging.getLogger(name)
logger.info_on_master = types.MethodType(_info_on_master, logger)
return logger
_LG = getLogger(__name__)
def setup_distributed(
world_size, rank, local_rank, backend="nccl", init_method="env://"
):
"""Perform env setup and initialization for distributed training"""
if init_method == "env://":
_set_env_vars(world_size, rank, local_rank)
if world_size > 1 and "OMP_NUM_THREADS" not in os.environ:
_LG.info("Setting OMP_NUM_THREADS == 1")
os.environ["OMP_NUM_THREADS"] = "1"
params = {
"backend": backend,
"init_method": init_method,
"world_size": world_size,
"rank": rank,
}
_LG.info("Initializing distributed process group with %s", params)
dist.init_process_group(**params)
_LG.info("Initialized distributed process group.")
def _set_env_vars(world_size, rank, local_rank):
for key, default in [("MASTER_ADDR", "127.0.0.1"), ("MASTER_PORT", "29500")]:
if key not in os.environ:
os.environ[key] = default
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(local_rank)
def save_on_master(path, obj):
if dist.get_rank() == 0:
_LG.info("Saving %s", path)
torch.save(obj, path)
def write_csv_on_master(path, *rows):
if dist.get_rank() == 0:
with open(path, "a", newline="") as fileobj:
writer = csv.writer(fileobj)
for row in rows:
writer.writerow(row)
def synchronize_params(path, device, *modules):
if dist.get_world_size() < 2:
return
rank = dist.get_rank()
if rank == 0:
_LG.info("[Parameter Sync]: Saving parameters to a temp file...")
torch.save({f"{i}": m.state_dict() for i, m in enumerate(modules)}, path)
dist.barrier()
if rank != 0:
_LG.info("[Parameter Sync]: Loading parameters...")
data = torch.load(path, map_location=device)
for i, m in enumerate(modules):
m.load_state_dict(data[f"{i}"])
dist.barrier()
if rank == 0:
_LG.info("[Parameter Sync]: Removing the temp file...")
os.remove(path)
_LG.info_on_master("[Parameter Sync]: Complete.")
import math
from typing import Optional
from itertools import permutations
import torch
def sdr(
estimate: torch.Tensor,
reference: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8
) -> torch.Tensor:
"""Computes source-to-distortion ratio.
1. scale the reference signal with power(s_est * s_ref) / power(s_ref * s_ref)
2. compute SNR between adjusted estimate and reference.
Args:
estimate (torch.Tensor): Estimated signal.
Shape: [batch, speakers (can be 1), time frame]
reference (torch.Tensor): Reference signal.
Shape: [batch, speakers, time frame]
mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float, optional): constant value used to stabilize division.
Returns:
torch.Tensor: scale-invariant source-to-distortion ratio.
Shape: [batch, speaker]
References:
- Single-channel multi-speaker separation using deep clustering
Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
Notes:
This function is tested to produce the exact same result as
https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L34-L56
"""
reference_pow = reference.pow(2).mean(axis=2, keepdim=True)
mix_pow = (estimate * reference).mean(axis=2, keepdim=True)
scale = mix_pow / (reference_pow + epsilon)
reference = scale * reference
error = estimate - reference
reference_pow = reference.pow(2)
error_pow = error.pow(2)
if mask is None:
reference_pow = reference_pow.mean(axis=2)
error_pow = error_pow.mean(axis=2)
else:
denom = mask.sum(axis=2)
reference_pow = (mask * reference_pow).sum(axis=2) / denom
error_pow = (mask * error_pow).sum(axis=2) / denom
return 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
class PIT(torch.nn.Module):
"""Applies utterance-level speaker permutation
Computes the maximum possible value of the given utility function
over the permutations of the speakers.
Args:
utility_func (function):
Function that computes the utility (opposite of loss) with signature of
(estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor
where input Tensors are shape of [batch, speakers, frame] and
the output Tensor is shape of [batch, speakers].
References:
- Multi-talker Speech Separation with Utterance-level Permutation Invariant Training of
Deep Recurrent Neural Networks
Morten Kolbæk, Dong Yu, Zheng-Hua Tan and Jesper Jensen
https://arxiv.org/abs/1703.06284
"""
def __init__(self, utility_func):
super().__init__()
self.utility_func = utility_func
def forward(
self,
estimate: torch.Tensor,
reference: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8
) -> torch.Tensor:
"""Compute utterance-level PIT Loss
Args:
estimate (torch.Tensor): Estimated source signals.
Shape: [batch, speakers, time frame]
reference (torch.Tensor): Reference (original) source signals.
Shape: [batch, speakers, time frame]
mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float, optional): constant value used to stabilize division.
Returns:
torch.Tensor: Maximum criterion over the speaker permutation.
Shape: [batch, ]
"""
assert estimate.shape == reference.shape
batch_size, num_speakers = reference.shape[:2]
num_permute = math.factorial(num_speakers)
util_mat = torch.zeros(
batch_size, num_permute, dtype=estimate.dtype, device=estimate.device
)
for i, idx in enumerate(permutations(range(num_speakers))):
util = self.utility_func(estimate, reference[:, idx, :], mask=mask, epsilon=epsilon)
util_mat[:, i] = util.mean(dim=1) # take the average over speaker dimension
return util_mat.max(dim=1).values
_sdr_pit = PIT(utility_func=sdr)
def sdr_pit(
estimate: torch.Tensor,
reference: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8):
"""Computes scale-invariant source-to-distortion ratio.
1. adjust both estimate and reference to have 0-mean
2. scale the reference signal with power(s_est * s_ref) / power(s_ref * s_ref)
3. compute SNR between adjusted estimate and reference.
Args:
estimate (torch.Tensor): Estimated signal.
Shape: [batch, speakers (can be 1), time frame]
reference (torch.Tensor): Reference signal.
Shape: [batch, speakers, time frame]
mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float, optional): constant value used to stabilize division.
Returns:
torch.Tensor: scale-invariant source-to-distortion ratio.
Shape: [batch, speaker]
References:
- Single-channel multi-speaker separation using deep clustering
Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
Notes:
This function is tested to produce the exact same result as the reference implementation,
*when the inputs have 0-mean*
https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L107-L153
"""
return _sdr_pit(estimate, reference, mask, epsilon)
def sdri(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8,
) -> torch.Tensor:
"""Compute the improvement of SDR (SDRi).
This function computes how much the SDR improves when the estimate is changed from
the original mixture signal to the actual estimated source signals. That is,
``SDR(estimate, reference) - SDR(mix, reference)``.
For computing ``SDR(estimate, reference)``, PIT (permutation invariant training) is applied,
so that the best matching between the reference signals and the estimated signals
is picked.
Args:
estimate (torch.Tensor): Estimated source signals.
Shape: [batch, speakers, time frame]
reference (torch.Tensor): Reference (original) source signals.
Shape: [batch, speakers, time frame]
mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
Shape: [batch, speakers == 1, time frame]
mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float, optional): constant value used to stabilize division.
Returns:
torch.Tensor: Improved SDR. Shape: [batch, ]
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
sdr_ = sdr_pit(estimate, reference, mask=mask, epsilon=epsilon) # [batch, ]
base_sdr = sdr(mix, reference, mask=mask, epsilon=epsilon) # [batch, speaker]
return sdr_ - base_sdr.mean(dim=1)
import argparse
import logging
import os
import unittest
from interactive_asr.utils import setup_asr, transcribe_file
class ASRTest(unittest.TestCase):
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
arguments_dict = {
"path": "/scratch/jamarshon/downloads/model.pt",
"input_file": "/scratch/jamarshon/audio/examples/interactive_asr/data/sample.wav",
"data": "/scratch/jamarshon/downloads",
"user_dir": "/scratch/jamarshon/fairseq-py/examples/speech_recognition",
"no_progress_bar": False,
"log_interval": 1000,
"log_format": None,
"tensorboard_logdir": "",
"tbmf_wrapper": False,
"seed": 1,
"cpu": True,
"fp16": False,
"memory_efficient_fp16": False,
"fp16_init_scale": 128,
"fp16_scale_window": None,
"fp16_scale_tolerance": 0.0,
"min_loss_scale": 0.0001,
"threshold_loss_scale": None,
"criterion": "cross_entropy",
"tokenizer": None,
"bpe": None,
"optimizer": "nag",
"lr_scheduler": "fixed",
"task": "speech_recognition",
"num_workers": 0,
"skip_invalid_size_inputs_valid_test": False,
"max_tokens": 10000000,
"max_sentences": None,
"required_batch_size_multiple": 8,
"dataset_impl": None,
"gen_subset": "test",
"num_shards": 1,
"shard_id": 0,
"remove_bpe": None,
"quiet": False,
"model_overrides": "{}",
"results_path": None,
"beam": 40,
"nbest": 1,
"max_len_a": 0,
"max_len_b": 200,
"min_len": 1,
"match_source_len": False,
"no_early_stop": False,
"unnormalized": False,
"no_beamable_mm": False,
"lenpen": 1,
"unkpen": 0,
"replace_unk": None,
"sacrebleu": False,
"score_reference": False,
"prefix_size": 0,
"no_repeat_ngram_size": 0,
"sampling": False,
"sampling_topk": -1,
"sampling_topp": -1.0,
"temperature": 1.0,
"diverse_beam_groups": -1,
"diverse_beam_strength": 0.5,
"print_alignment": False,
"ctc": False,
"rnnt": False,
"kspmodel": None,
"wfstlm": None,
"rnnt_decoding_type": "greedy",
"lm_weight": 0.2,
"rnnt_len_penalty": -0.5,
"momentum": 0.99,
"weight_decay": 0.0,
"force_anneal": None,
"lr_shrink": 0.1,
"warmup_updates": 0,
}
arguments_dict["path"] = os.environ.get("ASR_MODEL_PATH", None)
arguments_dict["input_file"] = os.environ.get("ASR_INPUT_FILE", None)
arguments_dict["data"] = os.environ.get("ASR_DATA_PATH", None)
arguments_dict["user_dir"] = os.environ.get("ASR_USER_DIR", None)
args = argparse.Namespace(**arguments_dict)
def test_transcribe_file(self):
task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
_, transcription = transcribe_file(
self.args, task, generator, models, sp, tgt_dict
)
expected_transcription = [["THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG"]]
self.assertEqual(transcription, expected_transcription, msg=str(transcription))
if __name__ == "__main__":
unittest.main()
[mypy]
allow_redefinition = True
ignore_missing_imports = True
# Building torchaudio packages for release
## Anaconda packages
### Linux
```bash
docker run -it --ipc=host --rm -v $(pwd):/remote soumith/conda-cuda bash
cd remote
PYTHON_VERSION=3.7 packaging/build_conda.sh
```
To install the built bz2 package,
```bash
cd /opt/conda/conda-bld/linux-64/
# install dependencies
conda install pytorch-cpu=1.1.0
conda install sox
# install torchaudio
conda install /opt/conda/conda-bld/linux-64/torchaudio-cpu-0.2.0-py27_1.tar.bz2
```
To upload the built bz2 package,
```bash
anaconda upload -u pytorch /opt/conda/conda-bld/linux-64/torchaudio*.bz2
```
### OSX
```bash
# create a fresh anaconda environment / install and activate it
PYTHON_VERSION=3.7 packaging/build_conda.sh
```
To install the built bz2 package,
```bash
cd /Users/jamarshon/anaconda3/conda-bld/osx-64/
# activate the conda env, e.g.
conda info --envs
conda activate /Users/jamarshon/minconda_wheel_env_tmp/envs/env2.7
# install dependencies
conda install pytorch-cpu=1.1.0
conda install sox
# install torchaudio
# and then try installing, e.g.
conda install /Users/jamarshon/anaconda3/conda-bld/osx-64/torchaudio-0.2.0-py27_1.tar.bz2
```
To upload the built bz2 package,
```bash
anaconda upload -u pytorch /Users/jamarshon/anaconda3/conda-bld/osx-64/torchaudio*.bz2
```
## Wheels
### Linux
```bash
nvidia-docker run -it --ipc=host --rm -v $(pwd):/remote soumith/manylinux-cuda90:latest bash
cd remote
PYTHON_VERSION=3.7 packaging/build_wheel.sh
```
To install wheels,
```bash
cd ../cpu
/opt/python/cp35-cp35m/bin/pip install torchaudio-0.2-cp35-cp35m-linux_x86_64.whl
```
To upload wheels,
```bash
cd ../cpu
/opt/python/cp35-cp35m/bin/pip install twine
/opt/python/cp35-cp35m/bin/twine upload *.whl
```
### OSX
```bash
PYTHON_VERSION=3.7 packaging/build_wheel.sh
```
To install wheels,
```bash
cd ~/torchaudio_wheels
conda activate /Users/jamarshon/minconda_wheel_env_tmp/envs/env2.7
pip install torchaudio-0.2-cp27-cp27m-macosx_10_6_x86_64.whl
```
To upload wheels,
```bash
pip install twine
cd ~/torchaudio_wheels
twine upload *.whl
```