"tests/models/auto/test_modeling_tf_auto.py" did not exist on "c014d1f0c64a318ef288ab8dca47aa3f29052540"
Commit 0112b0f0 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2394 canceled with stages
import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
from inspiremusic.wavtokenizer.decoder.spectral_ops import IMDCT, ISTFT
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
class FourierHead(nn.Module):
"""Base class for inverse fourier modules."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class ISTFTHead(FourierHead):
"""
ISTFT Head module for predicting STFT complex coefficients.
Args:
dim (int): Hidden dimension of the model.
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames, which should align with
the resolution of the input features.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
super().__init__()
out_dim = n_fft + 2
self.out = torch.nn.Linear(dim, out_dim)
self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ISTFTHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
# wrapping happens here. These two lines produce real and imaginary value
x = torch.cos(p)
y = torch.sin(p)
# recalculating phase here does not produce anything new
# only costs time
# phase = torch.atan2(y, x)
# S = mag * torch.exp(phase * 1j)
# better directly produce the complex value
S = mag * (x + 1j * y)
audio = self.istft(S)
return audio
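# A minimal usage sketch for ISTFTHead (illustrative only; `_example_istft_head` is a
# hypothetical helper and the hyperparameters are assumed, not taken from a released config).
# It shows how (B, L, H) features become a (B, T) waveform with T = L * hop_length.
def _example_istft_head():
    import torch
    head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
    features = torch.randn(2, 200, 512)   # (B, L, H)
    audio = head(features)                # (B, T) = (2, 200 * 256) = (2, 51200)
    return audio.shape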
class IMDCTSymExpHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
based on perceptual scaling. Defaults to None.
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
):
super().__init__()
out_dim = mdct_frame_len // 2
self.out = nn.Linear(dim, out_dim)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
self.clip_audio = clip_audio
if sample_rate is not None:
# optionally init the last layer following mel-scale
m_max = _hz_to_mel(sample_rate // 2)
m_pts = torch.linspace(0, m_max, out_dim)
f_pts = _mel_to_hz(m_pts)
scale = 1 - (f_pts / f_pts.max())
with torch.no_grad():
self.out.weight.mul_(scale.view(-1, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTSymExpHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
x = symexp(x)
x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes
audio = self.imdct(x)
if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)
return audio
class IMDCTCosHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
super().__init__()
self.clip_audio = clip_audio
self.out = nn.Linear(dim, mdct_frame_len)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTCosHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
m, p = x.chunk(2, dim=2)
m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes
audio = self.imdct(m * torch.cos(p))
if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)
return audio
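# A minimal usage sketch for the MDCT-based heads (illustrative only; `_example_imdct_cos_head`
# is a hypothetical helper with assumed hyperparameters). Under "same" padding both heads map
# (B, L, H) features to a (B, T) waveform with T = L * mdct_frame_len // 2.
def _example_imdct_cos_head():
    import torch
    head = IMDCTCosHead(dim=512, mdct_frame_len=512)
    features = torch.randn(2, 200, 512)   # (B, L, H)
    audio = head(features)                # (B, T) = (2, 200 * 256) = (2, 51200)
    return audio.shape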
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback
matplotlib.use("Agg")
def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
"""
Save a matplotlib figure to a numpy array.
Args:
fig (Figure): Matplotlib figure object.
Returns:
ndarray: Numpy array representing the figure.
"""
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
"""
Plot a spectrogram and convert it to a numpy array.
Args:
spectrogram (ndarray): Spectrogram data.
Returns:
ndarray: Numpy array representing the plotted spectrogram.
"""
spectrogram = spectrogram.astype(np.float32)
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
class GradNormCallback(Callback):
"""
Callback to log the gradient norm.
"""
def on_after_backward(self, trainer, model):
model.log("grad_norm", gradient_norm(model))
def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
"""
Compute the gradient norm.
Args:
model (Module): PyTorch model.
norm_type (float, optional): Type of the norm. Defaults to 2.0.
Returns:
Tensor: Gradient norm.
"""
grads = [p.grad for p in model.parameters() if p.grad is not None]
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
return total_norm
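# A minimal sketch of what GradNormCallback logs (illustrative only; `_example_gradient_norm`
# is a hypothetical helper with a toy model). gradient_norm() stacks the per-parameter gradient
# norms and takes their p-norm, i.e. the global gradient norm of the model.
def _example_gradient_norm():
    import torch
    model = torch.nn.Linear(4, 2)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    return gradient_norm(model, norm_type=2.0)   # scalar tensor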
from typing import List, Tuple
import torch
import torchaudio
from torch import nn
from inspiremusic.wavtokenizer.decoder.modules import safe_log
import torch.nn.functional as F
class MelSpecReconstructionLoss(nn.Module):
"""
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
"""
def __init__(
self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
):
super().__init__()
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
)
def forward(self, y_hat, y) -> torch.Tensor:
"""
Args:
y_hat (Tensor): Predicted audio waveform.
y (Tensor): Ground truth audio waveform.
Returns:
Tensor: L1 loss between the mel-scaled magnitude spectrograms.
"""
mel_hat = safe_log(self.mel_spec(y_hat))
mel = safe_log(self.mel_spec(y))
loss = torch.nn.functional.l1_loss(mel, mel_hat)
return loss
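# A minimal usage sketch (illustrative only; `_example_mel_loss` is a hypothetical helper and
# the waveform length is arbitrary). The loss is the L1 distance between log-mel spectrograms.
def _example_mel_loss():
    import torch
    criterion = MelSpecReconstructionLoss(sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100)
    y = torch.randn(2, 24000)       # ground-truth waveform, (B, T)
    y_hat = torch.randn(2, 24000)   # generated waveform, (B, T)
    return criterion(y_hat, y)      # scalar tensor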
class GeneratorLoss(nn.Module):
"""
Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
"""
def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
disc_outputs (List[Tensor]): List of discriminator outputs.
Returns:
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
the sub-discriminators
"""
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean(torch.clamp(1 - dg, min=0))
gen_losses.append(l)
loss += l
return loss, gen_losses
class DiscriminatorLoss(nn.Module):
"""
Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
"""
def forward(
self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
"""
Args:
disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
Returns:
Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
the sub-discriminators for real outputs, and a list of
loss values for generated outputs.
"""
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
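# A minimal sketch of the hinge formulation above (illustrative only; `_example_hinge_losses`
# is a hypothetical helper with made-up scores). The generator pays mean(relu(1 - dg)) per
# sub-discriminator, while the discriminator pushes real scores above +1 and fake scores below -1.
def _example_hinge_losses():
    import torch
    disc_real = [torch.tensor([1.5, 0.5])]    # one sub-discriminator, two samples
    disc_fake = [torch.tensor([-1.2, 0.3])]
    gen_total, _ = GeneratorLoss()(disc_fake)                     # mean([2.2, 0.7]) = 1.45
    disc_total, _, _ = DiscriminatorLoss()(disc_real, disc_fake)  # 0.25 + 0.65 = 0.90
    return gen_total, disc_total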
class FeatureMatchingLoss(nn.Module):
"""
Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
"""
def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
"""
Args:
fmap_r (List[List[Tensor]]): List of feature maps from real samples.
fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
Returns:
Tensor: The calculated feature matching loss.
"""
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss
class DACGANLoss(nn.Module):
"""
Computes a discriminator loss, given a discriminator on
generated waveforms/spectrograms compared to ground truth
waveforms/spectrograms. Computes the loss for both the
discriminator and the generator in separate functions.
"""
def __init__(self, discriminator):
super().__init__()
self.discriminator = discriminator
def forward(self, fake, real):
# d_fake = self.discriminator(fake.audio_data)
# d_real = self.discriminator(real.audio_data)
d_fake = self.discriminator(fake)
d_real = self.discriminator(real)
return d_fake, d_real
def discriminator_loss(self, fake, real):
d_fake, d_real = self.forward(fake.clone().detach(), real)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += torch.mean(x_fake[-1] ** 2)
loss_d += torch.mean((1 - x_real[-1]) ** 2)
return loss_d
def generator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_g = 0
for x_fake in d_fake:
loss_g += torch.mean((1 - x_fake[-1]) ** 2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature
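# A minimal usage sketch (illustrative only; `_example_dacgan_loss` is a hypothetical helper).
# The wrapped discriminator is assumed to return, per sub-discriminator, a list of feature maps
# whose last entry is the logit map, which is what the [-1] indexing above relies on.
def _example_dacgan_loss(discriminator, fake, real):
    gan = DACGANLoss(discriminator)
    loss_d = gan.discriminator_loss(fake, real)          # least-squares discriminator loss
    loss_g, loss_feat = gan.generator_loss(fake, real)   # least-squares G loss + L1 feature matching
    return loss_d, loss_g, loss_feat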
from typing import List, Optional
import torch
from torch import nn
from torch.nn.utils import weight_norm
from inspiremusic.wavtokenizer.decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv1d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv1d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv1d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv1d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None]  # (B, C) -> (B, C, 1), broadcast over time
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h = q.shape
q = q.permute(0, 2, 1) # b,hw,c
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
class Backbone(nn.Module):
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
C denotes output features, and L is the sequence length.
Returns:
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
and H denotes the model dimension.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class VocosBackbone(Backbone):
"""
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional model. Defaults to None.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=adanorm_num_embeddings,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
self.temb_ch = 0
block_in = dim
dropout = 0.1
attn_type="vanilla"
pos_net : tp.List[nn.Module] = [
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
make_attn(block_in, attn_type=attn_type),
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
Normalize(block_in)
]
self.pos_net = nn.Sequential(*pos_net)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.embed(x)
x = self.pos_net(x)
if self.adanorm:
# assert bandwidth_id is not None
if bandwidth_id is None:
                bandwidth_id = torch.tensor(0, device=x.device)
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
else:
x = self.norm(x.transpose(1, 2))
x = x.transpose(1, 2)
for conv_block in self.convnext:
x = conv_block(x, cond_embedding_id=bandwidth_id)
x = self.final_layer_norm(x.transpose(1, 2))
return x
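# A minimal sketch of composing the backbone with a Fourier head (illustrative only;
# `_example_vocos_backbone` is a hypothetical helper and all hyperparameters are assumed).
# The backbone maps (B, C, L) features to (B, L, H); a head then produces the (B, T) waveform.
def _example_vocos_backbone():
    import torch
    from inspiremusic.wavtokenizer.decoder.heads import ISTFTHead
    backbone = VocosBackbone(input_channels=100, dim=512, intermediate_dim=1536, num_layers=8)
    head = ISTFTHead(dim=512, n_fft=1024, hop_length=256)
    features = torch.randn(1, 100, 50)   # (B, C, L)
    hidden = backbone(features)          # (B, L, H) = (1, 50, 512)
    return head(hidden)                  # (B, T) = (1, 50 * 256)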
class VocosResNetBackbone(Backbone):
"""
Vocos backbone module built with ResBlocks.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
num_blocks (int): Number of ResBlock1 blocks.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
"""
def __init__(
self, input_channels, dim, num_blocks, layer_scale_init_value=None,
):
super().__init__()
self.input_channels = input_channels
self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
self.resnet = nn.Sequential(
*[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.embed(x)
x = self.resnet(x)
x = x.transpose(1, 2)
return x
from typing import Optional
from typing import Tuple
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )
def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
if self.adanorm:
assert cond_embedding_id is not None
x = self.norm(x, cond_embedding_id)
else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class AdaLayerNorm(nn.Module):
"""
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
Args:
num_embeddings (int): Number of embeddings.
embedding_dim (int): Dimension of the embeddings.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.dim = embedding_dim
self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
torch.nn.init.ones_(self.scale.weight)
torch.nn.init.zeros_(self.shift.weight)
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
scale = self.scale(cond_embedding_id)
shift = self.shift(cond_embedding_id)
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
x = x * scale + shift
return x
class ResBlock1(nn.Module):
"""
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
but without upsampling layers.
Args:
dim (int): Number of input channels.
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
Defaults to (1, 3, 5).
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
Defaults to 0.1.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: Tuple[int, ...] = (1, 3, 5),
lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
self.convs1 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[0],
padding=self.get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[1],
padding=self.get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[2],
padding=self.get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
]
)
self.gamma = nn.ParameterList(
[
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
xt = c1(xt)
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
xt = c2(xt)
if gamma is not None:
xt = gamma * xt
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
@staticmethod
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
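# Worked example of get_padding (illustrative): with kernel_size=3 and dilation=3 the dilated
# conv spans 1 + (3 - 1) * 3 = 7 input samples, and padding = (3 * 3 - 3) // 2 = 3 keeps the
# output length equal to the input length at stride 1, i.e. "same" padding.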
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))
def symlog(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(x.abs())
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
import os
from typing import Tuple, Any, Union, Dict
import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
"""Instantiates a class with the given args and init.
Args:
args: Positional arguments required for instantiation.
init: Dict of the form {"class_path":...,"init_args":...}.
Returns:
The instantiated class object.
"""
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
args = (args,)
class_module, class_name = init["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
return args_class(*args, **kwargs)
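# A minimal usage sketch (illustrative only; `_example_instantiate_class` is a hypothetical
# helper, and the init_args values are assumed rather than taken from a released config).
def _example_instantiate_class():
    init = {
        "class_path": "inspiremusic.wavtokenizer.decoder.heads.ISTFTHead",
        "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256},
    }
    return instantiate_class(args=(), init=init)   # an ISTFTHead instance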
class WavTokenizer(nn.Module):
"""
    WavTokenizer (adapted from Vocos) is a Fourier-based neural vocoder for audio synthesis.
This class is primarily designed for inference, with support for loading from pretrained
model checkpoints. It consists of three main components: a feature extractor,
a backbone, and a head.
"""
def __init__(
self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
):
super().__init__()
self.feature_extractor = feature_extractor
self.backbone = backbone
self.head = head
@classmethod
def from_hparams(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
backbone = instantiate_class(args=(), init=config["backbone"])
head = instantiate_class(args=(), init=config["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
@classmethod
def from_pretrained(self, repo_id: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
model = self.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu")
if isinstance(model.feature_extractor, EncodecFeatures):
encodec_parameters = {
"feature_extractor.encodec." + key: value
for key, value in model.feature_extractor.encodec.state_dict().items()
}
state_dict.update(encodec_parameters)
model.load_state_dict(state_dict)
model.eval()
return model
@classmethod
def from_hparams_feat(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
head = instantiate_class(args=(), init=config['model']['init_args']["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
@classmethod
def from_pretrained_feat(self, config_path, model_path):
"""
        Class method to create a new model instance from a local configuration file and checkpoint.
"""
model = self.from_hparams_feat(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict[k] = v
model.load_state_dict(state_dict)
model.eval()
return model
@classmethod
def estimator(self, config_path, model_path):
"""
        Class method to create a new model instance from a local configuration file and checkpoint.
"""
model = self.from_hparams_feat(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict[k] = v
model.load_state_dict(state_dict)
model.eval()
return model
@classmethod
def from_pretrained0911(self, config_path, model_folder_path):
"""
        Class method to create a new model instance from a folder of checkpoints, averaging the weights of the best-performing ones.
"""
model = self.from_hparams0802(config_path)
models = os.listdir(model_folder_path)
val_loss = []
for item in models:
if not item.startswith('vocos_'):
continue
val_loss.append(item[-11:-5])
val_loss.sort()
        val_loss = val_loss[:3]  # keep the 3 checkpoints with the best validation loss and average them
state_dict = dict()
state_dicts = []
for item in models:
if not item.startswith('vocos_'):
continue
ll = item[-11:-5]
if ll not in val_loss:
continue
model_path = model_folder_path + '/' + item
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict_single = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict_single[k] = v
state_dicts.append(state_dict_single)
for kk in state_dicts[0].keys():
vv = state_dicts[0][kk]
for i in range(1, len(state_dicts)):
ss = state_dicts[i]
vv += ss[kk]
vm = vv/len(state_dicts)
state_dict[kk] = vm
model.load_state_dict(state_dict)
model.eval()
return model
@torch.inference_mode()
def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
which is then passed through the backbone and the head to reconstruct the audio output.
Args:
audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818
audio_output = self.decode(features, **kwargs)
return audio_output
# 0818
@torch.inference_mode()
def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
return features,discrete_codes
# 0818
@torch.inference_mode()
def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
return features,discrete_codes
@torch.inference_mode()
def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
_, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs)
discrete_codes = discrete_codes.clamp(min=0, max=16383)
return discrete_codes
@torch.inference_mode()
def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to decode audio waveform from already calculated features. The features input is passed through
the backbone and the head to reconstruct the audio output.
Args:
features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
C denotes the feature dimension, and L is the sequence length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
x = self.backbone(features_input, **kwargs)
audio_output = self.head(x)
return audio_output
@torch.inference_mode()
def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
"""
Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
codebook weights.
Args:
codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
where K is the number of codebooks, B is the batch size and L is the sequence length.
Returns:
Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
and L is the sequence length.
"""
assert isinstance(
self.feature_extractor, EncodecFeatures
), "Feature extractor should be an instance of EncodecFeatures"
if codes.dim() == 2:
codes = codes.unsqueeze(1)
n_bins = self.feature_extractor.encodec.quantizer.bins
offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
embeddings_idxs = codes + offsets.view(-1, 1, 1)
        codebook_weights = torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        features = torch.nn.functional.embedding(embeddings_idxs, codebook_weights).sum(dim=0)
features = features.transpose(1, 2)
return features
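# A minimal end-to-end sketch (illustrative only; `_example_wavtokenizer_roundtrip` is a
# hypothetical helper, the paths are placeholders, the 24 kHz input length is arbitrary, and
# passing bandwidth_id through to the feature extractor's infer() is an assumption).
def _example_wavtokenizer_roundtrip(config_path, model_path):
    import torch
    model = WavTokenizer.from_pretrained_feat(config_path, model_path)
    audio = torch.randn(1, 24000)                   # (B, T) waveform
    bandwidth_id = torch.tensor([0])
    features, codes = model.encode_infer(audio, bandwidth_id=bandwidth_id)
    return model.decode(features, bandwidth_id=bandwidth_id)   # (B, T) reconstruction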
from typing import Tuple, Any, Union, Dict
import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone
from inspiremusic.wavtokenizer.decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
"""Instantiates a class with the given args and init.
Args:
args: Positional arguments required for instantiation.
init: Dict of the form {"class_path":...,"init_args":...}.
Returns:
The instantiated class object.
"""
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
args = (args,)
class_module, class_name = init["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
return args_class(*args, **kwargs)
class WavTokenizer(nn.Module):
"""
    WavTokenizer (adapted from Vocos) is a Fourier-based neural vocoder for audio synthesis.
This class is primarily designed for inference, with support for loading from pretrained
model checkpoints. It consists of three main components: a feature extractor,
a backbone, and a head.
"""
def __init__(
self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
multiperioddisc: MultiPeriodDiscriminator, multiresddisc: MultiResolutionDiscriminator,
):
super().__init__()
self.feature_extractor = feature_extractor
self.backbone = backbone
self.head = head
self.multiperioddisc = multiperioddisc
self.multiresddisc = multiresddisc
@classmethod
def from_hparams0828(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
head = instantiate_class(args=(), init=config['model']['init_args']["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
return model
@classmethod
def from_pretrained0828(self, config_path, model_path):
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
model = self.from_hparams0828(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \
or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'):
state_dict[k] = v
# if isinstance(model.feature_extractor, EncodecFeatures):
# encodec_parameters = {
# "feature_extractor.encodec." + key: value
# for key, value in model.feature_extractor.encodec.state_dict().items()
# }
# state_dict.update(encodec_parameters)
model.load_state_dict(state_dict)
return model
@classmethod
def from_hparams0802(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
head = instantiate_class(args=(), init=config['model']['init_args']["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
@classmethod
def from_pretrained0802(self, config_path, model_path):
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
model = self.from_hparams0802(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict[k] = v
# if isinstance(model.feature_extractor, EncodecFeatures):
# encodec_parameters = {
# "feature_extractor.encodec." + key: value
# for key, value in model.feature_extractor.encodec.state_dict().items()
# }
# state_dict.update(encodec_parameters)
model.load_state_dict(state_dict)
model.eval()
return model
@torch.inference_mode()
def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
which is then passed through the backbone and the head to reconstruct the audio output.
Args:
audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818
audio_output = self.decode(features, **kwargs)
return audio_output
# 0818
@torch.inference_mode()
def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
features, _, _ = self.feature_extractor(audio_input, **kwargs)
return features
@torch.inference_mode()
def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to decode audio waveform from already calculated features. The features input is passed through
the backbone and the head to reconstruct the audio output.
Args:
features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
C denotes the feature dimension, and L is the sequence length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
x = self.backbone(features_input, **kwargs)
audio_output = self.head(x)
return audio_output
@torch.inference_mode()
def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
"""
Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
codebook weights.
Args:
codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
where K is the number of codebooks, B is the batch size and L is the sequence length.
Returns:
Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
and L is the sequence length.
"""
assert isinstance(
self.feature_extractor, EncodecFeatures
), "Feature extractor should be an instance of EncodecFeatures"
if codes.dim() == 2:
codes = codes.unsqueeze(1)
n_bins = self.feature_extractor.encodec.quantizer.bins
offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
embeddings_idxs = codes + offsets.view(-1, 1, 1)
features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
features = features.transpose(1, 2)
return features
import numpy as np
import scipy.signal
import torch
from torch import nn, view_as_real, view_as_complex
class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
# assert (window_envelope > 1e-11).all()
if not torch.all(window_envelope > 1e-11):
window_envelope = torch.clamp(window_envelope, min=1e-11)
y = y / window_envelope
return y
def onnx_forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
# assert (window_envelope > 1e-11).all()
if not torch.all(window_envelope > 1e-11):
window_envelope = torch.clamp(window_envelope, min=1e-11)
y = y / window_envelope
return y
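# A minimal sketch of the "same"-padding ISTFT (illustrative only; `_example_istft` is a
# hypothetical helper). T frames are overlap-added to (T - 1) * hop + win samples and trimmed
# by (win - hop) // 2 on each side, so the output length is exactly T * hop_length.
def _example_istft():
    import torch
    istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024, padding="same")
    spec = torch.randn(1, 513, 100, dtype=torch.complex64)   # (B, n_fft // 2 + 1, T)
    return istft(spec).shape                                  # torch.Size([1, 25600])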
class MDCT(nn.Module):
"""
Modified Discrete Cosine Transform (MDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
# view_as_real: NCCL Backend does not support ComplexFloat data type
# https://github.com/pytorch/pytorch/issues/71613
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, audio: torch.Tensor) -> torch.Tensor:
"""
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
Args:
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
and T is the length of the audio.
Returns:
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
and N is the number of frequency bins.
"""
if self.padding == "center":
audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
elif self.padding == "same":
# hop_length is 1/2 frame_len
audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
else:
raise ValueError("Padding must be 'center' or 'same'.")
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
N = self.frame_len // 2
x = x * self.window.expand(x.shape)
X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
return torch.real(res) * np.sqrt(2)
class IMDCT(nn.Module):
"""
Inverse Modified Discrete Cosine Transform (IMDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
Args:
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
L is the number of frames, and N is the number of frequency bins.
Returns:
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
"""
B, L, N = X.shape
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
Y[..., :N] = X
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
result = y * self.window.expand(y.shape)
output_size = (1, (L + 1) * N)
audio = torch.nn.functional.fold(
result.transpose(1, 2),
output_size=output_size,
kernel_size=(1, self.frame_len),
stride=(1, self.frame_len // 2),
)[:, 0, 0, :]
if self.padding == "center":
pad = self.frame_len // 2
elif self.padding == "same":
pad = self.frame_len // 4
else:
raise ValueError("Padding must be 'center' or 'same'.")
audio = audio[:, pad:-pad]
return audio
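# A minimal round-trip sketch (illustrative only; `_example_mdct_roundtrip` is a hypothetical
# helper with an arbitrary signal length). Under "same" padding, MDCT maps (B, T) audio to
# (B, T // N, N) coefficients with N = frame_len // 2, and IMDCT maps them back to (B, T).
def _example_mdct_roundtrip():
    import torch
    mdct = MDCT(frame_len=512, padding="same")
    imdct = IMDCT(frame_len=512, padding="same")
    audio = torch.randn(1, 512 * 16)   # (B, T) = (1, 8192)
    coeffs = mdct(audio)               # (B, L, N) = (1, 32, 256)
    return imdct(coeffs).shape         # torch.Size([1, 8192])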
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""EnCodec neural audio codec."""
__version__ = "0.1.2a3"
from .model import EncodecModel
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp
import torch
def rank():
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return 0
def world_size():
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return 1
def is_distributed():
return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
if is_distributed():
return torch.distributed.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
def _check_number_of_params(params: tp.List[torch.Tensor]):
# utility function to check that the number of params in all workers is the same,
# and thus avoid a deadlock with distributed all reduce.
if not is_distributed() or not params:
return
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
all_reduce(tensor)
if tensor.item() != len(params) * world_size():
# If not all the workers have the same number, for at least one of them,
# this inequality will be verified.
raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
"at least one worker has a different one.")
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
"""Broadcast the tensors from the given parameters to all workers.
This can be used to ensure that all workers have the same model to start with.
"""
if not is_distributed():
return
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
_check_number_of_params(tensors)
handles = []
for tensor in tensors:
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
def sync_buffer(buffers, average=True):
"""
Sync grad for buffers. If average is False, broadcast instead of averaging.
"""
if not is_distributed():
return
handles = []
for buffer in buffers:
if torch.is_floating_point(buffer.data):
if average:
handle = torch.distributed.all_reduce(
buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
else:
handle = torch.distributed.broadcast(
buffer.data, src=0, async_op=True)
handles.append((buffer, handle))
for buffer, handle in handles:
handle.wait()
if average:
            buffer.data /= world_size()
def sync_grad(params):
"""
Simpler alternative to DistributedDataParallel, that doesn't rely
on any black magic. For simple models it can also be as fast.
Just call this on your model parameters after the call to backward!
"""
if not is_distributed():
return
handles = []
for p in params:
if p.grad is not None:
handle = torch.distributed.all_reduce(
p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
handles.append((p, handle))
for p, handle in handles:
handle.wait()
p.grad.data /= world_size()
def average_metrics(metrics: tp.Dict[str, float], count=1.):
"""Average a dictionary of metrics across all workers, using the optional
`count` as unnormalized weight.
"""
if not is_distributed():
return metrics
keys, values = zip(*metrics.items())
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
tensor *= count
all_reduce(tensor)
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
return dict(zip(keys, averaged))
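# Worked example (illustrative): if one worker calls average_metrics({"loss": 1.0}, count=10)
# and another calls it with {"loss": 3.0} and count=30, the all-reduced tensor is
# [10 * 1.0 + 30 * 3.0, 10 + 30] = [100.0, 40.0], so both workers get {"loss": 2.5},
# i.e. a count-weighted mean across workers.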
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""EnCodec model implementation."""
import math
from pathlib import Path
import typing as tp
import numpy as np
import torch
from torch import nn
from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url
ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'
EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
class LMModel(nn.Module):
"""Language Model to estimate probabilities of each codebook entry.
We predict all codebooks in parallel for a given time step.
Args:
n_q (int): number of codebooks.
card (int): codebook cardinality.
dim (int): transformer dimension.
**kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`.
"""
def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
super().__init__()
self.card = card
self.n_q = n_q
self.dim = dim
self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])
def forward(self, indices: torch.Tensor,
states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
"""
Args:
indices (torch.Tensor): indices from the previous time step. Indices
should be 1 + actual index in the codebook. The value 0 is reserved for
when the index is missing (i.e. first time step). Shape should be
`[B, n_q, T]`.
states: state for the streaming decoding.
offset: offset of the current time step.
Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
with a shape `[B, card, n_q, T]`.
"""
B, K, T = indices.shape
input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
out, states, offset = self.transformer(input_, states, offset)
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
return torch.softmax(logits, dim=1), states, offset
class EncodecModel(nn.Module):
"""EnCodec model operating on the raw waveform.
Args:
target_bandwidths (list of float): Target bandwidths.
encoder (nn.Module): Encoder network.
decoder (nn.Module): Decoder network.
sample_rate (int): Audio sample rate.
channels (int): Number of audio channels.
normalize (bool): Whether to apply audio normalization.
segment (float or None): segment duration in sec. when doing overlap-add.
overlap (float): overlap between segment, given as a fraction of the segment duration.
name (str): name of the model, used as metadata when compressing audio.
"""
def __init__(self,
encoder: m.SEANetEncoder,
decoder: m.SEANetDecoder,
quantizer: qt.ResidualVectorQuantizer,
target_bandwidths: tp.List[float],
sample_rate: int,
channels: int,
normalize: bool = False,
segment: tp.Optional[float] = None,
overlap: float = 0.01,
name: str = 'unset'):
super().__init__()
self.bandwidth: tp.Optional[float] = None
self.target_bandwidths = target_bandwidths
self.encoder = encoder
self.quantizer = quantizer
self.decoder = decoder
self.sample_rate = sample_rate
self.channels = channels
self.normalize = normalize
self.segment = segment
self.overlap = overlap
self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
self.name = name
self.bits_per_codebook = int(math.log2(self.quantizer.bins))
assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
"quantizer bins must be a power of 2."
@property
def segment_length(self) -> tp.Optional[int]:
if self.segment is None:
return None
return int(self.segment * self.sample_rate)
@property
def segment_stride(self) -> tp.Optional[int]:
segment_length = self.segment_length
if segment_length is None:
return None
return max(1, int((1 - self.overlap) * segment_length))
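    # Worked example (illustrative): for the 48 kHz configuration defined below
    # (segment=1.0, overlap=0.01, sample_rate=48_000), segment_length is 48_000 samples and
    # segment_stride is int((1 - 0.01) * 48_000) = 47_520, so consecutive segments overlap
    # by 480 samples before the overlap-add performed in decode().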
def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
"""Given a tensor `x`, returns a list of frames containing
the discrete encoded codes for `x`, along with rescaling factors
for each segment, when `self.normalize` is True.
Each frames is a tuple `(codebook, scale)`, with `codebook` of
shape `[B, K, T]`, with `K` the number of codebooks.
"""
assert x.dim() == 3
_, channels, length = x.shape
assert channels > 0 and channels <= 2
segment_length = self.segment_length
if segment_length is None:
segment_length = length
stride = length
else:
stride = self.segment_stride # type: ignore
assert stride is not None
encoded_frames: tp.List[EncodedFrame] = []
for offset in range(0, length, stride):
frame = x[:, :, offset: offset + segment_length]
encoded_frames.append(self._encode_frame(frame))
return encoded_frames
def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
length = x.shape[-1]
duration = length / self.sample_rate
assert self.segment is None or duration <= 1e-5 + self.segment
if self.normalize:
mono = x.mean(dim=1, keepdim=True)
volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
scale = 1e-8 + volume
x = x / scale
scale = scale.view(-1, 1)
else:
scale = None
emb = self.encoder(x)
codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
codes = codes.transpose(0, 1)
# codes is [B, K, T], with T frames, K nb of codebooks.
return codes, scale
def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
"""Decode the given frames into a waveform.
Note that the output might be a bit bigger than the input. In that case,
any extra steps at the end can be trimmed.
"""
segment_length = self.segment_length
if segment_length is None:
assert len(encoded_frames) == 1
return self._decode_frame(encoded_frames[0])
frames = [self._decode_frame(frame) for frame in encoded_frames]
return _linear_overlap_add(frames, self.segment_stride or 1)
def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
codes, scale = encoded_frame
codes = codes.transpose(0, 1)
emb = self.quantizer.decode(codes)
out = self.decoder(emb)
if scale is not None:
out = out * scale.view(-1, 1, 1)
return out
def forward(self, x: torch.Tensor) -> torch.Tensor:
frames = self.encode(x)
return self.decode(frames)[:, :, :x.shape[-1]]
def set_target_bandwidth(self, bandwidth: float):
if bandwidth not in self.target_bandwidths:
raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
f"Select one of {self.target_bandwidths}.")
self.bandwidth = bandwidth
def get_lm_model(self) -> LMModel:
"""Return the associated LM model to improve the compression rate.
"""
device = next(self.parameters()).device
lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
past_context=int(3.5 * self.frame_rate)).to(device)
checkpoints = {
'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
}
try:
checkpoint_name = checkpoints[self.name]
except KeyError:
raise RuntimeError("No LM pre-trained for the current Encodec model.")
url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
state = torch.hub.load_state_dict_from_url(
url, map_location='cpu', check_hash=True) # type: ignore
lm.load_state_dict(state)
lm.eval()
return lm
@staticmethod
def _get_model(target_bandwidths: tp.List[float],
sample_rate: int = 24_000,
channels: int = 1,
causal: bool = True,
model_norm: str = 'weight_norm',
audio_normalize: bool = False,
segment: tp.Optional[float] = None,
name: str = 'unset'):
encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
quantizer = qt.ResidualVectorQuantizer(
dimension=encoder.dimension,
n_q=n_q,
bins=1024,
)
model = EncodecModel(
encoder,
decoder,
quantizer,
target_bandwidths,
sample_rate,
channels,
normalize=audio_normalize,
segment=segment,
name=name,
)
return model
@staticmethod
def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
if repository is not None:
if not repository.is_dir():
raise ValueError(f"{repository} must exist and be a directory.")
file = repository / checkpoint_name
checksum = file.stem.split('-')[1]
_check_checksum(file, checksum)
return torch.load(file)
else:
url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) # type:ignore
@staticmethod
def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
"""Return the pretrained causal 24khz model.
"""
if repository:
assert pretrained
target_bandwidths = [1.5, 3., 6, 12., 24.]
checkpoint_name = 'encodec_24khz-d7cc33bc.th'
sample_rate = 24_000
channels = 1
model = EncodecModel._get_model(
target_bandwidths, sample_rate, channels,
causal=True, model_norm='weight_norm', audio_normalize=False,
name='encodec_24khz' if pretrained else 'unset')
if pretrained:
state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
model.load_state_dict(state_dict)
model.eval()
return model
@staticmethod
def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
"""Return the pretrained 48khz model.
"""
if repository:
assert pretrained
target_bandwidths = [3., 6., 12., 24.]
checkpoint_name = 'encodec_48khz-7e698e3e.th'
sample_rate = 48_000
channels = 2
model = EncodecModel._get_model(
target_bandwidths, sample_rate, channels,
causal=False, model_norm='time_group_norm', audio_normalize=True,
segment=1., name='encodec_48khz' if pretrained else 'unset')
if pretrained:
state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
model.load_state_dict(state_dict)
model.eval()
return model
def test():
from itertools import product
import torchaudio
bandwidths = [3, 6, 12, 24]
models = {
'encodec_24khz': EncodecModel.encodec_model_24khz,
'encodec_48khz': EncodecModel.encodec_model_48khz
}
for model_name, bw in product(models.keys(), bandwidths):
model = models[model_name]()
model.set_target_bandwidth(bw)
audio_suffix = model_name.split('_')[1][:3]
wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
wav = wav[:, :model.sample_rate * 2]
wav_in = wav.unsqueeze(0)
wav_dec = model(wav_in)[0]
assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)
if __name__ == '__main__':
test()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch modules."""
# flake8: noqa
from .conv import (
pad1d,
unpad1d,
NormConv1d,
NormConvTranspose1d,
NormConv2d,
NormConvTranspose2d,
SConv1d,
SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformerEncoder