Commit ab9c00af authored by yangzhong

init submission
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os, sys
import os.path as osp
import numpy as np
import torch
from torch import nn
from torch.optim import Optimizer
from functools import reduce
from torch.optim import AdamW
class MultiOptimizer:
def __init__(self, optimizers={}, schedulers={}):
self.optimizers = optimizers
self.schedulers = schedulers
self.keys = list(optimizers.keys())
self.param_groups = reduce(
lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
)
def state_dict(self):
state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
return state_dicts
def scheduler_state_dict(self):
state_dicts = [(key, self.schedulers[key].state_dict()) for key in self.keys]
return state_dicts
def load_state_dict(self, state_dict):
for key, val in state_dict:
try:
self.optimizers[key].load_state_dict(val)
            except Exception:
                print("Failed to load optimizer state for %s" % key)
def load_scheduler_state_dict(self, state_dict):
for key, val in state_dict:
try:
self.schedulers[key].load_state_dict(val)
            except Exception:
                print("Failed to load scheduler state for %s" % key)
def step(self, key=None, scaler=None):
keys = [key] if key is not None else self.keys
_ = [self._step(key, scaler) for key in keys]
def _step(self, key, scaler=None):
if scaler is not None:
scaler.step(self.optimizers[key])
scaler.update()
else:
self.optimizers[key].step()
def zero_grad(self, key=None):
if key is not None:
self.optimizers[key].zero_grad()
else:
_ = [self.optimizers[key].zero_grad() for key in self.keys]
def scheduler(self, *args, key=None):
if key is not None:
self.schedulers[key].step(*args)
else:
_ = [self.schedulers[key].step_batch(*args) for key in self.keys]
def define_scheduler(optimizer, params):
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])
return scheduler
def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"):
optim = {}
for key, model in model_dict.items():
model_parameters = model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in model.named_parameters()]
)
if type == "AdamW":
optim[key] = AdamW(
model_parameters,
lr=lr,
betas=(0.9, 0.98),
eps=1e-9,
weight_decay=0.1,
)
else:
raise ValueError("Unknown optimizer type: %s" % type)
schedulers = dict(
[
(key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996))
for key, opt in optim.items()
]
)
multi_optim = MultiOptimizer(optim, schedulers)
return multi_optim
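# Minimal usage sketch for MultiOptimizer / build_optimizer (illustrative only; the toy
# modules, learning rate, and loss below are assumptions, not part of this file).
if __name__ == "__main__":
    toy_models = {"encoder": nn.Linear(8, 8), "decoder": nn.Linear(8, 8)}
    multi_optim = build_optimizer(toy_models, scheduler_params_dict={}, lr=1e-4)

    x = torch.randn(4, 8)
    loss = toy_models["decoder"](toy_models["encoder"](x)).pow(2).mean()

    multi_optim.zero_grad()               # zero gradients of every wrapped optimizer
    loss.backward()
    multi_optim.step()                    # step every wrapped AdamW optimizer
    multi_optim.scheduler(key="encoder")  # step a single ExponentialLR scheduler by key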
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat
from indextts.utils.maskgct.models.codec.amphion_codec.quantize import ResidualVQ
from indextts.utils.maskgct.models.codec.kmeans.vocos import VocosBackbone
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def compute_codebook_perplexity(indices, codebook_size):
indices = indices.flatten()
prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
return perp
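# Sanity sketch for compute_codebook_perplexity (illustrative): perfectly uniform code
# usage yields a perplexity equal to the codebook size, while a collapsed codebook
# (every frame mapped to the same entry) yields a perplexity close to 1.
if __name__ == "__main__":
    uniform_ids = torch.arange(8192)
    collapsed_ids = torch.zeros(1000, dtype=torch.long)
    print(compute_codebook_perplexity(uniform_ids, 8192))    # ~8192
    print(compute_codebook_perplexity(collapsed_ids, 8192))  # ~1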
class RepCodec(nn.Module):
def __init__(
self,
codebook_size=8192,
hidden_size=1024,
codebook_dim=8,
vocos_dim=384,
vocos_intermediate_dim=2048,
vocos_num_layers=12,
num_quantizers=1,
downsample_scale=1,
cfg=None,
):
super().__init__()
codebook_size = (
cfg.codebook_size
if cfg is not None and hasattr(cfg, "codebook_size")
else codebook_size
)
codebook_dim = (
cfg.codebook_dim
if cfg is not None and hasattr(cfg, "codebook_dim")
else codebook_dim
)
hidden_size = (
cfg.hidden_size
if cfg is not None and hasattr(cfg, "hidden_size")
else hidden_size
)
vocos_dim = (
cfg.vocos_dim
if cfg is not None and hasattr(cfg, "vocos_dim")
else vocos_dim
)
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
num_quantizers = (
cfg.num_quantizers
if cfg is not None and hasattr(cfg, "num_quantizers")
else num_quantizers
)
downsample_scale = (
cfg.downsample_scale
if cfg is not None and hasattr(cfg, "downsample_scale")
else downsample_scale
)
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.hidden_size = hidden_size
self.vocos_dim = vocos_dim
self.vocos_intermediate_dim = vocos_intermediate_dim
self.vocos_num_layers = vocos_num_layers
self.num_quantizers = num_quantizers
self.downsample_scale = downsample_scale
        if self.downsample_scale is not None and self.downsample_scale > 1:
self.down = nn.Conv1d(
self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
)
self.up = nn.Conv1d(
self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
)
self.encoder = nn.Sequential(
VocosBackbone(
input_channels=self.hidden_size,
dim=self.vocos_dim,
intermediate_dim=self.vocos_intermediate_dim,
num_layers=self.vocos_num_layers,
adanorm_num_embeddings=None,
),
nn.Linear(self.vocos_dim, self.hidden_size),
)
self.decoder = nn.Sequential(
VocosBackbone(
input_channels=self.hidden_size,
dim=self.vocos_dim,
intermediate_dim=self.vocos_intermediate_dim,
num_layers=self.vocos_num_layers,
adanorm_num_embeddings=None,
),
nn.Linear(self.vocos_dim, self.hidden_size),
)
self.quantizer = ResidualVQ(
input_dim=hidden_size,
num_quantizers=num_quantizers,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_type="fvq",
quantizer_dropout=0.0,
commitment=0.15,
codebook_loss_weight=1.0,
use_l2_normlize=True,
)
self.reset_parameters()
def forward(self, x):
# downsample
        if self.downsample_scale is not None and self.downsample_scale > 1:
x = x.transpose(1, 2)
x = self.down(x)
x = F.gelu(x)
x = x.transpose(1, 2)
# encoder
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
# vq
(
quantized_out,
all_indices,
all_commit_losses,
all_codebook_losses,
_,
) = self.quantizer(x)
# decoder
x = self.decoder(quantized_out)
        # upsample back to the input frame rate
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x_rec = self.up(x).transpose(1, 2)
        else:
            x_rec = x
        codebook_loss = (all_codebook_losses + all_commit_losses).mean()
        return x_rec, codebook_loss, all_indices
def quantize(self, x):
        if self.downsample_scale is not None and self.downsample_scale > 1:
x = x.transpose(1, 2)
x = self.down(x)
x = F.gelu(x)
x = x.transpose(1, 2)
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
(
quantized_out,
all_indices,
all_commit_losses,
all_codebook_losses,
_,
) = self.quantizer(x)
if all_indices.shape[0] == 1:
return all_indices.squeeze(0), quantized_out.transpose(1, 2)
return all_indices, quantized_out.transpose(1, 2)
def reset_parameters(self):
self.apply(init_weights)
if __name__ == "__main__":
repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
print(repcodec)
print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
x = torch.randn(5, 10, 1024)
x_rec, codebook_loss, all_indices = repcodec(x)
print(x_rec.shape, codebook_loss, all_indices.shape)
vq_id, emb = repcodec.quantize(x)
print(vq_id.shape, emb.shape)
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
from torch.nn.utils import weight_norm, remove_weight_norm
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))
def symlog(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(x.abs())
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
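# Round-trip sketch for symlog/symexp (illustrative): symexp exactly inverts symlog,
# so compressing a signal's dynamic range and expanding it again recovers the input.
if __name__ == "__main__":
    _x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
    assert torch.allclose(symexp(symlog(_x)), _x, atol=1e-6)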
class STFT(nn.Module):
def __init__(
self,
n_fft: int,
hop_length: int,
win_length: int,
center=True,
):
super().__init__()
self.center = center
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, T * hop_length)
if not self.center:
pad = self.win_length - self.hop_length
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
stft_spec = torch.stft(
x,
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
return_complex=False,
) # (B, n_fft // 2 + 1, T, 2)
        rea = stft_spec[:, :, :, 0]  # (B, n_fft // 2 + 1, T)
        imag = stft_spec[:, :, :, 1]  # (B, n_fft // 2 + 1, T)
log_mag = torch.log(
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
) # (B, n_fft // 2 + 1, T)
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
return log_mag, phase
class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(
spec,
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
assert (window_envelope > 1e-11).all()
y = y / window_envelope
return y
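# Round-trip sketch for STFT/ISTFT with "same" padding (illustrative; n_fft, hop length,
# and signal length are assumptions). The magnitude/phase pair from STFT is recombined
# into a complex spectrogram before inversion; only shapes are checked here.
if __name__ == "__main__":
    n_fft, hop = 1024, 256
    stft = STFT(n_fft=n_fft, hop_length=hop, win_length=n_fft, center=True)
    istft = ISTFT(n_fft=n_fft, hop_length=hop, win_length=n_fft, padding="same")
    wav = torch.randn(2, hop * 100)
    log_mag, phase = stft(wav)                         # each (B, n_fft // 2 + 1, T)
    spec = torch.exp(log_mag) * torch.exp(1j * phase)  # complex spectrogram
    recon = istft(spec)
    print(wav.shape, recon.shape)                      # both roughly (2, hop * 100)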
class MDCT(nn.Module):
"""
Modified Discrete Cosine Transform (MDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
# view_as_real: NCCL Backend does not support ComplexFloat data type
# https://github.com/pytorch/pytorch/issues/71613
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, audio: torch.Tensor) -> torch.Tensor:
"""
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
Args:
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
and T is the length of the audio.
Returns:
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
and N is the number of frequency bins.
"""
if self.padding == "center":
audio = torch.nn.functional.pad(
audio, (self.frame_len // 2, self.frame_len // 2)
)
elif self.padding == "same":
# hop_length is 1/2 frame_len
audio = torch.nn.functional.pad(
audio, (self.frame_len // 4, self.frame_len // 4)
)
else:
raise ValueError("Padding must be 'center' or 'same'.")
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
N = self.frame_len // 2
x = x * self.window.expand(x.shape)
X = torch.fft.fft(
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
)[..., :N]
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
return torch.real(res) * np.sqrt(2)
class IMDCT(nn.Module):
"""
Inverse Modified Discrete Cosine Transform (IMDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
Args:
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
L is the number of frames, and N is the number of frequency bins.
Returns:
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
"""
B, L, N = X.shape
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
Y[..., :N] = X
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
y = torch.fft.ifft(
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
)
y = (
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
* np.sqrt(N)
* np.sqrt(2)
)
result = y * self.window.expand(y.shape)
output_size = (1, (L + 1) * N)
audio = torch.nn.functional.fold(
result.transpose(1, 2),
output_size=output_size,
kernel_size=(1, self.frame_len),
stride=(1, self.frame_len // 2),
)[:, 0, 0, :]
if self.padding == "center":
pad = self.frame_len // 2
elif self.padding == "same":
pad = self.frame_len // 4
else:
raise ValueError("Padding must be 'center' or 'same'.")
audio = audio[:, pad:-pad]
return audio
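# Round-trip sketch for MDCT/IMDCT (illustrative; the frame length and signal length
# are assumptions). With "same" padding, IMDCT(MDCT(x)) returns a signal of the same
# length as x.
if __name__ == "__main__":
    frame_len = 512
    mdct = MDCT(frame_len=frame_len, padding="same")
    imdct = IMDCT(frame_len=frame_len, padding="same")
    wav = torch.randn(2, frame_len * 20)
    coeffs = mdct(wav)        # (B, L, N) with N = frame_len // 2
    recon = imdct(coeffs)     # (B, T) with T equal to the input length
    print(coeffs.shape, recon.shape)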
class FourierHead(nn.Module):
"""Base class for inverse fourier modules."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class ISTFTHead(FourierHead):
"""
ISTFT Head module for predicting STFT complex coefficients.
Args:
dim (int): Hidden dimension of the model.
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames, which should align with
the resolution of the input features.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
super().__init__()
out_dim = n_fft + 2
self.out = torch.nn.Linear(dim, out_dim)
self.istft = ISTFT(
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ISTFTHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(
mag, max=1e2
) # safeguard to prevent excessively large magnitudes
# wrapping happens here. These two lines produce real and imaginary value
x = torch.cos(p)
y = torch.sin(p)
# recalculating phase here does not produce anything new
# only costs time
# phase = torch.atan2(y, x)
# S = mag * torch.exp(phase * 1j)
# better directly produce the complex value
S = mag * (x + 1j * y)
audio = self.istft(S)
return audio
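# Shape sketch for ISTFTHead (illustrative; dim, n_fft, and hop_length are assumptions):
# backbone features of shape (B, L, H) are projected to magnitude/phase and inverted
# into a waveform of roughly L * hop_length samples.
if __name__ == "__main__":
    head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
    feats = torch.randn(2, 50, 512)   # (B, L, H)
    audio = head(feats)
    print(audio.shape)                # (B, T) with T == 50 * 256 here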
class IMDCTSymExpHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
based on perceptual scaling. Defaults to None.
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self,
dim: int,
mdct_frame_len: int,
padding: str = "same",
sample_rate: Optional[int] = None,
clip_audio: bool = False,
):
super().__init__()
out_dim = mdct_frame_len // 2
self.out = nn.Linear(dim, out_dim)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
self.clip_audio = clip_audio
if sample_rate is not None:
# optionally init the last layer following mel-scale
m_max = _hz_to_mel(sample_rate // 2)
m_pts = torch.linspace(0, m_max, out_dim)
f_pts = _mel_to_hz(m_pts)
scale = 1 - (f_pts / f_pts.max())
with torch.no_grad():
self.out.weight.mul_(scale.view(-1, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTSymExpHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
x = symexp(x)
x = torch.clip(
x, min=-1e2, max=1e2
) # safeguard to prevent excessively large magnitudes
audio = self.imdct(x)
if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)
return audio
class IMDCTCosHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self,
dim: int,
mdct_frame_len: int,
padding: str = "same",
clip_audio: bool = False,
):
super().__init__()
self.clip_audio = clip_audio
self.out = nn.Linear(dim, mdct_frame_len)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTCosHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
m, p = x.chunk(2, dim=2)
m = torch.exp(m).clip(
max=1e2
) # safeguard to prevent excessively large magnitudes
audio = self.imdct(m * torch.cos(p))
if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)
return audio
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: float,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, intermediate_dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
def forward(
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
if self.adanorm:
assert cond_embedding_id is not None
x = self.norm(x, cond_embedding_id)
else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class AdaLayerNorm(nn.Module):
"""
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
Args:
num_embeddings (int): Number of embeddings.
embedding_dim (int): Dimension of the embeddings.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.dim = embedding_dim
self.scale = nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=embedding_dim
)
self.shift = nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=embedding_dim
)
torch.nn.init.ones_(self.scale.weight)
torch.nn.init.zeros_(self.shift.weight)
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
scale = self.scale(cond_embedding_id)
shift = self.shift(cond_embedding_id)
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
x = x * scale + shift
return x
class ResBlock1(nn.Module):
"""
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
but without upsampling layers.
Args:
dim (int): Number of input channels.
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
Defaults to (1, 3, 5).
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
Defaults to 0.1.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: Tuple[int, int, int] = (1, 3, 5),
lrelu_slope: float = 0.1,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
self.convs1 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[0],
padding=self.get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[1],
padding=self.get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[2],
padding=self.get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
]
)
self.gamma = nn.ParameterList(
[
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
xt = c1(xt)
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
xt = c2(xt)
if gamma is not None:
xt = gamma * xt
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
@staticmethod
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
class Backbone(nn.Module):
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
C denotes output features, and L is the sequence length.
Returns:
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
and H denotes the model dimension.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class VocosBackbone(Backbone):
"""
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional model. Defaults to None.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=adanorm_num_embeddings,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
bandwidth_id = kwargs.get("bandwidth_id", None)
x = self.embed(x)
if self.adanorm:
assert bandwidth_id is not None
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
else:
x = self.norm(x.transpose(1, 2))
x = x.transpose(1, 2)
for conv_block in self.convnext:
x = conv_block(x, cond_embedding_id=bandwidth_id)
x = self.final_layer_norm(x.transpose(1, 2))
return x
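# Shape sketch for VocosBackbone (illustrative; channel and layer sizes are assumptions):
# the backbone consumes channel-first features (B, C, L) and returns (B, L, H) hidden
# states ready for a Fourier head.
if __name__ == "__main__":
    backbone = VocosBackbone(
        input_channels=100, dim=384, intermediate_dim=1152, num_layers=8
    )
    feats = torch.randn(2, 100, 50)   # e.g. 100-bin mel-spectrogram frames
    hidden = backbone(feats)
    print(hidden.shape)               # torch.Size([2, 50, 384])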
class VocosResNetBackbone(Backbone):
"""
Vocos backbone module built with ResBlocks.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
num_blocks (int): Number of ResBlock1 blocks.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
"""
def __init__(
self,
input_channels,
dim,
num_blocks,
layer_scale_init_value=None,
):
super().__init__()
self.input_channels = input_channels
self.embed = weight_norm(
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
)
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
self.resnet = nn.Sequential(
*[
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
for _ in range(num_blocks)
]
)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.embed(x)
x = self.resnet(x)
x = x.transpose(1, 2)
return x
class Vocos(nn.Module):
def __init__(
self,
input_channels: int = 256,
dim: int = 384,
intermediate_dim: int = 1152,
num_layers: int = 8,
adanorm_num_embeddings: int = 4,
n_fft: int = 800,
hop_size: int = 200,
padding: str = "same",
):
super().__init__()
self.backbone = VocosBackbone(
input_channels=input_channels,
dim=dim,
intermediate_dim=intermediate_dim,
num_layers=num_layers,
adanorm_num_embeddings=adanorm_num_embeddings,
)
self.head = ISTFTHead(dim, n_fft, hop_size, padding)
def forward(self, x):
x = self.backbone(x)
x = self.head(x)
return x[:, None, :]
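# End-to-end sketch for the Vocos wrapper (illustrative; adanorm_num_embeddings is set
# to None here so the backbone runs without a conditioning id, and the feature sizes
# are assumptions).
if __name__ == "__main__":
    vocoder = Vocos(
        input_channels=256,
        dim=384,
        intermediate_dim=1152,
        num_layers=8,
        adanorm_num_embeddings=None,
        n_fft=800,
        hop_size=200,
    )
    feats = torch.randn(1, 256, 100)  # (B, C, L) frame-level features
    wav = vocoder(feats)              # (B, 1, T) with T == 100 * 200 here
    print(wav.shape)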
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import pyworld as pw
import numpy as np
import soundfile as sf
import os
from torchaudio.functional import pitch_shift
import librosa
from librosa.filters import mel as librosa_mel_fn
import torch.nn as nn
import torch.nn.functional as F
import tqdm
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
class MelSpectrogram(nn.Module):
def __init__(
self,
n_fft,
num_mels,
sampling_rate,
hop_size,
win_size,
fmin,
fmax,
center=False,
):
super(MelSpectrogram, self).__init__()
self.n_fft = n_fft
self.hop_size = hop_size
self.win_size = win_size
self.sampling_rate = sampling_rate
self.num_mels = num_mels
self.fmin = fmin
self.fmax = fmax
self.center = center
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis = torch.from_numpy(mel).float()
hann_window = torch.hann_window(win_size)
self.register_buffer("mel_basis", mel_basis)
self.register_buffer("hann_window", hann_window)
def forward(self, y):
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2),
),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=self.center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec)
return spec
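# Shape sketch for MelSpectrogram (illustrative; the sample rate, FFT, and mel settings
# below are assumptions, not dictated by this file).
if __name__ == "__main__":
    mel_fn = MelSpectrogram(
        n_fft=1024,
        num_mels=100,
        sampling_rate=24000,
        hop_size=256,
        win_size=1024,
        fmin=0,
        fmax=12000,
    )
    wav = torch.randn(2, 24000)  # one second of audio per item
    mel = mel_fn(wav)            # (B, num_mels, T)
    print(mel.shape)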
## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2403.03100.pdf)
[![demo](https://img.shields.io/badge/FACodec-Demo-red)](https://speechresearch.github.io/naturalspeech3/)
[![model](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/naturalspeech3_facodec)
[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)
## Overview
FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. FACodec converts complex speech waveforms into disentangled subspaces representing the speech attributes of content, prosody, timbre, and acoustic details, and reconstructs high-quality speech waveforms from these attributes. By decomposing speech into attribute-specific subspaces, FACodec simplifies the modeling of speech representations.
Researchers can use FACodec to develop different kinds of TTS models, such as non-autoregressive discrete-diffusion models (NaturalSpeech 3) or autoregressive models (like VALL-E).
<br>
<div align="center">
<img src="../../../imgs/ns3/ns3_overview.png" width="65%">
</div>
<br>
<br>
<div align="center">
<img src="../../../imgs/ns3/ns3_facodec.png" width="100%">
</div>
<br>
## Usage
Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec)
Install Amphion
```bash
git clone https://github.com/open-mmlab/Amphion.git
```
A few lines of code are enough to use the pre-trained FACodec model:
```python
import torch
from Amphion.models.codec.ns3_codec import FACodecEncoder, FACodecDecoder
from huggingface_hub import hf_hub_download
fa_encoder = FACodecEncoder(
ngf=32,
up_ratios=[2, 4, 5, 5],
out_channels=256,
)
fa_decoder = FACodecDecoder(
in_channels=256,
upsample_initial_channel=1024,
ngf=32,
up_ratios=[5, 5, 4, 2],
vq_num_q_c=2,
vq_num_q_p=1,
vq_num_q_r=3,
vq_dim=256,
codebook_dim=8,
codebook_size_prosody=10,
codebook_size_content=10,
codebook_size_residual=10,
use_gr_x_timbre=True,
use_gr_residual_f0=True,
use_gr_residual_phone=True,
)
encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")
fa_encoder.load_state_dict(torch.load(encoder_ckpt))
fa_decoder.load_state_dict(torch.load(decoder_ckpt))
fa_encoder.eval()
fa_decoder.eval()
```
Inference
```python
import librosa
import soundfile as sf

test_wav_path = "test.wav"
test_wav = librosa.load(test_wav_path, sr=16000)[0]
test_wav = torch.from_numpy(test_wav).float()
test_wav = test_wav.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
# encode
enc_out = fa_encoder(test_wav)
print(enc_out.shape)
# quantize
vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)
# latent after quantization
print(vq_post_emb.shape)
# codes
print("vq id shape:", vq_id.shape)
# get prosody code
prosody_code = vq_id[:1]
print("prosody code shape:", prosody_code.shape)
# get content code
    content_code = vq_id[1:3]
    print("content code shape:", content_code.shape)
# get residual code (acoustic detail codes)
residual_code = vq_id[3:]
print("residual code shape:", residual_code.shape)
# speaker embedding
print("speaker embedding shape:", spk_embs.shape)
    # decode (recommended)
recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
print(recon_wav.shape)
sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000)
```
FACodec can achieve zero-shot voice conversion with FACodecEncoderV2/FACodecDecoderV2 or FACodecRedecoder
```python
from Amphion.models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2
# Same parameters as FACodecEncoder/FACodecDecoder
fa_encoder_v2 = FACodecEncoderV2(...)
fa_decoder_v2 = FACodecDecoderV2(...)
encoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder_v2.bin")
decoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder_v2.bin")
fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt))
fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt))
with torch.no_grad():
enc_out_a = fa_encoder_v2(wav_a)
prosody_a = fa_encoder_v2.get_prosody_feature(wav_a)
enc_out_b = fa_encoder_v2(wav_b)
prosody_b = fa_encoder_v2.get_prosody_feature(wav_b)
vq_post_emb_a, vq_id_a, _, quantized, spk_embs_a = fa_decoder_v2(
enc_out_a, prosody_a, eval_vq=False, vq=True
)
vq_post_emb_b, vq_id_b, _, quantized, spk_embs_b = fa_decoder_v2(
enc_out_b, prosody_b, eval_vq=False, vq=True
)
vq_post_emb_a_to_b = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False)
recon_wav_a_to_b = fa_decoder_v2.inference(vq_post_emb_a_to_b, spk_embs_b)
```
or
```python
from Amphion.models.codec.ns3_codec import FACodecRedecoder
fa_redecoder = FACodecRedecoder()
redecoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_redecoder.bin")
fa_redecoder.load_state_dict(torch.load(redecoder_ckpt))
with torch.no_grad():
enc_out_a = fa_encoder(wav_a)
enc_out_b = fa_encoder(wav_b)
vq_post_emb_a, vq_id_a, _, quantized_a, spk_embs_a = fa_decoder(enc_out_a, eval_vq=False, vq=True)
vq_post_emb_b, vq_id_b, _, quantized_b, spk_embs_b = fa_decoder(enc_out_b, eval_vq=False, vq=True)
# convert speaker
vq_post_emb_a_to_b = fa_redecoder.vq2emb(vq_id_a, spk_embs_b, use_residual=False)
recon_wav_a_to_b = fa_redecoder.inference(vq_post_emb_a_to_b, spk_embs_b)
sf.write("recon_a_to_b.wav", recon_wav_a_to_b[0][0].cpu().numpy(), 16000)
```
## Q&A
Q1: What audio sample rate does FACodec support? What is the hop size? How many codes are generated for each frame?
A1: FACodec supports 16 kHz speech audio. The hop size is 200 samples, so the frame rate is 16000 / 200 = 80 frames per second, and each frame is represented by 6 codes (one per codebook), i.e. 480 codes per second of audio.
Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec?
A2: Yes. In fact, the authors of NaturalSpeech 3 have already explored autoregressive generative models for discrete token generation with FACodec. They use an autoregressive language model to generate the prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic-detail codes.
Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech2 using FACodec?
A3: Yes. You can use the latent obtained after quantization as the modeling target for the latent diffusion model.
Q4: Can FACodec compress and reconstruct audio from other domains? Such as sound effects, music, etc.
A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. It is still possible to use FACodec to compress and reconstruct audio from other domains, but the reconstruction quality may not be as good as the original audio.
Q5: Can FACodec provide content features for other tasks such as voice conversion?
A5: We believe so. Researchers can use the content codes of FACodec as content features for voice conversion. We hope to see more research in this direction.
## Citations
If you use our FACodec model, please cite the following paper:
```bibtex
@article{ju2024naturalspeech,
title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models},
author={Ju, Zeqian and Wang, Yuancheng and Shen, Kai and Tan, Xu and Xin, Detai and Yang, Dongchao and Liu, Yanqing and Leng, Yichong and Song, Kaitao and Tang, Siliang and others},
journal={arXiv preprint arXiv:2403.03100},
year={2024}
}
@article{zhang2023amphion,
title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu},
journal={arXiv},
year={2024},
volume={abs/2312.09911}
}
```
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .facodec import *
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
from .filter import *
from .resample import *
from .act import *
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch.nn as nn
from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12,
):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
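# Usage sketch for Activation1d (illustrative; nn.SiLU stands in for any pointwise
# activation such as SnakeBeta): the nonlinearity is applied at 2x oversampling to
# suppress aliasing, then the signal is returned to its original temporal resolution.
if __name__ == "__main__":
    import torch

    act = Activation1d(activation=nn.SiLU())
    x = torch.randn(3, 16, 200)  # [B, C, T]
    y = act(x)
    print(y.shape)               # torch.Size([3, 16, 200])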
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if "sinc" in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(
cutoff, half_width, kernel_size
): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0
half_size = kernel_size // 2
# For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = "replicate",
kernel_size: int = 12,
):
        # kernel_size should be an even number for the StyleGAN3 setup;
        # in this implementation, odd kernel sizes are also supported.
super().__init__()
if cutoff < -0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
# input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
return out
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
xx = self.lowpass(x)
return xx
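# Round-trip sketch for the anti-aliased resamplers (illustrative): upsampling by 2 and
# then downsampling by 2 preserves the temporal length and approximately preserves a
# band-limited signal.
if __name__ == "__main__":
    import torch

    up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
    x = torch.randn(1, 4, 256)   # [B, C, T]
    y = down(up(x))
    print(x.shape, y.shape)      # both torch.Size([1, 4, 256])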
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from .alias_free_torch import *
from .quantize import *
from einops import rearrange
from einops.layers.torch import Rearrange
from .transformer import TransformerEncoder
from .gradient_reversal import GradientReversal
from .melspec import MelSpectrogram
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class CNNLSTM(nn.Module):
def __init__(self, indim, outdim, head, global_pred=False):
super().__init__()
self.global_pred = global_pred
self.model = nn.Sequential(
ResidualUnit(indim, dilation=1),
ResidualUnit(indim, dilation=2),
ResidualUnit(indim, dilation=3),
Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
Rearrange("b c t -> b t c"),
)
self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
def forward(self, x):
# x: [B, C, T]
x = self.model(x)
if self.global_pred:
x = torch.mean(x, dim=1, keepdim=False)
outs = [head(x) for head in self.heads]
return outs
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
self.beta = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.beta = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta := x + 1/b * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
return x + self.block(x)
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1),
ResidualUnit(dim // 2, dilation=3),
ResidualUnit(dim // 2, dilation=9),
Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
WNConv1d(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
),
)
def forward(self, x):
return self.block(x)
class FACodecEncoder(nn.Module):
def __init__(
self,
ngf=32,
up_ratios=(2, 4, 5, 5),
out_channels=1024,
):
super().__init__()
self.hop_length = np.prod(up_ratios)
self.up_ratios = up_ratios
# Create first convolution
d_model = ngf
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in up_ratios:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
]
        # Wrap the block list into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
self.reset_parameters()
def forward(self, x):
out = self.block(x)
return out
def inference(self, x):
return self.block(x)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, nn.Conv1d):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
def reset_parameters(self):
self.apply(init_weights)
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
output_padding=stride % 2,
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
class FACodecDecoder(nn.Module):
def __init__(
self,
in_channels=256,
upsample_initial_channel=1536,
ngf=32,
up_ratios=(5, 5, 4, 2),
vq_num_q_c=2,
vq_num_q_p=1,
vq_num_q_r=3,
vq_dim=1024,
vq_commit_weight=0.005,
vq_weight_init=False,
vq_full_commit_loss=False,
codebook_dim=8,
codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
codebook_size_content=10,
codebook_size_residual=10,
quantizer_dropout=0.0,
dropout_type="linear",
use_gr_content_f0=False,
use_gr_prosody_phone=False,
use_gr_residual_f0=False,
use_gr_residual_phone=False,
use_gr_x_timbre=False,
use_random_mask_residual=True,
prob_random_mask_residual=0.75,
):
super().__init__()
self.hop_length = np.prod(up_ratios)
self.ngf = ngf
self.up_ratios = up_ratios
self.use_random_mask_residual = use_random_mask_residual
self.prob_random_mask_residual = prob_random_mask_residual
self.vq_num_q_p = vq_num_q_p
self.vq_num_q_c = vq_num_q_c
self.vq_num_q_r = vq_num_q_r
self.codebook_size_prosody = codebook_size_prosody
self.codebook_size_content = codebook_size_content
self.codebook_size_residual = codebook_size_residual
quantizer_class = ResidualVQ
self.quantizer = nn.ModuleList()
# prosody
quantizer = quantizer_class(
num_quantizers=vq_num_q_p,
dim=vq_dim,
codebook_size=codebook_size_prosody,
codebook_dim=codebook_dim,
threshold_ema_dead_code=2,
commitment=vq_commit_weight,
weight_init=vq_weight_init,
full_commit_loss=vq_full_commit_loss,
quantizer_dropout=quantizer_dropout,
dropout_type=dropout_type,
)
self.quantizer.append(quantizer)
# phone
quantizer = quantizer_class(
num_quantizers=vq_num_q_c,
dim=vq_dim,
codebook_size=codebook_size_content,
codebook_dim=codebook_dim,
threshold_ema_dead_code=2,
commitment=vq_commit_weight,
weight_init=vq_weight_init,
full_commit_loss=vq_full_commit_loss,
quantizer_dropout=quantizer_dropout,
dropout_type=dropout_type,
)
self.quantizer.append(quantizer)
# residual
if self.vq_num_q_r > 0:
quantizer = quantizer_class(
num_quantizers=vq_num_q_r,
dim=vq_dim,
codebook_size=codebook_size_residual,
codebook_dim=codebook_dim,
threshold_ema_dead_code=2,
commitment=vq_commit_weight,
weight_init=vq_weight_init,
full_commit_loss=vq_full_commit_loss,
quantizer_dropout=quantizer_dropout,
dropout_type=dropout_type,
)
self.quantizer.append(quantizer)
# Add first conv layer
channels = upsample_initial_channel
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(up_ratios):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
self.timbre_encoder = TransformerEncoder(
enc_emb_tokens=None,
encoder_layer=4,
encoder_hidden=256,
encoder_head=4,
conv_filter_size=1024,
conv_kernel_size=5,
encoder_dropout=0.1,
use_cln=False,
)
self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
self.timbre_linear.bias.data[:in_channels] = 1
self.timbre_linear.bias.data[in_channels:] = 0
self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
self.f0_predictor = CNNLSTM(in_channels, 1, 2)
self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
self.use_gr_content_f0 = use_gr_content_f0
self.use_gr_prosody_phone = use_gr_prosody_phone
self.use_gr_residual_f0 = use_gr_residual_f0
self.use_gr_residual_phone = use_gr_residual_phone
self.use_gr_x_timbre = use_gr_x_timbre
if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
self.res_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
)
if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
self.res_phone_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
)
if self.use_gr_content_f0:
self.content_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
)
if self.use_gr_prosody_phone:
self.prosody_phone_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
)
if self.use_gr_x_timbre:
self.x_timbre_predictor = nn.Sequential(
GradientReversal(alpha=1),
CNNLSTM(in_channels, 245200, 1, global_pred=True),
)
self.reset_parameters()
def quantize(self, x, n_quantizers=None):
outs, qs, commit_loss, quantized_buf = 0, [], [], []
# prosody
f0_input = x # (B, d, T)
f0_quantizer = self.quantizer[0]
out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
outs += out
qs.append(q)
quantized_buf.append(quantized.sum(0))
commit_loss.append(commit)
# phone
phone_input = x
phone_quantizer = self.quantizer[1]
out, q, commit, quantized = phone_quantizer(
phone_input, n_quantizers=n_quantizers
)
outs += out
qs.append(q)
quantized_buf.append(quantized.sum(0))
commit_loss.append(commit)
# residual
if self.vq_num_q_r > 0:
residual_quantizer = self.quantizer[2]
residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
out, q, commit, quantized = residual_quantizer(
residual_input, n_quantizers=n_quantizers
)
outs += out
qs.append(q)
quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
commit_loss.append(commit)
qs = torch.cat(qs, dim=0)
commit_loss = torch.cat(commit_loss, dim=0)
return outs, qs, commit_loss, quantized_buf
def forward(
self,
x,
vq=True,
get_vq=False,
eval_vq=True,
speaker_embedding=None,
n_quantizers=None,
quantized=None,
):
if get_vq:
return self.quantizer.get_emb()
if vq is True:
if eval_vq:
self.quantizer.eval()
x_timbre = x
outs, qs, commit_loss, quantized_buf = self.quantize(
x, n_quantizers=n_quantizers
)
x_timbre = x_timbre.transpose(1, 2)
x_timbre = self.timbre_encoder(x_timbre, None, None)
x_timbre = x_timbre.transpose(1, 2)
spk_embs = torch.mean(x_timbre, dim=2)
return outs, qs, commit_loss, quantized_buf, spk_embs
out = {}
layer_0 = quantized[0]
f0, uv = self.f0_predictor(layer_0)
f0 = rearrange(f0, "... 1 -> ...")
uv = rearrange(uv, "... 1 -> ...")
layer_1 = quantized[1]
(phone,) = self.phone_predictor(layer_1)
out = {"f0": f0, "uv": uv, "phone": phone}
if self.use_gr_prosody_phone:
(prosody_phone,) = self.prosody_phone_predictor(layer_0)
out["prosody_phone"] = prosody_phone
if self.use_gr_content_f0:
content_f0, content_uv = self.content_f0_predictor(layer_1)
content_f0 = rearrange(content_f0, "... 1 -> ...")
content_uv = rearrange(content_uv, "... 1 -> ...")
out["content_f0"] = content_f0
out["content_uv"] = content_uv
if self.vq_num_q_r > 0:
layer_2 = quantized[2]
if self.use_gr_residual_f0:
res_f0, res_uv = self.res_f0_predictor(layer_2)
res_f0 = rearrange(res_f0, "... 1 -> ...")
res_uv = rearrange(res_uv, "... 1 -> ...")
out["res_f0"] = res_f0
out["res_uv"] = res_uv
if self.use_gr_residual_phone:
(res_phone,) = self.res_phone_predictor(layer_2)
out["res_phone"] = res_phone
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
if self.vq_num_q_r > 0:
if self.use_random_mask_residual:
bsz = quantized[2].shape[0]
res_mask = np.random.choice(
[0, 1],
size=bsz,
p=[
self.prob_random_mask_residual,
1 - self.prob_random_mask_residual,
],
)
res_mask = (
torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
) # (B, 1, 1)
res_mask = res_mask.to(
device=quantized[2].device, dtype=quantized[2].dtype
)
x = (
quantized[0].detach()
+ quantized[1].detach()
+ quantized[2] * res_mask
)
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
else:
x = quantized[0].detach() + quantized[1].detach() + quantized[2]
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
else:
x = quantized[0].detach() + quantized[1].detach()
# x = quantized_perturbe[0].detach() + quantized[1].detach()
if self.use_gr_x_timbre:
(x_timbre,) = self.x_timbre_predictor(x)
out["x_timbre"] = x_timbre
x = x.transpose(1, 2)
x = self.timbre_norm(x)
x = x.transpose(1, 2)
x = x * gamma + beta
x = self.model(x)
out["audio"] = x
return out
def vq2emb(self, vq, use_residual_code=True):
# vq: [num_quantizer, B, T]
self.quantizer = self.quantizer.eval()
out = 0
out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
out += self.quantizer[1].vq2emb(
vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
)
if self.vq_num_q_r > 0 and use_residual_code:
out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
return out
def inference(self, x, speaker_embedding):
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
x = x.transpose(1, 2)
x = self.timbre_norm(x)
x = x.transpose(1, 2)
x = x * gamma + beta
x = self.model(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
def reset_parameters(self):
self.apply(init_weights)
class FACodecRedecoder(nn.Module):
    """Decoder that reconstructs audio directly from FACodec codebook indices
    (prosody / content / optional residual streams) conditioned on a speaker
    embedding, without re-running quantization."""

    def __init__(
self,
in_channels=256,
upsample_initial_channel=1280,
up_ratios=(5, 5, 4, 2),
vq_num_q_c=2,
vq_num_q_p=1,
vq_num_q_r=3,
vq_dim=256,
codebook_size_prosody=10,
codebook_size_content=10,
codebook_size_residual=10,
):
super().__init__()
self.hop_length = np.prod(up_ratios)
self.up_ratios = up_ratios
self.vq_num_q_p = vq_num_q_p
self.vq_num_q_c = vq_num_q_c
self.vq_num_q_r = vq_num_q_r
self.vq_dim = vq_dim
self.codebook_size_prosody = codebook_size_prosody
self.codebook_size_content = codebook_size_content
self.codebook_size_residual = codebook_size_residual
self.prosody_embs = nn.ModuleList()
for i in range(self.vq_num_q_p):
emb_tokens = nn.Embedding(
num_embeddings=2**self.codebook_size_prosody,
embedding_dim=self.vq_dim,
)
emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
self.prosody_embs.append(emb_tokens)
self.content_embs = nn.ModuleList()
for i in range(self.vq_num_q_c):
emb_tokens = nn.Embedding(
num_embeddings=2**self.codebook_size_content,
embedding_dim=self.vq_dim,
)
emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
self.content_embs.append(emb_tokens)
self.residual_embs = nn.ModuleList()
for i in range(self.vq_num_q_r):
emb_tokens = nn.Embedding(
num_embeddings=2**self.codebook_size_residual,
embedding_dim=self.vq_dim,
)
emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
self.residual_embs.append(emb_tokens)
# Add first conv layer
channels = upsample_initial_channel
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(up_ratios):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
self.timbre_linear.bias.data[:in_channels] = 1
self.timbre_linear.bias.data[in_channels:] = 0
self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
self.timbre_cond_prosody_enc = TransformerEncoder(
enc_emb_tokens=None,
encoder_layer=4,
encoder_hidden=256,
encoder_head=4,
conv_filter_size=1024,
conv_kernel_size=5,
encoder_dropout=0.1,
use_cln=True,
cfg=None,
)
def forward(
self,
vq,
speaker_embedding,
use_residual_code=False,
):
x = 0
x_p = 0
for i in range(self.vq_num_q_p):
x_p = x_p + self.prosody_embs[i](vq[i]) # (B, T, d)
spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_p.shape[1], -1)
x_p = self.timbre_cond_prosody_enc(
x_p, key_padding_mask=None, condition=spk_cond
)
x = x + x_p
x_c = 0
for i in range(self.vq_num_q_c):
x_c = x_c + self.content_embs[i](vq[self.vq_num_q_p + i])
x = x + x_c
if use_residual_code:
x_r = 0
for i in range(self.vq_num_q_r):
x_r = x_r + self.residual_embs[i](
vq[self.vq_num_q_p + self.vq_num_q_c + i]
)
x = x + x_r
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
x = x.transpose(1, 2)
x = self.timbre_norm(x)
x = x.transpose(1, 2)
x = x * gamma + beta
x = self.model(x)
return x
def vq2emb(self, vq, speaker_embedding, use_residual=True):
out = 0
x_t = 0
for i in range(self.vq_num_q_p):
x_t += self.prosody_embs[i](vq[i]) # (B, T, d)
spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_t.shape[1], -1)
x_t = self.timbre_cond_prosody_enc(
x_t, key_padding_mask=None, condition=spk_cond
)
# prosody
out += x_t
# content
for i in range(self.vq_num_q_c):
out += self.content_embs[i](vq[self.vq_num_q_p + i])
# residual
if use_residual:
for i in range(self.vq_num_q_r):
out += self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i])
out = out.transpose(1, 2) # (B, T, d) -> (B, d, T)
return out
def inference(self, x, speaker_embedding):
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
x = x.transpose(1, 2)
x = self.timbre_norm(x)
x = x.transpose(1, 2)
x = x * gamma + beta
x = self.model(x)
return x
class FACodecEncoderV2(nn.Module):
    """Convolutional waveform encoder for FACodec; get_prosody_feature additionally
    returns the first 20 mel bins as a low-band prosody feature for FACodecDecoderV2."""

    def __init__(
self,
ngf=32,
up_ratios=(2, 4, 5, 5),
out_channels=1024,
):
super().__init__()
self.hop_length = np.prod(up_ratios)
self.up_ratios = up_ratios
# Create first convolution
d_model = ngf
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in up_ratios:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
]
        # Wrap block into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
self.mel_transform = MelSpectrogram(
n_fft=1024,
num_mels=80,
sampling_rate=16000,
hop_size=200,
win_size=800,
fmin=0,
fmax=8000,
)
self.reset_parameters()
def forward(self, x):
out = self.block(x)
return out
def inference(self, x):
return self.block(x)
def get_prosody_feature(self, x):
return self.mel_transform(x.squeeze(1))[:, :20, :]
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, nn.Conv1d):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
def reset_parameters(self):
self.apply(init_weights)
class FACodecDecoderV2(nn.Module):
    """FACodec decoder variant whose prosody quantizer operates on a low-band mel
    feature (see quantize) instead of on the encoder output directly."""

    def __init__(
self,
in_channels=256,
upsample_initial_channel=1536,
ngf=32,
up_ratios=(5, 5, 4, 2),
vq_num_q_c=2,
vq_num_q_p=1,
vq_num_q_r=3,
vq_dim=1024,
vq_commit_weight=0.005,
vq_weight_init=False,
vq_full_commit_loss=False,
codebook_dim=8,
codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
codebook_size_content=10,
codebook_size_residual=10,
quantizer_dropout=0.0,
dropout_type="linear",
use_gr_content_f0=False,
use_gr_prosody_phone=False,
use_gr_residual_f0=False,
use_gr_residual_phone=False,
use_gr_x_timbre=False,
use_random_mask_residual=True,
prob_random_mask_residual=0.75,
):
super().__init__()
self.hop_length = np.prod(up_ratios)
self.ngf = ngf
self.up_ratios = up_ratios
self.use_random_mask_residual = use_random_mask_residual
self.prob_random_mask_residual = prob_random_mask_residual
self.vq_num_q_p = vq_num_q_p
self.vq_num_q_c = vq_num_q_c
self.vq_num_q_r = vq_num_q_r
self.codebook_size_prosody = codebook_size_prosody
self.codebook_size_content = codebook_size_content
self.codebook_size_residual = codebook_size_residual
quantizer_class = ResidualVQ
self.quantizer = nn.ModuleList()
# prosody
quantizer = quantizer_class(
num_quantizers=vq_num_q_p,
dim=vq_dim,
codebook_size=codebook_size_prosody,
codebook_dim=codebook_dim,
threshold_ema_dead_code=2,
commitment=vq_commit_weight,
weight_init=vq_weight_init,
full_commit_loss=vq_full_commit_loss,
quantizer_dropout=quantizer_dropout,
dropout_type=dropout_type,
)
self.quantizer.append(quantizer)
# phone
quantizer = quantizer_class(
num_quantizers=vq_num_q_c,
dim=vq_dim,
codebook_size=codebook_size_content,
codebook_dim=codebook_dim,
threshold_ema_dead_code=2,
commitment=vq_commit_weight,
weight_init=vq_weight_init,
full_commit_loss=vq_full_commit_loss,
quantizer_dropout=quantizer_dropout,
dropout_type=dropout_type,
)
self.quantizer.append(quantizer)
# residual
if self.vq_num_q_r > 0:
quantizer = quantizer_class(
num_quantizers=vq_num_q_r,
dim=vq_dim,
codebook_size=codebook_size_residual,
codebook_dim=codebook_dim,
threshold_ema_dead_code=2,
commitment=vq_commit_weight,
weight_init=vq_weight_init,
full_commit_loss=vq_full_commit_loss,
quantizer_dropout=quantizer_dropout,
dropout_type=dropout_type,
)
self.quantizer.append(quantizer)
# Add first conv layer
channels = upsample_initial_channel
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(up_ratios):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
self.timbre_encoder = TransformerEncoder(
enc_emb_tokens=None,
encoder_layer=4,
encoder_hidden=256,
encoder_head=4,
conv_filter_size=1024,
conv_kernel_size=5,
encoder_dropout=0.1,
use_cln=False,
)
self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
self.timbre_linear.bias.data[:in_channels] = 1
self.timbre_linear.bias.data[in_channels:] = 0
self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
self.f0_predictor = CNNLSTM(in_channels, 1, 2)
self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
self.use_gr_content_f0 = use_gr_content_f0
self.use_gr_prosody_phone = use_gr_prosody_phone
self.use_gr_residual_f0 = use_gr_residual_f0
self.use_gr_residual_phone = use_gr_residual_phone
self.use_gr_x_timbre = use_gr_x_timbre
if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
self.res_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
)
        if self.vq_num_q_r > 0 and self.use_gr_residual_phone:
self.res_phone_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
)
if self.use_gr_content_f0:
self.content_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
)
if self.use_gr_prosody_phone:
self.prosody_phone_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
)
if self.use_gr_x_timbre:
self.x_timbre_predictor = nn.Sequential(
GradientReversal(alpha=1),
CNNLSTM(in_channels, 245200, 1, global_pred=True),
)
self.melspec_linear = nn.Linear(20, 256)
self.melspec_encoder = TransformerEncoder(
enc_emb_tokens=None,
encoder_layer=4,
encoder_hidden=256,
encoder_head=4,
conv_filter_size=1024,
conv_kernel_size=5,
encoder_dropout=0.1,
use_cln=False,
cfg=None,
)
self.reset_parameters()
def quantize(self, x, prosody_feature, n_quantizers=None):
outs, qs, commit_loss, quantized_buf = 0, [], [], []
# prosody
f0_input = prosody_feature.transpose(1, 2) # (B, T, 20)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(f0_input, None, None)
f0_input = f0_input.transpose(1, 2)
f0_quantizer = self.quantizer[0]
out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
outs += out
qs.append(q)
quantized_buf.append(quantized.sum(0))
commit_loss.append(commit)
# phone
phone_input = x
phone_quantizer = self.quantizer[1]
out, q, commit, quantized = phone_quantizer(
phone_input, n_quantizers=n_quantizers
)
outs += out
qs.append(q)
quantized_buf.append(quantized.sum(0))
commit_loss.append(commit)
# residual
if self.vq_num_q_r > 0:
residual_quantizer = self.quantizer[2]
residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
out, q, commit, quantized = residual_quantizer(
residual_input, n_quantizers=n_quantizers
)
outs += out
qs.append(q)
quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
commit_loss.append(commit)
qs = torch.cat(qs, dim=0)
commit_loss = torch.cat(commit_loss, dim=0)
return outs, qs, commit_loss, quantized_buf
def forward(
self,
x,
prosody_feature,
vq=True,
get_vq=False,
eval_vq=True,
speaker_embedding=None,
n_quantizers=None,
quantized=None,
):
if get_vq:
return self.quantizer.get_emb()
if vq is True:
if eval_vq:
self.quantizer.eval()
x_timbre = x
outs, qs, commit_loss, quantized_buf = self.quantize(
x, prosody_feature, n_quantizers=n_quantizers
)
x_timbre = x_timbre.transpose(1, 2)
x_timbre = self.timbre_encoder(x_timbre, None, None)
x_timbre = x_timbre.transpose(1, 2)
spk_embs = torch.mean(x_timbre, dim=2)
return outs, qs, commit_loss, quantized_buf, spk_embs
out = {}
layer_0 = quantized[0]
f0, uv = self.f0_predictor(layer_0)
f0 = rearrange(f0, "... 1 -> ...")
uv = rearrange(uv, "... 1 -> ...")
layer_1 = quantized[1]
(phone,) = self.phone_predictor(layer_1)
out = {"f0": f0, "uv": uv, "phone": phone}
if self.use_gr_prosody_phone:
(prosody_phone,) = self.prosody_phone_predictor(layer_0)
out["prosody_phone"] = prosody_phone
if self.use_gr_content_f0:
content_f0, content_uv = self.content_f0_predictor(layer_1)
content_f0 = rearrange(content_f0, "... 1 -> ...")
content_uv = rearrange(content_uv, "... 1 -> ...")
out["content_f0"] = content_f0
out["content_uv"] = content_uv
if self.vq_num_q_r > 0:
layer_2 = quantized[2]
if self.use_gr_residual_f0:
res_f0, res_uv = self.res_f0_predictor(layer_2)
res_f0 = rearrange(res_f0, "... 1 -> ...")
res_uv = rearrange(res_uv, "... 1 -> ...")
out["res_f0"] = res_f0
out["res_uv"] = res_uv
if self.use_gr_residual_phone:
(res_phone,) = self.res_phone_predictor(layer_2)
out["res_phone"] = res_phone
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
if self.vq_num_q_r > 0:
if self.use_random_mask_residual:
bsz = quantized[2].shape[0]
res_mask = np.random.choice(
[0, 1],
size=bsz,
p=[
self.prob_random_mask_residual,
1 - self.prob_random_mask_residual,
],
)
res_mask = (
torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
) # (B, 1, 1)
res_mask = res_mask.to(
device=quantized[2].device, dtype=quantized[2].dtype
)
x = (
quantized[0].detach()
+ quantized[1].detach()
+ quantized[2] * res_mask
)
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
else:
x = quantized[0].detach() + quantized[1].detach() + quantized[2]
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
else:
x = quantized[0].detach() + quantized[1].detach()
# x = quantized_perturbe[0].detach() + quantized[1].detach()
if self.use_gr_x_timbre:
(x_timbre,) = self.x_timbre_predictor(x)
out["x_timbre"] = x_timbre
x = x.transpose(1, 2)
x = self.timbre_norm(x)
x = x.transpose(1, 2)
x = x * gamma + beta
x = self.model(x)
out["audio"] = x
return out
def vq2emb(self, vq, use_residual=True):
# vq: [num_quantizer, B, T]
self.quantizer = self.quantizer.eval()
out = 0
out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
out += self.quantizer[1].vq2emb(
vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
)
if self.vq_num_q_r > 0 and use_residual:
out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
return out
def inference(self, x, speaker_embedding):
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
x = x.transpose(1, 2)
x = self.timbre_norm(x)
x = x.transpose(1, 2)
x = x * gamma + beta
x = self.model(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
def reset_parameters(self):
self.apply(init_weights)
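# Illustrative sketch (not part of the original file): the random residual-code
# masking used in FACodecDecoder / FACodecDecoderV2.forward above, reproduced on
# toy tensors. With probability prob_random_mask_residual a sample's residual
# stream is zeroed before decoding, so the decoder also learns to reconstruct
# audio from the prosody and content streams alone. Shapes here are arbitrary.
if __name__ == "__main__":
    prob_random_mask_residual = 0.75
    q_prosody = torch.randn(2, 256, 50)   # quantized[0]: (B, d, T)
    q_content = torch.randn(2, 256, 50)   # quantized[1]
    q_residual = torch.randn(2, 256, 50)  # quantized[2]
    res_mask = np.random.choice(
        [0, 1],
        size=q_residual.shape[0],
        p=[prob_random_mask_residual, 1 - prob_random_mask_residual],
    )
    res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)  # (B, 1, 1)
    res_mask = res_mask.to(dtype=q_residual.dtype)
    x = q_prosody.detach() + q_content.detach() + q_residual * res_mask
    print(x.shape)  # torch.Size([2, 256, 50])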
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.autograd import Function
import torch
from torch import nn
class GradientReversalFunction(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.save_for_backward(x, alpha)
return x
@staticmethod
def backward(ctx, grad_output):
grad_input = None
_, alpha = ctx.saved_tensors
if ctx.needs_input_grad[0]:
grad_input = -alpha * grad_output
return grad_input, None
revgrad = GradientReversalFunction.apply
class GradientReversal(nn.Module):
    """Identity layer whose backward pass reverses (and scales by ``alpha``) the gradient."""

    def __init__(self, alpha):
super().__init__()
self.alpha = torch.tensor(alpha, requires_grad=False)
def forward(self, x):
return revgrad(x, self.alpha)
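# Illustrative sketch (not part of the original file): GradientReversal is the
# identity in the forward pass but flips the sign of the gradient, which is how
# the adversarial predictors in the decoders above (e.g. prosody_phone_predictor)
# discourage information leakage. alpha=1.0 is an arbitrary value for this demo.
if __name__ == "__main__":
    grl = GradientReversal(alpha=1.0)
    x = torch.ones(3, requires_grad=True)
    grl(x).sum().backward()
    print(x.grad)  # tensor([-1., -1., -1.]) instead of the usual all-ones gradient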
import torch
import pyworld as pw
import numpy as np
import soundfile as sf
import os
from torchaudio.functional import pitch_shift
import librosa
from librosa.filters import mel as librosa_mel_fn
import torch.nn as nn
import torch.nn.functional as F
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
class MelSpectrogram(nn.Module):
def __init__(
self,
n_fft,
num_mels,
sampling_rate,
hop_size,
win_size,
fmin,
fmax,
center=False,
):
super(MelSpectrogram, self).__init__()
self.n_fft = n_fft
self.hop_size = hop_size
self.win_size = win_size
self.sampling_rate = sampling_rate
self.num_mels = num_mels
self.fmin = fmin
self.fmax = fmax
self.center = center
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis = torch.from_numpy(mel).float()
        hann_window = torch.hann_window(win_size)
self.register_buffer("mel_basis", mel_basis)
self.register_buffer("hann_window", hann_window)
def forward(self, y):
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2),
),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=self.center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec)
return spec
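# Illustrative sketch (not part of the original file): extracting a log-mel
# spectrogram from one second of random 16 kHz audio with the same settings
# FACodecEncoderV2 uses above (80 mels, hop 200, win 800); the first 20 bins of
# this output are what get_prosody_feature feeds to the prosody quantizer.
if __name__ == "__main__":
    mel_fn = MelSpectrogram(
        n_fft=1024,
        num_mels=80,
        sampling_rate=16000,
        hop_size=200,
        win_size=800,
        fmin=0,
        fmax=8000,
    )
    wav = torch.randn(1, 16000)  # (B, T) raw waveform
    mel = mel_fn(wav)            # (B, num_mels, frames)
    print(mel.shape)             # torch.Size([1, 80, 80])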
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .fvq import *
from .rvq import *
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
class FactorizedVectorQuantize(nn.Module):
def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
if dim != self.codebook_dim:
self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
else:
self.in_proj = nn.Identity()
self.out_proj = nn.Identity()
self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
@property
def codebook(self):
return self._codebook
    def forward(self, z):
        """Quantizes the input tensor using a fixed codebook and returns
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized continuous representation of the input
            (straight-through gradients to the encoder).
        indices : Tensor[B x T]
            Codebook indices (quantized discrete representation of the input).
        commit_loss : Tensor[B]
            Commitment plus codebook loss per batch item (zeros in eval mode).
        """
# transpose since we use linear
z = rearrange(z, "b d t -> b t d")
# Factorized codes project input into low-dimensional space
z_e = self.in_proj(z) # z_e : (B x T x D)
z_e = rearrange(z_e, "b t d -> b d t")
z_q, indices = self.decode_latents(z_e)
if self.training:
commitment_loss = (
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
* self.commitment
)
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
commit_loss = commitment_loss + codebook_loss
else:
commit_loss = torch.zeros(z.shape[0], device=z.device)
z_q = (
z_e + (z_q - z_e).detach()
) # noop in forward pass, straight-through gradient estimator in backward pass
z_q = rearrange(z_q, "b d t -> b t d")
z_q = self.out_proj(z_q)
z_q = rearrange(z_q, "b t d -> b d t")
return z_q, indices, commit_loss
def vq2emb(self, vq, proj=True):
emb = self.embed_code(vq)
if proj:
emb = self.out_proj(emb)
return emb.transpose(1, 2)
def get_emb(self):
return self.codebook.weight
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
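# Illustrative sketch (not part of the original file): a single forward pass
# through FactorizedVectorQuantize on random features in eval mode, assuming
# dim=256, a 1024-entry codebook and codebook_dim=8 (values chosen only for
# this demo). Shapes follow the docstring of forward above.
if __name__ == "__main__":
    fvq = FactorizedVectorQuantize(
        dim=256, codebook_size=1024, codebook_dim=8, commitment=0.005
    )
    fvq.eval()
    z = torch.randn(2, 256, 50)        # (B, D, T)
    z_q, indices, commit_loss = fvq(z)
    print(z_q.shape, indices.shape, commit_loss.shape)
    # torch.Size([2, 256, 50]) torch.Size([2, 50]) torch.Size([2])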
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import nn
from .fvq import FactorizedVectorQuantize
class ResidualVQ(nn.Module):
    """Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf"""
def __init__(self, *, num_quantizers, codebook_size, **kwargs):
super().__init__()
VQ = FactorizedVectorQuantize
        if isinstance(codebook_size, int):
codebook_size = [codebook_size] * num_quantizers
self.layers = nn.ModuleList(
[VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
)
self.num_quantizers = num_quantizers
self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
self.dropout_type = kwargs.get("dropout_type", None)
def forward(self, x, n_quantizers=None):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
all_quantized = []
if n_quantizers is None:
n_quantizers = self.num_quantizers
if self.training:
n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
if self.dropout_type == "linear":
dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
elif self.dropout_type == "exp":
dropout = torch.randint(
1, int(math.log2(self.num_quantizers)), (x.shape[0],)
)
dropout = torch.pow(2, dropout)
n_dropout = int(x.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(x.device)
for idx, layer in enumerate(self.layers):
if not self.training and idx >= n_quantizers:
break
quantized, indices, loss = layer(residual)
mask = (
torch.full((x.shape[0],), fill_value=idx, device=x.device)
< n_quantizers
)
residual = residual - quantized
quantized_out = quantized_out + quantized * mask[:, None, None]
# loss
loss = (loss * mask).mean()
all_indices.append(indices)
all_losses.append(loss)
all_quantized.append(quantized)
all_losses, all_indices, all_quantized = map(
torch.stack, (all_losses, all_indices, all_quantized)
)
return quantized_out, all_indices, all_losses, all_quantized
def vq2emb(self, vq):
# vq: [n_quantizers, B, T]
quantized_out = 0.0
for idx, layer in enumerate(self.layers):
quantized = layer.vq2emb(vq[idx])
quantized_out += quantized
return quantized_out
def get_emb(self):
embs = []
for idx, layer in enumerate(self.layers):
embs.append(layer.get_emb())
return embs
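# Illustrative sketch (not part of the original file): a 3-layer ResidualVQ in
# eval mode on random features. Note that codebook_size here is an exponent
# (each layer gets a 2**codebook_size-entry codebook), matching how the FACodec
# decoders above construct their quantizers.
if __name__ == "__main__":
    rvq = ResidualVQ(
        num_quantizers=3,
        codebook_size=10,
        dim=256,
        codebook_dim=8,
        commitment=0.005,
    )
    rvq.eval()
    feats = torch.randn(2, 256, 50)  # (B, D, T)
    quantized_out, codes, losses, per_layer = rvq(feats)
    print(quantized_out.shape, codes.shape, losses.shape, per_layer.shape)
    # torch.Size([2, 256, 50]) torch.Size([3, 2, 50]) torch.Size([3]) torch.Size([3, 2, 256, 50])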
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
import math
from torch.nn import functional as F
class StyleAdaptiveLayerNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
self.in_dim = normalized_shape
self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
self.style = nn.Linear(self.in_dim, self.in_dim * 2)
self.style.bias.data[: self.in_dim] = 1
self.style.bias.data[self.in_dim :] = 0
def forward(self, x, condition):
# x: (B, T, d); condition: (B, T, d)
style = self.style(torch.mean(condition, dim=1, keepdim=True))
gamma, beta = style.chunk(2, -1)
out = self.norm(x)
out = gamma * out + beta
return out
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout, max_len=5000):
super().__init__()
self.dropout = dropout
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0)]
return F.dropout(x, self.dropout, training=self.training)
class TransformerFFNLayer(nn.Module):
def __init__(
self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
):
super().__init__()
self.encoder_hidden = encoder_hidden
self.conv_filter_size = conv_filter_size
self.conv_kernel_size = conv_kernel_size
self.encoder_dropout = encoder_dropout
self.ffn_1 = nn.Conv1d(
self.encoder_hidden,
self.conv_filter_size,
self.conv_kernel_size,
padding=self.conv_kernel_size // 2,
)
self.ffn_1.weight.data.normal_(0.0, 0.02)
self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
self.ffn_2.weight.data.normal_(0.0, 0.02)
def forward(self, x):
# x: (B, T, d)
x = self.ffn_1(x.permute(0, 2, 1)).permute(
0, 2, 1
) # (B, T, d) -> (B, d, T) -> (B, T, d)
x = F.relu(x)
x = F.dropout(x, self.encoder_dropout, training=self.training)
x = self.ffn_2(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
encoder_hidden,
encoder_head,
conv_filter_size,
conv_kernel_size,
encoder_dropout,
use_cln,
):
super().__init__()
self.encoder_hidden = encoder_hidden
self.encoder_head = encoder_head
self.conv_filter_size = conv_filter_size
self.conv_kernel_size = conv_kernel_size
self.encoder_dropout = encoder_dropout
self.use_cln = use_cln
if not self.use_cln:
self.ln_1 = nn.LayerNorm(self.encoder_hidden)
self.ln_2 = nn.LayerNorm(self.encoder_hidden)
else:
self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
self.self_attn = nn.MultiheadAttention(
self.encoder_hidden, self.encoder_head, batch_first=True
)
self.ffn = TransformerFFNLayer(
self.encoder_hidden,
self.conv_filter_size,
self.conv_kernel_size,
self.encoder_dropout,
)
    def forward(self, x, key_padding_mask, condition=None):
        # x: (B, T, d); key_padding_mask: (B, T), where 0 marks padded positions; condition: (B, T, d)
# self attention
residual = x
if self.use_cln:
            x = self.ln_1(x, condition)
else:
x = self.ln_1(x)
        if key_padding_mask is not None:
key_padding_mask_input = ~(key_padding_mask.bool())
else:
key_padding_mask_input = None
x, _ = self.self_attn(
query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
)
x = F.dropout(x, self.encoder_dropout, training=self.training)
x = residual + x
# ffn
residual = x
if self.use_cln:
            x = self.ln_2(x, condition)
else:
x = self.ln_2(x)
x = self.ffn(x)
x = residual + x
return x
class TransformerEncoder(nn.Module):
def __init__(
self,
enc_emb_tokens=None,
encoder_layer=4,
encoder_hidden=256,
encoder_head=4,
conv_filter_size=1024,
conv_kernel_size=5,
encoder_dropout=0.1,
use_cln=False,
cfg=None,
):
super().__init__()
self.encoder_layer = (
encoder_layer if encoder_layer is not None else cfg.encoder_layer
)
self.encoder_hidden = (
encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
)
self.encoder_head = (
encoder_head if encoder_head is not None else cfg.encoder_head
)
self.conv_filter_size = (
conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
)
self.conv_kernel_size = (
conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
)
self.encoder_dropout = (
encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
)
self.use_cln = use_cln if use_cln is not None else cfg.use_cln
        if enc_emb_tokens is not None:
self.use_enc_emb = True
self.enc_emb_tokens = enc_emb_tokens
else:
self.use_enc_emb = False
self.position_emb = PositionalEncoding(
self.encoder_hidden, self.encoder_dropout
)
self.layers = nn.ModuleList([])
self.layers.extend(
[
TransformerEncoderLayer(
self.encoder_hidden,
self.encoder_head,
self.conv_filter_size,
self.conv_kernel_size,
self.encoder_dropout,
self.use_cln,
)
for i in range(self.encoder_layer)
]
)
if self.use_cln:
self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
else:
self.last_ln = nn.LayerNorm(self.encoder_hidden)
def forward(self, x, key_padding_mask, condition=None):
if len(x.shape) == 2 and self.use_enc_emb:
x = self.enc_emb_tokens(x)
x = self.position_emb(x)
else:
x = self.position_emb(x) # (B, T, d)
for layer in self.layers:
x = layer(x, key_padding_mask, condition)
if self.use_cln:
x = self.last_ln(x, condition)
else:
x = self.last_ln(x)
return x
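# Illustrative sketch (not part of the original file): running the default
# TransformerEncoder (4 layers, hidden size 256, no style conditioning) over a
# batch of random frame features with no padding mask. The shapes are arbitrary.
if __name__ == "__main__":
    enc = TransformerEncoder()
    enc.eval()
    frames = torch.randn(2, 120, 256)  # (B, T, d)
    out = enc(frames, key_padding_mask=None, condition=None)
    print(out.shape)                   # torch.Size([2, 120, 256])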
# Copyright (c) 2023 Amphion.
#
# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
# Licensed under Apache License 2.0
from .modules.seanet import SEANetEncoder, SEANetDecoder
from .modules.quantization import ResidualVectorQuantizer
import torch.nn as nn
from einops import rearrange
import torch
import numpy as np
class SpeechTokenizer(nn.Module):
def __init__(self, config):
"""
Parameters
----------
        config : dict
            Model configuration (typically loaded from a JSON file).
"""
super().__init__()
self.encoder = SEANetEncoder(
n_filters=config.get("n_filters"),
dimension=config.get("dimension"),
ratios=config.get("strides"),
lstm=config.get("lstm_layers"),
bidirectional=config.get("bidirectional"),
dilation_base=config.get("dilation_base"),
residual_kernel_size=config.get("residual_kernel_size"),
n_residual_layers=config.get("n_residual_layers"),
activation=config.get("activation"),
)
self.sample_rate = config.get("sample_rate")
self.n_q = config.get("n_q")
self.downsample_rate = np.prod(config.get("strides"))
if config.get("dimension") != config.get("semantic_dimension"):
self.transform = nn.Linear(
config.get("dimension"), config.get("semantic_dimension")
)
else:
self.transform = nn.Identity()
self.quantizer = ResidualVectorQuantizer(
dimension=config.get("dimension"),
n_q=config.get("n_q"),
bins=config.get("codebook_size"),
)
self.decoder = SEANetDecoder(
n_filters=config.get("n_filters"),
dimension=config.get("dimension"),
ratios=config.get("strides"),
lstm=config.get("lstm_layers"),
bidirectional=False,
dilation_base=config.get("dilation_base"),
residual_kernel_size=config.get("residual_kernel_size"),
n_residual_layers=config.get("n_residual_layers"),
activation=config.get("activation"),
)
@classmethod
def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
"""
Parameters
----------
config_path : str
Path of model configuration file.
ckpt_path : str
Path of model checkpoint.
Returns
-------
model : SpeechTokenizer
SpeechTokenizer model.
"""
import json
with open(config_path) as f:
cfg = json.load(f)
model = cls(cfg)
params = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(params)
return model
    def forward(self, x: torch.Tensor, n_q: int = None, layers: list = [0]):
"""
Parameters
----------
x : torch.tensor
Input wavs. Shape: (batch, channels, timesteps).
n_q : int, optional
Number of quantizers in RVQ used to encode. The default is all layers.
layers : list[int], optional
            Indices of the RVQ layers whose quantized outputs should be returned. The default is the first layer.
Returns
-------
o : torch.tensor
Output wavs. Shape: (batch, channels, timesteps).
commit_loss : torch.tensor
Commitment loss from residual vector quantizers.
feature : torch.tensor
Output of RVQ's first layer. Shape: (batch, timesteps, dimension)
"""
n_q = n_q if n_q else self.n_q
e = self.encoder(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(
e, n_q=n_q, layers=layers
)
feature = rearrange(quantized_list[0], "b d t -> b t d")
feature = self.transform(feature)
o = self.decoder(quantized)
return o, commit_loss, feature
    def forward_feature(self, x: torch.Tensor, layers: list = None):
"""
Parameters
----------
x : torch.tensor
Input wavs. Shape should be (batch, channels, timesteps).
layers : list[int], optional
            Indices of the RVQ layers whose quantized outputs should be returned. The default is all layers.
Returns
-------
        quantized_list : list[torch.Tensor]
            Quantized outputs of the requested layers.
"""
e = self.encoder(x)
layers = layers if layers else list(range(self.n_q))
quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
return quantized_list
    def encode(self, x: torch.Tensor, n_q: int = None, st: int = None):
"""
Parameters
----------
x : torch.tensor
Input wavs. Shape: (batch, channels, timesteps).
n_q : int, optional
Number of quantizers in RVQ used to encode. The default is all layers.
st : int, optional
Start quantizer index in RVQ. The default is 0.
Returns
-------
codes : torch.tensor
Output indices for each quantizer. Shape: (n_q, batch, timesteps)
"""
e = self.encoder(x)
if st is None:
st = 0
n_q = n_q if n_q else self.n_q
codes = self.quantizer.encode(e, n_q=n_q, st=st)
return codes
    def decode(self, codes: torch.Tensor, st: int = 0):
"""
Parameters
----------
codes : torch.tensor
Indices for each quantizer. Shape: (n_q, batch, timesteps).
st : int, optional
Start quantizer index in RVQ. The default is 0.
Returns
-------
o : torch.tensor
Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
"""
quantized = self.quantizer.decode(codes, st=st)
o = self.decoder(quantized)
return o
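# Illustrative sketch (not part of the original file): a typical encode/decode
# round trip with a trained SpeechTokenizer. The config and checkpoint paths are
# placeholders supplied by the caller, not files shipped with this repository.
def _example_round_trip(config_path: str, ckpt_path: str):
    model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
    model.eval()
    # one second of dummy audio; use model.sample_rate for real inputs
    wav = torch.randn(1, 1, 16000)   # (batch, channels, timesteps)
    with torch.no_grad():
        codes = model.encode(wav)    # (n_q, batch, frames)
        recon = model.decode(codes)  # (batch, channels, timesteps)
    return codes, recon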