Commit ab9c00af authored by yangzhong's avatar yangzhong
Browse files

init submission

parents
Pipeline #3176 failed with stages
in 0 seconds
from .base import CodecMixin
from .base import DACFile
from .dac import DAC
from .discriminator import Discriminator
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import numpy as np
import torch
import tqdm
from audiotools import AudioSignal
from torch import nn
SUPPORTED_VERSIONS = ["1.0.0"]
@dataclass
class DACFile:
codes: torch.Tensor
# Metadata
chunk_length: int
original_length: int
input_db: float
channels: int
sample_rate: int
padding: bool
dac_version: str
def save(self, path):
artifacts = {
"codes": self.codes.numpy().astype(np.uint16),
"metadata": {
"input_db": self.input_db.numpy().astype(np.float32),
"original_length": self.original_length,
"sample_rate": self.sample_rate,
"chunk_length": self.chunk_length,
"channels": self.channels,
"padding": self.padding,
"dac_version": SUPPORTED_VERSIONS[-1],
},
}
path = Path(path).with_suffix(".dac")
with open(path, "wb") as f:
np.save(f, artifacts)
return path
@classmethod
def load(cls, path):
artifacts = np.load(path, allow_pickle=True)[()]
codes = torch.from_numpy(artifacts["codes"].astype(int))
if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
raise RuntimeError(
f"Given file {path} can't be loaded with this version of descript-audio-codec."
)
return cls(codes=codes, **artifacts["metadata"])
class CodecMixin:
@property
def padding(self):
if not hasattr(self, "_padding"):
self._padding = True
return self._padding
@padding.setter
def padding(self, value):
assert isinstance(value, bool)
layers = [
l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
]
for layer in layers:
if value:
if hasattr(layer, "original_padding"):
layer.padding = layer.original_padding
else:
layer.original_padding = layer.padding
layer.padding = tuple(0 for _ in range(len(layer.padding)))
self._padding = value
def get_delay(self):
# Any number works here, delay is invariant to input length
l_out = self.get_output_length(0)
L = l_out
layers = []
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
layers.append(layer)
for layer in reversed(layers):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.ConvTranspose1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.Conv1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.ceil(L)
l_in = L
return (l_in - l_out) // 2
def get_output_length(self, input_length):
L = input_length
# Calculate output length
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.Conv1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.ConvTranspose1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.floor(L)
return L
@torch.no_grad()
def compress(
self,
audio_path_or_signal: Union[str, Path, AudioSignal],
win_duration: float = 1.0,
verbose: bool = False,
normalize_db: float = -16,
n_quantizers: int = None,
) -> DACFile:
"""Processes an audio signal from a file or AudioSignal object into
discrete codes. This function processes the signal in short windows,
using constant GPU memory.
Parameters
----------
audio_path_or_signal : Union[str, Path, AudioSignal]
audio signal to reconstruct
win_duration : float, optional
window duration in seconds, by default 5.0
verbose : bool, optional
by default False
normalize_db : float, optional
normalize db, by default -16
Returns
-------
DACFile
Object containing compressed codes and metadata
required for decompression
"""
audio_signal = audio_path_or_signal
if isinstance(audio_signal, (str, Path)):
audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
self.eval()
original_padding = self.padding
original_device = audio_signal.device
audio_signal = audio_signal.clone()
original_sr = audio_signal.sample_rate
resample_fn = audio_signal.resample
loudness_fn = audio_signal.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if audio_signal.signal_duration >= 10 * 60 * 60:
resample_fn = audio_signal.ffmpeg_resample
loudness_fn = audio_signal.ffmpeg_loudness
original_length = audio_signal.signal_length
resample_fn(self.sample_rate)
input_db = loudness_fn()
if normalize_db is not None:
audio_signal.normalize(normalize_db)
audio_signal.ensure_max_of_audio()
nb, nac, nt = audio_signal.audio_data.shape
audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
win_duration = (
audio_signal.signal_duration if win_duration is None else win_duration
)
if audio_signal.signal_duration <= win_duration:
# Unchunked compression (used if signal length < win duration)
self.padding = True
n_samples = nt
hop = nt
else:
# Chunked inference
self.padding = False
# Zero-pad signal on either side by the delay
audio_signal.zero_pad(self.delay, self.delay)
n_samples = int(win_duration * self.sample_rate)
# Round n_samples to nearest hop length multiple
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
hop = self.get_output_length(n_samples)
codes = []
range_fn = range if not verbose else tqdm.trange
for i in range_fn(0, nt, hop):
x = audio_signal[..., i : i + n_samples]
x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
audio_data = x.audio_data.to(self.device)
audio_data = self.preprocess(audio_data, self.sample_rate)
_, c, _, _, _ = self.encode(audio_data, n_quantizers)
codes.append(c.to(original_device))
chunk_length = c.shape[-1]
codes = torch.cat(codes, dim=-1)
dac_file = DACFile(
codes=codes,
chunk_length=chunk_length,
original_length=original_length,
input_db=input_db,
channels=nac,
sample_rate=original_sr,
padding=self.padding,
dac_version=SUPPORTED_VERSIONS[-1],
)
if n_quantizers is not None:
codes = codes[:, :n_quantizers, :]
self.padding = original_padding
return dac_file
@torch.no_grad()
def decompress(
self,
obj: Union[str, Path, DACFile],
verbose: bool = False,
) -> AudioSignal:
"""Reconstruct audio from a given .dac file
Parameters
----------
obj : Union[str, Path, DACFile]
.dac file location or corresponding DACFile object.
verbose : bool, optional
Prints progress if True, by default False
Returns
-------
AudioSignal
Object with the reconstructed audio
"""
self.eval()
if isinstance(obj, (str, Path)):
obj = DACFile.load(obj)
original_padding = self.padding
self.padding = obj.padding
range_fn = range if not verbose else tqdm.trange
codes = obj.codes
original_device = codes.device
chunk_length = obj.chunk_length
recons = []
for i in range_fn(0, codes.shape[-1], chunk_length):
c = codes[..., i : i + chunk_length].to(self.device)
z = self.quantizer.from_codes(c)[0]
r = self.decode(z)
recons.append(r.to(original_device))
recons = torch.cat(recons, dim=-1)
recons = AudioSignal(recons, self.sample_rate)
resample_fn = recons.resample
loudness_fn = recons.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
self.padding = original_padding
return recons
import math
from typing import List
from typing import Union
import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from torch import nn
from .base import CodecMixin
from indextts.s2mel.dac.nn.layers import Snake1d
from indextts.s2mel.dac.nn.layers import WNConv1d
from indextts.s2mel.dac.nn.layers import WNConvTranspose1d
from indextts.s2mel.dac.nn.quantize import ResidualVectorQuantize
from .encodec import SConv1d, SConvTranspose1d, SLSTM
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
super().__init__()
conv1d_type = SConv1d# if causal else WNConv1d
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'),
Snake1d(dim),
conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
if pad > 0:
x = x[..., pad:-pad]
return x + y
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
super().__init__()
conv1d_type = SConv1d# if causal else WNConv1d
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1, causal=causal),
ResidualUnit(dim // 2, dilation=3, causal=causal),
ResidualUnit(dim // 2, dilation=9, causal=causal),
Snake1d(dim // 2),
conv1d_type(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
causal=causal,
norm='weight_norm',
),
)
def forward(self, x):
return self.block(x)
class Encoder(nn.Module):
def __init__(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 64,
causal: bool = False,
lstm: int = 2,
):
super().__init__()
conv1d_type = SConv1d# if causal else WNConv1d
# Create first convolution
self.block = [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]
# Add LSTM if needed
self.use_lstm = lstm
if lstm:
self.block += [SLSTM(d_model, lstm)]
# Create last convolution
self.block += [
Snake1d(d_model),
conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'),
]
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
return self.block(x)
def reset_cache(self):
# recursively find all submodules named SConv1d in self.block and use their reset_cache method
def reset_cache(m):
if isinstance(m, SConv1d) or isinstance(m, SLSTM):
m.reset_cache()
return
for child in m.children():
reset_cache(child)
reset_cache(self.block)
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
super().__init__()
conv1d_type = SConvTranspose1d #if causal else WNConvTranspose1d
self.block = nn.Sequential(
Snake1d(input_dim),
conv1d_type(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
causal=causal,
norm='weight_norm'
),
ResidualUnit(output_dim, dilation=1, causal=causal),
ResidualUnit(output_dim, dilation=3, causal=causal),
ResidualUnit(output_dim, dilation=9, causal=causal),
)
def forward(self, x):
return self.block(x)
class Decoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
d_out: int = 1,
causal: bool = False,
lstm: int = 2,
):
super().__init__()
conv1d_type = SConv1d# if causal else WNConv1d
# Add first conv layer
layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
if lstm:
layers += [SLSTM(channels, num_layers=lstm)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]
# Add final conv layer
layers += [
Snake1d(output_dim),
conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class DAC(BaseModel, CodecMixin):
def __init__(
self,
encoder_dim: int = 64,
encoder_rates: List[int] = [2, 4, 8, 8],
latent_dim: int = None,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 4, 2],
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: bool = False,
sample_rate: int = 44100,
lstm: int = 2,
causal: bool = False,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.sample_rate = sample_rate
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer = ResidualVectorQuantize(
input_dim=latent_dim,
n_codebooks=n_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.decoder = Decoder(
latent_dim,
decoder_dim,
decoder_rates,
lstm=lstm,
causal=causal,
)
self.sample_rate = sample_rate
self.apply(init_weights)
self.delay = self.get_delay()
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
length = audio_data.shape[-1]
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def encode(
self,
audio_data: torch.Tensor,
n_quantizers: int = None,
):
"""Encode given audio data and return quantized latent codes
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
n_quantizers : int, optional
Number of quantizers to use, by default None
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"""
z = self.encoder(audio_data)
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
z, n_quantizers
)
return z, codes, latents, commitment_loss, codebook_loss
def decode(self, z: torch.Tensor):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
return self.decoder(z)
def forward(
self,
audio_data: torch.Tensor,
sample_rate: int = None,
n_quantizers: int = None,
):
"""Model forward pass
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
sample_rate : int, optional
Sample rate of audio data in Hz, by default None
If None, defaults to `self.sample_rate`
n_quantizers : int, optional
Number of quantizers to use, by default None.
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
length = audio_data.shape[-1]
audio_data = self.preprocess(audio_data, sample_rate)
z, codes, latents, commitment_loss, codebook_loss = self.encode(
audio_data, n_quantizers
)
x = self.decode(z)
return {
"audio": x[..., :length],
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
if __name__ == "__main__":
import numpy as np
from functools import partial
model = DAC().to("cpu")
for n, m in model.named_modules():
o = m.extra_repr()
p = sum([np.prod(p.size()) for p in m.parameters()])
fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
setattr(m, "extra_repr", partial(fn, o=o, p=p))
print(model)
print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
length = 88200 * 2
x = torch.randn(1, 1, length).to(model.device)
x.requires_grad_(True)
x.retain_grad()
# Make a forward pass
out = model(x)["audio"]
print("Input shape:", x.shape)
print("Output shape:", out.shape)
# Create gradient variable
grad = torch.zeros_like(out)
grad[:, :, grad.shape[-1] // 2] = 1
# Make a backward pass
out.backward(grad)
# Check non-zero values
gradmap = x.grad.squeeze(0)
gradmap = (gradmap != 0).sum(0) # sum across features
rf = (gradmap != 0).sum()
print(f"Receptive field: {rf.item()}")
x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
model.decompress(model.compress(x, verbose=True), verbose=True)
import torch
import torch.nn as nn
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import ml
from audiotools import STFTParams
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
act = kwargs.pop("act", True)
conv = weight_norm(nn.Conv1d(*args, **kwargs))
if not act:
return conv
return nn.Sequential(conv, nn.LeakyReLU(0.1))
def WNConv2d(*args, **kwargs):
act = kwargs.pop("act", True)
conv = weight_norm(nn.Conv2d(*args, **kwargs))
if not act:
return conv
return nn.Sequential(conv, nn.LeakyReLU(0.1))
class MPD(nn.Module):
def __init__(self, period):
super().__init__()
self.period = period
self.convs = nn.ModuleList(
[
WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
]
)
self.conv_post = WNConv2d(
1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
)
def pad_to_period(self, x):
t = x.shape[-1]
x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
return x
def forward(self, x):
fmap = []
x = self.pad_to_period(x)
x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
class MSD(nn.Module):
def __init__(self, rate: int = 1, sample_rate: int = 44100):
super().__init__()
self.convs = nn.ModuleList(
[
WNConv1d(1, 16, 15, 1, padding=7),
WNConv1d(16, 64, 41, 4, groups=4, padding=20),
WNConv1d(64, 256, 41, 4, groups=16, padding=20),
WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
WNConv1d(1024, 1024, 5, 1, padding=2),
]
)
self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
self.sample_rate = sample_rate
self.rate = rate
def forward(self, x):
x = AudioSignal(x, self.sample_rate)
x.resample(self.sample_rate // self.rate)
x = x.audio_data
fmap = []
for l in self.convs:
x = l(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
class MRD(nn.Module):
def __init__(
self,
window_length: int,
hop_factor: float = 0.25,
sample_rate: int = 44100,
bands: list = BANDS,
):
"""Complex multi-band spectrogram discriminator.
Parameters
----------
window_length : int
Window length of STFT.
hop_factor : float, optional
Hop factor of the STFT, defaults to ``0.25 * window_length``.
sample_rate : int, optional
Sampling rate of audio in Hz, by default 44100
bands : list, optional
Bands to run discriminator over.
"""
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.sample_rate = sample_rate
self.stft_params = STFTParams(
window_length=window_length,
hop_length=int(window_length * hop_factor),
match_stride=True,
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
ch = 32
convs = lambda: nn.ModuleList(
[
WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
def spectrogram(self, x):
x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
x = torch.view_as_real(x.stft())
x = rearrange(x, "b 1 f t c -> (b 1) c t f")
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
return x_bands
def forward(self, x):
x_bands = self.spectrogram(x)
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for layer in stack:
band = layer(band)
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
x = self.conv_post(x)
fmap.append(x)
return fmap
class Discriminator(nn.Module):
def __init__(
self,
rates: list = [],
periods: list = [2, 3, 5, 7, 11],
fft_sizes: list = [2048, 1024, 512],
sample_rate: int = 44100,
bands: list = BANDS,
):
"""Discriminator that combines multiple discriminators.
Parameters
----------
rates : list, optional
sampling rates (in Hz) to run MSD at, by default []
If empty, MSD is not used.
periods : list, optional
periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
fft_sizes : list, optional
Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
sample_rate : int, optional
Sampling rate of audio in Hz, by default 44100
bands : list, optional
Bands to run MRD at, by default `BANDS`
"""
super().__init__()
discs = []
discs += [MPD(p) for p in periods]
discs += [MSD(r, sample_rate=sample_rate) for r in rates]
discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
self.discriminators = nn.ModuleList(discs)
def preprocess(self, y):
# Remove DC offset
y = y - y.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
return y
def forward(self, x):
x = self.preprocess(x)
fmaps = [d(x) for d in self.discriminators]
return fmaps
if __name__ == "__main__":
disc = Discriminator()
x = torch.zeros(1, 1, 44100)
results = disc(x)
for i, result in enumerate(results):
print(f"disc{i}")
for i, r in enumerate(result):
print(r.shape, r.mean(), r.min(), r.max())
print()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convolutional layers wrappers and utilities."""
import math
import typing as tp
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
import typing as tp
import einops
class ConvLayerNorm(nn.LayerNorm):
"""
Convolution-friendly LayerNorm that moves channels to last dimensions
before running the normalization and moves them back to original position right after.
"""
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
super().__init__(normalized_shape, **kwargs)
def forward(self, x):
x = einops.rearrange(x, 'b ... t -> b t ...')
x = super().forward(x)
x = einops.rearrange(x, 'b t ... -> b ... t')
return
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
'time_layer_norm', 'layer_norm', 'time_group_norm'])
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
assert norm in CONV_NORMALIZATIONS
if norm == 'weight_norm':
return weight_norm(module)
elif norm == 'spectral_norm':
return spectral_norm(module)
else:
# We already check was in CONV_NORMALIZATION, so any other choice
# doesn't need reparametrization.
return module
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
"""Return the proper normalization module. If causal is True, this will ensure the returned
module is causal, or return an error if the normalization doesn't support causal evaluation.
"""
assert norm in CONV_NORMALIZATIONS
if norm == 'layer_norm':
assert isinstance(module, nn.modules.conv._ConvNd)
return ConvLayerNorm(module.out_channels, **norm_kwargs)
elif norm == 'time_group_norm':
if causal:
raise ValueError("GroupNorm doesn't support causal evaluation.")
assert isinstance(module, nn.modules.conv._ConvNd)
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
else:
return nn.Identity()
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
padding_total: int = 0) -> int:
"""See `pad_for_conv1d`.
"""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
"""Pad for a convolution to make sure that the last window is full.
Extra padding is added at the end. This is required to ensure that we can rebuild
an output of the same length, as otherwise, even with padding, some time steps
might get removed.
For instance, with total padding = 4, kernel size = 4, stride = 2:
0 0 1 2 3 4 5 0 0 # (0s are padding)
1 2 3 # (output frames of a convolution, last 0 is never used)
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
1 2 3 4 # once you removed padding, we are missing one time step !
"""
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
return F.pad(x, (0, extra_padding))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == 'reflect':
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left: end]
class NormConv1d(nn.Module):
"""Wrapper around Conv1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, causal: bool = False, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConv2d(nn.Module):
"""Wrapper around Conv2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConvTranspose1d(nn.Module):
"""Wrapper around ConvTranspose1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, causal: bool = False, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class NormConvTranspose2d(nn.Module):
"""Wrapper around ConvTranspose2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class SConv1d(nn.Module):
"""Conv1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int = 1, dilation: int = 1,
groups: int = 1, bias: bool = True, causal: bool = False,
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
pad_mode: str = 'reflect', **kwargs):
super().__init__()
# warn user on unusual setup between dilation and stride
if stride > 1 and dilation > 1:
warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
dilation=dilation, groups=groups, bias=bias, causal=causal,
norm=norm, norm_kwargs=norm_kwargs)
self.causal = causal
self.pad_mode = pad_mode
self.cache_enabled = False
def reset_cache(self):
"""Reset the cache when starting a new stream."""
self.cache = None
self.cache_enabled = True
def forward(self, x):
B, C, T = x.shape
kernel_size = self.conv.conv.kernel_size[0]
stride = self.conv.conv.stride[0]
dilation = self.conv.conv.dilation[0]
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
padding_total = kernel_size - stride
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
if self.causal:
# Left padding for causal
if self.cache_enabled and self.cache is not None:
# Concatenate the cache (previous inputs) with the new input for streaming
x = torch.cat([self.cache, x], dim=2)
else:
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
# Store the most recent input frames for future cache use
if self.cache_enabled:
if self.cache is None:
# Initialize cache with zeros (at the start of streaming)
self.cache = torch.zeros(B, C, kernel_size - 1, device=x.device)
# Update the cache by storing the latest input frames
if kernel_size > 1:
self.cache = x[:, :, -kernel_size + 1:].detach() # Only store the necessary frames
return self.conv(x)
class SConvTranspose1d(nn.Module):
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int = 1, causal: bool = False,
norm: str = 'none', trim_right_ratio: float = 1.,
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
causal=causal, norm=norm, norm_kwargs=norm_kwargs)
self.causal = causal
self.trim_right_ratio = trim_right_ratio
assert self.causal or self.trim_right_ratio == 1., \
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
def forward(self, x):
kernel_size = self.convtr.convtr.kernel_size[0]
stride = self.convtr.convtr.stride[0]
padding_total = kernel_size - stride
y = self.convtr(x)
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
# removed at the very end, when keeping only the right length for the output,
# as removing it here would require also passing the length at the matching layer
# in the encoder.
if self.causal:
# Trim the padding on the right according to the specified ratio
# if trim_right_ratio = 1.0, trim everything from right
padding_right = math.ceil(padding_total * self.trim_right_ratio)
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
return y
class SLSTM(nn.Module):
"""
LSTM without worrying about the hidden state, nor the layout of the data.
Expects input as convolutional layout.
"""
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
super().__init__()
self.skip = skip
self.lstm = nn.LSTM(dimension, dimension, num_layers)
self.hidden = None
self.cache_enabled = False
def forward(self, x):
x = x.permute(2, 0, 1)
if self.training or not self.cache_enabled:
y, _ = self.lstm(x)
else:
y, self.hidden = self.lstm(x, self.hidden)
if self.skip:
y = y + x
y = y.permute(1, 2, 0)
return y
def reset_cache(self):
self.hidden = None
self.cache_enabled = True
\ No newline at end of file
from . import layers
from . import loss
from . import quantize
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
import typing
from typing import List
import torch
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import STFTParams
from torch import nn
class L1Loss(nn.L1Loss):
"""L1 Loss between AudioSignals. Defaults
to comparing ``audio_data``, but any
attribute of an AudioSignal can be used.
Parameters
----------
attribute : str, optional
Attribute of signal to compare, defaults to ``audio_data``.
weight : float, optional
Weight of this loss, defaults to 1.0.
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
"""
def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
self.attribute = attribute
self.weight = weight
super().__init__(**kwargs)
def forward(self, x: AudioSignal, y: AudioSignal):
"""
Parameters
----------
x : AudioSignal
Estimate AudioSignal
y : AudioSignal
Reference AudioSignal
Returns
-------
torch.Tensor
L1 loss between AudioSignal attributes.
"""
if isinstance(x, AudioSignal):
x = getattr(x, self.attribute)
y = getattr(y, self.attribute)
return super().forward(x, y)
class SISDRLoss(nn.Module):
"""
Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
of estimated and reference audio signals or aligned features.
Parameters
----------
scaling : int, optional
Whether to use scale-invariant (True) or
signal-to-noise ratio (False), by default True
reduction : str, optional
How to reduce across the batch (either 'mean',
'sum', or none).], by default ' mean'
zero_mean : int, optional
Zero mean the references and estimates before
computing the loss, by default True
clip_min : int, optional
The minimum possible loss value. Helps network
to not focus on making already good examples better, by default None
weight : float, optional
Weight of this loss, defaults to 1.0.
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
"""
def __init__(
self,
scaling: int = True,
reduction: str = "mean",
zero_mean: int = True,
clip_min: int = None,
weight: float = 1.0,
):
self.scaling = scaling
self.reduction = reduction
self.zero_mean = zero_mean
self.clip_min = clip_min
self.weight = weight
super().__init__()
def forward(self, x: AudioSignal, y: AudioSignal):
eps = 1e-8
# nb, nc, nt
if isinstance(x, AudioSignal):
references = x.audio_data
estimates = y.audio_data
else:
references = x
estimates = y
nb = references.shape[0]
references = references.reshape(nb, 1, -1).permute(0, 2, 1)
estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
# samples now on axis 1
if self.zero_mean:
mean_reference = references.mean(dim=1, keepdim=True)
mean_estimate = estimates.mean(dim=1, keepdim=True)
else:
mean_reference = 0
mean_estimate = 0
_references = references - mean_reference
_estimates = estimates - mean_estimate
references_projection = (_references**2).sum(dim=-2) + eps
references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
scale = (
(references_on_estimates / references_projection).unsqueeze(1)
if self.scaling
else 1
)
e_true = scale * _references
e_res = _estimates - e_true
signal = (e_true**2).sum(dim=1)
noise = (e_res**2).sum(dim=1)
sdr = -10 * torch.log10(signal / noise + eps)
if self.clip_min is not None:
sdr = torch.clamp(sdr, min=self.clip_min)
if self.reduction == "mean":
sdr = sdr.mean()
elif self.reduction == "sum":
sdr = sdr.sum()
return sdr
class MultiScaleSTFTLoss(nn.Module):
"""Computes the multi-scale STFT loss from [1].
Parameters
----------
window_lengths : List[int], optional
Length of each window of each STFT, by default [2048, 512]
loss_fn : typing.Callable, optional
How to compare each loss, by default nn.L1Loss()
clamp_eps : float, optional
Clamp on the log magnitude, below, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 1.0
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 2.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
References
----------
1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
"DDSP: Differentiable Digital Signal Processing."
International Conference on Learning Representations. 2019.
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
"""
def __init__(
self,
window_lengths: List[int] = [2048, 512],
loss_fn: typing.Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
window_type: str = None,
):
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
]
self.loss_fn = loss_fn
self.log_weight = log_weight
self.mag_weight = mag_weight
self.clamp_eps = clamp_eps
self.weight = weight
self.pow = pow
def forward(self, x: AudioSignal, y: AudioSignal):
"""Computes multi-scale STFT between an estimate and a reference
signal.
Parameters
----------
x : AudioSignal
Estimate signal
y : AudioSignal
Reference signal
Returns
-------
torch.Tensor
Multi-scale STFT loss.
"""
loss = 0.0
for s in self.stft_params:
x.stft(s.window_length, s.hop_length, s.window_type)
y.stft(s.window_length, s.hop_length, s.window_type)
loss += self.log_weight * self.loss_fn(
x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
)
loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
return loss
class MelSpectrogramLoss(nn.Module):
"""Compute distance between mel spectrograms. Can be used
in a multi-scale way.
Parameters
----------
n_mels : List[int]
Number of mels per STFT, by default [150, 80],
window_lengths : List[int], optional
Length of each window of each STFT, by default [2048, 512]
loss_fn : typing.Callable, optional
How to compare each loss, by default nn.L1Loss()
clamp_eps : float, optional
Clamp on the log magnitude, below, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 1.0
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 2.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
"""
def __init__(
self,
n_mels: List[int] = [150, 80],
window_lengths: List[int] = [2048, 512],
loss_fn: typing.Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
mel_fmin: List[float] = [0.0, 0.0],
mel_fmax: List[float] = [None, None],
window_type: str = None,
):
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
]
self.n_mels = n_mels
self.loss_fn = loss_fn
self.clamp_eps = clamp_eps
self.log_weight = log_weight
self.mag_weight = mag_weight
self.weight = weight
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.pow = pow
def forward(self, x: AudioSignal, y: AudioSignal):
"""Computes mel loss between an estimate and a reference
signal.
Parameters
----------
x : AudioSignal
Estimate signal
y : AudioSignal
Reference signal
Returns
-------
torch.Tensor
Mel loss.
"""
loss = 0.0
for n_mels, fmin, fmax, s in zip(
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
):
kwargs = {
"window_length": s.window_length,
"hop_length": s.hop_length,
"window_type": s.window_type,
}
x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
loss += self.log_weight * self.loss_fn(
x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
)
loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
return loss
class GANLoss(nn.Module):
"""
Computes a discriminator loss, given a discriminator on
generated waveforms/spectrograms compared to ground truth
waveforms/spectrograms. Computes the loss for both the
discriminator and the generator in separate functions.
"""
def __init__(self, discriminator):
super().__init__()
self.discriminator = discriminator
def forward(self, fake, real):
d_fake = self.discriminator(fake.audio_data)
d_real = self.discriminator(real.audio_data)
return d_fake, d_real
def discriminator_loss(self, fake, real):
d_fake, d_real = self.forward(fake.clone().detach(), real)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += torch.mean(x_fake[-1] ** 2)
loss_d += torch.mean((1 - x_real[-1]) ** 2)
return loss_d
def generator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_g = 0
for x_fake in d_fake:
loss_g += torch.mean((1 - x_fake[-1]) ** 2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
from indextts.s2mel.dac.nn.layers import WNConv1d
class VectorQuantizeLegacy(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
removed in-out projection
"""
def __init__(self, input_dim: int, codebook_size: int):
super().__init__()
self.codebook_size = codebook_size
self.codebook = nn.Embedding(codebook_size, input_dim)
def forward(self, z, z_mask=None):
"""Quantized the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
z_e = z
z_q, indices = self.decode_latents(z)
if z_mask is not None:
commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
else:
commitment_loss = F.mse_loss(z_e, z_q.detach())
codebook_loss = F.mse_loss(z_q, z_e.detach())
z_q = (
z_e + (z_q - z_e).detach()
) # noop in forward pass, straight-through gradient estimator in backward pass
return z_q, indices, z_e, commitment_loss, codebook_loss
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
class VectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
Additionally uses following tricks from Improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
self.codebook = nn.Embedding(codebook_size, codebook_dim)
def forward(self, z, z_mask=None):
"""Quantized the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
z_e = self.in_proj(z) # z_e : (B x D x T)
z_q, indices = self.decode_latents(z_e)
if z_mask is not None:
commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
else:
commitment_loss = F.mse_loss(z_e, z_q.detach())
codebook_loss = F.mse_loss(z_q, z_e.detach())
z_q = (
z_e + (z_q - z_e).detach()
) # noop in forward pass, straight-through gradient estimator in backward pass
z_q = self.out_proj(z_q)
return z_q, commitment_loss, codebook_loss, indices, z_e
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
class ResidualVectorQuantize(nn.Module):
"""
Introduced in SoundStream: An end2end neural audio codec
https://arxiv.org/abs/2107.03312
"""
def __init__(
self,
input_dim: int = 512,
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: float = 0.0,
):
super().__init__()
if isinstance(codebook_dim, int):
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
self.n_codebooks = n_codebooks
self.codebook_dim = codebook_dim
self.codebook_size = codebook_size
self.quantizers = nn.ModuleList(
[
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
for i in range(n_codebooks)
]
)
self.quantizer_dropout = quantizer_dropout
def forward(self, z, n_quantizers: int = None):
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"""
z_q = 0
residual = z
commitment_loss = 0
codebook_loss = 0
codebook_indices = []
latents = []
if n_quantizers is None:
n_quantizers = self.n_codebooks
if self.training:
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = (
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
)
z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i
# Sum losses
commitment_loss += (commitment_loss_i * mask).mean()
codebook_loss += (codebook_loss_i * mask).mean()
codebook_indices.append(indices_i)
latents.append(z_e_i)
codes = torch.stack(codebook_indices, dim=1)
latents = torch.cat(latents, dim=1)
return z_q, codes, latents, commitment_loss, codebook_loss
def from_codes(self, codes: torch.Tensor):
"""Given the quantized codes, reconstruct the continuous representation
Parameters
----------
codes : Tensor[B x N x T]
Quantized discrete representation of input
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
"""
z_q = 0.0
z_p = []
n_codebooks = codes.shape[1]
for i in range(n_codebooks):
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
z_p.append(z_p_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), codes
def from_latents(self, latents: torch.Tensor):
"""Given the unquantized latents, reconstruct the
continuous representation after quantization.
Parameters
----------
latents : Tensor[B x N x T]
Continuous representation of input after projection
Returns
-------
Tensor[B x D x T]
Quantized representation of full-projected space
Tensor[B x D x T]
Quantized representation of latent space
"""
z_q = 0
z_p = []
codes = []
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
0
]
for i in range(n_codebooks):
j, k = dims[i], dims[i + 1]
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
z_p.append(z_p_i)
codes.append(codes_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
if __name__ == "__main__":
rvq = ResidualVectorQuantize(quantizer_dropout=True)
x = torch.randn(16, 512, 80)
y = rvq(x)
print(y["latents"].shape)
from pathlib import Path
import argbind
from audiotools import ml
import indextts.s2mel.dac as dac
DAC = dac.model.DAC
Accelerator = ml.Accelerator
__MODEL_LATEST_TAGS__ = {
("44khz", "8kbps"): "0.0.1",
("24khz", "8kbps"): "0.0.4",
("16khz", "8kbps"): "0.0.5",
("44khz", "16kbps"): "1.0.0",
}
__MODEL_URLS__ = {
(
"44khz",
"0.0.1",
"8kbps",
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
(
"24khz",
"0.0.4",
"8kbps",
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
(
"16khz",
"0.0.5",
"8kbps",
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
(
"44khz",
"1.0.0",
"16kbps",
): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
}
@argbind.bind(group="download", positional=True, without_prefix=True)
def download(
model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
):
"""
Function that downloads the weights file from URL if a local cache is not found.
Parameters
----------
model_type : str
The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
model_bitrate: str
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
Only 44khz model supports 16kbps.
tag : str
The tag of the model to download. Defaults to "latest".
Returns
-------
Path
Directory path required to load model via audiotools.
"""
model_type = model_type.lower()
tag = tag.lower()
assert model_type in [
"44khz",
"24khz",
"16khz",
], "model_type must be one of '44khz', '24khz', or '16khz'"
assert model_bitrate in [
"8kbps",
"16kbps",
], "model_bitrate must be one of '8kbps', or '16kbps'"
if tag == "latest":
tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
if download_link is None:
raise ValueError(
f"Could not find model with tag {tag} and model type {model_type}"
)
local_path = (
Path.home()
/ ".cache"
/ "descript"
/ "dac"
/ f"weights_{model_type}_{model_bitrate}_{tag}.pth"
)
if not local_path.exists():
local_path.parent.mkdir(parents=True, exist_ok=True)
# Download the model
import requests
response = requests.get(download_link)
if response.status_code != 200:
raise ValueError(
f"Could not download model. Received response code {response.status_code}"
)
local_path.write_bytes(response.content)
return local_path
def load_model(
model_type: str = "44khz",
model_bitrate: str = "8kbps",
tag: str = "latest",
load_path: str = None,
):
if not load_path:
load_path = download(
model_type=model_type, model_bitrate=model_bitrate, tag=tag
)
generator = DAC.load(load_path)
return generator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment