import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils import weight_norm, spectral_norm
from distutils.version import LooseVersion
from pytorch_wavelets import DWT1DForward
from .layers import (
Conv1d,
CausalConv1d,
ConvTranspose1d,
CausalConvTranspose1d,
ResidualBlock,
SourceModule,
)
from kantts.utils.audio_torch import stft
import copy
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
class Generator(torch.nn.Module):
def __init__(
self,
in_channels=80,
out_channels=1,
channels=512,
kernel_size=7,
upsample_scales=(8, 8, 2, 2),
upsample_kernal_sizes=(16, 16, 4, 4),
resblock_kernel_sizes=(3, 7, 11),
resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
repeat_upsample=True,
bias=True,
causal=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
use_weight_norm=True,
nsf_params=None,
):
super(Generator, self).__init__()
# check hyperparameters are valid
assert kernel_size % 2 == 1, "Kernel size must be an odd number."
assert len(upsample_scales) == len(upsample_kernal_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
self.upsample_scales = upsample_scales
self.repeat_upsample = repeat_upsample
self.num_upsamples = len(upsample_kernal_sizes)
self.num_kernels = len(resblock_kernel_sizes)
self.out_channels = out_channels
self.nsf_enable = nsf_params is not None
self.transpose_upsamples = torch.nn.ModuleList()
self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling
self.conv_blocks = torch.nn.ModuleList()
conv_cls = CausalConv1d if causal else Conv1d
conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
self.conv_pre = conv_cls(
in_channels, channels, kernel_size, 1, padding=(kernel_size - 1) // 2
)
for i in range(len(upsample_kernal_sizes)):
self.transpose_upsamples.append(
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv_transposed_cls(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernal_sizes[i],
upsample_scales[i],
padding=(upsample_kernal_sizes[i] - upsample_scales[i]) // 2,
),
)
)
if repeat_upsample:
self.repeat_upsamples.append(
nn.Sequential(
nn.Upsample(mode="nearest", scale_factor=upsample_scales[i]),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv_cls(
channels // (2 ** i),
channels // (2 ** (i + 1)),
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
),
)
)
for j in range(len(resblock_kernel_sizes)):
self.conv_blocks.append(
ResidualBlock(
channels=channels // (2 ** (i + 1)),
kernel_size=resblock_kernel_sizes[j],
dilation=resblock_dilations[j],
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
causal=causal,
)
)
self.conv_post = conv_cls(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
if self.nsf_enable:
self.source_module = SourceModule(
nb_harmonics=nsf_params["nb_harmonics"],
upsample_ratio=np.cumprod(self.upsample_scales)[-1],
sampling_rate=nsf_params["sampling_rate"],
)
self.source_downs = nn.ModuleList()
self.downsample_rates = [1] + list(self.upsample_scales[::-1][:-1])
self.downsample_cum_rates = np.cumprod(self.downsample_rates)
for i, u in enumerate(self.downsample_cum_rates[::-1]):
if u == 1:
self.source_downs.append(
Conv1d(1, channels // (2 ** (i + 1)), 1, 1)
)
else:
self.source_downs.append(
conv_cls(
1,
channels // (2 ** (i + 1)),
u * 2,
u,
padding=u // 2,
)
)
def forward(self, x):
if self.nsf_enable:
mel = x[:, :-2, :]
pitch = x[:, -2:-1, :]
uv = x[:, -1:, :]
excitation = self.source_module(pitch, uv)
else:
mel = x
x = self.conv_pre(mel)
for i in range(self.num_upsamples):
# FIXME: sin function here seems to be causing issues
x = torch.sin(x) + x
rep = self.repeat_upsamples[i](x)
# transconv
up = self.transpose_upsamples[i](x)
if self.nsf_enable:
# Downsampling the excitation signal
e = self.source_downs[i](excitation)
# augment inputs with the excitation
x = rep + e + up[:, :, : rep.shape[-1]]
else:
x = rep + up[:, :, : rep.shape[-1]]
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.conv_blocks[i * self.num_kernels + j](x)
else:
xs += self.conv_blocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for layer in self.transpose_upsamples:
layer[-1].remove_weight_norm()
for layer in self.repeat_upsamples:
layer[-1].remove_weight_norm()
for layer in self.conv_blocks:
layer.remove_weight_norm()
self.conv_pre.remove_weight_norm()
self.conv_post.remove_weight_norm()
if self.nsf_enable:
self.source_module.remove_weight_norm()
for layer in self.source_downs:
layer.remove_weight_norm()
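# Usage sketch (illustrative, not part of the original file): with nsf_params set,
# forward() expects mel, pitch (Hz) and voiced/unvoiced flags stacked along the
# channel axis, matching the slicing at the top of forward(). The concrete sizes
# and nsf_params values below are assumptions for demonstration only.
def _example_generator_nsf_input():
    mel = torch.randn(2, 80, 100)              # (B, num_mels, T_frames)
    pitch = torch.full((2, 1, 100), 220.0)     # frame-level F0 in Hz
    uv = torch.ones(2, 1, 100)                 # 1 = voiced, 0 = unvoiced
    cond = torch.cat([mel, pitch, uv], dim=1)  # (B, 80 + 2, T_frames)
    model = Generator(
        upsample_scales=[8, 8, 2, 2],
        upsample_kernal_sizes=[16, 16, 4, 4],
        nsf_params={"nb_harmonics": 7, "sampling_rate": 16000},
    )
    wav = model(cond)                          # (B, 1, T_frames * prod(upsample_scales))
    return wav.shape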
class PeriodDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
period=3,
kernel_sizes=[5, 3],
channels=32,
downsample_scales=[3, 3, 3, 3, 1],
max_downsample_channels=1024,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
use_spectral_norm=False,
):
super(PeriodDiscriminator, self).__init__()
self.period = period
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList()
in_chs, out_chs = in_channels, channels
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
in_chs,
out_chs,
(kernel_sizes[0], 1),
(downsample_scale, 1),
padding=((kernel_sizes[0] - 1) // 2, 0),
)
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
)
in_chs = out_chs
out_chs = min(out_chs * 4, max_downsample_channels)
self.conv_post = nn.Conv2d(
out_chs,
out_channels,
(kernel_sizes[1] - 1, 1),
1,
padding=((kernel_sizes[1] - 1) // 2, 0),
)
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
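# Shape sketch (illustrative): the waveform is reflection-padded up to a multiple
# of `period` and viewed as a (T // period, period) map, so each column of the 2d
# input collects samples that are `period` steps apart before the 2d convolutions.
def _example_period_discriminator():
    d = PeriodDiscriminator(period=3)
    x = torch.randn(1, 1, 100)                 # (B, 1, T); T is padded 100 -> 102
    score, fmap = d(x)                         # score: (B, flattened logits), fmap: 2d feature maps
    return score.shape, len(fmap)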
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(
self,
periods=[2, 3, 5, 7, 11],
discriminator_params={
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_spectral_norm": False,
},
):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for period in periods:
params = copy.deepcopy(discriminator_params)
params["period"] = period
self.discriminators += [PeriodDiscriminator(**params)]
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
class ScaleDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_sizes=[15, 41, 5, 3],
channels=128,
max_downsample_channels=1024,
max_groups=16,
bias=True,
downsample_scales=[2, 2, 4, 4, 1],
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
use_spectral_norm=False,
):
super(ScaleDiscriminator, self).__init__()
norm_f = weight_norm if not use_spectral_norm else spectral_norm
assert len(kernel_sizes) == 4
for ks in kernel_sizes:
assert ks % 2 == 1
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_channels,
channels,
kernel_sizes[0],
bias=bias,
padding=(kernel_sizes[0] - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
in_chs = channels
out_chs = channels
groups = 4
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[1],
stride=downsample_scale,
padding=(kernel_sizes[1] - 1) // 2,
groups=groups,
bias=bias,
)
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
)
in_chs = out_chs
out_chs = min(in_chs * 2, max_downsample_channels)
groups = min(groups * 4, max_groups)
out_chs = min(in_chs * 2, max_downsample_channels)
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
self.conv_post = norm_f(
nn.Conv1d(
out_chs,
out_channels,
kernel_size=kernel_sizes[3],
stride=1,
padding=(kernel_sizes[3] - 1) // 2,
bias=bias,
)
)
def forward(self, x):
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(
self,
scales=3,
downsample_pooling="DWT",
# follow the official implementation setting
downsample_pooling_params={
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
discriminator_params={
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm=False,
):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = torch.nn.ModuleList()
# add discriminators
for i in range(scales):
params = copy.deepcopy(discriminator_params)
if follow_official_norm:
params["use_spectral_norm"] = True if i == 0 else False
self.discriminators += [ScaleDiscriminator(**params)]
if downsample_pooling == "DWT":
self.meanpools = nn.ModuleList(
[DWT1DForward(wave="db3", J=1), DWT1DForward(wave="db3", J=1)]
)
self.aux_convs = nn.ModuleList(
[
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
]
)
else:
self.meanpools = nn.ModuleList(
[nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
)
self.aux_convs = None
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
if i != 0:
if self.aux_convs is None:
y = self.meanpools[i - 1](y)
else:
yl, yh = self.meanpools[i - 1](y)
y = torch.cat([yl, yh[0]], dim=1)
y = self.aux_convs[i - 1](y)
y = F.leaky_relu(y, 0.1)
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
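# Usage sketch (illustrative): with the default downsample_pooling="DWT", every
# scale after the first sees a 1-level db3 DWT of the waveform whose low/high
# bands are concatenated and mixed back to a single channel by the auxiliary
# 15-tap convolutions above.
def _example_multi_scale_discriminator():
    msd = MultiScaleDiscriminator()
    y = torch.randn(1, 1, 8192)                # (B, 1, T) waveform
    scores, fmaps = msd(y)                     # one (score, feature-map list) pair per scale
    return [s.shape for s in scores]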
class SpecDiscriminator(torch.nn.Module):
def __init__(
self,
channels=32,
init_kernel=15,
kernel_size=11,
stride=2,
use_spectral_norm=False,
fft_size=1024,
shift_size=120,
win_length=600,
window="hann_window",
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
):
super(SpecDiscriminator, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
# number of input frequency bins: fft_size // 2 + 1
norm_f = weight_norm if not use_spectral_norm else spectral_norm
final_kernel = 5
post_conv_kernel = 3
blocks = 3 # TODO: remove hard code here
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
fft_size // 2 + 1,
channels,
(init_kernel, 1),
(1, 1),
padding=(init_kernel - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
for i in range(blocks):
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
)
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(final_kernel, 1),
(1, 1),
padding=(final_kernel - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
self.conv_post = norm_f(
nn.Conv2d(
channels,
1,
(post_conv_kernel, 1),
(1, 1),
padding=((post_conv_kernel - 1) // 2, 0),
)
)
self.register_buffer("window", getattr(torch, window)(win_length))
def forward(self, wav):
with torch.no_grad():
wav = torch.squeeze(wav, 1)
x_mag = stft(
wav, self.fft_size, self.shift_size, self.win_length, self.window
)
x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = x.squeeze(-1)
return x, fmap
class MultiSpecDiscriminator(torch.nn.Module):
def __init__(
self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
discriminator_params={
"channels": 15,
"init_kernel": 1,
"kernel_sizes": 11,
"stride": 2,
"use_spectral_norm": False,
"window": "hann_window",
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
):
super(MultiSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
params = copy.deepcopy(discriminator_params)
params["fft_size"] = fft_size
params["shift_size"] = hop_size
params["win_length"] = win_length
self.discriminators += [SpecDiscriminator(**params)]
def forward(self, y):
y_d = []
fmap = []
for i, d in enumerate(self.discriminators):
x, x_map = d(y)
y_d.append(x)
fmap.append(x_map)
return y_d, fmap
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.distributions.uniform import Uniform
from torch.distributions.normal import Normal
from kantts.models.utils import init_weights
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class Conv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(Conv1d, self).__init__()
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
)
)
self.conv1d.apply(init_weights)
def forward(self, x):
x = self.conv1d(x)
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
class CausalConv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(CausalConv1d, self).__init__()
self.pad = (kernel_size - 1) * dilation
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
)
)
self.conv1d.apply(init_weights)
def forward(self, x): # bdt
x = F.pad(
x, (self.pad, 0, 0, 0, 0, 0), "constant"
)  # pad sizes are given starting from the last dimension forward, so this left-pads the time axis only.
# x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
x = self.conv1d(x)[:, :, : x.size(2)]
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
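# Shape sketch (illustrative): CausalConv1d left-pads by (kernel_size - 1) * dilation
# zeros, so the output keeps the input length and frame t never depends on samples
# to its right.
def _example_causal_conv1d():
    conv = CausalConv1d(1, 4, kernel_size=5, dilation=2)
    x = torch.randn(1, 1, 100)
    y = conv(x)                                # (1, 4, 100): length preserved, strictly causal
    return y.shape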
class ConvTranspose1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
super(ConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
output_padding=0,
)
)
self.deconv.apply(init_weights)
def forward(self, x):
return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
# FIXME: HACK to get shape right
class CausalConvTranspose1d(torch.nn.Module):
"""CausalConvTranspose1d module with customized initialization."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
"""Initialize CausalConvTranspose1d module."""
super(CausalConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
)
)
self.stride = stride
self.deconv.apply(init_weights)
self.pad = kernel_size - stride
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_in).
Returns:
Tensor: Output tensor (B, out_channels, T_out).
"""
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
return self.deconv(x)[:, :, : -self.pad]
# return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
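# Shape sketch (illustrative): CausalConvTranspose1d trims kernel_size - stride
# samples from the right of the transposed convolution, so an input of length T
# comes out with length exactly T * stride.
def _example_causal_conv_transpose1d():
    deconv = CausalConvTranspose1d(4, 2, kernel_size=16, stride=8)
    x = torch.randn(1, 4, 100)
    y = deconv(x)                              # (1, 2, 800): (100 - 1) * 8 + 16 = 808, minus 8 trimmed samples
    return y.shape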
class ResidualBlock(torch.nn.Module):
def __init__(
self,
channels,
kernel_size=3,
dilation=(1, 3, 5),
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
causal=False,
):
super(ResidualBlock, self).__init__()
assert kernel_size % 2 == 1, "Kernel size must be an odd number."
conv_cls = CausalConv1d if causal else Conv1d
self.convs1 = nn.ModuleList(
[
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=dilation[i],
padding=get_padding(kernel_size, dilation[i]),
)
for i in range(len(dilation))
]
)
self.convs2 = nn.ModuleList(
[
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
for i in range(len(dilation))
]
)
self.activation = getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = self.activation(x)
xt = c1(xt)
xt = self.activation(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for layer in self.convs1:
layer.remove_weight_norm()
for layer in self.convs2:
layer.remove_weight_norm()
class SourceModule(torch.nn.Module):
def __init__(
self, nb_harmonics, upsample_ratio, sampling_rate, alpha=0.1, sigma=0.003
):
super(SourceModule, self).__init__()
self.nb_harmonics = nb_harmonics
self.upsample_ratio = upsample_ratio
self.sampling_rate = sampling_rate
self.alpha = alpha
self.sigma = sigma
self.ffn = nn.Sequential(
weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
nn.Tanh(),
)
def forward(self, pitch, uv):
"""
:param pitch: [B, 1, frame_len], Hz
:param uv: [B, 1, frame_len] vuv flag
:return: [B, 1, sample_len]
"""
with torch.no_grad():
pitch_samples = F.interpolate(
pitch, scale_factor=(self.upsample_ratio), mode="nearest"
)
uv_samples = F.interpolate(
uv, scale_factor=(self.upsample_ratio), mode="nearest"
)
F_mat = torch.zeros(
(pitch_samples.size(0), self.nb_harmonics + 1, pitch_samples.size(-1))
).to(pitch_samples.device)
for i in range(self.nb_harmonics + 1):
F_mat[:, i : i + 1, :] = pitch_samples * (i + 1) / self.sampling_rate
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
phase_vec = u_dist.sample(
sample_shape=(pitch.size(0), self.nb_harmonics + 1, 1)
).to(F_mat.device)
phase_vec[:, 0, :] = 0
n_dist = Normal(loc=0.0, scale=self.sigma)
noise = n_dist.sample(
sample_shape=(
pitch_samples.size(0),
self.nb_harmonics + 1,
pitch_samples.size(-1),
)
).to(F_mat.device)
e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
e_unvoice = self.alpha / 3 / self.sigma * noise
e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
return self.ffn(e)
def remove_weight_norm(self):
remove_weight_norm(self.ffn[0])
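# Shape sketch (illustrative): SourceModule turns frame-level F0 / V-UV tracks into
# a sample-level excitation by nearest-neighbour upsampling, summing sinusoidal
# harmonics plus noise for voiced frames and scaled noise for unvoiced ones.
# The upsample_ratio of 256 is an assumed hop size for demonstration only.
def _example_source_module():
    src = SourceModule(nb_harmonics=7, upsample_ratio=256, sampling_rate=16000)
    pitch = torch.full((2, 1, 50), 120.0)      # (B, 1, frame_len), Hz
    uv = torch.ones(2, 1, 50)                  # voiced/unvoiced flags
    e = src(pitch, uv)                         # (B, 1, 50 * 256)
    return e.shape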
# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Pseudo QMF modules."""
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal.windows import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
"""Design prototype filter for PQMF.
This method is based on `A Kaiser window approach for the design of prototype
filters of cosine modulated filterbanks`_.
Args:
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
Returns:
ndarray: Impulse response of prototype filter (taps + 1,).
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
https://ieeexplore.ieee.org/abstract/document/681427
"""
# check the arguments are valid
assert taps % 2 == 0, "The number of taps must be an even number."
assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid="ignore"):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
np.pi * (np.arange(taps + 1) - 0.5 * taps)
)
h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
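# Quick numeric sketch (illustrative): the prototype is a Kaiser-windowed sinc
# lowpass of length taps + 1; it is symmetric (linear phase) and its centre tap
# equals cutoff_ratio thanks to the explicit fix of the 0/0 sample above.
def _example_prototype_filter():
    h = design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0)
    return h.shape, h[62 // 2]                 # -> (63,), 0.142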
class PQMF(torch.nn.Module):
"""PQMF module.
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
https://ieeexplore.ieee.org/document/258122
"""
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
"""Initilize PQMF module.
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
See dicussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
Args:
subbands (int): The number of subbands.
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# build analysis & synthesis filter coefficients
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = (
2
* h_proto
* np.cos(
(2 * k + 1)
* (np.pi / (2 * subbands))
* (np.arange(taps + 1) - (taps / 2))
+ (-1) ** k * np.pi / 4
)
)
h_synthesis[k] = (
2
* h_proto
* np.cos(
(2 * k + 1)
* (np.pi / (2 * subbands))
* (np.arange(taps + 1) - (taps / 2))
- (-1) ** k * np.pi / 4
)
)
# convert to tensor
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
# register coefficients as buffer
self.register_buffer("analysis_filter", analysis_filter)
self.register_buffer("synthesis_filter", synthesis_filter)
# filter for downsampling & upsampling
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
for k in range(subbands):
updown_filter[k, k, 0] = 1.0
self.register_buffer("updown_filter", updown_filter)
self.subbands = subbands
# keep padding info
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def analysis(self, x):
"""Analysis with PQMF.
Args:
x (Tensor): Input tensor (B, 1, T).
Returns:
Tensor: Output tensor (B, subbands, T // subbands).
"""
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
return F.conv1d(x, self.updown_filter, stride=self.subbands)
def synthesis(self, x):
"""Synthesis with PQMF.
Args:
x (Tensor): Input tensor (B, subbands, T // subbands).
Returns:
Tensor: Output tensor (B, 1, T).
"""
# NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
# Not sure this is the correct way, it is better to check again.
# TODO(kan-bayashi): Understand the reconstruction procedure
x = F.conv_transpose1d(
x, self.updown_filter * self.subbands, stride=self.subbands
)
return F.conv1d(self.pad_fn(x), self.synthesis_filter)
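# Round-trip sketch (illustrative): analysis splits a waveform into `subbands`
# critically sampled bands and synthesis approximately reconstructs it, which is
# how multi-band vocoders use this module.
def _example_pqmf_round_trip():
    pqmf = PQMF(subbands=4)
    x = torch.randn(1, 1, 4096)
    bands = pqmf.analysis(x)                   # (1, 4, 1024)
    x_hat = pqmf.synthesis(bands)              # (1, 1, 4096), near-perfect reconstruction
    return bands.shape, x_hat.shape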
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class ScaledDotProductAttention(nn.Module):
""" Scaled Dot-Product Attention """
def __init__(self, temperature, dropatt=0.0):
super().__init__()
self.temperature = temperature
self.softmax = nn.Softmax(dim=2)
self.dropatt = nn.Dropout(dropatt)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropatt(attn)
output = torch.bmm(attn, v)
return output, attn
class Prenet(nn.Module):
def __init__(self, in_units, prenet_units, out_units=0):
super(Prenet, self).__init__()
self.fcs = nn.ModuleList()
for in_dim, out_dim in zip([in_units] + prenet_units[:-1], prenet_units):
self.fcs.append(nn.Linear(in_dim, out_dim))
self.fcs.append(nn.ReLU())
self.fcs.append(nn.Dropout(0.5))
if out_units:
self.fcs.append(nn.Linear(prenet_units[-1], out_units))
def forward(self, input):
output = input
for layer in self.fcs:
output = layer(output)
return output
class MultiHeadSelfAttention(nn.Module):
""" Multi-Head SelfAttention module """
def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0):
super().__init__()
self.n_head = n_head
self.d_head = d_head
self.d_in = d_in
self.d_model = d_model
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_head, 0.5), dropatt=dropatt
)
self.fc = nn.Linear(n_head * d_head, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, input, mask=None):
d_head, n_head = self.d_head, self.n_head
sz_b, len_in, _ = input.size()
residual = input
x = self.layer_norm(input)
qkv = self.w_qkv(x)
q, k, v = qkv.chunk(3, -1)
q = q.view(sz_b, len_in, n_head, d_head)
k = k.view(sz_b, len_in, n_head, d_head)
v = v.view(sz_b, len_in, n_head, d_head)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head) # (n*b) x l x d
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head) # (n*b) x l x d
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head) # (n*b) x l x d
if mask is not None:
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_in, d_head)
output = (
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
) # b x l x (n*d)
output = self.dropout(self.fc(output))
if output.size(-1) == residual.size(-1):
output = output + residual
return output, attn
class PositionwiseConvFeedForward(nn.Module):
""" A two-feed-forward-layer module """
def __init__(self, d_in, d_hid, kernel_size=(3, 1), dropout_inner=0.1, dropout=0.1):
super().__init__()
# Use Conv1D
# position-wise
self.w_1 = nn.Conv1d(
d_in,
d_hid,
kernel_size=kernel_size[0],
padding=(kernel_size[0] - 1) // 2,
)
# position-wise
self.w_2 = nn.Conv1d(
d_hid,
d_in,
kernel_size=kernel_size[1],
padding=(kernel_size[1] - 1) // 2,
)
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout_inner = nn.Dropout(dropout_inner)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
residual = x
x = self.layer_norm(x)
output = x.transpose(1, 2)
output = F.relu(self.w_1(output))
if mask is not None:
output = output.masked_fill(mask.unsqueeze(1), 0)
output = self.dropout_inner(output)
output = self.w_2(output)
output = output.transpose(1, 2)
output = self.dropout(output)
output = output + residual
return output
class FFTBlock(nn.Module):
"""FFT Block"""
def __init__(
self,
d_in,
d_model,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(FFTBlock, self).__init__()
self.slf_attn = MultiHeadSelfAttention(
n_head, d_in, d_model, d_head, dropout=dropout, dropatt=dropout_attn
)
self.pos_ffn = PositionwiseConvFeedForward(
d_model, d_inner, kernel_size, dropout_inner=dropout_relu, dropout=dropout
)
def forward(self, input, mask=None, slf_attn_mask=None):
output, slf_attn = self.slf_attn(input, mask=slf_attn_mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
output = self.pos_ffn(output, mask=mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output, slf_attn
class MultiHeadPNCAAttention(nn.Module):
""" Multi-Head Attention PNCA module """
def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0):
super().__init__()
self.n_head = n_head
self.d_head = d_head
self.d_model = d_model
self.d_mem = d_mem
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head)
self.fc_x = nn.Linear(n_head * d_head, d_model)
self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head)
self.fc_h = nn.Linear(n_head * d_head, d_model)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_head, 0.5), dropatt=dropatt
)
self.dropout = nn.Dropout(dropout)
def update_x_state(self, x):
d_head, n_head = self.d_head, self.n_head
sz_b, len_x, _ = x.size()
x_qkv = self.w_x_qkv(x)
x_q, x_k, x_v = x_qkv.chunk(3, -1)
x_q = x_q.view(sz_b, len_x, n_head, d_head)
x_k = x_k.view(sz_b, len_x, n_head, d_head)
x_v = x_v.view(sz_b, len_x, n_head, d_head)
x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
if self.x_state_size:
self.x_k = torch.cat([self.x_k, x_k], dim=1)
self.x_v = torch.cat([self.x_v, x_v], dim=1)
else:
self.x_k = x_k
self.x_v = x_v
self.x_state_size += len_x
return x_q, x_k, x_v
def update_h_state(self, h):
if self.h_state_size == h.size(1):
return None, None
d_head, n_head = self.d_head, self.n_head
# H
sz_b, len_h, _ = h.size()
h_kv = self.w_h_kv(h)
h_k, h_v = h_kv.chunk(2, -1)
h_k = h_k.view(sz_b, len_h, n_head, d_head)
h_v = h_v.view(sz_b, len_h, n_head, d_head)
self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
self.h_state_size += len_h
return h_k, h_v
def reset_state(self):
self.h_k = None
self.h_v = None
self.h_state_size = 0
self.x_k = None
self.x_v = None
self.x_state_size = 0
def forward(self, x, h, mask_x=None, mask_h=None):
residual = x
self.update_h_state(h)
x_q, x_k, x_v = self.update_x_state(self.layer_norm(x))
d_head, n_head = self.d_head, self.n_head
sz_b, len_in, _ = x.size()
# X
if mask_x is not None:
mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x ..
output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x)
output_x = output_x.view(n_head, sz_b, len_in, d_head)
output_x = (
output_x.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
) # b x l x (n*d)
output_x = self.fc_x(output_x)
# H
if mask_h is not None:
mask_h = mask_h.repeat(n_head, 1, 1)
output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h)
output_h = output_h.view(n_head, sz_b, len_in, d_head)
output_h = (
output_h.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
) # b x l x (n*d)
output_h = self.fc_h(output_h)
output = output_x + output_h
output = self.dropout(output)
output = output + residual
return output, attn_x, attn_h
class PNCABlock(nn.Module):
"""PNCA Block"""
def __init__(
self,
d_model,
d_mem,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(PNCABlock, self).__init__()
self.pnca_attn = MultiHeadPNCAAttention(
n_head, d_model, d_mem, d_head, dropout=dropout, dropatt=dropout_attn
)
self.pos_ffn = PositionwiseConvFeedForward(
d_model, d_inner, kernel_size, dropout_inner=dropout_relu, dropout=dropout
)
def forward(
self, input, memory, mask=None, pnca_x_attn_mask=None, pnca_h_attn_mask=None
):
output, pnca_attn_x, pnca_attn_h = self.pnca_attn(
input, memory, pnca_x_attn_mask, pnca_h_attn_mask
)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
output = self.pos_ffn(output, mask=mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output, pnca_attn_x, pnca_attn_h
def reset_state(self):
self.pnca_attn.reset_state()
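# Usage sketch (illustrative): MultiHeadPNCAAttention caches its key/value state,
# so reset_state() must be called before every new sequence; the block can then be
# run on a whole target sequence (training) or step by step (inference). The sizes
# below are assumptions for demonstration only.
def _example_pnca_block():
    block = PNCABlock(
        d_model=32, d_mem=16, n_head=2, d_head=16, d_inner=64,
        kernel_size=(3, 1), dropout=0.1,
    )
    x = torch.randn(2, 5, 32)                  # decoder-side input
    h = torch.randn(2, 5, 16)                  # per-step memory
    block.reset_state()                        # mandatory before a fresh sequence
    out, attn_x, attn_h = block(x, h)
    return out.shape                           # (2, 5, 32)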
import torch
import torch.nn as nn
import torch.nn.functional as F
from kantts.models.sambert.fsmn import FsmnEncoderV2
from kantts.models.sambert import Prenet
class LengthRegulator(nn.Module):
def __init__(self, r=1):
super(LengthRegulator, self).__init__()
self.r = r
def forward(self, inputs, durations, masks=None):
reps = (durations + 0.5).long()
output_lens = reps.sum(dim=1)
max_len = output_lens.max()
reps_cumsum = torch.cumsum(F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[
:, None, :
]
range_ = torch.arange(max_len).to(inputs.device)[None, :, None]
mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
mult = mult.float()
out = torch.matmul(mult, inputs)
if masks is not None:
out = out.masked_fill(masks.unsqueeze(-1), 0.0)
seq_len = out.size(1)
padding = self.r - int(seq_len) % self.r
if padding < self.r:
out = F.pad(out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0)
out = out.transpose(1, 2)
return out, output_lens
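# Worked example (illustrative): durations [2, 3, 1] expand three text frames into
# six output frames (frame 0 twice, frame 1 three times, frame 2 once); with r = 2
# no extra padding is added because 6 is already a multiple of r.
def _example_length_regulator():
    lr = LengthRegulator(r=2)
    inputs = torch.arange(3, dtype=torch.float32).view(1, 3, 1)   # (B, T_text, D)
    durations = torch.tensor([[2.0, 3.0, 1.0]])                   # (B, T_text)
    out, out_lens = lr(inputs, durations)
    # out.squeeze() == tensor([0., 0., 1., 1., 1., 2.]), out_lens == tensor([6])
    return out, out_lens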
class VarRnnARPredictor(nn.Module):
def __init__(self, cond_units, prenet_units, rnn_units):
super(VarRnnARPredictor, self).__init__()
self.prenet = Prenet(1, prenet_units)
self.lstm = nn.LSTM(
prenet_units[-1] + cond_units,
rnn_units,
num_layers=2,
batch_first=True,
bidirectional=False,
)
self.fc = nn.Linear(rnn_units, 1)
def forward(self, inputs, cond, h=None, masks=None):
x = torch.cat([self.prenet(inputs), cond], dim=-1)
# The input can also be a packed variable length sequence,
# here we just omit it for simplicity due to the mask and uni-directional lstm.
x, h_new = self.lstm(x, h)
x = self.fc(x).squeeze(-1)
x = F.relu(x)
if masks is not None:
x = x.masked_fill(masks, 0.0)
return x, h_new
def infer(self, cond, masks=None):
batch_size, length = cond.size(0), cond.size(1)
output = []
x = torch.zeros((batch_size, 1)).to(cond.device)
h = None
for i in range(length):
x, h = self.forward(x.unsqueeze(1), cond[:, i : i + 1, :], h=h)
output.append(x)
output = torch.cat(output, dim=-1)
if masks is not None:
output = output.masked_fill(masks, 0.0)
return output
class VarFsmnRnnNARPredictor(nn.Module):
def __init__(
self,
in_dim,
filter_size,
fsmn_num_layers,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
lstm_units,
):
super(VarFsmnRnnNARPredictor, self).__init__()
self.fsmn = FsmnEncoderV2(
filter_size,
fsmn_num_layers,
in_dim,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
)
self.blstm = nn.LSTM(
num_memory_units,
lstm_units,
num_layers=1,
batch_first=True,
bidirectional=True,
)
self.fc = nn.Linear(2 * lstm_units, 1)
def forward(self, inputs, masks=None):
input_lengths = None
if masks is not None:
input_lengths = torch.sum((~masks).float(), dim=1).long()
x = self.fsmn(inputs, masks)
if input_lengths is not None:
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths.tolist(), batch_first=True, enforce_sorted=False
)
x, _ = self.blstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True, total_length=inputs.size(1)
)
else:
x, _ = self.blstm(x)
x = self.fc(x).squeeze(-1)
if masks is not None:
x = x.masked_fill(masks, 0.0)
return x
import numpy as np
import numba as nb
@nb.jit(nopython=True)
def mas(attn_map, width=1):
# assumes mel x text
opt = np.zeros_like(attn_map)
attn_map = np.log(attn_map)
attn_map[0, 1:] = -np.inf
log_p = np.zeros_like(attn_map)
log_p[0, :] = attn_map[0, :]
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
for i in range(1, attn_map.shape[0]):
for j in range(attn_map.shape[1]): # for each text dim
prev_j = np.arange(max(0, j - width), j + 1)
prev_log = np.array([log_p[i - 1, prev_idx] for prev_idx in prev_j])
ind = np.argmax(prev_log)
log_p[i, j] = attn_map[i, j] + prev_log[ind]
prev_ind[i, j] = prev_j[ind]
# now backtrack
curr_text_idx = attn_map.shape[1] - 1
for i in range(attn_map.shape[0] - 1, -1, -1):
opt[i, curr_text_idx] = 1
curr_text_idx = prev_ind[i, curr_text_idx]
opt[0, curr_text_idx] = 1
return opt
@nb.jit(nopython=True)
def mas_width1(attn_map):
"""mas with hardcoded width=1"""
# assumes mel x text
opt = np.zeros_like(attn_map)
attn_map = np.log(attn_map)
attn_map[0, 1:] = -np.inf
log_p = np.zeros_like(attn_map)
log_p[0, :] = attn_map[0, :]
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
for i in range(1, attn_map.shape[0]):
for j in range(attn_map.shape[1]): # for each text dim
prev_log = log_p[i - 1, j]
prev_j = j
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
prev_log = log_p[i - 1, j - 1]
prev_j = j - 1
log_p[i, j] = attn_map[i, j] + prev_log
prev_ind[i, j] = prev_j
# now backtrack
curr_text_idx = attn_map.shape[1] - 1
for i in range(attn_map.shape[0] - 1, -1, -1):
opt[i, curr_text_idx] = 1
curr_text_idx = prev_ind[i, curr_text_idx]
opt[0, curr_text_idx] = 1
return opt
@nb.jit(nopython=True, parallel=True)
def b_mas(b_attn_map, in_lens, out_lens, width=1):
assert width == 1
attn_out = np.zeros_like(b_attn_map)
for b in nb.prange(b_attn_map.shape[0]):
out = mas_width1(b_attn_map[b, 0, : out_lens[b], : in_lens[b]])
attn_out[b, 0, : out_lens[b], : in_lens[b]] = out
return attn_out
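# Usage sketch (illustrative): b_mas takes a batch of soft alignments of shape
# (B, 1, T_mel, T_text) plus the true text/mel lengths and returns hard monotonic
# 0/1 alignments; summing over the mel axis gives integer durations per token.
def _example_b_mas():
    t_mel, t_text = 6, 3
    attn = np.random.rand(1, 1, t_mel, t_text)
    attn = attn / attn.sum(axis=3, keepdims=True)          # each mel row sums to 1
    hard = b_mas(attn, np.array([t_text]), np.array([t_mel]), width=1)
    durations = hard[0, 0].sum(axis=0)                     # (t_text,), sums to t_mel
    return hard, durations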
import numpy as np
import torch
from torch import nn
class ConvNorm(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain="linear",
):
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
)
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class ConvAttention(torch.nn.Module):
def __init__(
self,
n_mel_channels=80,
n_text_channels=512,
n_att_channels=80,
temperature=1.0,
use_query_proj=True,
):
super(ConvAttention, self).__init__()
self.temperature = temperature
self.att_scaling_factor = np.sqrt(n_att_channels)
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
self.use_query_proj = bool(use_query_proj)
self.key_proj = nn.Sequential(
ConvNorm(
n_text_channels,
n_text_channels * 2,
kernel_size=3,
bias=True,
w_init_gain="relu",
),
torch.nn.ReLU(),
ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
)
self.query_proj = nn.Sequential(
ConvNorm(
n_mel_channels,
n_mel_channels * 2,
kernel_size=3,
bias=True,
w_init_gain="relu",
),
torch.nn.ReLU(),
ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
torch.nn.ReLU(),
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
)
def forward(self, queries, keys, mask=None, attn_prior=None):
"""Attention mechanism for flowtron parallel
Unlike in Flowtron, we have no restrictions such as causality etc,
since we only need this during training.
Args:
queries (torch.tensor): B x C x T1 tensor
(probably going to be mel data)
keys (torch.tensor): B x C2 x T2 tensor (text data)
mask (torch.tensor): uint8 binary mask for variable length entries
(should be in the T2 domain)
Output:
attn (torch.tensor): B x 1 x T1 x T2 attention mask.
Final dim T2 should sum to 1
"""
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
# Beware can only do this since query_dim = attn_dim = n_mel_channels
if self.use_query_proj:
queries_enc = self.query_proj(queries)
else:
queries_enc = queries
# different ways of computing attn,
# one is isotropic gaussians (per phoneme)
# Simplistic Gaussian Isotropic Attention
# B x n_attn_dims x T1 x T2
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
# compute log likelihood from a gaussian
attn = -0.0005 * attn.sum(1, keepdim=True)
if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)
attn_logprob = attn.clone()
if mask is not None:
attn.data.masked_fill_(mask.unsqueeze(1).unsqueeze(1), -float("inf"))
attn = self.softmax(attn) # Softmax along T2
return attn, attn_logprob
"""
FSMN Pytorch Version
"""
import torch.nn as nn
import torch.nn.functional as F
class FeedForwardNet(nn.Module):
""" A two-feed-forward-layer module """
def __init__(self, d_in, d_hid, d_out, kernel_size=[1, 1], dropout=0.1):
super().__init__()
# Use Conv1D
# position-wise
self.w_1 = nn.Conv1d(
d_in,
d_hid,
kernel_size=kernel_size[0],
padding=(kernel_size[0] - 1) // 2,
)
# position-wise
self.w_2 = nn.Conv1d(
d_hid,
d_out,
kernel_size=kernel_size[1],
padding=(kernel_size[1] - 1) // 2,
bias=False,
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
output = x.transpose(1, 2)
output = F.relu(self.w_1(output))
output = self.dropout(output)
output = self.w_2(output)
output = output.transpose(1, 2)
return output
class MemoryBlockV2(nn.Module):
def __init__(self, d, filter_size, shift, dropout=0.0):
super(MemoryBlockV2, self).__init__()
left_padding = int(round((filter_size - 1) / 2))
right_padding = int((filter_size - 1) / 2)
if shift > 0:
left_padding += shift
right_padding -= shift
self.lp, self.rp = left_padding, right_padding
self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, input, mask=None):
if mask is not None:
input = input.masked_fill(mask.unsqueeze(-1), 0)
x = F.pad(input, (0, 0, self.lp, self.rp, 0, 0), mode="constant", value=0.0)
output = (
self.conv_dw(x.contiguous().transpose(1, 2)).contiguous().transpose(1, 2)
)
output += input
output = self.dropout(output)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output
class FsmnEncoderV2(nn.Module):
def __init__(
self,
filter_size,
fsmn_num_layers,
input_dim,
num_memory_units,
ffn_inner_dim,
dropout=0.0,
shift=0,
):
super(FsmnEncoderV2, self).__init__()
self.filter_size = filter_size
self.fsmn_num_layers = fsmn_num_layers
self.num_memory_units = num_memory_units
self.ffn_inner_dim = ffn_inner_dim
self.dropout = dropout
self.shift = shift
if not isinstance(shift, list):
self.shift = [shift for _ in range(self.fsmn_num_layers)]
self.ffn_lst = nn.ModuleList()
self.ffn_lst.append(
FeedForwardNet(input_dim, ffn_inner_dim, num_memory_units, dropout=dropout)
)
for i in range(1, fsmn_num_layers):
self.ffn_lst.append(
FeedForwardNet(
num_memory_units, ffn_inner_dim, num_memory_units, dropout=dropout
)
)
self.memory_block_lst = nn.ModuleList()
for i in range(fsmn_num_layers):
self.memory_block_lst.append(
MemoryBlockV2(num_memory_units, filter_size, self.shift[i], dropout)
)
def forward(self, input, mask=None):
x = F.dropout(input, self.dropout, self.training)
for (ffn, memory_block) in zip(self.ffn_lst, self.memory_block_lst):
context = ffn(x)
memory = memory_block(context, mask)
memory = F.dropout(memory, self.dropout, self.training)
if memory.size(-1) == x.size(-1):
memory += x
x = memory
return x
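# Usage sketch (illustrative): the FSMN encoder maps a (B, T, input_dim) sequence
# to (B, T, num_memory_units); each layer is a position-wise FFN followed by a
# depthwise "memory" convolution over time with a residual connection. The sizes
# below are assumptions for demonstration only.
def _example_fsmn_encoder():
    import torch  # only needed by this sketch; the module above only uses nn/F

    enc = FsmnEncoderV2(
        filter_size=11,
        fsmn_num_layers=3,
        input_dim=80,
        num_memory_units=128,
        ffn_inner_dim=256,
        dropout=0.1,
        shift=0,
    )
    x = torch.randn(2, 50, 80)
    y = enc(x)                                 # (2, 50, 128)
    return y.shape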
import torch
import torch.nn as nn
import torch.nn.functional as F
from kantts.models.sambert import FFTBlock, PNCABlock, Prenet
from kantts.models.sambert.positions import (
SinusoidalPositionEncoder,
DurSinusoidalPositionEncoder,
)
from kantts.models.sambert.adaptors import (
LengthRegulator,
VarFsmnRnnNARPredictor,
VarRnnARPredictor,
)
from kantts.models.sambert.fsmn import FsmnEncoderV2
from kantts.models.sambert.alignment import b_mas
from kantts.models.sambert.attention import ConvAttention
from kantts.models.utils import get_mask_from_lengths
class SelfAttentionEncoder(nn.Module):
def __init__(
self,
n_layer,
d_in,
d_model,
n_head,
d_head,
d_inner,
dropout,
dropout_att,
dropout_relu,
position_encoder,
):
super(SelfAttentionEncoder, self).__init__()
self.d_in = d_in
self.d_model = d_model
self.dropout = dropout
d_in_lst = [d_in] + [d_model] * (n_layer - 1)
self.fft = nn.ModuleList(
[
FFTBlock(
d,
d_model,
n_head,
d_head,
d_inner,
(3, 1),
dropout,
dropout_att,
dropout_relu,
)
for d in d_in_lst
]
)
self.ln = nn.LayerNorm(d_model, eps=1e-6)
self.position_enc = position_encoder
def forward(self, input, mask=None, return_attns=False):
input *= self.d_model ** 0.5
if isinstance(self.position_enc, SinusoidalPositionEncoder):
input = self.position_enc(input)
else:
raise NotImplementedError
input = F.dropout(input, p=self.dropout, training=self.training)
enc_slf_attn_list = []
max_len = input.size(1)
if mask is not None:
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
else:
slf_attn_mask = None
enc_output = input
for id, layer in enumerate(self.fft):
enc_output, enc_slf_attn = layer(
enc_output, mask=mask, slf_attn_mask=slf_attn_mask
)
if return_attns:
enc_slf_attn_list += [enc_slf_attn]
enc_output = self.ln(enc_output)
return enc_output, enc_slf_attn_list
class HybridAttentionDecoder(nn.Module):
def __init__(
self,
d_in,
prenet_units,
n_layer,
d_model,
d_mem,
n_head,
d_head,
d_inner,
dropout,
dropout_att,
dropout_relu,
d_out,
):
super(HybridAttentionDecoder, self).__init__()
self.d_model = d_model
self.dropout = dropout
self.prenet = Prenet(d_in, prenet_units, d_model)
self.dec_in_proj = nn.Linear(d_model + d_mem, d_model)
self.pnca = nn.ModuleList(
[
PNCABlock(
d_model,
d_mem,
n_head,
d_head,
d_inner,
(1, 1),
dropout,
dropout_att,
dropout_relu,
)
for _ in range(n_layer)
]
)
self.ln = nn.LayerNorm(d_model, eps=1e-6)
self.dec_out_proj = nn.Linear(d_model, d_out)
def reset_state(self):
for layer in self.pnca:
layer.reset_state()
def get_pnca_attn_mask(
self, device, max_len, x_band_width, h_band_width, mask=None
):
if mask is not None:
pnca_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
else:
pnca_attn_mask = None
range_ = torch.arange(max_len).to(device)
x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :]
x_end = (range_ + 1)[None, None, :]
h_start = range_[None, None, :]
h_end = torch.clamp_max(range_ + h_band_width + 1, max_len + 1)[None, None, :]
pnca_x_attn_mask = ~(
(x_start <= range_[None, :, None]) & (x_end > range_[None, :, None])
).transpose(1, 2)
pnca_h_attn_mask = ~(
(h_start <= range_[None, :, None]) & (h_end > range_[None, :, None])
).transpose(1, 2)
if pnca_attn_mask is not None:
pnca_x_attn_mask = pnca_x_attn_mask | pnca_attn_mask
pnca_h_attn_mask = pnca_h_attn_mask | pnca_attn_mask
pnca_x_attn_mask = pnca_x_attn_mask.masked_fill(
pnca_attn_mask.transpose(1, 2), False
)
pnca_h_attn_mask = pnca_h_attn_mask.masked_fill(
pnca_attn_mask.transpose(1, 2), False
)
return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask
# must call reset_state before
def forward(
self, input, memory, x_band_width, h_band_width, mask=None, return_attns=False
):
input = self.prenet(input)
input = torch.cat([memory, input], dim=-1)
input = self.dec_in_proj(input)
if mask is not None:
input = input.masked_fill(mask.unsqueeze(-1), 0)
input *= self.d_model ** 0.5
input = F.dropout(input, p=self.dropout, training=self.training)
max_len = input.size(1)
pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
input.device, max_len, x_band_width, h_band_width, mask
)
dec_pnca_attn_x_list = []
dec_pnca_attn_h_list = []
dec_output = input
for id, layer in enumerate(self.pnca):
dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
dec_output,
memory,
mask=mask,
pnca_x_attn_mask=pnca_x_attn_mask,
pnca_h_attn_mask=pnca_h_attn_mask,
)
if return_attns:
dec_pnca_attn_x_list += [dec_pnca_attn_x]
dec_pnca_attn_h_list += [dec_pnca_attn_h]
dec_output = self.ln(dec_output)
dec_output = self.dec_out_proj(dec_output)
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
# must call reset_state before when step == 0
def infer(
self,
step,
input,
memory,
x_band_width,
h_band_width,
mask=None,
return_attns=False,
):
max_len = memory.size(1)
input = self.prenet(input)
input = torch.cat([memory[:, step : step + 1, :], input], dim=-1)
input = self.dec_in_proj(input)
input *= self.d_model ** 0.5
input = F.dropout(input, p=self.dropout, training=self.training)
pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
input.device, max_len, x_band_width, h_band_width, mask
)
dec_pnca_attn_x_list = []
dec_pnca_attn_h_list = []
dec_output = input
for id, layer in enumerate(self.pnca):
if mask is not None:
mask_step = mask[:, step : step + 1]
else:
mask_step = None
dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
dec_output,
memory,
mask=mask_step,
pnca_x_attn_mask=pnca_x_attn_mask[:, step : step + 1, : (step + 1)],
pnca_h_attn_mask=pnca_h_attn_mask[:, step : step + 1, :],
)
if return_attns:
dec_pnca_attn_x_list += [dec_pnca_attn_x]
dec_pnca_attn_h_list += [dec_pnca_attn_h]
dec_output = self.ln(dec_output)
dec_output = self.dec_out_proj(dec_output)
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
class TextFftEncoder(nn.Module):
def __init__(self, config):
super(TextFftEncoder, self).__init__()
d_emb = config["embedding_dim"]
self.using_byte = False
if config.get("using_byte", False):
self.using_byte = True
nb_ling_byte_index = config["byte_index"]
self.byte_index_emb = nn.Embedding(nb_ling_byte_index, d_emb)
else:
# linguistic unit lookup table
nb_ling_sy = config["sy"]
nb_ling_tone = config["tone"]
nb_ling_syllable_flag = config["syllable_flag"]
nb_ling_ws = config["word_segment"]
self.sy_emb = nn.Embedding(nb_ling_sy, d_emb)
self.tone_emb = nn.Embedding(nb_ling_tone, d_emb)
self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb)
self.ws_emb = nn.Embedding(nb_ling_ws, d_emb)
max_len = config["max_len"]
nb_layers = config["encoder_num_layers"]
nb_heads = config["encoder_num_heads"]
d_model = config["encoder_num_units"]
d_head = d_model // nb_heads
d_inner = config["encoder_ffn_inner_dim"]
dropout = config["encoder_dropout"]
dropout_attn = config["encoder_attention_dropout"]
dropout_relu = config["encoder_relu_dropout"]
d_proj = config["encoder_projection_units"]
self.d_model = d_model
position_enc = SinusoidalPositionEncoder(max_len, d_emb)
self.ling_enc = SelfAttentionEncoder(
nb_layers,
d_emb,
d_model,
nb_heads,
d_head,
d_inner,
dropout,
dropout_attn,
dropout_relu,
position_enc,
)
self.ling_proj = nn.Linear(d_model, d_proj, bias=False)
def forward(self, inputs_ling, masks=None, return_attns=False):
# Parse inputs_ling_seq
if self.using_byte:
inputs_byte_index = inputs_ling[:, :, 0]
byte_index_embedding = self.byte_index_emb(inputs_byte_index)
ling_embedding = byte_index_embedding
else:
inputs_sy = inputs_ling[:, :, 0]
inputs_tone = inputs_ling[:, :, 1]
inputs_syllable_flag = inputs_ling[:, :, 2]
inputs_ws = inputs_ling[:, :, 3]
# Lookup table
sy_embedding = self.sy_emb(inputs_sy)
tone_embedding = self.tone_emb(inputs_tone)
syllable_flag_embedding = self.syllable_flag_emb(inputs_syllable_flag)
ws_embedding = self.ws_emb(inputs_ws)
ling_embedding = (
sy_embedding + tone_embedding + syllable_flag_embedding + ws_embedding
)
enc_output, enc_slf_attn_list = self.ling_enc(
ling_embedding, masks, return_attns
)
if hasattr(self, "ling_proj"):
enc_output = self.ling_proj(enc_output)
return enc_output, enc_slf_attn_list, ling_embedding
class VarianceAdaptor(nn.Module):
def __init__(self, config):
super(VarianceAdaptor, self).__init__()
input_dim = (
config["encoder_projection_units"]
+ config["emotion_units"]
+ config["speaker_units"]
)
filter_size = config["predictor_filter_size"]
fsmn_num_layers = config["predictor_fsmn_num_layers"]
num_memory_units = config["predictor_num_memory_units"]
ffn_inner_dim = config["predictor_ffn_inner_dim"]
dropout = config["predictor_dropout"]
shift = config["predictor_shift"]
lstm_units = config["predictor_lstm_units"]
dur_pred_prenet_units = config["dur_pred_prenet_units"]
dur_pred_lstm_units = config["dur_pred_lstm_units"]
self.pitch_predictor = VarFsmnRnnNARPredictor(
input_dim,
filter_size,
fsmn_num_layers,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
lstm_units,
)
self.energy_predictor = VarFsmnRnnNARPredictor(
input_dim,
filter_size,
fsmn_num_layers,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
lstm_units,
)
self.duration_predictor = VarRnnARPredictor(
input_dim, dur_pred_prenet_units, dur_pred_lstm_units
)
self.length_regulator = LengthRegulator(config["outputs_per_step"])
self.dur_position_encoder = DurSinusoidalPositionEncoder(
config["encoder_projection_units"], config["outputs_per_step"]
)
self.pitch_emb = nn.Conv1d(
1, config["encoder_projection_units"], kernel_size=9, padding=4
)
self.energy_emb = nn.Conv1d(
1, config["encoder_projection_units"], kernel_size=9, padding=4
)
def forward(
self,
inputs_text_embedding,
inputs_emo_embedding,
inputs_spk_embedding,
masks=None,
output_masks=None,
duration_targets=None,
pitch_targets=None,
energy_targets=None,
):
batch_size = inputs_text_embedding.size(0)
variance_predictor_inputs = torch.cat(
[inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding], dim=-1
)
pitch_predictions = self.pitch_predictor(variance_predictor_inputs, masks)
energy_predictions = self.energy_predictor(variance_predictor_inputs, masks)
if pitch_targets is not None:
pitch_embeddings = self.pitch_emb(pitch_targets.unsqueeze(1)).transpose(
1, 2
)
else:
pitch_embeddings = self.pitch_emb(pitch_predictions.unsqueeze(1)).transpose(
1, 2
)
if energy_targets is not None:
energy_embeddings = self.energy_emb(energy_targets.unsqueeze(1)).transpose(
1, 2
)
else:
energy_embeddings = self.energy_emb(
energy_predictions.unsqueeze(1)
).transpose(1, 2)
inputs_text_embedding_aug = (
inputs_text_embedding + pitch_embeddings + energy_embeddings
)
duration_predictor_cond = torch.cat(
[inputs_text_embedding_aug, inputs_spk_embedding, inputs_emo_embedding],
dim=-1,
)
if duration_targets is not None:
duration_predictor_go_frame = torch.zeros(batch_size, 1).to(
inputs_text_embedding.device
)
duration_predictor_input = torch.cat(
[duration_predictor_go_frame, duration_targets[:, :-1].float()], dim=-1
)
duration_predictor_input = torch.log(duration_predictor_input + 1)
log_duration_predictions, _ = self.duration_predictor(
duration_predictor_input.unsqueeze(-1),
duration_predictor_cond,
masks=masks,
)
duration_predictions = torch.exp(log_duration_predictions) - 1
else:
log_duration_predictions = self.duration_predictor.infer(
duration_predictor_cond, masks=masks
)
duration_predictions = torch.exp(log_duration_predictions) - 1
if duration_targets is not None:
LR_text_outputs, LR_length_rounded = self.length_regulator(
inputs_text_embedding_aug, duration_targets, masks=output_masks
)
LR_position_embeddings = self.dur_position_encoder(
duration_targets, masks=output_masks
)
LR_emo_outputs, _ = self.length_regulator(
inputs_emo_embedding, duration_targets, masks=output_masks
)
LR_spk_outputs, _ = self.length_regulator(
inputs_spk_embedding, duration_targets, masks=output_masks
)
else:
LR_text_outputs, LR_length_rounded = self.length_regulator(
inputs_text_embedding_aug, duration_predictions, masks=output_masks
)
LR_position_embeddings = self.dur_position_encoder(
duration_predictions, masks=output_masks
)
LR_emo_outputs, _ = self.length_regulator(
inputs_emo_embedding, duration_predictions, masks=output_masks
)
LR_spk_outputs, _ = self.length_regulator(
inputs_spk_embedding, duration_predictions, masks=output_masks
)
LR_text_outputs = LR_text_outputs + LR_position_embeddings
return (
LR_text_outputs,
LR_emo_outputs,
LR_spk_outputs,
LR_length_rounded,
log_duration_predictions,
pitch_predictions,
energy_predictions,
)
class MelPNCADecoder(nn.Module):
def __init__(self, config):
super(MelPNCADecoder, self).__init__()
prenet_units = config["decoder_prenet_units"]
nb_layers = config["decoder_num_layers"]
nb_heads = config["decoder_num_heads"]
d_model = config["decoder_num_units"]
d_head = d_model // nb_heads
d_inner = config["decoder_ffn_inner_dim"]
dropout = config["decoder_dropout"]
dropout_attn = config["decoder_attention_dropout"]
dropout_relu = config["decoder_relu_dropout"]
outputs_per_step = config["outputs_per_step"]
d_mem = (
config["encoder_projection_units"] * outputs_per_step
+ config["emotion_units"]
+ config["speaker_units"]
)
d_mel = config["num_mels"]
self.d_mel = d_mel
self.r = outputs_per_step
self.nb_layers = nb_layers
self.mel_dec = HybridAttentionDecoder(
d_mel,
prenet_units,
nb_layers,
d_model,
d_mem,
nb_heads,
d_head,
d_inner,
dropout,
dropout_attn,
dropout_relu,
d_mel * outputs_per_step,
)
def forward(
self,
memory,
x_band_width,
h_band_width,
target=None,
mask=None,
return_attns=False,
):
batch_size = memory.size(0)
go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device)
if target is not None:
self.mel_dec.reset_state()
input = target[:, self.r - 1 :: self.r, :]
input = torch.cat([go_frame, input], dim=1)[:, :-1, :]
dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec(
input,
memory,
x_band_width,
h_band_width,
mask=mask,
return_attns=return_attns,
)
else:
dec_output = []
dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)]
dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)]
self.mel_dec.reset_state()
input = go_frame
for step in range(memory.size(1)):
(
dec_output_step,
dec_pnca_attn_x_step,
dec_pnca_attn_h_step,
) = self.mel_dec.infer(
step,
input,
memory,
x_band_width,
h_band_width,
mask=mask,
return_attns=return_attns,
)
input = dec_output_step[:, :, -self.d_mel :]
dec_output.append(dec_output_step)
for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)
):
left = memory.size(1) - pnca_x_attn.size(-1)
if left > 0:
padding = torch.zeros((pnca_x_attn.size(0), 1, left)).to(
pnca_x_attn
)
pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1)
dec_pnca_attn_x_list[layer_id].append(pnca_x_attn)
dec_pnca_attn_h_list[layer_id].append(pnca_h_attn)
dec_output = torch.cat(dec_output, dim=1)
for layer_id in range(self.nb_layers):
dec_pnca_attn_x_list[layer_id] = torch.cat(
dec_pnca_attn_x_list[layer_id], dim=1
)
dec_pnca_attn_h_list[layer_id] = torch.cat(
dec_pnca_attn_h_list[layer_id], dim=1
)
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
class PostNet(nn.Module):
def __init__(self, config):
super(PostNet, self).__init__()
self.filter_size = config["postnet_filter_size"]
self.fsmn_num_layers = config["postnet_fsmn_num_layers"]
self.num_memory_units = config["postnet_num_memory_units"]
self.ffn_inner_dim = config["postnet_ffn_inner_dim"]
self.dropout = config["postnet_dropout"]
self.shift = config["postnet_shift"]
self.lstm_units = config["postnet_lstm_units"]
self.num_mels = config["num_mels"]
self.fsmn = FsmnEncoderV2(
self.filter_size,
self.fsmn_num_layers,
self.num_mels,
self.num_memory_units,
self.ffn_inner_dim,
self.dropout,
self.shift,
)
self.lstm = nn.LSTM(
self.num_memory_units, self.lstm_units, num_layers=1, batch_first=True
)
self.fc = nn.Linear(self.lstm_units, self.num_mels)
def forward(self, x, mask=None):
postnet_fsmn_output = self.fsmn(x, mask)
        # The input could also be packed as a variable-length sequence; we omit
        # that here for simplicity, since the mask and the uni-directional LSTM
        # make it unnecessary.
postnet_lstm_output, _ = self.lstm(postnet_fsmn_output)
mel_residual_output = self.fc(postnet_lstm_output)
return mel_residual_output
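
# Usage sketch (hypothetical wrapper): the post-net predicts a residual over
# the coarse decoder mel, and the refined spectrogram is their sum, mirroring
# the additive skip connection in KanTtsSAMBERT.forward below.
def _postnet_refine_demo(postnet: "PostNet", dec_outputs, output_masks=None):
    residual = postnet(dec_outputs, output_masks)
    return dec_outputs + residual  # same shape as dec_outputs: [B, T, num_mels]
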
def average_frame_feat(pitch, durs):
durs_cums_ends = torch.cumsum(durs, dim=1).long()
durs_cums_starts = F.pad(durs_cums_ends[:, :-1], (1, 0))
pitch_nonzero_cums = F.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
pitch_cums = F.pad(torch.cumsum(pitch, dim=2), (1, 0))
bs, lengths = durs_cums_ends.size()
n_formants = pitch.size(1)
dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, lengths)
dce = durs_cums_ends[:, None, :].expand(bs, n_formants, lengths)
pitch_sums = (
torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)
).float()
pitch_nelems = (
torch.gather(pitch_nonzero_cums, 2, dce)
- torch.gather(pitch_nonzero_cums, 2, dcs)
).float()
pitch_avg = torch.where(
pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems
)
return pitch_avg
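
# Worked toy example (illustrative only): average_frame_feat collapses a
# frame-level contour of shape [B, n_formants, T_frames] into per-token
# averages of shape [B, n_formants, T_tokens], counting only non-zero
# (voiced) frames in each token's span.
def _average_frame_feat_demo():
    pitch = torch.tensor([[[100.0, 0.0, 200.0, 220.0, 0.0, 50.0]]])
    durs = torch.tensor([[2, 3, 1]])  # frames per token: 2 + 3 + 1 = 6
    # token 0 -> 100 (one voiced frame), token 1 -> (200 + 220) / 2 = 210,
    # token 2 -> 50; expected output: tensor([[[100., 210., 50.]]])
    return average_frame_feat(pitch, durs)
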
class FP_Predictor(nn.Module):
def __init__(self, config):
super(FP_Predictor, self).__init__()
self.w_1 = nn.Conv1d(
config["encoder_projection_units"],
config["embedding_dim"] // 2,
kernel_size=3,
padding=1,
)
self.w_2 = nn.Conv1d(
config["embedding_dim"] // 2,
config["encoder_projection_units"],
kernel_size=1,
padding=0,
)
self.layer_norm1 = nn.LayerNorm(config["embedding_dim"] // 2, eps=1e-6)
self.layer_norm2 = nn.LayerNorm(config["encoder_projection_units"], eps=1e-6)
self.dropout_inner = nn.Dropout(0.1)
self.dropout = nn.Dropout(0.1)
self.fc = nn.Linear(config["encoder_projection_units"], 4)
def forward(self, x):
x = x.transpose(1, 2)
x = F.relu(self.w_1(x))
x = x.transpose(1, 2)
x = self.dropout_inner(self.layer_norm1(x))
x = x.transpose(1, 2)
x = F.relu(self.w_2(x))
x = x.transpose(1, 2)
x = self.dropout(self.layer_norm2(x))
output = F.softmax(self.fc(x), dim=2)
return output
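
# Shape sketch (the config values here are assumptions, not project defaults):
# FP_Predictor maps encoder states [B, T, encoder_projection_units] to a
# per-token distribution over 4 classes -- "no filled pause" plus the three
# filler tokens consumed by KanTtsSAMBERT.insert_fp below.
def _fp_predictor_demo():
    cfg = {"encoder_projection_units": 8, "embedding_dim": 16}
    probs = FP_Predictor(cfg)(torch.randn(2, 5, 8))
    return probs.shape  # torch.Size([2, 5, 4]); each row sums to 1
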
class KanTtsSAMBERT(nn.Module):
def __init__(self, config):
super(KanTtsSAMBERT, self).__init__()
self.text_encoder = TextFftEncoder(config)
self.se_enable = config.get("SE", False)
if not self.se_enable:
self.spk_tokenizer = nn.Embedding(config["speaker"], config["speaker_units"])
self.emo_tokenizer = nn.Embedding(config["emotion"], config["emotion_units"])
self.variance_adaptor = VarianceAdaptor(config)
self.mel_decoder = MelPNCADecoder(config)
self.mel_postnet = PostNet(config)
self.MAS = False
if config.get("MAS", False):
self.MAS = True
self.align_attention = ConvAttention(
n_mel_channels=config["num_mels"],
n_text_channels=config["embedding_dim"],
n_att_channels=config["num_mels"],
)
self.fp_enable = config.get("FP", False)
if self.fp_enable:
self.FP_predictor = FP_Predictor(config)
def get_lfr_mask_from_lengths(self, lengths, max_len):
batch_size = lengths.size(0)
        # round each length up to a multiple of outputs_per_step (the reduction
        # factor r) and convert it to a number of decoder steps
padded_lr_lengths = torch.zeros_like(lengths)
for i in range(batch_size):
len_item = int(lengths[i].item())
padding = self.mel_decoder.r - len_item % self.mel_decoder.r
if padding < self.mel_decoder.r:
padded_lr_lengths[i] = (len_item + padding) // self.mel_decoder.r
else:
padded_lr_lengths[i] = len_item // self.mel_decoder.r
return get_mask_from_lengths(
padded_lr_lengths, max_len=max_len // self.mel_decoder.r
)
def binarize_attention_parallel(self, attn, in_lens, out_lens):
"""For training purposes only. Binarizes attention with MAS.
These will no longer recieve a gradient.
Args:
attn: B x 1 x max_mel_len x max_text_len
"""
with torch.no_grad():
attn_cpu = attn.data.cpu().numpy()
attn_out = b_mas(
attn_cpu, in_lens.cpu().numpy(), out_lens.cpu().numpy(), width=1
)
        return torch.from_numpy(attn_out).to(attn.device)
def insert_fp(
self,
text_hid,
FP_p,
fp_label,
fp_dict,
inputs_emotion,
inputs_speaker,
input_lengths,
input_masks,
):
en, _, _ = self.text_encoder(fp_dict[1], return_attns=True)
a, _, _ = self.text_encoder(fp_dict[2], return_attns=True)
e, _, _ = self.text_encoder(fp_dict[3], return_attns=True)
en = en.squeeze()
a = a.squeeze()
e = e.squeeze()
max_len_ori = max(input_lengths)
if fp_label is None:
input_masks_r = ~input_masks
fp_mask = (FP_p == FP_p.max(dim=2, keepdim=True)[0]).to(dtype=torch.int32)
fp_mask = fp_mask[:, :, 1:] * input_masks_r.unsqueeze(2).expand(-1, -1, 3)
fp_number = torch.sum(torch.sum(fp_mask, dim=2), dim=1)
else:
fp_number = torch.sum((fp_label > 0), dim=1)
inter_lengths = input_lengths + 3 * fp_number
max_len = max(inter_lengths)
delta = max_len - max_len_ori
if delta > 0:
if delta > text_hid.shape[1]:
nrepeat = delta // text_hid.shape[1]
bias = delta % text_hid.shape[1]
text_hid = torch.cat(
(text_hid, text_hid.repeat(1, nrepeat, 1), text_hid[:, :bias, :]), 1
)
inputs_emotion = torch.cat(
(
inputs_emotion,
inputs_emotion.repeat(1, nrepeat),
inputs_emotion[:, :bias],
),
1,
)
inputs_speaker = torch.cat(
(
inputs_speaker,
inputs_speaker.repeat(1, nrepeat),
inputs_speaker[:, :bias],
),
1,
)
else:
text_hid = torch.cat((text_hid, text_hid[:, :delta, :]), 1)
inputs_emotion = torch.cat(
(inputs_emotion, inputs_emotion[:, :delta]), 1
)
inputs_speaker = torch.cat(
(inputs_speaker, inputs_speaker[:, :delta]), 1
)
if fp_label is None:
for i in range(fp_mask.shape[0]):
for j in range(fp_mask.shape[1] - 1, -1, -1):
if fp_mask[i][j][0] == 1:
text_hid[i] = torch.cat(
(text_hid[i][:j], en, text_hid[i][j:-3]), 0
)
elif fp_mask[i][j][1] == 1:
text_hid[i] = torch.cat(
(text_hid[i][:j], a, text_hid[i][j:-3]), 0
)
elif fp_mask[i][j][2] == 1:
text_hid[i] = torch.cat(
(text_hid[i][:j], e, text_hid[i][j:-3]), 0
)
else:
for i in range(fp_label.shape[0]):
for j in range(fp_label.shape[1] - 1, -1, -1):
if fp_label[i][j] == 1:
text_hid[i] = torch.cat(
(text_hid[i][:j], en, text_hid[i][j:-3]), 0
)
elif fp_label[i][j] == 2:
text_hid[i] = torch.cat(
(text_hid[i][:j], a, text_hid[i][j:-3]), 0
)
elif fp_label[i][j] == 3:
text_hid[i] = torch.cat(
(text_hid[i][:j], e, text_hid[i][j:-3]), 0
)
return text_hid, inputs_emotion, inputs_speaker, inter_lengths
def forward(
self,
inputs_ling,
inputs_emotion,
inputs_speaker,
input_lengths,
output_lengths=None,
mel_targets=None,
duration_targets=None,
pitch_targets=None,
energy_targets=None,
attn_priors=None,
fp_label=None,
):
batch_size = inputs_ling.size(0)
is_training = mel_targets is not None
input_masks = get_mask_from_lengths(input_lengths, max_len=inputs_ling.size(1))
text_hid, enc_sla_attn_lst, ling_embedding = self.text_encoder(
inputs_ling, input_masks, return_attns=True
)
inter_lengths = input_lengths
FP_p = None
if self.fp_enable:
FP_p = self.FP_predictor(text_hid)
fp_dict = self.fp_dict
text_hid, inputs_emotion, inputs_speaker, inter_lengths = self.insert_fp(
text_hid,
FP_p,
fp_label,
fp_dict,
inputs_emotion,
inputs_speaker,
input_lengths,
input_masks,
)
        # Monotonic Alignment Search (MAS): derive duration targets from the
        # learned soft alignment during training
if self.MAS and is_training:
attn_soft, attn_logprob = self.align_attention(
mel_targets.permute(0, 2, 1),
ling_embedding.permute(0, 2, 1),
input_masks,
attn_priors,
)
attn_hard = self.binarize_attention_parallel(
attn_soft, input_lengths, output_lengths
)
attn_hard_dur = attn_hard.sum(2)[:, 0, :]
duration_targets = attn_hard_dur
assert torch.all(torch.eq(duration_targets.sum(dim=1), output_lengths))
pitch_targets = average_frame_feat(
pitch_targets.unsqueeze(1), duration_targets
).squeeze(1)
energy_targets = average_frame_feat(
energy_targets.unsqueeze(1), duration_targets
).squeeze(1)
            # Assign the leftover padding frames to the first padded token so
            # that each row of duration_targets sums to the padded mel length
for i in range(batch_size):
len_item = int(output_lengths[i].item())
padding = mel_targets.size(1) - len_item
duration_targets[i, input_lengths[i]] = padding
emo_hid = self.emo_tokenizer(inputs_emotion)
spk_hid = inputs_speaker if self.se_enable else self.spk_tokenizer(inputs_speaker)
inter_masks = get_mask_from_lengths(inter_lengths, max_len=text_hid.size(1))
if output_lengths is not None:
output_masks = get_mask_from_lengths(
output_lengths, max_len=mel_targets.size(1)
)
else:
output_masks = None
(
LR_text_outputs,
LR_emo_outputs,
LR_spk_outputs,
LR_length_rounded,
log_duration_predictions,
pitch_predictions,
energy_predictions,
) = self.variance_adaptor(
text_hid,
emo_hid,
spk_hid,
masks=inter_masks,
output_masks=output_masks,
duration_targets=duration_targets,
pitch_targets=pitch_targets,
energy_targets=energy_targets,
)
if output_lengths is not None:
lfr_masks = self.get_lfr_mask_from_lengths(
output_lengths, max_len=LR_text_outputs.size(1)
)
else:
output_masks = get_mask_from_lengths(
LR_length_rounded, max_len=LR_text_outputs.size(1)
)
lfr_masks = None
        # LFR (low frame rate): stack every outputs_per_step frames into one decoder step
LFR_text_inputs = LR_text_outputs.contiguous().view(
batch_size, -1, self.mel_decoder.r * text_hid.shape[-1]
)
LFR_emo_inputs = LR_emo_outputs.contiguous().view(
batch_size, -1, self.mel_decoder.r * emo_hid.shape[-1]
)[:, :, : emo_hid.shape[-1]]
LFR_spk_inputs = LR_spk_outputs.contiguous().view(
batch_size, -1, self.mel_decoder.r * spk_hid.shape[-1]
)[:, :, : spk_hid.shape[-1]]
memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs], dim=-1)
if duration_targets is not None:
x_band_width = int(
duration_targets.float().masked_fill(inter_masks, 0).max()
/ self.mel_decoder.r
+ 0.5
)
h_band_width = x_band_width
else:
x_band_width = int(
(torch.exp(log_duration_predictions) - 1).max() / self.mel_decoder.r
+ 0.5
)
h_band_width = x_band_width
dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder(
memory,
x_band_width,
h_band_width,
target=mel_targets,
mask=lfr_masks,
return_attns=True,
)
        # De-LFR: unstack the grouped frames back to the original frame rate
dec_outputs = dec_outputs.contiguous().view(
batch_size, -1, self.mel_decoder.d_mel
)
if output_masks is not None:
dec_outputs = dec_outputs.masked_fill(output_masks.unsqueeze(-1), 0)
postnet_outputs = self.mel_postnet(dec_outputs, output_masks) + dec_outputs
if output_masks is not None:
postnet_outputs = postnet_outputs.masked_fill(output_masks.unsqueeze(-1), 0)
res = {
"x_band_width": x_band_width,
"h_band_width": h_band_width,
"enc_slf_attn_lst": enc_sla_attn_lst,
"pnca_x_attn_lst": pnca_x_attn_lst,
"pnca_h_attn_lst": pnca_h_attn_lst,
"dec_outputs": dec_outputs,
"postnet_outputs": postnet_outputs,
"LR_length_rounded": LR_length_rounded,
"log_duration_predictions": log_duration_predictions,
"pitch_predictions": pitch_predictions,
"energy_predictions": energy_predictions,
"duration_targets": duration_targets,
"pitch_targets": pitch_targets,
"energy_targets": energy_targets,
"fp_predictions": FP_p,
"valid_inter_lengths": inter_lengths,
}
res["LR_text_outputs"] = LR_text_outputs
res["LR_emo_outputs"] = LR_emo_outputs
res["LR_spk_outputs"] = LR_spk_outputs
if self.MAS and is_training:
res["attn_soft"] = attn_soft
res["attn_hard"] = attn_hard
res["attn_logprob"] = attn_logprob
return res
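
# Inference sketch (hypothetical wrapper; tensors must be prepared by the
# upstream frontend): when mel_targets is omitted the model predicts duration,
# pitch and energy itself and the PNCA decoder runs autoregressively, so only
# linguistic ids, emotion/speaker ids and input lengths are required.
@torch.no_grad()
def _sambert_infer_demo(
    model: "KanTtsSAMBERT", inputs_ling, inputs_emotion, inputs_speaker, input_lengths
):
    res = model(inputs_ling, inputs_emotion, inputs_speaker, input_lengths)
    return res["postnet_outputs"], res["LR_length_rounded"]
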
class KanTtsTextsyBERT(nn.Module):
def __init__(self, config):
super(KanTtsTextsyBERT, self).__init__()
self.text_encoder = TextFftEncoder(config)
delattr(self.text_encoder, "ling_proj")
self.fc = nn.Linear(self.text_encoder.d_model, config["sy"])
def forward(self, inputs_ling, input_lengths):
res = {}
input_masks = get_mask_from_lengths(input_lengths, max_len=inputs_ling.size(1))
        text_hid, enc_sla_attn_lst, _ = self.text_encoder(
inputs_ling, input_masks, return_attns=True
)
logits = self.fc(text_hid)
res["logits"] = logits
res["enc_slf_attn_lst"] = enc_sla_attn_lst
return res