"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "6ac50e26413465fe0418cd2326af89e4101cc099"
Unverified Commit 4e97213b authored by moto's avatar moto Committed by GitHub
Browse files

Add Conv-TasNet training script to example (#896)

# Source Separation Example
This directory contains reference implementations for source separation. For the details of each model, please check out the following.
- [Conv-TasNet](./conv_tasnet/README.md)
## Usage
### Overview
To train a model, you can use [`train.py`](./train.py). This script takes the form of
`train.py [parameters for distributed training] -- [parameters for model/training]`
```
python train.py \
    [--worker-id WORKER_ID] \
    [--device-id DEVICE_ID] \
    [--num-workers NUM_WORKERS] \
    [--sync-protocol SYNC_PROTOCOL] \
    -- \
    <model specific training parameters>

# For the details of the distributed training parameters, use:
python train.py --help

# For the details of the model/training parameters, use:
python train.py -- --help
```
If you would like to just try out the training script, run it without any parameters
for distributed training:
`train.py -- --num-speakers <NUM_SPEAKERS> --sample-rate 8000 --batch-size <BATCH_SIZE> --dataset-dir <DATASET_DIR> --save-dir <SAVE_DIR>`
This script runs training in the Distributed Data Parallel (DDP) framework and has two major
operation modes, depending on whether the `--worker-id` argument is given.

1. (`--worker-id` is not given) Launches training worker subprocesses that perform the actual training.
2. (`--worker-id` is given) Performs the training as a part of distributed training.

When launched without any distributed training parameters (operation mode 1),
this script checks the number of GPUs available on the local system and spawns the same
number of training subprocesses (in operation mode 2). You can reduce the number of GPUs used with
`--num-workers`. If there is no GPU available, only one subprocess is launched, and providing
`--num-workers` larger than 1 results in an error.
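For example, on a machine with at least two GPUs, a run restricted to two worker processes could look like the following sketch (the dataset and output paths are placeholders):

```
python train.py \
    --num-workers 2 \
    -- \
    --num-speakers 2 \
    --sample-rate 8000 \
    --dataset-dir <DATASET_DIR> \
    --save-dir <SAVE_DIR>
```

With two workers, the default batch size works out to `16 // 2 = 8` per worker.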
When launching the script as a worker process of a distributed training job, you need to configure
the coordination of the workers.
- `--num-workers` is the number of training processes being launched.
- `--worker-id` is the process rank (must be unique across all the processes).
- `--device-id` is the GPU device ID (should be unique within a node).
- `--sync-protocol` is how the worker processes communicate and synchronize.
If the training is carried out on a single node, then the default `"env://"` should do.
If the training processes span multiple nodes, then you need to provide a protocol that
can communicate over the network. If you know where the master node is located, you can use
`"env://"` in combination with the `MASTER_ADDR` and `MASTER_PORT` environment variables. If you do
not know where the master node is located beforehand, you can use the `"file://..."` protocol to
point to a file to which all the worker processes have access. For other
protocols, please refer to the official documentation.
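As an illustrative sketch (the host name and shared path below are placeholders, not values used by this example), two workers spanning two nodes could be coordinated either way:

```
# Worker 0 on node A and worker 1 on node B, coordinated via "env://"
MASTER_ADDR=<NODE_A_HOSTNAME> MASTER_PORT=29500 \
    python train.py --worker-id 0 --num-workers 2 --device-id 0 --sync-protocol "env://" -- <model/training parameters>
MASTER_ADDR=<NODE_A_HOSTNAME> MASTER_PORT=29500 \
    python train.py --worker-id 1 --num-workers 2 --device-id 0 --sync-protocol "env://" -- <model/training parameters>

# The same run, coordinated via a file on shared storage
python train.py --worker-id 0 --num-workers 2 --device-id 0 --sync-protocol "file://<SHARED_DIR>/sync" -- <model/training parameters>
python train.py --worker-id 1 --num-workers 2 --device-id 0 --sync-protocol "file://<SHARED_DIR>/sync" -- <model/training parameters>
```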
### Distributed Training Notes
<details><summary>Quick overview on DDP (distributed data parallel)</summary>
DDP is a single-program, multiple-data training paradigm.
With DDP, the model is replicated on every process,
and every model replica is fed a different set of input data samples.

- **Process**: Worker process (as in a Linux process). There are `P` processes per node.
- **Node**: A machine. There are `N` machines, each of which holds `P` processes.
- **World**: The network of nodes, composed of `N` nodes and `N * P` processes.
- **Rank**: Global process ID (unique across nodes): `[0, N * P)`
- **Local Rank**: Local process ID (unique only within a node): `[0, P)`
```
            Node 0                          Node 1                             Node N-1
┌─────────────────────────────┐ ┌─────────────────────────────┐     ┌─────────────────────────────┐
│ ╔═════════════╗ ┌─────────┐ │ │ ┌─────────────┐ ┌─────────┐ │     │ ┌─────────────┐ ┌─────────┐ │
│ ║ Process     ╟─┤ GPU: 0  │ │ │ │ Process     ├─┤ GPU: 0  │ │     │ │ Process     ├─┤ GPU: 0  │ │
│ ║ Rank: 0     ║ └─────────┘ │ │ │ Rank: P     │ └─────────┘ │     │ │ Rank: NP-P  │ └─────────┘ │
│ ╚═════════════╝             │ │ └─────────────┘             │     │ └─────────────┘             │
│ ┌─────────────┐ ┌─────────┐ │ │ ┌─────────────┐ ┌─────────┐ │     │ ┌─────────────┐ ┌─────────┐ │
│ │ Process     ├─┤ GPU: 1  │ │ │ │ Process     ├─┤ GPU: 1  │ │     │ │ Process     ├─┤ GPU: 1  │ │
│ │ Rank: 1     │ └─────────┘ │ │ │ Rank: P+1   │ └─────────┘ │ ... │ │ Rank: NP-P+1│ └─────────┘ │
│ └─────────────┘             │ │ └─────────────┘             │     │ └─────────────┘             │
│             ...             │ │             ...             │     │             ...             │
│ ┌─────────────┐ ┌─────────┐ │ │ ┌─────────────┐ ┌─────────┐ │     │ ┌─────────────┐ ┌─────────┐ │
│ │ Process     ├─┤ GPU: P-1│ │ │ │ Process     ├─┤ GPU: P-1│ │     │ │ Process     ├─┤ GPU: P-1│ │
│ │ Rank: P-1   │ └─────────┘ │ │ │ Rank: 2P-1  │ └─────────┘ │     │ │ Rank: NP-1  │ └─────────┘ │
│ └─────────────┘             │ │ └─────────────┘             │     │ └─────────────┘             │
└─────────────────────────────┘ └─────────────────────────────┘     └─────────────────────────────┘
```
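In other words (a small sketch using the `N`/`P` notation above):

```
world_size = N * P                      # total number of processes (--num-workers)
rank       = node_id * P + local_rank   # global process ID          (--worker-id)
device_id  = local_rank                 # GPU index within the node  (--device-id)
```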
</details>
### SLURM
When launched as a SLURM job, the following environment variables correspond to:
- **SLURM_PROCID**: `--worker-id` (Rank)
- **SLURM_NTASKS** (or the legacy **SLURM_NPROCS**): the total number of processes (`--num-workers`, i.e. the world size)
- **SLURM_LOCALID**: Local Rank (to be mapped to a GPU index*)

\* Even when GPU resources are allocated with `--gpus-per-task=1`, if there are multiple
tasks allocated on the same node (and thus multiple GPUs of the node are allocated to the job),
each task can see all the GPUs allocated for those tasks. Therefore, we need to use
`SLURM_LOCALID` to tell each task which GPU it should use.

<details><summary>Example scripts for running the training on a SLURM cluster</summary>
- **launch_job.sh**
```bash
#!/bin/bash
#SBATCH --job-name=source_separation
#SBATCH --output=/checkpoint/%u/jobs/%x/%j.out
#SBATCH --error=/checkpoint/%u/jobs/%x/%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=8
#SBATCH --mem-per-cpu=16G
#SBATCH --gpus-per-task=1
#srun env
srun wrapper.sh $@
```
- **wrapper.sh**
```bash
#!/bin/bash
num_speakers=2

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
save_dir="/checkpoint/${USER}/jobs/${SLURM_JOB_NAME}/${SLURM_JOB_ID}"
dataset_dir="/dataset/wsj0-mix/${num_speakers}speakers/wav8k/min"

if [ "${SLURM_JOB_NUM_NODES}" -gt 1 ]; then
  protocol="file:///checkpoint/${USER}/jobs/source_separation/${SLURM_JOB_ID}/sync"
else
  protocol="env://"
fi

mkdir -p "${save_dir}"

python -u \
  "${this_dir}/train.py" \
  --worker-id "${SLURM_PROCID}" \
  --num-workers "${SLURM_NTASKS}" \
  --device-id "${SLURM_LOCALID}" \
  --sync-protocol "${protocol}" \
  -- \
  --num-speakers "${num_speakers}" \
  --sample-rate 8000 \
  --dataset-dir "${dataset_dir}" \
  --save-dir "${save_dir}" \
  --batch-size $((16 / SLURM_NTASKS))
```
</details>
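Assuming the two scripts above are saved next to [`train.py`](./train.py), the job can then be submitted with `sbatch` (a sketch; adjust the `#SBATCH` resources and the hard-coded paths to your cluster):

```bash
sbatch ./launch_job.sh
# or, for example, override the node count at submission time:
sbatch --nodes=2 ./launch_job.sh
```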
# Conv-TasNet
This is a reference implementation of Conv-TasNet.
> Luo, Yi, and Nima Mesgarani. "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation." IEEE/ACM Transactions on Audio, Speech, and Language Processing 27.8 (2019): 1256-1266. Crossref. Web.
This implementation is based on [arXiv:1809.07454v3](https://arxiv.org/abs/1809.07454v3) and [the reference implementation](https://github.com/naplab/Conv-TasNet) provided by the authors.
For the usage, please check out the [source separation README](../README.md).
## (Default) Training Configurations
The default training/model configurations follow the best-performing non-causal configuration from the paper (the causal configuration is not implemented). The mapping of these hyperparameters to `torchaudio.models.ConvTasNet` arguments is sketched after the list.
- Sample rate: 8000 Hz
- Batch size: total 16 over distributed training workers
- Epochs: 100
- Initial learning rate: 1e-3
- Gradient clipping: maximum L2 norm of 5.0
- Optimizer: Adam
- Learning rate scheduling: Halved after 3 epochs of no improvement in the validation metric (SI-SNRi).
- Objective function: SI-SNRi
- Reported metrics: SI-SNRi, SDRi
- Sample audio length: 4 seconds (randomized position)
- Encoder/Decoder feature dimension (N): 512
- Encoder/Decoder convolution kernel size (L): 16
- TCN bottleneck/output feature dimension (B): 128
- TCN hidden feature dimension (H): 512
- TCN skip connection feature dimension (Sc): 128
- TCN convolution kernel size (P): 3
- The number of TCN convolution block layers (X): 8
- The number of TCN convolution blocks (R): 3
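For reference, these hyperparameters map onto the `torchaudio.models.ConvTasNet` constructor as in the sketch below (this mirrors how the training script in this directory instantiates the model; `num_sources=2` assumes the 2-speaker dataset):

```python
import torchaudio

# Best non-causal configuration from the paper, as used by this example.
model = torchaudio.models.ConvTasNet(
    num_sources=2,             # number of speakers to separate
    enc_kernel_size=16,        # L
    enc_num_feats=512,         # N
    msk_kernel_size=3,         # P
    msk_num_feats=128,         # B (also used as Sc in this implementation)
    msk_num_hidden_feats=512,  # H
    msk_num_layers=8,          # X
    msk_num_stacks=3,          # R
)
```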
## Evaluation
The following are the evaluation results of models trained on the WSJ0-2mix and WSJ0-3mix datasets.
### wsj0-mix 2speakers
| | SI-SNRi (dB) | SDRi (dB) | Epoch |
|:------------------:|-------------:|----------:|------:|
| Reference | 15.3 | 15.6 | |
| Validation dataset | 13.1 | 13.1 | 100 |
| Evaluation dataset | 11.0 | 11.0 | 100 |
### wsj0-mix 3speakers
| | SI-SNRi (dB) | SDRi (dB) | Epoch |
|:------------------:|-------------:|----------:|------:|
| Reference | 12.7 | 13.1 | |
| Validation dataset | 11.4 | 11.4 | 100 |
| Evaluation dataset | 8.9 | 8.9 | 100 |
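For reference, the reported SI-SNRi is based on the standard scale-invariant SNR (here $\hat{s}$ is an estimated source, $s$ the corresponding reference, and $x$ the input mixture, all made zero-mean before the computation):

$$
s_{\text{target}} = \frac{\langle \hat{s}, s \rangle}{\lVert s \rVert^2}\, s,
\qquad
\text{SI-SNR}(\hat{s}, s) = 10 \log_{10} \frac{\lVert s_{\text{target}} \rVert^2}{\lVert \hat{s} - s_{\text{target}} \rVert^2},
\qquad
\text{SI-SNRi} = \text{SI-SNR}(\hat{s}, s) - \text{SI-SNR}(x, s).
$$

SDRi is the analogous improvement measured with SDR; in this implementation both metrics are averaged over speakers under the best speaker permutation (permutation invariant training, PIT).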
#!/usr/bin/env python3
"""Train Conv-TasNet"""
import time
import pathlib
import argparse
import torch
import torchaudio
import torchaudio.models
import conv_tasnet
from utils import dist_utils
from utils.dataset import utils as dataset_utils
_LG = dist_utils.getLogger(__name__)


def _parse_args(args):
    parser = argparse.ArgumentParser(description=__doc__,)
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug behavior. Each epoch will end with just one batch.")
    group = parser.add_argument_group("Model Options")
    group.add_argument(
        "--num-speakers", required=True, type=int, help="The number of speakers."
    )
    group = parser.add_argument_group("Dataset Options")
    group.add_argument(
        "--sample-rate",
        required=True,
        type=int,
        help="Sample rate of audio files in the given dataset.",
    )
    group.add_argument(
        "--dataset",
        default="wsj0mix",
        choices=["wsj0mix"],
        help='Dataset type. (default: "wsj0mix")',
    )
    group.add_argument(
        "--dataset-dir",
        required=True,
        type=pathlib.Path,
        help=(
            "Directory where the dataset is found. "
            'If the dataset type is "wsj0mix", then this is the directory where '
            'the "cv", "tt" and "tr" subdirectories are found.'
        ),
    )
    group = parser.add_argument_group("Save Options")
    group.add_argument(
        "--save-dir",
        required=True,
        type=pathlib.Path,
        help=(
            "Directory where the checkpoints and logs are saved. "
            "Though only worker 0 saves checkpoint data, "
            "all the worker processes must have access to the directory."
        ),
    )
    group = parser.add_argument_group("Dataloader Options")
    group.add_argument(
        "--batch-size",
        type=int,
        help="Batch size. (default: 16 // world_size)",
    )
    group = parser.add_argument_group("Training Options")
    group.add_argument(
        "--epochs",
        metavar="NUM_EPOCHS",
        default=100,
        type=int,
        help="The number of epochs to train. (default: 100)",
    )
    group.add_argument(
        "--learning-rate",
        default=1e-3,
        type=float,
        help="Initial learning rate. (default: 1e-3)",
    )
    group.add_argument(
        "--grad-clip",
        metavar="CLIP_VALUE",
        default=5.0,
        type=float,
        help="Gradient clip value (l2 norm). (default: 5.0)",
    )
    group.add_argument(
        "--resume",
        metavar="CHECKPOINT_PATH",
        help="Previous checkpoint file from which the training is resumed.",
    )
    args = parser.parse_args(args)

    # Delaying the default value initialization until parse_args is done because
    # if `--help` is given, distributed training is not enabled.
    if args.batch_size is None:
        args.batch_size = 16 // torch.distributed.get_world_size()
    return args


def _get_model(
    num_sources,
    enc_kernel_size=16,
    enc_num_feats=512,
    msk_kernel_size=3,
    msk_num_feats=128,
    msk_num_hidden_feats=512,
    msk_num_layers=8,
    msk_num_stacks=3,
):
    model = torchaudio.models.ConvTasNet(
        num_sources=num_sources,
        enc_kernel_size=enc_kernel_size,
        enc_num_feats=enc_num_feats,
        msk_kernel_size=msk_kernel_size,
        msk_num_feats=msk_num_feats,
        msk_num_hidden_feats=msk_num_hidden_feats,
        msk_num_layers=msk_num_layers,
        msk_num_stacks=msk_num_stacks,
    )
    _LG.info_on_master("Model Configuration:")
    _LG.info_on_master(" - N: %d", enc_num_feats)
    _LG.info_on_master(" - L: %d", enc_kernel_size)
    _LG.info_on_master(" - B: %d", msk_num_feats)
    _LG.info_on_master(" - H: %d", msk_num_hidden_feats)
    _LG.info_on_master(" - Sc: %d", msk_num_feats)
    _LG.info_on_master(" - P: %d", msk_kernel_size)
    _LG.info_on_master(" - X: %d", msk_num_layers)
    _LG.info_on_master(" - R: %d", msk_num_stacks)
    _LG.info_on_master(
        " - Receptive Field: %s [samples]", model.mask_generator.receptive_field,
    )
    return model


def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_size):
    train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
        dataset_type, dataset_dir, num_speakers, sample_rate,
    )
    train_collate_fn = dataset_utils.get_collate_fn(
        dataset_type, mode='train', sample_rate=sample_rate, duration=4
    )
    test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test')

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.distributed.DistributedSampler(train_dataset),
        collate_fn=train_collate_fn,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.distributed.DistributedSampler(valid_dataset),
        collate_fn=test_collate_fn,
        pin_memory=True,
    )
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.distributed.DistributedSampler(eval_dataset),
        collate_fn=test_collate_fn,
        pin_memory=True,
    )
    return train_loader, valid_loader, eval_loader


def _write_header(log_path, args):
    rows = [
        [f"# torch: {torch.__version__}", ],
        [f"# torchaudio: {torchaudio.__version__}", ]
    ]
    rows.append(["# arguments"])
    for key, item in vars(args).items():
        rows.append([f"# {key}: {item}"])

    dist_utils.write_csv_on_master(log_path, *rows)


def train(args):
    args = _parse_args(args)
    _LG.info("%s", args)

    args.save_dir.mkdir(parents=True, exist_ok=True)

    if "sox_io" in torchaudio.list_audio_backends():
        torchaudio.set_audio_backend("sox_io")

    start_epoch = 1
    if args.resume:
        checkpoint = torch.load(args.resume)
        if args.sample_rate != checkpoint["sample_rate"]:
            raise ValueError(
                f"The provided sample rate ({args.sample_rate}) does not match "
                f"the sample rate from the checkpoint ({checkpoint['sample_rate']})."
            )
        if args.num_speakers != checkpoint["num_speakers"]:
            raise ValueError(
                f"The provided # of speakers ({args.num_speakers}) does not match "
                f"the # of speakers from the checkpoint ({checkpoint['num_speakers']})."
            )
        start_epoch = checkpoint["epoch"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _LG.info("Using: %s", device)

    model = _get_model(num_sources=args.num_speakers)
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[device] if torch.cuda.is_available() else None
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.resume:
        _LG.info("Loading parameters from the checkpoint...")
        model.module.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        dist_utils.synchronize_params(
            str(args.save_dir / "tmp.pt"), device, model, optimizer
        )

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )

    train_loader, valid_loader, eval_loader = _get_dataloader(
        args.dataset,
        args.dataset_dir,
        args.num_speakers,
        args.sample_rate,
        args.batch_size,
    )

    num_train_samples = len(train_loader.dataset)
    num_valid_samples = len(valid_loader.dataset)
    num_eval_samples = len(eval_loader.dataset)
    _LG.info_on_master("Datasets:")
    _LG.info_on_master(" - Train: %s", num_train_samples)
    _LG.info_on_master(" - Valid: %s", num_valid_samples)
    _LG.info_on_master(" - Eval: %s", num_eval_samples)

    trainer = conv_tasnet.trainer.Trainer(
        model,
        optimizer,
        train_loader,
        valid_loader,
        eval_loader,
        args.grad_clip,
        device,
        debug=args.debug,
    )

    log_path = args.save_dir / "log.csv"
    _write_header(log_path, args)
    dist_utils.write_csv_on_master(
        log_path,
        [
            "epoch",
            "learning_rate",
            "valid_si_snri",
            "valid_sdri",
            "eval_si_snri",
            "eval_sdri",
        ],
    )

    _LG.info_on_master("Running %s epochs", args.epochs)
    for epoch in range(start_epoch, start_epoch + args.epochs):
        _LG.info_on_master("=" * 70)
        _LG.info_on_master("Epoch: %s", epoch)
        _LG.info_on_master("Learning rate: %s", optimizer.param_groups[0]["lr"])
        _LG.info_on_master("=" * 70)

        t0 = time.monotonic()
        trainer.train_one_epoch()
        train_sps = num_train_samples / (time.monotonic() - t0)

        _LG.info_on_master("-" * 70)
        t0 = time.monotonic()
        valid_metric = trainer.validate()
        valid_sps = num_valid_samples / (time.monotonic() - t0)
        _LG.info_on_master("Valid: %s", valid_metric)

        _LG.info_on_master("-" * 70)
        t0 = time.monotonic()
        eval_metric = trainer.evaluate()
        eval_sps = num_eval_samples / (time.monotonic() - t0)
        _LG.info_on_master(" Eval: %s", eval_metric)

        _LG.info_on_master("-" * 70)
        _LG.info_on_master("Train: Speed: %6.2f [samples/sec]", train_sps)
        _LG.info_on_master("Valid: Speed: %6.2f [samples/sec]", valid_sps)
        _LG.info_on_master(" Eval: Speed: %6.2f [samples/sec]", eval_sps)
        _LG.info_on_master("-" * 70)

        dist_utils.write_csv_on_master(
            log_path,
            [
                epoch,
                optimizer.param_groups[0]["lr"],
                valid_metric.si_snri,
                valid_metric.sdri,
                eval_metric.si_snri,
                eval_metric.sdri,
            ],
        )

        lr_scheduler.step(valid_metric.si_snri)

        save_path = args.save_dir / f"epoch_{epoch}.pt"
        dist_utils.save_on_master(
            save_path,
            {
                "model": model.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "num_speakers": args.num_speakers,
                "sample_rate": args.sample_rate,
                "epoch": epoch,
            },
        )
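A usage note for the training options defined above (a sketch; `<DATASET_DIR>`, `<SAVE_DIR>`, and `<N>` are placeholders): checkpoints are saved as `epoch_<N>.pt` under `--save-dir`, and a run can be resumed from one of them with `--resume`, provided the sample rate and the number of speakers match the checkpoint:

```
python train.py -- \
    --num-speakers 2 \
    --sample-rate 8000 \
    --dataset-dir <DATASET_DIR> \
    --save-dir <SAVE_DIR> \
    --resume <SAVE_DIR>/epoch_<N>.pt
```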
import time
from typing import Tuple
from collections import namedtuple
import torch
import torch.distributed as dist
from utils import dist_utils, metrics
_LG = dist_utils.getLogger(__name__)
Metric = namedtuple("SNR", ["si_snri", "sdri"])
Metric.__str__ = (
lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"
)


def si_sdr_improvement(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mix: torch.Tensor,
    mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the improvement of scale-invariant SDR (SI-SNRi) and bare SDR (SDRi).

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Shape: [batch, speakers, time frame]
        reference (torch.Tensor): Reference (original) source signals.
            Shape: [batch, speakers, time frame]
        mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
            Shape: [batch, speakers == 1, time frame]
        mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
            Shape: [batch, 1, time frame]

    Returns:
        torch.Tensor: Improved SI-SDR. Shape: [batch, ]
        torch.Tensor: Absolute SI-SDR. Shape: [batch, ]

    References:
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    # SDRi is only logged, so it does not need gradients.
    with torch.no_grad():
        sdri = metrics.sdri(estimate, reference, mix, mask=mask)

    # SI-SNRi additionally requires zero-mean signals and is used as the loss.
    estimate = estimate - estimate.mean(axis=2, keepdim=True)
    reference = reference - reference.mean(axis=2, keepdim=True)
    mix = mix - mix.mean(axis=2, keepdim=True)
    si_sdri = metrics.sdri(estimate, reference, mix, mask=mask)
    return si_sdri, sdri


class OccasionalLogger:
    """Simple helper class to log once in a while or when progress is quick enough"""

    def __init__(self, time_interval=180, progress_interval=0.1):
        self.time_interval = time_interval
        self.progress_interval = progress_interval
        self.last_time = 0.0
        self.last_progress = 0.0

    def log(self, metric, progress, force=False):
        now = time.monotonic()
        if (
            force
            or now > self.last_time + self.time_interval
            or progress > self.last_progress + self.progress_interval
        ):
            self.last_time = now
            self.last_progress = progress
            _LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress)


class Trainer:
    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        valid_loader,
        eval_loader,
        grad_clip,
        device,
        *,
        debug,
    ):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.eval_loader = eval_loader
        self.grad_clip = grad_clip
        self.device = device
        self.debug = debug

    def train_one_epoch(self):
        self.model.train()
        logger = OccasionalLogger()

        num_batches = len(self.train_loader)
        for i, batch in enumerate(self.train_loader, start=1):
            mix = batch.mix.to(self.device)
            src = batch.src.to(self.device)
            mask = batch.mask.to(self.device)

            estimate = self.model(mix)

            si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)
            si_snri = si_snri.mean()
            sdri = sdri.mean()

            loss = -si_snri

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.grad_clip, norm_type=2.0
            )
            self.optimizer.step()

            metric = Metric(si_snri.item(), sdri.item())
            logger.log(metric, progress=i / num_batches, force=i == num_batches)

            if self.debug:
                break

    def evaluate(self):
        with torch.no_grad():
            return self._test(self.eval_loader)

    def validate(self):
        with torch.no_grad():
            return self._test(self.valid_loader)

    def _test(self, loader):
        self.model.eval()
        total_si_snri = torch.zeros(1, dtype=torch.float32, device=self.device)
        total_sdri = torch.zeros(1, dtype=torch.float32, device=self.device)
        for batch in loader:
            mix = batch.mix.to(self.device)
            src = batch.src.to(self.device)
            mask = batch.mask.to(self.device)

            estimate = self.model(mix)
            si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)
            total_si_snri += si_snri.sum()
            total_sdri += sdri.sum()

            if self.debug:
                break

        dist.all_reduce(total_si_snri, dist.ReduceOp.SUM)
        dist.all_reduce(total_sdri, dist.ReduceOp.SUM)

        num_samples = len(loader.dataset)
        metric = Metric(total_si_snri.item() / num_samples, total_sdri.item() / num_samples)
        return metric
#!/usr/bin/env python3
"""Launch source separation training.

This script runs training in the Distributed Data Parallel (DDP) framework and has two major
operation modes, depending on whether the `--worker-id` argument is given.

1. (`--worker-id` is not given) Launches worker subprocesses that perform the actual training.
2. (`--worker-id` is given) Performs the training as a part of distributed training.

When launched without any distributed training parameters (operation mode 1),
this script checks the number of GPUs available on the local system and spawns the same
number of training subprocesses (in operation mode 2). You can reduce the number of GPUs with
`--num-workers`. If there is no GPU available, only one subprocess is launched.

When launching the script as a worker process of a distributed training job, you need to configure
the coordination of the workers.
"""
import sys
import logging
import argparse
import subprocess
import torch
from utils import dist_utils
_LG = dist_utils.getLogger(__name__)


def _parse_args(args=None):
    max_world_size = torch.cuda.device_count() or 1
    parser = argparse.ArgumentParser(
        description=__doc__,
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug log")
    group = parser.add_argument_group("Distributed Training")
    group.add_argument(
        "--worker-id",
        type=int,
        help=(
            "If not provided, the launched process serves as a master process of "
            "single-node, multi-worker training and spawns the worker subprocesses. "
            "If provided, the launched process serves as a worker process, which "
            "performs the actual training. The valid value is [0, --num-workers)."
        ),
    )
    group.add_argument(
        "--device-id",
        type=int,
        help="The CUDA device ID. Allowed only when --worker-id is provided.",
    )
    group.add_argument(
        "--num-workers",
        type=int,
        default=max_world_size,
        help=(
            "The number of distributed training workers (world size). "
            "If launching the training as single-node, multi-worker training "
            "(i.e. --worker-id is not provided), then this value should not exceed "
            "the number of available GPUs. "
            "If launching the training process as multi-node, multi-GPU training "
            "(i.e. --worker-id is provided), then the value has to match "
            f"the number of workers across nodes. (default: {max_world_size})"
        ),
    )
    group.add_argument(
        "--sync-protocol",
        type=str,
        default="env://",
        help=(
            "Synchronization protocol for distributed training. "
            "This value is passed as the `init_method` argument of "
            "`torch.distributed.init_process_group`. "
            'If you are using `"env://"`, you can additionally configure '
            'the environment variables "MASTER_ADDR" and "MASTER_PORT". '
            'If you are using `"file://..."`, then the process has to have '
            "access to the designated file. "
            "See the documentation of `torch.distributed` for the detail. "
            'If you are running the training on a single node, `"env://"` '
            "should do. If you are running the training on multiple nodes, "
            "you need to provide a file location to which all the nodes have "
            'access, using the `"file://..."` protocol. (default: "env://")'
        ),
    )
    group.add_argument(
        "--random-seed",
        type=int,
        help="Set random seed value. (default: None)",
    )
    parser.add_argument(
        "rest", nargs=argparse.REMAINDER, help="Model-specific arguments."
    )
    namespace = parser.parse_args(args)
    if namespace.worker_id is None:
        if namespace.device_id is not None:
            raise ValueError(
                "`--device-id` cannot be provided when running as the master process."
            )
        if namespace.num_workers > max_world_size:
            raise ValueError(
                f"--num-workers ({namespace.num_workers}) cannot exceed {max_world_size}."
            )
    if namespace.rest[:1] == ["--"]:
        namespace.rest = namespace.rest[1:]
    return namespace


def _main(cli_args):
    args = _parse_args(cli_args)
    if any(arg in ["--help", "-h"] for arg in args.rest):
        _run_training(args.rest)
    _init_logger(args.worker_id, args.debug)
    if args.worker_id is None:
        _run_training_subprocesses(args.num_workers, cli_args)
    else:
        dist_utils.setup_distributed(
            world_size=args.num_workers,
            rank=args.worker_id,
            local_rank=args.device_id,
            backend='nccl' if torch.cuda.is_available() else 'gloo',
            init_method=args.sync_protocol,
        )
        if args.random_seed is not None:
            torch.manual_seed(args.random_seed)
        if torch.cuda.is_available():
            torch.cuda.set_device(args.device_id)
            _LG.info("CUDA device set to %s", args.device_id)
        _run_training(args.rest)


def _run_training_subprocesses(num_workers, original_args):
    workers = []
    _LG.info("Spawning %s workers", num_workers)
    for i in range(num_workers):
        worker_arg = ["--worker-id", f"{i}", "--num-workers", f"{num_workers}"]
        device_arg = ["--device-id", f"{i}"] if torch.cuda.is_available() else []
        command = (
            [sys.executable, "-u", sys.argv[0]]
            + worker_arg
            + device_arg
            + original_args
        )
        _LG.info("Launching worker %s: `%s`", i, " ".join(command))
        worker = subprocess.Popen(command)
        workers.append(worker)

    num_failed = 0
    for worker in workers:
        worker.wait()
        if worker.returncode != 0:
            num_failed += 1
    sys.exit(num_failed)


def _run_training(args):
    import conv_tasnet.train
    conv_tasnet.train.train(args)


def _init_logger(rank=None, debug=False):
    worker_fmt = "[master]" if rank is None else f"[worker {rank:2d}]"
    message_fmt = (
        "%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
    )
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format=f"%(asctime)s: {worker_fmt} {message_fmt}",
    )


if __name__ == "__main__":
    _main(sys.argv[1:])
from . import (
    dataset,
    dist_utils,
    metrics,
)
import os
import csv
import types
import logging
import torch
import torch.distributed as dist


def _info_on_master(self, *args, **kwargs):
    if dist.get_rank() == 0:
        self.info(*args, **kwargs)


def getLogger(name):
    """Get logging.Logger module with additional ``info_on_master`` method."""
    logger = logging.getLogger(name)
    logger.info_on_master = types.MethodType(_info_on_master, logger)
    return logger


_LG = getLogger(__name__)


def setup_distributed(
    world_size, rank, local_rank, backend="nccl", init_method="env://"
):
    """Perform env setup and initialization for distributed training"""
    if init_method == "env://":
        _set_env_vars(world_size, rank, local_rank)
    if world_size > 1 and "OMP_NUM_THREADS" not in os.environ:
        _LG.info("Setting OMP_NUM_THREADS == 1")
        os.environ["OMP_NUM_THREADS"] = "1"
    params = {
        "backend": backend,
        "init_method": init_method,
        "world_size": world_size,
        "rank": rank,
    }
    _LG.info("Initializing distributed process group with %s", params)
    dist.init_process_group(**params)
    _LG.info("Initialized distributed process group.")


def _set_env_vars(world_size, rank, local_rank):
    for key, default in [("MASTER_ADDR", "127.0.0.1"), ("MASTER_PORT", "29500")]:
        if key not in os.environ:
            os.environ[key] = default
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)


def save_on_master(path, obj):
    if dist.get_rank() == 0:
        _LG.info("Saving %s", path)
        torch.save(obj, path)


def write_csv_on_master(path, *rows):
    if dist.get_rank() == 0:
        with open(path, "a", newline="") as fileobj:
            writer = csv.writer(fileobj)
            for row in rows:
                writer.writerow(row)


def synchronize_params(path, device, *modules):
    if dist.get_world_size() < 2:
        return

    rank = dist.get_rank()
    if rank == 0:
        _LG.info("[Parameter Sync]: Saving parameters to a temp file...")
        torch.save({f"{i}": m.state_dict() for i, m in enumerate(modules)}, path)
    dist.barrier()
    if rank != 0:
        _LG.info("[Parameter Sync]: Loading parameters...")
        data = torch.load(path, map_location=device)
        for i, m in enumerate(modules):
            m.load_state_dict(data[f"{i}"])
    dist.barrier()
    if rank == 0:
        _LG.info("[Parameter Sync]: Removing the temp file...")
        os.remove(path)
    _LG.info_on_master("[Parameter Sync]: Complete.")
# Excerpts of the metrics utilities; unchanged lines are elided ("...").
import math
from typing import Optional
from itertools import permutations

import torch


def sdr(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    epsilon: float = 1e-8
) -> torch.Tensor:
    """Computes source-to-distortion ratio.

    1. scale the reference signal with power(s_est * s_ref) / power(s_ref * s_ref)
    ...
            Shape: [batch, speakers (can be 1), time frame]
        reference (torch.Tensor): Reference signal.
            Shape: [batch, speakers, time frame]
        mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
            Shape: [batch, 1, time frame]
        epsilon (float): constant value used to stabilize division.

    Returns:
    ...
    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)
    if mask is None:
        reference_pow = reference_pow.mean(axis=2)
        error_pow = error_pow.mean(axis=2)
    else:
        denom = mask.sum(axis=2)
        reference_pow = (mask * reference_pow).sum(axis=2) / denom
        error_pow = (mask * error_pow).sum(axis=2) / denom
    return 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)


...


class PIT(torch.nn.Module):
    ...
        super().__init__()
        self.utility_func = utility_func

    def forward(
        self,
        estimate: torch.Tensor,
        reference: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        epsilon: float = 1e-8
    ) -> torch.Tensor:
        """Compute utterance-level PIT Loss

        Args:
        ...
                Shape: [batch, speakers, time frame]
            reference (torch.Tensor): Reference (original) source signals.
                Shape: [batch, speakers, time frame]
            mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
                Shape: [batch, 1, time frame]
            epsilon (float): constant value used to stabilize division.

        Returns:
            torch.Tensor: Maximum criterion over the speaker permutation.
        ...
            batch_size, num_permute, dtype=estimate.dtype, device=estimate.device
        )
        for i, idx in enumerate(permutations(range(num_speakers))):
            util = self.utility_func(estimate, reference[:, idx, :], mask=mask, epsilon=epsilon)
            util_mat[:, i] = util.mean(dim=1)  # take the average over speaker dimension
        return util_mat.max(dim=1).values


_sdr_pit = PIT(utility_func=sdr)


def sdr_pit(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    epsilon: float = 1e-8):
    """Computes scale-invariant source-to-distortion ratio.

    1. adjust both estimate and reference to have 0-mean
    ...
            Shape: [batch, speakers (can be 1), time frame]
        reference (torch.Tensor): Reference signal.
            Shape: [batch, speakers, time frame]
        mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
            Shape: [batch, 1, time frame]
        epsilon (float): constant value used to stabilize division.

    Returns:
    ...
    *when the inputs have 0-mean*

    https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L107-L153
    """
    return _sdr_pit(estimate, reference, mask, epsilon)


def sdri(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mix: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    """Compute the improvement of SDR (SDRi).

    This function computes how much SDR is improved if the estimation is changed from
    ...
        Shape: [batch, speakers, time frame]
        mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
            Shape: [batch, speakers == 1, time frame]
        mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
            Shape: [batch, 1, time frame]
        epsilon (float): constant value used to stabilize division.

    Returns:
        torch.Tensor: Improved SDR. Shape: [batch, ]
    ...
        Luo, Yi and Mesgarani, Nima
        https://arxiv.org/abs/1809.07454
    """
    sdr_ = sdr_pit(estimate, reference, mask=mask, epsilon=epsilon)  # [batch, ]
    base_sdr = sdr(mix, reference, mask=mask, epsilon=epsilon)  # [batch, speaker]
    return (sdr_.unsqueeze(1) - base_sdr).mean(dim=1)
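A minimal usage sketch for the mask-aware metrics above (tensor shapes follow the docstrings; the random tensors stand in for real model outputs, and the `utils` package of this example is assumed to be importable):

```python
import torch

from utils import metrics

batch, speakers, frames = 3, 2, 32000
estimate = torch.randn(batch, speakers, frames)   # separated sources from the model
reference = torch.randn(batch, speakers, frames)  # ground-truth sources
mix = reference.sum(dim=1, keepdim=True)          # mixture: [batch, 1, frames]
mask = torch.ones(batch, 1, frames)               # 1 = valid sample, 0 = padding

sdri = metrics.sdri(estimate, reference, mix, mask=mask)  # shape: [batch]
print(sdri.shape)
```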