"sgl-kernel/pyproject_cpu.toml" did not exist on "8abf74e3c9353c2c33c83d156d5a69acf6274b72"
Commit 0112b0f0 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2394 canceled with stages
from dataclasses import dataclass
import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
import soundfile
# import librosa
import random
torch.set_num_threads(1)
@dataclass
class DataConfig:
filelist_path: str
sampling_rate: int
num_samples: int
batch_size: int
num_workers: int
def collate_fn(batch):
batch = [item for item in batch if item is not None]
return torch.stack(batch, dim=0)
class VocosDataModule(LightningDataModule):
def __init__(self, train_params: DataConfig, val_params: DataConfig):
super().__init__()
self.train_config = train_params
self.val_config = val_params
def _get_dataloder(self, cfg: DataConfig, train: bool):
dataset = VocosDataset(cfg, train=train)
dataloader = DataLoader(
dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, collate_fn=collate_fn
)
return dataloader
def train_dataloader(self) -> DataLoader:
return self._get_dataloder(self.train_config, train=True)
def val_dataloader(self) -> DataLoader:
return self._get_dataloder(self.val_config, train=False)
class VocosDataset(Dataset):
def __init__(self, cfg: DataConfig, train: bool):
with open(cfg.filelist_path) as f:
self.filelist = f.read().splitlines()
self.sampling_rate = cfg.sampling_rate
self.num_samples = cfg.num_samples
self.train = train
def __len__(self) -> int:
return len(self.filelist)
def __getitem__(self, index: int) -> torch.Tensor:
audio_path = self.filelist[index]
# y, sr = torchaudio.load(audio_path)
# print(audio_path,"111")
try:
y1, sr = soundfile.read(audio_path)
# y1, sr = librosa.load(audio_path,sr=None)
y = torch.tensor(y1).float().unsqueeze(0)
# if y.size(0) > 1:
# # mix to mono
# y = y.mean(dim=0, keepdim=True)
if y.ndim > 2:
# mix to mono
# print("有问题哈,数据处理部分")
# y = y.mean(dim=-1, keepdim=False)
random_channel = random.randint(0, y.size(-1) - 1)
y = y[:, :, random_channel]
gain = np.random.uniform(-1, -6) if self.train else -3
y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
if sr != self.sampling_rate:
y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
if y.size(-1) < self.num_samples:
pad_length = self.num_samples - y.size(-1)
padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
elif self.train:
start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
y = y[:, start : start + self.num_samples]
else:
# During validation, take always the first segment for determinism
y = y[:, : self.num_samples]
return y[0]
except Exception as e:
print(f"Error processing file {audio_path} at index {index}: {e}")
# 这里可以继续选择抛出异常,或者返回一个 None 表示无效数据
return None
# def __getitem__(self, index: int) -> torch.Tensor:
# audio_path = self.filelist[index]
# try:
# y, sr = torchaudio.load(audio_path)
# if y.size(0) > 1:
# # 随机选择一个通道
# random_channel = random.randint(0, y.size(0) - 1)
# y = y[random_channel, :].unsqueeze(0) # 保持返回值为 (1, T) 的形式
# # gain = np.random.uniform(-1, -6) if self.train else -3
# # y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
# if sr != self.sampling_rate:
# y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
# if y.size(-1) < self.num_samples:
# pad_length = self.num_samples - y.size(-1)
# padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
# y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
# elif self.train:
# start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
# y = y[:, start: start + self.num_samples]
# else:
# # During validation, take always the first segment for determinism
# y = y[:, :self.num_samples]
# return y[0]
# except Exception as e:
# print(f"Error processing file {audio_path} at index {index}: {e}")
# # 这里可以继续选择抛出异常,或者返回一个 None 表示无效数据
# return None
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
# from audiotools import AudioSignal
# from audiotools import ml
# from audiotools import STFTParams
from einops import rearrange
from torch.nn.utils import weight_norm
from collections import namedtuple
STFTParams = namedtuple(
"STFTParams",
["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
)
STFTParams.__new__.__defaults__ = (None, None, None, None, None)
def WNConv1d(*args, **kwargs):
act = kwargs.pop("act", True)
conv = weight_norm(nn.Conv1d(*args, **kwargs))
if not act:
return conv
return nn.Sequential(conv, nn.LeakyReLU(0.1))
def WNConv2d(*args, **kwargs):
act = kwargs.pop("act", True)
conv = weight_norm(nn.Conv2d(*args, **kwargs))
if not act:
return conv
return nn.Sequential(conv, nn.LeakyReLU(0.1))
class MPD(nn.Module):
def __init__(self, period):
super().__init__()
self.period = period
self.convs = nn.ModuleList(
[
WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
]
)
self.conv_post = WNConv2d(
1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
)
def pad_to_period(self, x):
t = x.shape[-1]
x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
return x
def forward(self, x):
fmap = []
x = self.pad_to_period(x)
x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
class MSD(nn.Module):
def __init__(self, rate: int = 1, sample_rate: int = 48000):
super().__init__()
self.convs = nn.ModuleList(
[
WNConv1d(1, 16, 15, 1, padding=7),
WNConv1d(16, 64, 41, 4, groups=4, padding=20),
WNConv1d(64, 256, 41, 4, groups=16, padding=20),
WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
WNConv1d(1024, 1024, 5, 1, padding=2),
]
)
self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
self.sample_rate = sample_rate
self.rate = rate
def forward(self, x):
# x = AudioSignal(x, self.sample_rate)
# x.resample(self.sample_rate // self.rate)
# x = x.audio_data
fmap = []
for l in self.convs:
x = l(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
class MRD(nn.Module):
def __init__(
self,
window_length: int,
hop_factor: float = 0.25,
sample_rate: int = 24000,
bands: list = BANDS,
):
"""Complex multi-band spectrogram discriminator.
Parameters
----------
window_length : int
Window length of STFT.
hop_factor : float, optional
Hop factor of the STFT, defaults to ``0.25 * window_length``.
sample_rate : int, optional
Sampling rate of audio in Hz, by default 24000
bands : list, optional
Bands to run discriminator over.
"""
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.sample_rate = sample_rate
self.stft_params = STFTParams(
window_length=window_length,
hop_length=int(window_length * hop_factor),
match_stride=True,
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
self.n_fft = window_length
ch = 32
convs = lambda: nn.ModuleList(
[
WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
def spectrogram(self, x):
# x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
# x = torch.view_as_real(x.stft())
# x.squeeze(0).stft(n_fft=1024,win_length=1024,return_complex=True).size()
# breakpoint()
if x.size(0)==1:
# x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.window_length,return_complex=True).unsqueeze(0))
x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(0))
else:
# x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.window_length,return_complex=True).unsqueeze(1))
x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(1))
x = rearrange(x, "b 1 f t c -> (b 1) c t f")
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
return x_bands
def forward(self, x):
x_bands = self.spectrogram(x)
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for layer in stack:
band = layer(band)
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
x = self.conv_post(x)
fmap.append(x)
return fmap
# class DACDiscriminator(ml.BaseModel):
class DACDiscriminator(nn.Module):
def __init__(
self,
rates: list = [],
periods: list = [2, 3, 5, 7, 11],
fft_sizes: list = [2048, 1024, 512],
sample_rate: int = 24000,
bands: list = BANDS,
):
"""Discriminator that combines multiple discriminators.
Parameters
----------
rates : list, optional
sampling rates (in Hz) to run MSD at, by default []
If empty, MSD is not used.
periods : list, optional
periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
fft_sizes : list, optional
Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
sample_rate : int, optional
Sampling rate of audio in Hz, by default 24000
bands : list, optional
Bands to run MRD at, by default `BANDS`
"""
super().__init__()
discs = []
discs += [MPD(p) for p in periods]
discs += [MSD(r, sample_rate=sample_rate) for r in rates]
discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
self.discriminators = nn.ModuleList(discs)
def preprocess(self, y):
# Remove DC offset
y = y - y.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
return y
def forward(self, x):
x = self.preprocess(x)
fmaps = [d(x) for d in self.discriminators]
return fmaps
if __name__ == "__main__":
disc = DACDiscriminator()
x = torch.zeros(1, 1, 24000)
results = disc(x)
breakpoint()
for i, result in enumerate(results):
print(f"disc{i}")
for i, r in enumerate(result):
print(r.shape, r.mean(), r.min(), r.max())
print("00")
from typing import Tuple, List
import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm
class MultiPeriodDiscriminator(nn.Module):
"""
Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
periods (tuple[int]): Tuple of periods for each discriminator.
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
super().__init__()
self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorP(nn.Module):
def __init__(
self,
period: int,
in_channels: int = 1,
kernel_size: int = 5,
stride: int = 3,
lrelu_slope: float = 0.1,
num_embeddings: int = None,
):
super().__init__()
self.period = period
self.convs = nn.ModuleList(
[
weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
]
)
if num_embeddings is not None:
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
self.lrelu_slope = lrelu_slope
def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x = x.unsqueeze(1)
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for i, l in enumerate(self.convs):
x = l(x)
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
if i > 0:
fmap.append(x)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
num_embeddings: int = None,
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator.
Each resolution should be a tuple of (n_fft, hop_length, win_length).
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(
self,
resolution: Tuple[int, int, int],
channels: int = 64,
in_channels: int = 1,
num_embeddings: int = None,
lrelu_slope: float = 0.1,
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.lrelu_slope = lrelu_slope
self.convs = nn.ModuleList(
[
weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
]
)
if num_embeddings is not None:
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
x = self.spectrogram(x)
x = x.unsqueeze(1)
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
fmap.append(x)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
x = torch.flatten(x, 1, -1)
return x, fmap
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
n_fft, hop_length, win_length = self.resolution
magnitude_spectrogram = torch.stft(
x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=None, # interestingly rectangular window kind of works here
center=True,
return_complex=True,
).abs()
return magnitude_spectrogram
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio
import transformers
import yaml
from decoder.discriminator_dac import DACDiscriminator
from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from decoder.feature_extractors import FeatureExtractor
from decoder.heads import FourierHead
from decoder.helpers import plot_spectrogram_to_numpy
from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, DACGANLoss
from decoder.models import Backbone
from decoder.modules import safe_log
from decoder.pretrained_model import instantiate_class
class VocosExp(pl.LightningModule):
# noinspection PyUnusedLocal
def __init__(
self,
feature_extractor: FeatureExtractor,
backbone: Backbone,
head: FourierHead,
resume_config: str,
resume_model: str,
sample_rate: int = 24000,
initial_learning_rate: float = 2e-4,
num_warmup_steps: int = 0,
mel_loss_coeff: float = 45,
mrd_loss_coeff: float = 1.0,
pretrain_mel_steps: int = 0,
decay_mel_coeff: bool = False,
evaluate_utmos: bool = False,
evaluate_pesq: bool = False,
evaluate_periodicty: bool = False,
resume: bool = False,
):
"""
Args:
feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
backbone (Backbone): An instance of Backbone model.
head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
sample_rate (int): Sampling rate of the audio signals.
initial_learning_rate (float): Initial learning rate for the optimizer.
num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
"""
super().__init__()
self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
self.feature_extractor = feature_extractor
self.backbone = backbone
self.head = head
self.resume_config = resume_config
self.resume_model = resume_model
self.resume = resume
self.multiperioddisc = MultiPeriodDiscriminator()
self.multiresddisc = MultiResolutionDiscriminator()
self.dac = DACDiscriminator()
self.dacdiscriminator = DACGANLoss(self.dac)
self.disc_loss = DiscriminatorLoss()
self.gen_loss = GeneratorLoss()
self.feat_matching_loss = FeatureMatchingLoss()
self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
self.train_discriminator = False
self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
def configure_optimizers(self):
disc_params = [
{"params": self.multiperioddisc.parameters()},
{"params": self.multiresddisc.parameters()},
{"params": self.dac.parameters()},
]
gen_params = [
{"params": self.feature_extractor.parameters()},
{"params": self.backbone.parameters()},
{"params": self.head.parameters()},
]
opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
max_steps = self.trainer.max_steps // 2 # Max steps per optimizer
scheduler_disc = transformers.get_cosine_schedule_with_warmup(
opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
)
scheduler_gen = transformers.get_cosine_schedule_with_warmup(
opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
)
return (
[opt_disc, opt_gen],
[{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
)
def forward(self, audio_input, **kwargs):
features, _, commit_loss = self.feature_extractor(audio_input, **kwargs)
# print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g'])
x = self.backbone(features, **kwargs)
audio_output = self.head(x)
return audio_output, commit_loss
def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
audio_input = batch
# train discriminator
if optimizer_idx == 0 and self.train_discriminator:
with torch.no_grad():
audio_hat, _ = self(audio_input, **kwargs)
loss_dac=self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1))
real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
loss_mp, loss_mp_real, _ = self.disc_loss(
disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
)
loss_mrd, loss_mrd_real, _ = self.disc_loss(
disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
)
loss_mp /= len(loss_mp_real)
loss_mrd /= len(loss_mrd_real)
loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac
self.log("discriminator/total", loss, prog_bar=True)
self.log("discriminator/multi_period_loss", loss_mp)
self.log("discriminator/multi_res_loss", loss_mrd)
self.log("discriminator/dac", loss_dac)
return loss
# train generator
if optimizer_idx == 1:
audio_hat, commit_loss = self(audio_input, **kwargs)
if self.train_discriminator:
loss_dac_1,loss_dac_2 = self.dacdiscriminator.generator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1))
_, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
y=audio_input, y_hat=audio_hat, **kwargs,
)
_, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
y=audio_input, y_hat=audio_hat, **kwargs,
)
loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
self.log("generator/multi_period_loss", loss_gen_mp)
self.log("generator/multi_res_loss", loss_gen_mrd)
self.log("generator/feature_matching_mp", loss_fm_mp)
self.log("generator/feature_matching_mrd", loss_fm_mrd)
self.log("generator/loss_dac_1", loss_dac_1)
self.log("generator/loss_dac_2", loss_dac_2)
else:
loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0
mel_loss = self.melspec_loss(audio_hat, audio_input)
loss = (
loss_gen_mp
+ self.hparams.mrd_loss_coeff * loss_gen_mrd
+ loss_fm_mp
+ self.hparams.mrd_loss_coeff * loss_fm_mrd
+ self.mel_loss_coeff * mel_loss
+ 1000 * commit_loss
+ loss_dac_1
+ loss_dac_2
)
self.log("generator/total_loss", loss, prog_bar=True)
self.log("mel_loss_coeff", self.mel_loss_coeff)
self.log("generator/mel_loss", mel_loss)
self.log("commit_loss", commit_loss)
if self.global_step % 1000 == 0 and self.global_rank == 0:
self.logger.experiment.add_audio(
"train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
)
self.logger.experiment.add_audio(
"train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
)
with torch.no_grad():
mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
self.logger.experiment.add_image(
"train/mel_target",
plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
self.logger.experiment.add_image(
"train/mel_pred",
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
return loss
def on_validation_epoch_start(self):
if self.hparams.evaluate_utmos:
from metrics.UTMOS import UTMOSScore
if not hasattr(self, "utmos_model"):
self.utmos_model = UTMOSScore(device=self.device)
def validation_step(self, batch, batch_idx, **kwargs):
audio_input = batch
audio_hat, commit_loss = self(audio_input, **kwargs)
audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
if self.hparams.evaluate_periodicty:
from metrics.periodicity import calculate_periodicity_metrics
periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
else:
periodicity_loss = pitch_loss = f1_score = 0
if self.hparams.evaluate_utmos:
utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
else:
utmos_score = torch.zeros(1, device=self.device)
if self.hparams.evaluate_pesq:
from pesq import pesq
pesq_score = 0
for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
pesq_score /= len(audio_16_khz)
pesq_score = torch.tensor(pesq_score)
else:
pesq_score = torch.zeros(1, device=self.device)
mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss
return {
"val_loss": total_loss,
"mel_loss": mel_loss,
"utmos_score": utmos_score,
"pesq_score": pesq_score,
"periodicity_loss": periodicity_loss,
"pitch_loss": pitch_loss,
"f1_score": f1_score,
"audio_input": audio_input[0],
"audio_pred": audio_hat[0],
}
def validation_epoch_end(self, outputs):
if self.global_rank == 0:
*_, audio_in, audio_pred = outputs[0].values()
self.logger.experiment.add_audio(
"val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
)
self.logger.experiment.add_audio(
"val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
)
mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
self.logger.experiment.add_image(
"val_mel_target",
plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
self.logger.experiment.add_image(
"val_mel_hat",
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
f1_score = np.array([x["f1_score"] for x in outputs]).mean()
self.log("val_loss", avg_loss, sync_dist=True)
self.log("val/mel_loss", mel_loss, sync_dist=True)
self.log("val/utmos_score", utmos_score, sync_dist=True)
self.log("val/pesq_score", pesq_score, sync_dist=True)
self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
self.log("val/pitch_loss", pitch_loss, sync_dist=True)
self.log("val/f1_score", f1_score, sync_dist=True)
@property
def global_step(self):
"""
Override global_step so that it returns the total number of batches processed
"""
return self.trainer.fit_loop.epoch_loop.total_batch_idx
def on_train_batch_start(self, *args):
if self.global_step >= self.hparams.pretrain_mel_steps:
self.train_discriminator = True
else:
self.train_discriminator = False
def on_train_batch_end(self, *args):
def mel_loss_coeff_decay(current_step, num_cycles=0.5):
max_steps = self.trainer.max_steps // 2
if current_step < self.hparams.num_warmup_steps:
return 1.0
progress = float(current_step - self.hparams.num_warmup_steps) / float(
max(1, max_steps - self.hparams.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
if self.hparams.decay_mel_coeff:
self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
class WavTokenizer(VocosExp):
"""
WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
while during validation, a fixed bandwidth_id is used.
"""
def __init__(
self,
feature_extractor: FeatureExtractor,
backbone: Backbone,
head: FourierHead,
resume_config: str,
resume_model: str,
sample_rate: int = 24000,
initial_learning_rate: float = 2e-4,
num_warmup_steps: int = 0,
mel_loss_coeff: float = 45,
mrd_loss_coeff: float = 1.0,
pretrain_mel_steps: int = 0,
decay_mel_coeff: bool = False,
evaluate_utmos: bool = False,
evaluate_pesq: bool = False,
evaluate_periodicty: bool = False,
resume: bool = False,
):
super().__init__(
feature_extractor,
backbone,
head,
resume_config,
resume_model,
sample_rate,
initial_learning_rate,
num_warmup_steps,
mel_loss_coeff,
mrd_loss_coeff,
pretrain_mel_steps,
decay_mel_coeff,
evaluate_utmos,
evaluate_pesq,
evaluate_periodicty,
resume
)
# Override with conditional discriminators
# VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model)
# if self.resume:
# VocosExp.load_from_checkpoint(self.resume_model)
self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
self.dac = DACDiscriminator()
if self.resume:
print('加载预训练模型:', self.resume_model)
# with open(self.resume_config, "r") as f:
# config = yaml.safe_load(f)
# feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
# backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
# head = instantiate_class(args=(), init=config['model']['init_args']["head"])
# 不加载量化器部分权重
state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict']
state_dict_fa_qa = dict()
state_dict_fa_en = dict()
state_dict_fa_de = dict()
state_dict_bb = dict()
state_dict_hd = dict()
state_dict_mp = dict()
state_dict_mr = dict()
state_dict_dac = dict()
for k, v in state_dict_raw.items():
# breakpoint()
if k.startswith('feature_extractor.encodec.quantizer'):
# breakpoint()
# print("*****",k)
ss = k[46:48]
if ss[-1] == '.':
num = int(ss[0])
# print("num,k",num,k[36:])
if num <= 7:
state_dict_fa_qa[k[36:]] = v
if k.startswith('feature_extractor.encodec.encoder'):
state_dict_fa_en[k[34:]] = v
if k.startswith('feature_extractor.encodec.decoder'):
state_dict_fa_de[k[34:]] = v
if k.startswith('backbone.'):
state_dict_bb[k[9:]] = v
if k.startswith('head.'):
state_dict_hd[k[5:]] = v
if k.startswith('multiperioddisc.'):
state_dict_mp[k[16:]] = v
if k.startswith('multiresddisc.'):
state_dict_mr[k[14:]] = v
if k.startswith('dac.'):
state_dict_dac[k[4:]] = v
# breakpoint()
# feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True)
feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True)
feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
backbone.load_state_dict(state_dict_bb, strict=True)
head.load_state_dict(state_dict_hd, strict=True)
self.feature_extractor = feature_extractor.to(self.device)
self.backbone = backbone.to(self.device)
self.head = head.to(self.device)
self.multiperioddisc.load_state_dict(state_dict_mp, strict=True)
self.multiresddisc.load_state_dict(state_dict_mr, strict=True)
self.dac.load_state_dict(state_dict_dac, strict=True)
def training_step(self, *args):
# print('-------------------train--------------------')
# if self.global_rank == 0 and self.resume:
# config_path = self.resume_config
# model_path = self.resume_model
# self.pretrained_load(config_path, model_path)
# print('加载预训练模型:', model_path)
bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
output = super().training_step(*args, bandwidth_id=bandwidth_id)
return output
def validation_step(self, *args):
# print('-------------------valid--------------------')
bandwidth_id = torch.tensor([0], device=self.device)
output = super().validation_step(*args, bandwidth_id=bandwidth_id)
return output
def validation_epoch_end(self, outputs):
if self.global_rank == 0:
*_, audio_in, _ = outputs[0].values()
# Resynthesis with encodec for reference
self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
self.logger.experiment.add_audio(
"encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
)
super().validation_epoch_end(outputs)
from typing import List
import torch
import torchaudio
from torch import nn
import math
# from inspiremusic.wavtokenizer.decoder.modules import safe_log
from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder
from inspiremusic.wavtokenizer.encoder import EncodecModel
from inspiremusic.wavtokenizer.encoder.quantization import ResidualVectorQuantizer
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))
def symlog(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(x.abs())
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
class FeatureExtractor(nn.Module):
"""Base class for feature extractors."""
def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Extract features from the given audio.
Args:
audio (Tensor): Input audio waveform.
Returns:
Tensor: Extracted features of shape (B, C, L), where B is the batch size,
C denotes output features, and L is the sequence length.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class MelSpectrogramFeatures(FeatureExtractor):
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
center=padding == "center",
power=1,
)
def forward(self, audio, **kwargs):
if self.padding == "same":
pad = self.mel_spec.win_length - self.mel_spec.hop_length
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
mel = self.mel_spec(audio)
features = safe_log(mel)
return features
class EncodecFeatures(FeatureExtractor):
def __init__(
self,
encodec_model: str = "encodec_24khz",
bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
train_codebooks: bool = False,
num_quantizers: int = 1,
dowmsamples: List[int] = [6, 5, 5, 4],
vq_bins: int = 16384,
vq_kmeans: int = 800,
):
super().__init__()
# breakpoint()
self.frame_rate = 25 # not use
# n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate))
n_q = num_quantizers # important
encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
true_skip=False, compress=2)
decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
true_skip=False, compress=2)
quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
decay=0.99, kmeans_init=True)
# breakpoint()
if encodec_model == "encodec_24khz":
self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
target_bandwidths=bandwidths, sample_rate=24000, channels=1)
else:
raise ValueError(
f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
)
for param in self.encodec.parameters():
param.requires_grad = True
# self.num_q = n_q
# codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
# self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
self.bandwidths = bandwidths
# @torch.no_grad()
# def get_encodec_codes(self, audio):
# audio = audio.unsqueeze(1)
# emb = self.encodec.encoder(audio)
# codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
# return codes
def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
if self.training:
self.encodec.train()
audio = audio.unsqueeze(1) # audio(16,24000)
# breakpoint()
emb = self.encodec.encoder(audio)
q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
quantized = q_res.quantized
codes = q_res.codes
commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75)
return quantized, codes, commit_loss
# codes = self.get_encodec_codes(audio)
# # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
# # with offsets given by the number of bins, and finally summed in a vectorized operation.
# offsets = torch.arange(
# 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
# )
# embeddings_idxs = codes + offsets.view(-1, 1, 1)
# features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
# return features.transpose(1, 2)
def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
if self.training:
self.encodec.train()
audio = audio.unsqueeze(1) # audio(16,24000)
emb = self.encodec.encoder(audio)
q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
quantized = q_res.quantized
codes = q_res.codes
commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75)
return quantized, codes, commit_loss
def _infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
if self.training:
self.encodec.train()
audio = audio.unsqueeze(1) # audio(16,24000)
emb = self.encodec.encoder(audio)
q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
quantized = q_res.quantized
codes = q_res.codes
commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75)
return quantized, codes, commit_loss
\ No newline at end of file
import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
from inspiremusic.wavtokenizer.decoder.spectral_ops import IMDCT, ISTFT
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
class FourierHead(nn.Module):
"""Base class for inverse fourier modules."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class ISTFTHead(FourierHead):
"""
ISTFT Head module for predicting STFT complex coefficients.
Args:
dim (int): Hidden dimension of the model.
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames, which should align with
the resolution of the input features.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
super().__init__()
out_dim = n_fft + 2
self.out = torch.nn.Linear(dim, out_dim)
self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ISTFTHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
# wrapping happens here. These two lines produce real and imaginary value
x = torch.cos(p)
y = torch.sin(p)
# recalculating phase here does not produce anything new
# only costs time
# phase = torch.atan2(y, x)
# S = mag * torch.exp(phase * 1j)
# better directly produce the complex value
S = mag * (x + 1j * y)
audio = self.istft(S)
return audio
class IMDCTSymExpHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
based on perceptual scaling. Defaults to None.
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
):
super().__init__()
out_dim = mdct_frame_len // 2
self.out = nn.Linear(dim, out_dim)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
self.clip_audio = clip_audio
if sample_rate is not None:
# optionally init the last layer following mel-scale
m_max = _hz_to_mel(sample_rate // 2)
m_pts = torch.linspace(0, m_max, out_dim)
f_pts = _mel_to_hz(m_pts)
scale = 1 - (f_pts / f_pts.max())
with torch.no_grad():
self.out.weight.mul_(scale.view(-1, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTSymExpHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
x = symexp(x)
x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes
audio = self.imdct(x)
if self.clip_audio:
audio = torch.clip(x, min=-1.0, max=1.0)
return audio
class IMDCTCosHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
super().__init__()
self.clip_audio = clip_audio
self.out = nn.Linear(dim, mdct_frame_len)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTCosHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
m, p = x.chunk(2, dim=2)
m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes
audio = self.imdct(m * torch.cos(p))
if self.clip_audio:
audio = torch.clip(x, min=-1.0, max=1.0)
return audio
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback
matplotlib.use("Agg")
def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
"""
Save a matplotlib figure to a numpy array.
Args:
fig (Figure): Matplotlib figure object.
Returns:
ndarray: Numpy array representing the figure.
"""
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
"""
Plot a spectrogram and convert it to a numpy array.
Args:
spectrogram (ndarray): Spectrogram data.
Returns:
ndarray: Numpy array representing the plotted spectrogram.
"""
spectrogram = spectrogram.astype(np.float32)
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
class GradNormCallback(Callback):
"""
Callback to log the gradient norm.
"""
def on_after_backward(self, trainer, model):
model.log("grad_norm", gradient_norm(model))
def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
"""
Compute the gradient norm.
Args:
model (Module): PyTorch model.
norm_type (float, optional): Type of the norm. Defaults to 2.0.
Returns:
Tensor: Gradient norm.
"""
grads = [p.grad for p in model.parameters() if p.grad is not None]
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
return total_norm
from typing import List, Tuple
import torch
import torchaudio
from torch import nn
from decoder.modules import safe_log
import torch.nn.functional as F
class MelSpecReconstructionLoss(nn.Module):
"""
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
"""
def __init__(
self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
):
super().__init__()
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
)
def forward(self, y_hat, y) -> torch.Tensor:
"""
Args:
y_hat (Tensor): Predicted audio waveform.
y (Tensor): Ground truth audio waveform.
Returns:
Tensor: L1 loss between the mel-scaled magnitude spectrograms.
"""
mel_hat = safe_log(self.mel_spec(y_hat))
mel = safe_log(self.mel_spec(y))
loss = torch.nn.functional.l1_loss(mel, mel_hat)
return loss
class GeneratorLoss(nn.Module):
"""
Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
"""
def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
disc_outputs (List[Tensor]): List of discriminator outputs.
Returns:
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
the sub-discriminators
"""
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean(torch.clamp(1 - dg, min=0))
gen_losses.append(l)
loss += l
return loss, gen_losses
class DiscriminatorLoss(nn.Module):
"""
Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
"""
def forward(
self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
"""
Args:
disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
Returns:
Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
the sub-discriminators for real outputs, and a list of
loss values for generated outputs.
"""
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
class FeatureMatchingLoss(nn.Module):
"""
Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
"""
def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
"""
Args:
fmap_r (List[List[Tensor]]): List of feature maps from real samples.
fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
Returns:
Tensor: The calculated feature matching loss.
"""
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss
class DACGANLoss(nn.Module):
"""
Computes a discriminator loss, given a discriminator on
generated waveforms/spectrograms compared to ground truth
waveforms/spectrograms. Computes the loss for both the
discriminator and the generator in separate functions.
"""
def __init__(self, discriminator):
super().__init__()
self.discriminator = discriminator
def forward(self, fake, real):
# d_fake = self.discriminator(fake.audio_data)
# d_real = self.discriminator(real.audio_data)
d_fake = self.discriminator(fake)
d_real = self.discriminator(real)
return d_fake, d_real
def discriminator_loss(self, fake, real):
d_fake, d_real = self.forward(fake.clone().detach(), real)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += torch.mean(x_fake[-1] ** 2)
loss_d += torch.mean((1 - x_real[-1]) ** 2)
return loss_d
def generator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_g = 0
for x_fake in d_fake:
loss_g += torch.mean((1 - x_fake[-1]) ** 2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature
from typing import Optional
import torch
from torch import nn
from torch.nn.utils import weight_norm
from inspiremusic.wavtokenizer.decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv1d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv1d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv1d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv1d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv1d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h = q.shape
q = q.permute(0, 2, 1) # b,hw,c
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
class Backbone(nn.Module):
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
C denotes output features, and L is the sequence length.
Returns:
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
and H denotes the model dimension.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class VocosBackbone(Backbone):
"""
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional model. Defaults to None.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=adanorm_num_embeddings,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
self.temb_ch = 0
block_in = dim
dropout = 0.1
attn_type="vanilla"
pos_net : tp.List[nn.Module] = [
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
make_attn(block_in, attn_type=attn_type),
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
ResnetBlock(in_channels=block_in,out_channels=block_in,
temb_channels=self.temb_ch,dropout=dropout),
Normalize(block_in)
]
self.pos_net = nn.Sequential(*pos_net)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.embed(x)
x = self.pos_net(x)
if self.adanorm:
# assert bandwidth_id is not None
if bandwidth_id is None:
bandwidth_id = torch.tensor(0, device='cuda')
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
else:
x = self.norm(x.transpose(1, 2))
x = x.transpose(1, 2)
for conv_block in self.convnext:
x = conv_block(x, cond_embedding_id=bandwidth_id)
x = self.final_layer_norm(x.transpose(1, 2))
return x
class VocosResNetBackbone(Backbone):
"""
Vocos backbone module built with ResBlocks.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
num_blocks (int): Number of ResBlock1 blocks.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
"""
def __init__(
self, input_channels, dim, num_blocks, layer_scale_init_value=None,
):
super().__init__()
self.input_channels = input_channels
self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
self.resnet = nn.Sequential(
*[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.embed(x)
x = self.resnet(x)
x = x.transpose(1, 2)
return x
from typing import Optional
from typing import Tuple
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
if self.adanorm:
assert cond_embedding_id is not None
x = self.norm(x, cond_embedding_id)
else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class AdaLayerNorm(nn.Module):
"""
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
Args:
num_embeddings (int): Number of embeddings.
embedding_dim (int): Dimension of the embeddings.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.dim = embedding_dim
self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
torch.nn.init.ones_(self.scale.weight)
torch.nn.init.zeros_(self.shift.weight)
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
scale = self.scale(cond_embedding_id)
shift = self.shift(cond_embedding_id)
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
x = x * scale + shift
return x
class ResBlock1(nn.Module):
"""
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
but without upsampling layers.
Args:
dim (int): Number of input channels.
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
Defaults to (1, 3, 5).
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
Defaults to 0.1.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: Tuple[int, ...] = (1, 3, 5),
lrelu_slope: float = 0.1,
layer_scale_init_value: float = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
self.convs1 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[0],
padding=self.get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[1],
padding=self.get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[2],
padding=self.get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
]
)
self.gamma = nn.ParameterList(
[
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
xt = c1(xt)
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
xt = c2(xt)
if gamma is not None:
xt = gamma * xt
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
@staticmethod
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))
def symlog(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(x.abs())
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
import os
from typing import Tuple, Any, Union, Dict
import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
"""Instantiates a class with the given args and init.
Args:
args: Positional arguments required for instantiation.
init: Dict of the form {"class_path":...,"init_args":...}.
Returns:
The instantiated class object.
"""
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
args = (args,)
class_module, class_name = init["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
return args_class(*args, **kwargs)
class WavTokenizer(nn.Module):
"""
The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
This class is primarily designed for inference, with support for loading from pretrained
model checkpoints. It consists of three main components: a feature extractor,
a backbone, and a head.
"""
def __init__(
self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
):
super().__init__()
self.feature_extractor = feature_extractor
self.backbone = backbone
self.head = head
@classmethod
def from_hparams(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
backbone = instantiate_class(args=(), init=config["backbone"])
head = instantiate_class(args=(), init=config["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
@classmethod
def from_pretrained(self, repo_id: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
model = self.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu")
if isinstance(model.feature_extractor, EncodecFeatures):
encodec_parameters = {
"feature_extractor.encodec." + key: value
for key, value in model.feature_extractor.encodec.state_dict().items()
}
state_dict.update(encodec_parameters)
model.load_state_dict(state_dict)
model.eval()
return model
@classmethod
def from_hparams_feat(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
head = instantiate_class(args=(), init=config['model']['init_args']["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
@classmethod
def from_pretrained_feat(self, config_path, model_path):
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
model = self.from_hparams_feat(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict[k] = v
model.load_state_dict(state_dict)
model.eval()
return model
@classmethod
def estimator(self, config_path, model_path):
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
model = self.from_hparams_feat(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict[k] = v
model.load_state_dict(state_dict)
model.eval()
return model
@classmethod
def from_pretrained0911(self, config_path, model_folder_path):
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
model = self.from_hparams0802(config_path)
models = os.listdir(model_folder_path)
val_loss = []
for item in models:
if not item.startswith('vocos_'):
continue
val_loss.append(item[-11:-5])
val_loss.sort()
val_loss = val_loss[:3] # 取前3性能较好的模型平均
state_dict = dict()
state_dicts = []
for item in models:
if not item.startswith('vocos_'):
continue
ll = item[-11:-5]
if ll not in val_loss:
continue
model_path = model_folder_path + '/' + item
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict_single = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict_single[k] = v
state_dicts.append(state_dict_single)
for kk in state_dicts[0].keys():
vv = state_dicts[0][kk]
for i in range(1, len(state_dicts)):
ss = state_dicts[i]
vv += ss[kk]
vm = vv/len(state_dicts)
state_dict[kk] = vm
model.load_state_dict(state_dict)
model.eval()
return model
@torch.inference_mode()
def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
which is then passed through the backbone and the head to reconstruct the audio output.
Args:
audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
where B is the batch size and L is the waveform length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818
audio_output = self.decode(features, **kwargs)
return audio_output
# 0818
@torch.inference_mode()
def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
return features,discrete_codes
# 0818
@torch.inference_mode()
def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
return features,discrete_codes
@torch.inference_mode()
def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
_, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs)
discrete_codes = discrete_codes.clamp(min=0, max=16383)
return discrete_codes
@torch.inference_mode()
def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to decode audio waveform from already calculated features. The features input is passed through
the backbone and the head to reconstruct the audio output.
Args:
features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
C denotes the feature dimension, and L is the sequence length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
x = self.backbone(features_input, **kwargs)
audio_output = self.head(x)
return audio_output
@torch.inference_mode()
def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
"""
Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
codebook weights.
Args:
codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
where K is the number of codebooks, B is the batch size and L is the sequence length.
Returns:
Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
and L is the sequence length.
"""
assert isinstance(
self.feature_extractor, EncodecFeatures
), "Feature extractor should be an instance of EncodecFeatures"
if codes.dim() == 2:
codes = codes.unsqueeze(1)
n_bins = self.feature_extractor.encodec.quantizer.bins
offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
embeddings_idxs = codes + offsets.view(-1, 1, 1)
tmp=torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers],dim=0)
# features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
features = torch.nn.functional.embedding(embeddings_idxs, tmp).sum(dim=0)
features = features.transpose(1, 2)
return features
from typing import Tuple, Any, Union, Dict
import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone
from inspiremusic.wavtokenizer.decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
"""Instantiates a class with the given args and init.
Args:
args: Positional arguments required for instantiation.
init: Dict of the form {"class_path":...,"init_args":...}.
Returns:
The instantiated class object.
"""
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
args = (args,)
class_module, class_name = init["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
return args_class(*args, **kwargs)
class WavTokenizer(nn.Module):
"""
The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
This class is primarily designed for inference, with support for loading from pretrained
model checkpoints. It consists of three main components: a feature extractor,
a backbone, and a head.
"""
def __init__(
self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
multiperioddisc: MultiPeriodDiscriminator, multiresddisc: MultiResolutionDiscriminator,
):
super().__init__()
self.feature_extractor = feature_extractor
self.backbone = backbone
self.head = head
self.multiperioddisc = multiperioddisc
self.multiresddisc = multiresddisc
@classmethod
def from_hparams0828(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
head = instantiate_class(args=(), init=config['model']['init_args']["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
return model
@classmethod
def from_pretrained0828(self, config_path, model_path):
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
model = self.from_hparams0828(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \
or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'):
state_dict[k] = v
# if isinstance(model.feature_extractor, EncodecFeatures):
# encodec_parameters = {
# "feature_extractor.encodec." + key: value
# for key, value in model.feature_extractor.encodec.state_dict().items()
# }
# state_dict.update(encodec_parameters)
model.load_state_dict(state_dict)
return model
@classmethod
def from_hparams0802(cls, config_path: str) -> "Vocos":
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
head = instantiate_class(args=(), init=config['model']['init_args']["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
@classmethod
def from_pretrained0802(self, config_path, model_path):
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
model = self.from_hparams0802(config_path)
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
state_dict = dict()
for k, v in state_dict_raw.items():
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
state_dict[k] = v
# if isinstance(model.feature_extractor, EncodecFeatures):
# encodec_parameters = {
# "feature_extractor.encodec." + key: value
# for key, value in model.feature_extractor.encodec.state_dict().items()
# }
# state_dict.update(encodec_parameters)
model.load_state_dict(state_dict)
model.eval()
return model
@torch.inference_mode()
def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
which is then passed through the backbone and the head to reconstruct the audio output.
Args:
audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
where B is the batch size and L is the waveform length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818
audio_output = self.decode(features, **kwargs)
return audio_output
# 0818
@torch.inference_mode()
def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
features, _, _ = self.feature_extractor(audio_input, **kwargs)
return features
@torch.inference_mode()
def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""
Method to decode audio waveform from already calculated features. The features input is passed through
the backbone and the head to reconstruct the audio output.
Args:
features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
C denotes the feature dimension, and L is the sequence length.
Returns:
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
"""
x = self.backbone(features_input, **kwargs)
audio_output = self.head(x)
return audio_output
@torch.inference_mode()
def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
"""
Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
codebook weights.
Args:
codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
where K is the number of codebooks, B is the batch size and L is the sequence length.
Returns:
Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
and L is the sequence length.
"""
assert isinstance(
self.feature_extractor, EncodecFeatures
), "Feature extractor should be an instance of EncodecFeatures"
if codes.dim() == 2:
codes = codes.unsqueeze(1)
n_bins = self.feature_extractor.encodec.quantizer.bins
offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
embeddings_idxs = codes + offsets.view(-1, 1, 1)
features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
features = features.transpose(1, 2)
return features
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
import pdb
class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
# assert (window_envelope > 1e-11).all()
if not torch.all(window_envelope > 1e-11):
window_envelope = torch.clamp(window_envelope, min=1e-11)
y = y / window_envelope
return y
def onnx_forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
pdb.set_trace()
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
# assert (window_envelope > 1e-11).all()
if not torch.all(window_envelope > 1e-11):
window_envelope = torch.clamp(window_envelope, min=1e-11)
y = y / window_envelope
return y
class MDCT(nn.Module):
"""
Modified Discrete Cosine Transform (MDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
# view_as_real: NCCL Backend does not support ComplexFloat data type
# https://github.com/pytorch/pytorch/issues/71613
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, audio: torch.Tensor) -> torch.Tensor:
"""
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
Args:
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
and T is the length of the audio.
Returns:
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
and N is the number of frequency bins.
"""
if self.padding == "center":
audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
elif self.padding == "same":
# hop_length is 1/2 frame_len
audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
else:
raise ValueError("Padding must be 'center' or 'same'.")
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
N = self.frame_len // 2
x = x * self.window.expand(x.shape)
X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
return torch.real(res) * np.sqrt(2)
class IMDCT(nn.Module):
"""
Inverse Modified Discrete Cosine Transform (IMDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
Args:
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
L is the number of frames, and N is the number of frequency bins.
Returns:
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
"""
B, L, N = X.shape
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
Y[..., :N] = X
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
result = y * self.window.expand(y.shape)
output_size = (1, (L + 1) * N)
audio = torch.nn.functional.fold(
result.transpose(1, 2),
output_size=output_size,
kernel_size=(1, self.frame_len),
stride=(1, self.frame_len // 2),
)[:, 0, 0, :]
if self.padding == "center":
pad = self.frame_len // 2
elif self.padding == "same":
pad = self.frame_len // 4
else:
raise ValueError("Padding must be 'center' or 'same'.")
audio = audio[:, pad:-pad]
return audio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment