Unverified Commit fac1bba9 authored by jimchen90, committed by GitHub

Add wavernn example pipeline (#749)

* Add WaveRNN example

This is a pipeline example based on the [WaveRNN model](https://github.com/pytorch/audio/pull/735) in torchaudio. The design of this pipeline is inspired by [#632](https://github.com/pytorch/audio/pull/632). It offers a standardized implementation of the WaveRNN vocoder in torchaudio.

* Add utils and readme

The metric logger is added based on the Wav2letter pipeline ([#632](https://github.com/pytorch/audio/pull/632)). It provides a way to parse the standard output, as described in the readme.

* Add channel dimension

The channel dimension of the waveform in the datasets is added to match the input dimensions of the WaveRNN model, because the channel dimensions of the waveform and spectrogram are added in [this part](https://github.com/pytorch/audio/blob/master/torchaudio/models/_wavernn.py#L281) of the WaveRNN model.

* Update dataset split and transform

The design of the dataset structure is discussed in [this comment](https://github.com/pytorch/audio/pull/749#discussion_r454627027). The dataset file now has a clearer workflow after using the random-split function instead of walking through all the files. All transform functions are put together inside the transforms block.
Co-authored-by: Ji Chen <jimchen90@devfair0160.h2.fair>
parent 2381dd89
This is an example vocoder pipeline using the WaveRNN model trained on LJSpeech. The WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSpeech are available in torchaudio.
### Usage
An example can be invoked as follows.
```
python main.py \
--batch-size 256 \
--learning-rate 1e-4 \
--n-freq 80 \
--loss 'crossentropy' \
--n-bits 8
```
### Output
The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one JSON object per line. Here is an example python function to parse the output if redirected to a file.
```python
def read_json(filename):
    """
    Convert the standard output saved to filename into a pandas dataframe for analysis.
    """

    import pandas
    import json

    with open(filename, "r") as f:
        data = f.read()

    # pandas doesn't read single quotes for json
    data = data.replace("'", '"')

    data = [json.loads(l) for l in data.splitlines()]
    return pandas.DataFrame(data)
```
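As a usage sketch (the log file name `train.log` below is only an illustration, assuming the standard output of `main.py` was redirected to it), the parsed dataframe can be filtered by the `group` field written by the metric logger:
```python
# hypothetical usage: "train.log" holds the redirected standard output of main.py
df = read_json("train.log")

# per-epoch training loss, using the "group" field emitted by the metric logger
print(df[df["group"] == "train_epoch"][["epoch", "loss"]])
```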
import os
import random
import torch
import torchaudio
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH
from torchaudio.transforms import MuLawEncoding
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
class MapMemoryCache(torch.utils.data.Dataset):
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
"""
def __init__(self, dataset):
self.dataset = dataset
self._cache = [None] * len(dataset)
def __getitem__(self, n):
if self._cache[n] is not None:
return self._cache[n]
item = self.dataset[n]
self._cache[n] = item
return item
def __len__(self):
return len(self.dataset)
class Processed(torch.utils.data.Dataset):
def __init__(self, dataset, transforms):
self.dataset = dataset
self.transforms = transforms
def __getitem__(self, key):
item = self.dataset[key]
return self.process_datapoint(item)
def __len__(self):
return len(self.dataset)
def process_datapoint(self, item):
specgram = self.transforms(item[0])
return item[0].squeeze(0), specgram
def split_process_ljspeech(args, transforms):
data = LJSPEECH(root=args.file_path, download=False)
val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)
train_dataset = Processed(train_dataset, transforms)
val_dataset = Processed(val_dataset, transforms)
train_dataset = MapMemoryCache(train_dataset)
val_dataset = MapMemoryCache(val_dataset)
return train_dataset, val_dataset
def collate_factory(args):
def raw_collate(batch):
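# randomly crop aligned spectrogram / waveform segments from each item and build the batched (input, target) pairs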
pad = (args.kernel_size - 1) // 2
# input waveform length
wave_length = args.hop_length * args.seq_len_factor
# input spectrogram length
spec_length = args.seq_len_factor + pad * 2
# max start position in spectrogram
max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch]
# random start position in spectrogram
spec_offsets = [random.randint(0, offset) for offset in max_offsets]
# random start position in waveform
wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]
waveform_combine = [
x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1]
for i, x in enumerate(batch)
]
specgram = [
x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length]
for i, x in enumerate(batch)
]
specgram = torch.stack(specgram)
waveform_combine = torch.stack(waveform_combine)
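# the first wave_length samples are the model input; the target is the same segment shifted ahead by one sample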
waveform = waveform_combine[:, :wave_length]
target = waveform_combine[:, 1:]
# waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
if args.loss == "crossentropy":
if args.mulaw:
mulaw_encode = MuLawEncoding(2 ** args.n_bits)
waveform = mulaw_encode(waveform)
target = mulaw_encode(target)
waveform = bits_to_normalized_waveform(waveform, args.n_bits)
else:
target = normalized_waveform_to_bits(target, args.n_bits)
return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)
return raw_collate
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
class LongCrossEntropyLoss(nn.Module):
r""" CrossEntropy loss
"""
def __init__(self):
super(LongCrossEntropyLoss, self).__init__()
def forward(self, output, target):
output = output.transpose(1, 2)
target = target.long()
criterion = nn.CrossEntropyLoss()
return criterion(output, target)
class MoLLoss(nn.Module):
r""" Discretized mixture of logistic distributions loss
Adapted from wavenet vocoder
(https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155)
Args:
y_hat (Tensor): Predicted output (n_batch x n_time x n_channel)
y (Tensor): Target (n_batch x n_time x 1)
num_classes (int): Number of classes
log_scale_min (float): Log scale minimum value
reduce (bool): If True, the losses are averaged or summed for each minibatch
Returns
Tensor: loss
"""
def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
super(MoLLoss, self).__init__()
self.num_classes = num_classes
self.log_scale_min = log_scale_min
self.reduce = reduce
def forward(self, y_hat, y):
y = y.unsqueeze(-1)
if self.log_scale_min is None:
self.log_scale_min = math.log(1e-14)
assert y_hat.dim() == 3
assert y_hat.size(-1) % 3 == 0
nr_mix = y_hat.size(-1) // 3
# unpack parameters (n_batch, n_time, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix: 2 * nr_mix]
log_scales = torch.clamp(
y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min
)
# (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
y = y.expand_as(means)
centered_y = y - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1))
cdf_plus = torch.sigmoid(plus_in)
min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1))
cdf_min = torch.sigmoid(min_in)
# log probability for edge case of 0 (before scaling)
# equivalent: torch.log(F.sigmoid(plus_in))
log_cdf_plus = plus_in - F.softplus(plus_in)
# log probability for edge case of 255 (before scaling)
# equivalent: (1 - F.sigmoid(min_in)).log()
log_one_minus_cdf_min = -F.softplus(min_in)
# probability for all other cases
cdf_delta = cdf_plus - cdf_min
mid_in = inv_stdv * centered_y
# log probability in the center of the bin, to be used in extreme cases
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
inner_inner_cond = (cdf_delta > 1e-5).float()
inner_inner_out = inner_inner_cond * torch.log(
torch.clamp(cdf_delta, min=1e-12)
) + (1.0 - inner_inner_cond) * (
log_pdf_mid - math.log((self.num_classes - 1) / 2)
)
inner_cond = (y > 0.999).float()
inner_out = (
inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
)
cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
log_probs = log_probs + F.log_softmax(logit_probs, -1)
if self.reduce:
return -torch.mean(_log_sum_exp(log_probs))
else:
return -_log_sum_exp(log_probs).unsqueeze(-1)
def _log_sum_exp(x):
r""" Numerically stable log_sum_exp implementation that prevents overflow
"""
axis = len(x.size()) - 1
m, _ = torch.max(x, dim=axis)
m2, _ = torch.max(x, dim=axis, keepdim=True)
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
import argparse
import logging
import os
import signal
from collections import defaultdict
from datetime import datetime
from time import time
from typing import List
import torch
import torchaudio
from torch import nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models._wavernn import _WaveRNN
from datasets import collate_factory, split_process_ljspeech
from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--workers",
default=4,
type=int,
metavar="N",
help="number of data loading workers",
)
parser.add_argument(
"--checkpoint",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint",
)
parser.add_argument(
"--epochs",
default=8000,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument(
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency in epochs",
)
parser.add_argument(
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
)
parser.add_argument(
"--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate",
)
parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0)
parser.add_argument(
"--mulaw",
default=True,
action="store_true",
help="if used, waveform is mulaw encoded",
)
parser.add_argument(
"--jit", default=False, action="store_true", help="if used, model is jitted"
)
parser.add_argument(
"--upsample-scales",
default=[5, 5, 11],
type=List[int],
help="the list of upsample scales",
)
parser.add_argument(
"--n-bits", default=8, type=int, help="the bits of output waveform",
)
parser.add_argument(
"--sample-rate",
default=22050,
type=int,
help="the rate of audio dimensions (samples per second)",
)
parser.add_argument(
"--hop-length",
default=275,
type=int,
help="the number of samples between the starts of consecutive frames",
)
parser.add_argument(
"--win-length", default=1100, type=int, help="the length of the STFT window",
)
parser.add_argument(
"--f-min", default=40.0, type=float, help="the minimum frequency",
)
parser.add_argument(
"--min-level-db",
default=-100,
type=float,
help="the minimum db value for spectrogam normalization",
)
parser.add_argument(
"--n-res-block", default=10, type=int, help="the number of ResBlock in stack",
)
parser.add_argument(
"--n-rnn", default=512, type=int, help="the dimension of RNN layer",
)
parser.add_argument(
"--n-fc", default=512, type=int, help="the dimension of fully connected layer",
)
parser.add_argument(
"--kernel-size",
default=5,
type=int,
help="the number of kernel size in the first Conv1d layer",
)
parser.add_argument(
"--n-freq", default=80, type=int, help="the number of spectrogram bins to use",
)
parser.add_argument(
"--n-hidden-melresnet",
default=128,
type=int,
help="the number of hidden dimensions of resblock in melresnet",
)
parser.add_argument(
"--n-output-melresnet", default=128, type=int, help="the output dimension of melresnet",
)
parser.add_argument(
"--n-fft", default=2048, type=int, help="the number of Fourier bins",
)
parser.add_argument(
"--loss",
default="crossentropy",
choices=["crossentropy", "mol"],
type=str,
help="the type of loss",
)
parser.add_argument(
"--seq-len-factor",
default=5,
type=int,
help="the length of each waveform to process per batch = hop_length * seq_len_factor",
)
parser.add_argument(
"--val-ratio",
default=0.1,
type=float,
help="the ratio of waveforms for validation",
)
parser.add_argument(
"--file-path", default="", type=str, help="the path of audio files",
)
args = parser.parse_args()
return args
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
model.train()
sums = defaultdict(lambda: 0.0)
start1 = time()
metric = MetricLogger("train_iteration")
metric["epoch"] = epoch
for waveform, specgram, target in bg_iterator(data_loader, maxsize=2):
start2 = time()
waveform = waveform.to(device)
specgram = specgram.to(device)
target = target.to(device)
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)
loss = criterion(output, target)
loss_item = loss.item()
sums["loss"] += loss_item
metric["loss"] = loss_item
optimizer.zero_grad()
loss.backward()
if args.clip_grad > 0:
gradient = torch.nn.utils.clip_grad_norm_(
model.parameters(), args.clip_grad
)
sums["gradient"] += gradient.item()
metric["gradient"] = gradient.item()
optimizer.step()
metric["iteration"] = sums["iteration"]
metric["time"] = time() - start2
metric()
sums["iteration"] += 1
avg_loss = sums["loss"] / len(data_loader)
metric = MetricLogger("train_epoch")
metric["epoch"] = epoch
metric["loss"] = sums["loss"] / len(data_loader)
metric["gradient"] = avg_loss
metric["time"] = time() - start1
metric()
def validate(model, criterion, data_loader, device, epoch):
with torch.no_grad():
model.eval()
sums = defaultdict(lambda: 0.0)
start = time()
for waveform, specgram, target in bg_iterator(data_loader, maxsize=2):
waveform = waveform.to(device)
specgram = specgram.to(device)
target = target.to(device)
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)
loss = criterion(output, target)
sums["loss"] += loss.item()
avg_loss = sums["loss"] / len(data_loader)
metric = MetricLogger("validation")
metric["epoch"] = epoch
metric["loss"] = avg_loss
metric["time"] = time() - start
metric()
return avg_loss
def main(args):
devices = ["cuda" if torch.cuda.is_available() else "cpu"]
logging.info("Start time: {}".format(str(datetime.now())))
melkwargs = {
"n_fft": args.n_fft,
"power": 1,
"hop_length": args.hop_length,
"win_length": args.win_length,
}
transforms = torch.nn.Sequential(
torchaudio.transforms.Spectrogram(**melkwargs),
LinearToMel(
sample_rate=args.sample_rate,
n_fft=args.n_fft,
n_mels=args.n_freq,
fmin=args.f_min,
),
NormalizeDB(min_level_db=args.min_level_db),
)
train_dataset, val_dataset = split_process_ljspeech(args, transforms)
loader_training_params = {
"num_workers": args.workers,
"pin_memory": False,
"shuffle": True,
"drop_last": False,
}
loader_validation_params = loader_training_params.copy()
loader_validation_params["shuffle"] = False
collate_fn = collate_factory(args)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
collate_fn=collate_fn,
**loader_training_params,
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
collate_fn=collate_fn,
**loader_validation_params,
)
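# cross-entropy predicts one of 2 ** n_bits discrete levels; the mixture-of-logistics loss uses 10 mixtures with 3 parameters each, hence 30 outputs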
n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30
model = _WaveRNN(
upsample_scales=args.upsample_scales,
n_classes=n_classes,
hop_length=args.hop_length,
n_res_block=args.n_res_block,
n_rnn=args.n_rnn,
n_fc=args.n_fc,
kernel_size=args.kernel_size,
n_freq=args.n_freq,
n_hidden=args.n_hidden_melresnet,
n_output=args.n_output_melresnet,
)
if args.jit:
model = torch.jit.script(model)
model = torch.nn.DataParallel(model)
model = model.to(devices[0], non_blocking=True)
n = count_parameters(model)
logging.info(f"Number of parameters: {n}")
# Optimizer
optimizer_params = {
"lr": args.learning_rate,
}
optimizer = Adam(model.parameters(), **optimizer_params)
criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()
best_loss = 10.0
if args.checkpoint and os.path.isfile(args.checkpoint):
logging.info(f"Checkpoint: loading '{args.checkpoint}'")
checkpoint = torch.load(args.checkpoint)
args.start_epoch = checkpoint["epoch"]
best_loss = checkpoint["best_loss"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(
f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}"
)
else:
logging.info("Checkpoint: not found")
save_checkpoint(
{
"epoch": args.start_epoch,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
},
False,
args.checkpoint,
)
for epoch in range(args.start_epoch, args.epochs):
train_one_epoch(
model, criterion, optimizer, train_loader, devices[0], epoch,
)
if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
sum_loss = validate(model, criterion, val_loader, devices[0], epoch)
is_best = sum_loss < best_loss
best_loss = min(sum_loss, best_loss)
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
},
is_best,
args.checkpoint,
)
logging.info(f"End time: {datetime.now()}")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
args = parse_args()
main(args)
import librosa
import torch
import torch.nn as nn
# TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved
class LinearToMel(nn.Module):
def __init__(self, sample_rate, n_fft, n_mels, fmin, htk=False, norm="slaney"):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.n_mels = n_mels
self.fmin = fmin
self.htk = htk
self.norm = norm
def forward(self, specgram):
specgram = librosa.feature.melspectrogram(
S=specgram.squeeze(0).numpy(),
sr=self.sample_rate,
n_fft=self.n_fft,
n_mels=self.n_mels,
fmin=self.fmin,
htk=self.htk,
norm=self.norm,
)
return torch.from_numpy(specgram)
class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value
"""
def __init__(self, min_level_db):
super().__init__()
self.min_level_db = min_level_db
def forward(self, specgram):
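# convert amplitude to decibels (floored at 1e-5) and map the [min_level_db, 0] dB range to [0, 1]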
specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5))
return torch.clamp(
(self.min_level_db - specgram) / self.min_level_db, min=0, max=1
)
def normalized_waveform_to_bits(waveform, bits):
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]
"""
assert abs(waveform).max() <= 1.0
waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
return torch.clamp(waveform, 0, 2 ** bits - 1).int()
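# e.g. with bits=8: -1.0 -> 0, 0.0 -> 127, 1.0 -> 255; bits_to_normalized_waveform below is the approximate inverse (up to quantization)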
def bits_to_normalized_waveform(label, bits):
r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]
"""
return 2 * label / (2 ** bits - 1.0) - 1.0
import logging
import os
import shutil
from collections import defaultdict, deque
import torch
class MetricLogger:
r"""Logger for model metrics
"""
def __init__(self, group, print_freq=1):
self.print_freq = print_freq
self._iter = 0
self.data = defaultdict(lambda: deque(maxlen=self.print_freq))
self.data["group"].append(group)
def __setitem__(self, key, value):
self.data[key].append(value)
def _get_last(self):
return {k: v[-1] for k, v in self.data.items()}
def __str__(self):
return str(self._get_last())
def __call__(self):
self._iter = (self._iter + 1) % self.print_freq
if not self._iter:
print(self, flush=True)
def save_checkpoint(state, is_best, filename):
r"""Save the model to a temporary file first,
then copy it to filename, in case the signal interrupts
the torch.save() process.
"""
if filename == "":
return
tempfile = filename + ".temp"
# Remove the tempfile in case of interruption during the copy from tempfile to filename
if os.path.isfile(tempfile):
os.remove(tempfile)
torch.save(state, tempfile)
if os.path.isfile(tempfile):
os.rename(tempfile, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
logging.info("Checkpoint: saved")
def count_parameters(model):
r"""Count the total number of parameters in the model
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)