Commit 5656d5d4 authored by nateanl, committed by Zhaoheng Ni

Training recipe for ConvTasNet on Libri2Mix dataset. (#1757)

parent 5b1cd9a6
### Overview
To train a model, you can use [`lightning_train.py`](./lightning_train.py). This script takes the form of
`lightning_train.py [parameters]`
```
python lightning_train.py \
    [--data-dir DATA_DIR] \
    [--num-gpu NUM_GPU] \
    [--num-workers NUM_WORKERS] \
    ...

# For details of the parameter values, use:
python lightning_train.py --help
```
This script runs training with the PyTorch Lightning framework, using the Distributed Data Parallel (DDP) backend.
### SLURM
<details><summary>Example scripts for running the training on SLURM cluster</summary>

- **launch_job.sh**
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=8
#SBATCH --mem-per-cpu=16G
#SBATCH --gpus-per-node=2

#srun env
srun wrapper.sh $@
#!/bin/bash
num_speakers=2

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
exp_dir="/checkpoint/${USER}/exp/"
dataset_dir="/dataset/Libri${num_speakers}mix/wav8k/min"
if [ "${SLURM_JOB_NUM_NODES}" -gt 1 ]; then
protocol="file:///checkpoint/${USER}/jobs/source_separation/${SLURM_JOB_ID}/sync"
else
protocol="env://"
fi
mkdir -p "${exp_dir}"

python -u \
"${this_dir}/lightning_train.py" \
--num-speakers "${num_speakers}" \
--sample-rate 8000 \
--data-dir "${dataset_dir}" \
--exp-dir "${exp_dir}" \
--batch-size $((16 / SLURM_NTASKS))
```
...
For the usage, please check out the [source separation README](../README.md).
## (Default) Training Configurations
The default training/model configurations follow the non-causal implementation from [Asteroid](https://github.com/asteroid-team/asteroid/tree/master/egs/librimix/ConvTasNet). (The causal configuration is not implemented.)
- Sample rate: 8000 Hz
- Batch size: total 12 over distributed training workers
- Epochs: 200
- Initial learning rate: 1e-3
- Gradient clipping: maximum L2 norm of 5.0
- Optimizer: Adam
- Learning rate scheduling: Halved after 5 epochs of no improvement in validation loss.
- Objective function: SI-SNR
- Reported metrics: SI-SNRi, SDRi
- Sample audio length: 3 seconds (randomized position)
- Encoder/Decoder feature dimension (N): 512
- Encoder/Decoder convolution kernel size (L): 16
- TCN bottleneck/output feature dimension (B): 128
- TCN convolution kernel size (P): 3
- The number of TCN convolution block layers (X): 8
- The number of TCN convolution blocks (R): 3
- The mask activation function: ReLU (see the construction sketch below)
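For reference, the hyperparameters above map one-to-one onto `torchaudio.models.ConvTasNet`. The following sketch mirrors `_get_model` in `lightning_train.py`; the comments tie each argument back to the symbols in the list:

```
import torchaudio

# Mirrors `_get_model` in lightning_train.py; comments give the paper symbols.
model = torchaudio.models.ConvTasNet(
    num_sources=2,             # number of speakers
    enc_kernel_size=16,        # L
    enc_num_feats=512,         # N
    msk_kernel_size=3,         # P
    msk_num_feats=128,         # B
    msk_num_hidden_feats=512,  # H
    msk_num_layers=8,          # X
    msk_num_stacks=3,          # R
    msk_activate="relu",       # mask activation
)
```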
## Evaluation
The following are the evaluation results of training the model on the Libri2Mix dataset.

### LibriMix 2 speakers

|                      | SI-SNRi (dB) | SDRi (dB) | Epoch |
|:--------------------:|-------------:|----------:|------:|
| Reference (Asteroid) |         14.7 |      15.1 |   200 |
| torchaudio           |         15.3 |      15.6 |   200 |
| Evaluation dataset | 11.0 | 11.0 | 100 |
### wsj0-mix 3speakers
| | SI-SNRi (dB) | SDRi (dB) | Epoch |
|:------------------:|-------------:|----------:|------:|
| Reference | 12.7 | 13.1 | |
| Validation dataset | 11.4 | 11.4 | 100 |
| Evaluation dataset | 8.9 | 8.9 | 100 |
from . import (
train,
trainer
)

__all__ = ['train', 'trainer']
return model

def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_size, task=None):
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, dataset_dir, num_speakers, sample_rate, task
)
train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=4
)
...
from argparse import ArgumentParser
import pathlib
from lightning_train import _get_model, _get_dataloader, sisdri_metric
import mir_eval
import torch
def eval(model, data_loader, device):
results = torch.zeros(4)
with torch.no_grad():
for i, batch in enumerate(data_loader):
mix, src, mask = batch
mix, src, mask = mix.to(device), src.to(device), mask.to(device)
est = model(mix)
sisdri = sisdri_metric(est, src, mix, mask)
src = src.cpu().detach().numpy()
est = est.cpu().detach().numpy()
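# tile the mixture across the source dimension so the unprocessed mixture can be scored as a baseline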
mix = mix.repeat(1, src.shape[1], 1).cpu().detach().numpy()
sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(src[0], est[0])
sdr_mix, sir_mix, sar_mix, _ = mir_eval.separation.bss_eval_sources(src[0], mix[0])
results += torch.tensor([
sdr.mean() - sdr_mix.mean(),
sisdri,
sir.mean() - sir_mix.mean(),
sar.mean() - sar_mix.mean()
])
results /= len(data_loader)
print("SDR improvement: ", results[0].item())
print("Si-SDR improvement: ", results[1].item())
print("SIR improvement: ", results[2].item())
print("SAR improvement: ", results[3].item())
def cli_main():
parser = ArgumentParser()
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument("--data-dir", default=pathlib.Path("./Libri2Mix/wav8k/min"), type=pathlib.Path)
parser.add_argument(
"--librimix-tr-split",
default="train-360",
choices=["train-360", "train-100"],
help="The training partition of librimix dataset. (default: ``train-360``)",
)
parser.add_argument(
"--librimix-task",
default="sep_clean",
type=str,
choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"],
help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)",
)
parser.add_argument(
"--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)"
)
parser.add_argument(
"--sample-rate",
default=8000,
type=int,
help="Sample rate of audio files in the given dataset. (default: 8000)",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="The directory to save checkpoints and logs."
)
parser.add_argument(
"--gpu-device",
default=-1,
type=int,
help="The gpu device for model inference. (default: -1)"
)
args = parser.parse_args()
model = _get_model(num_sources=args.num_speakers)
state_dict = torch.load(args.exp_dir / 'best_model.pth')
model.load_state_dict(state_dict)
if args.gpu_device != -1:
device = torch.device('cuda:' + str(args.gpu_device))
else:
device = torch.device('cpu')
model = model.to(device)
model.eval()  # switch to evaluation mode for inference
_, _, eval_loader = _get_dataloader(
args.dataset,
args.data_dir,
args.num_speakers,
args.sample_rate,
1, # batch size is set to 1 to avoid masking
0, # set num_workers to 0
args.librimix_task,
args.librimix_tr_split,
)
eval(model, eval_loader, device)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
# pyre-strict
import pathlib
from argparse import ArgumentParser
from typing import (
Any,
Callable,
Dict,
Mapping,
List,
Optional,
Tuple,
TypedDict,
)
import torch
import torchaudio
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.plugins import DDPPlugin
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from utils import metrics
from utils.dataset import utils as dataset_utils
class Batch(TypedDict):
mix: torch.Tensor # (batch, time)
src: torch.Tensor # (batch, source, time)
mask: torch.Tensor # (batch, source, time)
def sisdri_metric(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor
) -> float:
"""Compute the improvement of scale-invariant SDR. (SI-SDRi).
Args:
estimate (torch.Tensor): Estimated source signals.
Tensor of dimension (batch, speakers, time)
reference (torch.Tensor): Reference (original) source signals.
Tensor of dimension (batch, speakers, time)
mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
Tensor of dimension (batch, speakers == 1, time)
mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
Tensor of dimension (batch, 1, time)
Returns:
float: Batch-averaged improvement of SI-SDR (SI-SDRi).
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
with torch.no_grad():
estimate = estimate - estimate.mean(axis=2, keepdim=True)
reference = reference - reference.mean(axis=2, keepdim=True)
mix = mix - mix.mean(axis=2, keepdim=True)
si_sdri = metrics.sdri(estimate, reference, mix, mask=mask)
return si_sdri.mean().item()
def sdri_metric(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor,
) -> float:
"""Compute the improvement of SDR. (SDRi).
Args:
estimate (torch.Tensor): Estimated source signals.
Tensor of dimension (batch, speakers, time)
reference (torch.Tensor): Reference (original) source signals.
Tensor of dimension (batch, speakers, time)
mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
Tensor of dimension (batch, speakers == 1, time)
mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
Tensor of dimension (batch, 1, time)
Returns:
float: Batch-averaged improvement of SDR (SDRi).
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
with torch.no_grad():
sdri = metrics.sdri(estimate, reference, mix, mask=mask)
return sdri.mean().item()
def si_sdr_loss(
estimate: torch.Tensor,
reference: torch.Tensor,
mask: torch.Tensor
) -> torch.Tensor:
"""Compute the Si-SDR loss.
Args:
estimate (torch.Tensor): Estimated source signals.
Tensor of dimension (batch, speakers, time)
reference (torch.Tensor): Reference (original) source signals.
Tensor of dimension (batch, speakers, time)
mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
Tensor of dimension (batch, 1, time)
Returns:
torch.Tensor: Scalar Si-SDR loss (the negated batch mean).
"""
estimate = estimate - estimate.mean(axis=2, keepdim=True)
reference = reference - reference.mean(axis=2, keepdim=True)
si_sdri = metrics.sdr_pit(estimate, reference, mask=mask)
return -si_sdri.mean()
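# Note: SI-SDR between estimate e and zero-mean reference s is
# 10 * log10(||a * s||^2 / ||a * s - e||^2) with a = <e, s> / ||s||^2;
# sdr_pit evaluates it under the best speaker permutation (uPIT), and the
# negated batch mean turns the metric into a minimizable training objective.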
class ConvTasNetModule(LightningModule):
"""
The Lightning Module for speech separation.
Args:
model (Any): The model to use for the separation task.
train_loader (DataLoader): the training dataloader.
val_loader (DataLoader or None): the validation dataloader.
loss (Any): The loss function to use.
optim (Any): The optimizer to use.
metrics (Dict[str, Callable]): The metrics to track, which will be used for both train and validation.
lr_scheduler (Any or None): The LR Scheduler.
"""
def __init__(
self,
model: Any,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
loss: Any,
optim: Any,
metrics: Mapping[str, Callable],
lr_scheduler: Optional[Any] = None,
) -> None:
super().__init__()
self.model: nn.Module = model
self.loss: nn.Module = loss
self.optim: torch.optim.Optimizer = optim
self.lr_scheduler: Optional[_LRScheduler] = None
if lr_scheduler:
self.lr_scheduler = lr_scheduler
self.metrics: Mapping[str, Callable] = metrics
self.train_metrics: Dict = {}
self.val_metrics: Dict = {}
self.test_metrics: Dict = {}
self.save_hyperparameters()
self.train_loader = train_loader
self.val_loader = val_loader
def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit":
self.train_metrics.update(self.metrics)
self.val_metrics.update(self.metrics)
else:
self.test_metrics.update(self.metrics)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward defines the prediction/inference actions.
"""
return self.model(x)
def training_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> torch.Tensor:
return self._step(batch, batch_idx, "train")
def validation_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> torch.Tensor:
"""
Operates on a single batch of data from the validation set.
"""
return self._step(batch, batch_idx, "val")
def test_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> torch.Tensor:
"""
Operates on a single batch of data from the test set.
"""
return self._step(batch, batch_idx, "test")
def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> torch.Tensor:
"""
Common step for training, validation, and testing.
"""
mix, src, mask = batch
pred = self.model(mix)
loss = self.loss(pred, src, mask)
self.log(f"Losses/{phase_type}_loss", loss.item(), on_step=True, on_epoch=True)
metrics_result = self._compute_metrics(pred, src, mix, mask, phase_type)
self.log_dict(metrics_result, on_epoch=True)
return loss
def configure_optimizers(
self,
) -> Tuple[Any]:
lr_scheduler = self.lr_scheduler
if not lr_scheduler:
return self.optim
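# ReduceLROnPlateau has no fixed schedule, so hand it to Lightning in dict
# form with a `monitor` key naming the logged value that drives the decay.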
epoch_schedulers = {
'scheduler': lr_scheduler,
'monitor': 'Losses/val_loss',
'interval': 'epoch'
}
return [self.optim], [epoch_schedulers]
def _compute_metrics(
self,
pred: torch.Tensor,
label: torch.Tensor,
inputs: torch.Tensor,
mask: torch.Tensor,
phase_type: str,
) -> Dict[str, torch.Tensor]:
metrics_dict = getattr(self, f"{phase_type}_metrics")
metrics_result = {}
for name, metric in metrics_dict.items():
metrics_result[f"Metrics/{phase_type}/{name}"] = metric(pred, label, inputs, mask)
return metrics_result
def train_dataloader(self):
"""Training dataloader"""
return self.train_loader
def val_dataloader(self):
"""Validation dataloader"""
return self.val_loader
def _get_model(
num_sources,
enc_kernel_size=16,
enc_num_feats=512,
msk_kernel_size=3,
msk_num_feats=128,
msk_num_hidden_feats=512,
msk_num_layers=8,
msk_num_stacks=3,
msk_activate="relu",
):
model = torchaudio.models.ConvTasNet(
num_sources=num_sources,
enc_kernel_size=enc_kernel_size,
enc_num_feats=enc_num_feats,
msk_kernel_size=msk_kernel_size,
msk_num_feats=msk_num_feats,
msk_num_hidden_feats=msk_num_hidden_feats,
msk_num_layers=msk_num_layers,
msk_num_stacks=msk_num_stacks,
msk_activate=msk_activate,
)
return model
def _get_dataloader(
dataset_type: str,
dataset_dir: pathlib.Path,
num_speakers: int = 2,
sample_rate: int = 8000,
batch_size: int = 6,
num_workers: int = 4,
librimix_task: Optional[str] = None,
librimix_tr_split: Optional[str] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
"""Get dataloaders for training, validation, and testing.
Args:
dataset_type (str): the dataset to use.
dataset_dir (pathlib.Path): the root directory of the dataset.
num_speakers (int): the number of speakers in the mixture. (Default: 2)
sample_rate (int): the sample rate of the audio. (Default: 8000)
batch_size (int): the batch size of the dataset. (Default: 6)
num_workers (int): the number of workers for each dataloader. (Default: 4)
librimix_task (str or None, optional): the task in LibriMix dataset.
librimix_tr_split (str or None, optional): the training split in LibriMix dataset.
Returns:
tuple: (train_loader, valid_loader, eval_loader)
"""
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, dataset_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split
)
train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=3
)
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test', sample_rate=sample_rate)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=train_collate_fn,
num_workers=num_workers,
drop_last=True,
)
valid_loader = DataLoader(
valid_dataset,
batch_size=batch_size,
collate_fn=test_collate_fn,
num_workers=num_workers,
drop_last=True,
)
eval_loader = DataLoader(
eval_dataset,
batch_size=batch_size,
collate_fn=test_collate_fn,
num_workers=num_workers,
)
return train_loader, valid_loader, eval_loader
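# Usage sketch (hypothetical path): loaders for a Libri2Mix corpus generated at 8 kHz:
#   train_loader, valid_loader, eval_loader = _get_dataloader(
#       "librimix", pathlib.Path("./Libri2Mix/wav8k/min"), num_speakers=2,
#       sample_rate=8000, batch_size=6, num_workers=4,
#       librimix_task="sep_clean", librimix_tr_split="train-360")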
def cli_main():
parser = ArgumentParser()
parser.add_argument("--batch-size", default=3, type=int)
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
parser.add_argument("--data-dir", default=pathlib.Path("./Libri2Mix/wav8k/min"), type=pathlib.Path)
parser.add_argument(
"--librimix-tr-split",
default="train-360",
choices=["train-360", "train-100"],
help="The training partition of librimix dataset. (default: ``train-360``)",
)
parser.add_argument(
"--librimix-task",
default="sep_clean",
type=str,
choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"],
help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)",
)
parser.add_argument(
"--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)"
)
parser.add_argument(
"--sample-rate",
default=8000,
type=int,
help="Sample rate of audio files in the given dataset. (default: 8000)",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="The directory to save checkpoints and logs."
)
parser.add_argument(
"--epochs",
metavar="NUM_EPOCHS",
default=200,
type=int,
help="The number of epochs to train. (default: 200)",
)
parser.add_argument(
"--learning-rate",
default=1e-3,
type=float,
help="Initial learning rate. (default: 1e-3)",
)
parser.add_argument(
"--num-gpu",
default=1,
type=int,
help="The number of GPUs for training. (default: 1)",
)
parser.add_argument(
"--num-workers",
default=4,
type=int,
help="The number of workers for dataloader. (default: 4)",
)
args = parser.parse_args()
model = _get_model(num_sources=args.num_speakers)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5
)
train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset,
args.data_dir,
args.num_speakers,
args.sample_rate,
args.batch_size,
args.num_workers,
args.librimix_task,
args.librimix_tr_split,
)
loss = si_sdr_loss
metric_dict = {
"sdri": sdri_metric,
"sisdri": sisdri_metric,
}
model = ConvTasNetModule(
model=model,
train_loader=train_loader,
val_loader=valid_loader,
loss=loss,
optim=optimizer,
metrics=metric_dict,
lr_scheduler=lr_scheduler,
)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True
)
callbacks = [
checkpoint,
EarlyStopping(monitor="Losses/val_loss", mode="min", patience=30, verbose=True),
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
gpus=args.num_gpu,
accelerator="ddp",
plugins=DDPPlugin(find_unused_parameters=False), # make sure there is no unused params
limit_train_batches=1.0, # Useful for fast experiment
gradient_clip_val=5.0,
callbacks=callbacks,
)
trainer.fit(model)
# export the best checkpoint as plain model weights (strip the Lightning "model." prefix)
state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
torch.save(state_dict, args.exp_dir / "best_model.pth")
trainer.test(model, eval_loader)
if __name__ == "__main__":
cli_main()
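After training finishes, `best_model.pth` holds plain `torchaudio` weights with the Lightning `model.` prefix stripped. A minimal inference sketch, assuming the default `--exp-dir` and the recipe's 2-speaker configuration:

```
import torch
import torchaudio

model = torchaudio.models.ConvTasNet(num_sources=2, msk_activate="relu")
model.load_state_dict(torch.load("./exp/best_model.pth", map_location="cpu"))
model.eval()

with torch.no_grad():
    mixture = torch.randn(1, 1, 8000 * 3)  # (batch, channel == 1, time): 3 s at 8 kHz
    estimates = model(mixture)             # (batch, num_sources, time)
```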
from . import utils, wsj0mix, librimix

__all__ = ['utils', 'wsj0mix', 'librimix']
from pathlib import Path
from typing import Union, Tuple, List
import torch
from torch.utils.data import Dataset
import torchaudio
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
class LibriMix(Dataset):
r"""Create the LibriMix dataset.
Args:
root (str or Path): the path to the directory where the dataset is stored.
num_speakers (int, optional): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios. (Default: 2)
sample_rate (int, optional): sample rate of audio files. If any of the audio has a
different sample rate, raises ``ValueError``. (Default: 8000)
task (str, optional): the task of LibriMix.
Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
(Default: ``sep_clean``)
"""
def __init__(
self,
root: Union[str, Path],
num_speakers: int = 2,
sample_rate: int = 8000,
task: str = "sep_clean",
):
self.root = Path(root)
self.sample_rate = sample_rate
self.task = task
# LibriMix provides mix_clean, mix_both and mix_single; the sep_noisy task reads from mix_both
mix_dir_name = "mix_both" if task == "sep_noisy" else "mix_{}".format(task.split('_')[1])
self.mix_dir = (self.root / mix_dir_name).resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob("*wav")]
self.files.sort()
def _load_audio(self, path) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(path)
if sample_rate != self.sample_rate:
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}, "
f"but the requested sample rate is {self.sample_rate}."
)
return waveform
def _load_sample(self, filename) -> SampleType:
mixed = self._load_audio(str(self.mix_dir / filename))
srcs = []
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
if mixed.shape != src.shape:
raise ValueError(
f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
)
srcs.append(src)
return self.sample_rate, mixed, srcs
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, key: int) -> SampleType:
"""Load the n-th sample from the dataset.
Args:
key (int): The index of the sample to be loaded
Returns:
tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
"""
return self._load_sample(self.files[key])
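A usage sketch for the dataset class; the path below is hypothetical and must point at one split of a generated LibriMix corpus:

```
# hypothetical local path to one split of a generated Libri2Mix corpus
dataset = LibriMix("./Libri2Mix/wav8k/min/train-360", num_speakers=2,
                   sample_rate=8000, task="sep_clean")
sample_rate, mixture, sources = dataset[0]  # mixture: (1, time); sources: list of (1, time)
```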
import torch

from . import wsj0mix, librimix

Batch = namedtuple("Batch", ["mix", "src", "mask"])

def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, librimix_tr_split=None):
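# For "librimix", root_dir must contain the chosen training split
# ("train-360" or "train-100") alongside the "dev" and "test" directories.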
if dataset_type == "wsj0mix": if dataset_type == "wsj0mix":
train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate) train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate) validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate) evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
elif dataset_type == "librimix":
train = librimix.LibriMix(root_dir / librimix_tr_split, num_speakers, sample_rate, task)
validation = librimix.LibriMix(root_dir / "dev", num_speakers, sample_rate, task)
evaluation = librimix.LibriMix(root_dir / "test", num_speakers, sample_rate, task)
else:
raise ValueError(f"Unexpected dataset: {dataset_type}")
return train, validation, evaluation
def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_rate: int, random_start=False):
"""Ensure waveform has exact number of frames by slicing or padding"""
mix = sample[1] # [1, time]
src = torch.cat(sample[2], 0) # [num_sources, time]
num_channels, num_frames = src.shape
num_seconds = num_frames // sample_rate
target_seconds = target_num_frames // sample_rate
if num_frames >= target_num_frames:
if random_start and num_frames > target_num_frames:
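# draw the crop start on a whole-second boundary, expressed in frames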
start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
mix = mix[:, start_frame:]
src = src[:, start_frame:]
mix = mix[:, :target_num_frames]
def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
...
mixes, srcs, masks = [], [], []

for sample in samples:
mix, src, mask = _fix_num_frames(sample, target_num_frames, sample_rate, random_start=True)
mixes.append(mix)
srcs.append(src)
...
return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))
def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate):
max_num_frames = max(s[1].shape[-1] for s in samples)

mixes, srcs, masks = [], [], []

for sample in samples:
mix, src, mask = _fix_num_frames(sample, max_num_frames, sample_rate, random_start=False)
mixes.append(mix)
srcs.append(src)
...

def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
assert mode in ["train", "test"]
if dataset_type in ["wsj0mix", "librimix"]:
if mode == 'train':
if sample_rate is None:
raise ValueError("sample_rate is not given.")
return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
return partial(collate_fn_wsj0mix_test, sample_rate=sample_rate)
raise ValueError(f"Unexpected dataset: {dataset_type}")
class WSJ0Mix(Dataset):
...
waveform, sample_rate = torchaudio.load(path)
if sample_rate != self.sample_rate:
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}, "
f"but the requested sample rate is {self.sample_rate}."
)
return waveform
...
def sdri(
...
"""
sdr_ = sdr_pit(estimate, reference, mask=mask, epsilon=epsilon) # [batch, ]
base_sdr = sdr(mix, reference, mask=mask, epsilon=epsilon) # [batch, speaker]
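# sdr_pit is already reduced over the best speaker permutation, so subtract the mixture's SDR averaged over speakers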
return sdr_ - base_sdr.mean(dim=1)
class MaskGenerator(torch.nn.Module):
...
num_hidden (int): Intermediate feature dimension of conv blocks, <H>
num_layers (int): The number of conv blocks in one stack, <X>.
num_stacks (int): The number of conv block stacks, <R>.
msk_activate (str): The activation function of the mask output.

Note:
This implementation corresponds to the "non-causal" setting in the paper.
...
num_hidden: int,
num_layers: int,
num_stacks: int,
msk_activate: str,
):
super().__init__()
...
self.output_conv = torch.nn.Conv1d(
in_channels=num_feats, out_channels=input_dim * num_sources, kernel_size=1,
)
if msk_activate == "sigmoid":
self.mask_activate = torch.nn.Sigmoid()
elif msk_activate == "relu":
self.mask_activate = torch.nn.ReLU()
else:
raise ValueError(f"Unsupported activation {msk_activate}")
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Generate separation mask.
...
output = output + skip
output = self.output_prelu(output)
output = self.output_conv(output)
output = self.mask_activate(output)
return output.view(batch_size, self.num_sources, self.input_dim, -1)
class ConvTasNet(torch.nn.Module):
...
msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, <H>.
msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
msk_num_stacks (int, optional): The number of conv blocks of the mask generator, <R>.
msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).

Note:
This implementation corresponds to the "non-causal" setting in the paper.
...
msk_num_hidden_feats: int = 512,
msk_num_layers: int = 8,
msk_num_stacks: int = 3,
msk_activate: str = "sigmoid",
):
super().__init__()
...
num_hidden=msk_num_hidden_feats,
num_layers=msk_num_layers,
num_stacks=msk_num_stacks,
msk_activate=msk_activate,
)
self.decoder = torch.nn.ConvTranspose1d(
in_channels=enc_num_feats,
...