Commit 0112b0f0 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2394 canceled with stages
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convolutional layers wrappers and utilities."""
import math
import typing as tp
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from .norm import ConvLayerNorm
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
'time_layer_norm', 'layer_norm', 'time_group_norm'])
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
assert norm in CONV_NORMALIZATIONS
if norm == 'weight_norm':
return weight_norm(module)
elif norm == 'spectral_norm':
return spectral_norm(module)
else:
# We already check was in CONV_NORMALIZATION, so any other choice
# doesn't need reparametrization.
return module
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
"""Return the proper normalization module. If causal is True, this will ensure the returned
module is causal, or return an error if the normalization doesn't support causal evaluation.
"""
assert norm in CONV_NORMALIZATIONS
if norm == 'layer_norm':
assert isinstance(module, nn.modules.conv._ConvNd)
return ConvLayerNorm(module.out_channels, **norm_kwargs)
elif norm == 'time_group_norm':
if causal:
raise ValueError("GroupNorm doesn't support causal evaluation.")
assert isinstance(module, nn.modules.conv._ConvNd)
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
else:
return nn.Identity()
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
padding_total: int = 0) -> int:
"""See `pad_for_conv1d`.
"""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
"""Pad for a convolution to make sure that the last window is full.
Extra padding is added at the end. This is required to ensure that we can rebuild
an output of the same length, as otherwise, even with padding, some time steps
might get removed.
For instance, with total padding = 4, kernel size = 4, stride = 2:
0 0 1 2 3 4 5 0 0 # (0s are padding)
1 2 3 # (output frames of a convolution, last 0 is never used)
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
1 2 3 4 # once you removed padding, we are missing one time step !
"""
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
return F.pad(x, (0, extra_padding))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == 'reflect':
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left: end]
class NormConv1d(nn.Module):
"""Wrapper around Conv1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, causal: bool = False, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConv2d(nn.Module):
"""Wrapper around Conv2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConvTranspose1d(nn.Module):
"""Wrapper around ConvTranspose1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, causal: bool = False, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class NormConvTranspose2d(nn.Module):
"""Wrapper around ConvTranspose2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(self, *args, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class SConv1d(nn.Module):
"""Conv1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int = 1, dilation: int = 1,
groups: int = 1, bias: bool = True, causal: bool = False,
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
pad_mode: str = 'reflect'):
super().__init__()
# warn user on unusual setup between dilation and stride
if stride > 1 and dilation > 1:
warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
dilation=dilation, groups=groups, bias=bias, causal=causal,
norm=norm, norm_kwargs=norm_kwargs)
self.causal = causal
self.pad_mode = pad_mode
def forward(self, x):
B, C, T = x.shape
kernel_size = self.conv.conv.kernel_size[0]
stride = self.conv.conv.stride[0]
dilation = self.conv.conv.dilation[0]
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
padding_total = kernel_size - stride
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
if self.causal:
# Left padding for causal
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
return self.conv(x)
class SConvTranspose1d(nn.Module):
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int = 1, causal: bool = False,
norm: str = 'none', trim_right_ratio: float = 1.,
norm_kwargs: tp.Dict[str, tp.Any] = {}):
super().__init__()
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
causal=causal, norm=norm, norm_kwargs=norm_kwargs)
self.causal = causal
self.trim_right_ratio = trim_right_ratio
assert self.causal or self.trim_right_ratio == 1., \
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
def forward(self, x):
kernel_size = self.convtr.convtr.kernel_size[0]
stride = self.convtr.convtr.stride[0]
padding_total = kernel_size - stride
y = self.convtr(x)
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
# removed at the very end, when keeping only the right length for the output,
# as removing it here would require also passing the length at the matching layer
# in the encoder.
if self.causal:
# Trim the padding on the right according to the specified ratio
# if trim_right_ratio = 1.0, trim everything from right
padding_right = math.ceil(padding_total * self.trim_right_ratio)
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
return y
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""LSTM layers module."""
from torch import nn
class SLSTM(nn.Module):
"""
LSTM without worrying about the hidden state, nor the layout of the data.
Expects input as convolutional layout.
"""
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
super().__init__()
self.skip = skip
self.lstm = nn.LSTM(dimension, dimension, num_layers)
# def forward(self, x):
# x = x.permute(2, 0, 1)
# y, _ = self.lstm(x)
# if self.skip:
# y = y + x
# y = y.permute(1, 2, 0)
# return y
# 修改transpose顺序
def forward(self, x):
# # 插入reshape
# x = x.reshape(x.shape)
x1 = x.permute(2, 0, 1)
y, _ = self.lstm(x1)
y = y.permute(1, 2, 0)
if self.skip:
y = y + x
return y
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Normalization modules."""
import typing as tp
import einops
import torch
from torch import nn
class ConvLayerNorm(nn.LayerNorm):
"""
Convolution-friendly LayerNorm that moves channels to last dimensions
before running the normalization and moves them back to original position right after.
"""
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
super().__init__(normalized_shape, **kwargs)
def forward(self, x):
x = einops.rearrange(x, 'b ... t -> b t ...')
x = super().forward(x)
x = einops.rearrange(x, 'b t ... -> b ... t')
return
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Encodec SEANet-based encoder and decoder implementation."""
import typing as tp
import numpy as np
import torch.nn as nn
from . import (
SConv1d,
SConvTranspose1d,
SLSTM
)
class SEANetResnetBlock(nn.Module):
"""Residual block from SEANet model.
Args:
dim (int): Dimension of the input/output
kernel_sizes (list): List of kernel sizes for the convolutions.
dilations (list): List of dilations for the convolutions.
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
"""
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
super().__init__()
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
act = getattr(nn, activation)
hidden = dim // compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [
act(**activation_params),
SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode),
]
self.block = nn.Sequential(*block)
self.shortcut: nn.Module
if true_skip:
self.shortcut = nn.Identity()
else:
self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
def forward(self, x):
return self.shortcut(x) + self.block(x)
class SEANetEncoder(nn.Module):
"""SEANet encoder.
Args:
channels (int): Audio channels.
dimension (int): Intermediate representation dimension.
n_filters (int): Base width for the model.
n_residual_layers (int): nb of residual layers.
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
that must match the decoder order
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
kernel_size (int): Kernel size for the initial convolution.
last_kernel_size (int): Kernel size for the initial convolution.
residual_kernel_size (int): Kernel size for the residual layers.
dilation_base (int): How much to increase the dilation with each layer.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
true_skip (bool): Whether to use true skip connection or a simple
(streamable) convolution as the skip connection in the residual network blocks.
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
lstm (int): Number of LSTM layers at the end of the encoder.
"""
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
super().__init__()
self.channels = channels
self.dimension = dimension
self.n_filters = n_filters
self.ratios = list(reversed(ratios))
del ratios
self.n_residual_layers = n_residual_layers
self.hop_length = np.prod(self.ratios)
act = getattr(nn, activation)
mult = 1
model: tp.List[nn.Module] = [
SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
]
# Downsample to raw audio scale
for i, ratio in enumerate(self.ratios):
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base ** j, 1],
norm=norm, norm_params=norm_params,
activation=activation, activation_params=activation_params,
causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
# Add downsampling layers
model += [
act(**activation_params),
SConv1d(mult * n_filters, mult * n_filters * 2,
kernel_size=ratio * 2, stride=ratio,
norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode),
]
mult *= 2
if lstm:
model += [SLSTM(mult * n_filters, num_layers=lstm)]
model += [
act(**activation_params),
SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
class SEANetDecoder(nn.Module):
"""SEANet decoder.
Args:
channels (int): Audio channels.
dimension (int): Intermediate representation dimension.
n_filters (int): Base width for the model.
n_residual_layers (int): nb of residual layers.
ratios (Sequence[int]): kernel size and stride ratios
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
final_activation (str): Final activation function after all convolutions.
final_activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
kernel_size (int): Kernel size for the initial convolution.
last_kernel_size (int): Kernel size for the initial convolution.
residual_kernel_size (int): Kernel size for the residual layers.
dilation_base (int): How much to increase the dilation with each layer.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
true_skip (bool): Whether to use true skip connection or a simple
(streamable) convolution as the skip connection in the residual network blocks.
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
lstm (int): Number of LSTM layers at the end of the encoder.
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
If equal to 1.0, it means that all the trimming is done at the right.
"""
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2,
trim_right_ratio: float = 1.0):
super().__init__()
self.dimension = dimension
self.channels = channels
self.n_filters = n_filters
self.ratios = ratios
del ratios
self.n_residual_layers = n_residual_layers
self.hop_length = np.prod(self.ratios)
act = getattr(nn, activation)
mult = int(2 ** len(self.ratios))
model: tp.List[nn.Module] = [
SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
]
if lstm:
model += [SLSTM(mult * n_filters, num_layers=lstm)]
# Upsample to raw audio scale
for i, ratio in enumerate(self.ratios):
# Add upsampling layers
model += [
act(**activation_params),
SConvTranspose1d(mult * n_filters, mult * n_filters // 2,
kernel_size=ratio * 2, stride=ratio,
norm=norm, norm_kwargs=norm_params,
causal=causal, trim_right_ratio=trim_right_ratio),
]
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base ** j, 1],
activation=activation, activation_params=activation_params,
norm=norm, norm_params=norm_params, causal=causal,
pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
mult //= 2
# Add final layers
model += [
act(**activation_params),
SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
]
# Add optional final activation to decoder (eg. tanh)
if final_activation is not None:
final_act = getattr(nn, final_activation)
final_activation_params = final_activation_params or {}
model += [
final_act(**final_activation_params)
]
self.model = nn.Sequential(*model)
def forward(self, z):
y = self.model(z)
return y
def test():
import torch
encoder = SEANetEncoder()
decoder = SEANetDecoder()
x = torch.randn(1, 1, 24000)
z = encoder(x)
assert list(z.shape) == [1, 128, 75], z.shape
y = decoder(z)
assert y.shape == x.shape, (x.shape, y.shape)
if __name__ == '__main__':
test()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A streamable transformer."""
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
"""Create time embedding for the given positions, target dimension `dim`.
"""
# We aim for BTC format
assert dim % 2 == 0
half_dim = dim // 2
adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
phase = positions / (max_period ** (adim / (half_dim - 1)))
return torch.cat([
torch.cos(phase),
torch.sin(phase),
], dim=-1)
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore
if self.norm_first:
sa_input = self.norm1(x)
x = x + self._sa_block(sa_input, x_past, past_context)
x = x + self._ff_block(self.norm2(x))
else:
sa_input = x
x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
x = self.norm2(x + self._ff_block(x))
return x, sa_input
# self-attention block
def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore
_, T, _ = x.shape
_, H, _ = x_past.shape
queries = x
keys = torch.cat([x_past, x], dim=1)
values = keys
queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
delta = queries_pos - keys_pos
valid_access = (delta >= 0) & (delta <= past_context)
x = self.self_attn(queries, keys, values,
attn_mask=~valid_access,
need_weights=False)[0]
return self.dropout1(x)
class StreamingTransformerEncoder(nn.Module):
"""TransformerEncoder with streaming support.
Args:
dim (int): dimension of the data.
hidden_scale (int): intermediate dimension of FF module is this times the dimension.
num_heads (int): number of heads.
num_layers (int): number of layers.
max_period (float): maxium period of cosines in the positional embedding.
past_context (int or None): receptive field for the causal mask, infinite if None.
gelu (bool): if true uses GeLUs, otherwise use ReLUs.
norm_in (bool): normalize the input.
dropout (float): dropout probability.
**kwargs: See `nn.TransformerEncoderLayer`.
"""
def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
max_period: float = 10000, past_context: int = 1000, gelu: bool = True,
norm_in: bool = True, dropout: float = 0., **kwargs):
super().__init__()
assert dim % num_heads == 0
hidden_dim = int(dim * hidden_scale)
self.max_period = max_period
self.past_context = past_context
activation: tp.Any = F.gelu if gelu else F.relu
self.norm_in: nn.Module
if norm_in:
self.norm_in = nn.LayerNorm(dim)
else:
self.norm_in = nn.Identity()
self.layers = nn.ModuleList()
for idx in range(num_layers):
self.layers.append(
StreamingTransformerEncoderLayer(
dim, num_heads, hidden_dim,
activation=activation, batch_first=True, dropout=dropout, **kwargs))
def forward(self, x: torch.Tensor,
states: tp.Optional[tp.List[torch.Tensor]] = None,
offset: tp.Union[int, torch.Tensor] = 0):
B, T, C = x.shape
if states is None:
states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]
positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
new_state: tp.List[torch.Tensor] = []
x = self.norm_in(x)
x = x + pos_emb
for layer_state, layer in zip(states, self.layers):
x, new_layer_state = layer(x, layer_state, self.past_context)
new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
new_state.append(new_layer_state[:, -self.past_context:, :])
return x, new_state, offset + T
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""MS-STFT discriminator, provided here for reference."""
import typing as tp
import torchaudio
import torch
from torch import nn
from einops import rearrange
from .modules import NormConv2d
FeatureMapType = tp.List[torch.Tensor]
LogitsType = torch.Tensor
DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
class DiscriminatorSTFT(nn.Module):
"""STFT sub-discriminator.
Args:
filters (int): Number of filters in convolutions
in_channels (int): Number of input channels. Default: 1
out_channels (int): Number of output channels. Default: 1
n_fft (int): Size of FFT for each scale. Default: 1024
hop_length (int): Length of hop between STFT windows for each scale. Default: 256
kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
win_length (int): Window size for each scale. Default: 1024
normalized (bool): Whether to normalize by magnitude after stft. Default: True
norm (str): Normalization method. Default: `'weight_norm'`
activation (str): Activation function. Default: `'LeakyReLU'`
activation_params (dict): Parameters to provide to the activation function.
growth (int): Growth factor for the filters. Default: 1
"""
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
super().__init__()
assert len(kernel_size) == 2
assert len(stride) == 2
self.filters = filters
self.in_channels = in_channels
self.out_channels = out_channels
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.normalized = normalized
self.activation = getattr(torch.nn, activation)(**activation_params)
self.spec_transform = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
normalized=self.normalized, center=False, pad_mode=None, power=None)
spec_channels = 2 * self.in_channels
self.convs = nn.ModuleList()
self.convs.append(
NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
)
in_chs = min(filters_scale * self.filters, max_filters)
for i, dilation in enumerate(dilations):
out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
norm=norm))
in_chs = out_chs
out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
norm=norm))
self.conv_post = NormConv2d(out_chs, self.out_channels,
kernel_size=(kernel_size[0], kernel_size[0]),
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
norm=norm)
def forward(self, x: torch.Tensor):
fmap = []
z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
z = torch.cat([z.real, z.imag], dim=1)
z = rearrange(z, 'b c w t -> b c t w')
for i, layer in enumerate(self.convs):
z = layer(z)
z = self.activation(z)
fmap.append(z)
z = self.conv_post(z)
return z, fmap
class MultiScaleSTFTDiscriminator(nn.Module):
"""Multi-Scale STFT (MS-STFT) discriminator.
Args:
filters (int): Number of filters in convolutions
in_channels (int): Number of input channels. Default: 1
out_channels (int): Number of output channels. Default: 1
n_ffts (Sequence[int]): Size of FFT for each scale
hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
win_lengths (Sequence[int]): Window size for each scale
**kwargs: additional args for STFTDiscriminator
"""
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
super().__init__()
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
self.discriminators = nn.ModuleList([
DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
for i in range(len(n_ffts))
])
self.num_discriminators = len(self.discriminators)
def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
logits = []
fmaps = []
for disc in self.discriminators:
logit, fmap = disc(x)
logits.append(logit)
fmaps.append(fmap)
return logits, fmaps
def test():
disc = MultiScaleSTFTDiscriminator(filters=32)
y = torch.randn(1, 1, 24000)
y_hat = torch.randn(1, 1, 24000)
y_disc_r, fmap_r = disc(y)
y_disc_gen, fmap_gen = disc(y_hat)
assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators
assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
if __name__ == '__main__':
test()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Arithmetic coder."""
import io
import math
import random
import typing as tp
import torch
from ..binary import BitPacker, BitUnpacker
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
roundoff: float = 1e-8, min_range: int = 2,
check: bool = True) -> torch.Tensor:
"""Turn the given PDF into a quantized CDF that splits
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
to the PDF.
Args:
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
during the coding process is `[0, 2 ** total_range_bits - 1]`.
roundoff (float): will round the pdf up to that level to remove difference coming
from e.g. evaluating the Language Model on different architectures.
min_range (int): minimum range width. Should always be at least 2 for numerical
stability. Use this to avoid pathological behavior is a value
that is expected to be rare actually happens in real life.
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
"""
pdf = pdf.detach()
if roundoff:
pdf = (pdf / roundoff).floor() * roundoff
# interpolate with uniform distribution to achieve desired minimum probability.
total_range = 2 ** total_range_bits
cardinality = len(pdf)
alpha = min_range * cardinality / total_range
assert alpha <= 1, "you must reduce min_range"
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
ranges += min_range
quantized_cdf = torch.cumsum(ranges, dim=-1)
if min_range < 2:
raise ValueError("min_range must be at least 2.")
if check:
assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
raise ValueError("You must increase your total_range_bits.")
return quantized_cdf
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
sequence `(s_t)` by doing the following:
1) Initialize the current range to` [0 ** 2 B - 1]`.
2) For each time step t, split the current range into contiguous chunks,
one for each possible outcome, with size roughly proportional to `p`.
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
would be `{[0, 2], [3, 3]}`.
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
4) When done encoding all the values, just select any value remaining in the range.
You will notice that this procedure can fail: for instance if at any point in time
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
possible outcome. Intuitively, the more likely a value is, the less the range width
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
coding scheme, likely outcomes would take less bits, and more of them can be coded
with a fixed budget.
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
when the current range decreases below a given limit (given by `total_range_bits`), without
having to redo all the computations. If we encode mostly likely values, we will seldom
need to inject new bits, but a single rare value can deplete our stock of entropy!
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
code works for any sequence `(p_t)` possibly different for each timestep.
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
the KL between the true distribution and `p_t`, the most efficient the coding will be.
Args:
fo (IO[bytes]): file-like object to which the bytes will be written to.
total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
Any time the current range width fall under this limit, new bits will
be injected to rescale the initial range.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
assert total_range_bits <= 30
self.total_range_bits = total_range_bits
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
self.low: int = 0
self.high: int = 0
self.max_bit: int = -1
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
@property
def delta(self) -> int:
"""Return the current range width."""
return self.high - self.low + 1
def _flush_common_prefix(self):
# If self.low and self.high start with the sames bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
assert self.high < 2 ** (self.max_bit + 1)
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= (b1 << self.max_bit)
self.high -= (b1 << self.max_bit)
assert self.high >= self.low, (self.high, self.low, self.max_bit)
assert self.low >= 0
self.max_bit -= 1
self.packer.push(b1)
else:
break
def push(self, symbol: int, quantized_cdf: torch.Tensor):
"""Push the given symbol on the stream, flushing out bits
if possible.
Args:
symbol (int): symbol to encode with the AC.
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate.
"""
while self.delta < 2 ** self.total_range_bits:
self.low *= 2
self.high = self.high * 2 + 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
self._dbg.append((self.low, self.high))
self._dbg2.append((self.low, self.high))
outs = self._flush_common_prefix()
assert self.low <= self.high
assert self.max_bit >= -1
assert self.max_bit <= 61, self.max_bit
return outs
def flush(self):
"""Flush the remaining information to the stream.
"""
while self.max_bit >= 0:
b1 = (self.low >> self.max_bit) & 1
self.packer.push(b1)
self.max_bit -= 1
self.packer.flush()
class ArithmeticDecoder:
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
If the AC encoder current range is [L, H], with `L` and `H` having the some common
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
`[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
self.total_range_bits = total_range_bits
self.low: int = 0
self.high: int = 0
self.current: int = 0
self.max_bit: int = -1
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
# Following is for debugging
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
self._last: tp.Any = None
@property
def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
# Given the current range [L, H], if both have a common prefix,
# we know we can remove it from our representation to avoid handling large numbers.
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= (b1 << self.max_bit)
self.high -= (b1 << self.max_bit)
self.current -= (b1 << self.max_bit)
assert self.high >= self.low
assert self.low >= 0
self.max_bit -= 1
else:
break
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
This returns `None` when the stream has been exhausted.
Args:
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate. This must be **exatly**
the same cdf as the one used at encoding time.
"""
while self.delta < 2 ** self.total_range_bits:
bit = self.unpacker.pull()
if bit is None:
return None
self.low *= 2
self.high = self.high * 2 + 1
self.current = self.current * 2 + bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
# Binary search is not just for coding interviews :)
if high_idx < low_idx:
raise RuntimeError("Binary search failed")
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
if self.current <= high:
return (mid, low, high, self.current)
else:
return bin_search(mid + 1, high_idx)
else:
return bin_search(low_idx, mid - 1)
self._last = (self.low, self.high, self.current, self.max_bit)
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
self._dbg.append((self.low, self.high, self.current))
self._flush_common_prefix()
self._dbg2.append((self.low, self.high, self.current))
return sym
def test():
torch.manual_seed(1234)
random.seed(1234)
for _ in range(4):
pdfs = []
cardinality = random.randrange(4000)
steps = random.randrange(100, 500)
fo = io.BytesIO()
encoder = ArithmeticCoder(fo)
symbols = []
for step in range(steps):
pdf = torch.softmax(torch.randn(cardinality), dim=0)
pdfs.append(pdf)
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
symbol = torch.multinomial(pdf, 1).item()
symbols.append(symbol)
encoder.push(symbol, q_cdf)
encoder.flush()
fo.seek(0)
decoder = ArithmeticDecoder(fo)
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
decoded_symbol = decoder.pull(q_cdf)
assert decoded_symbol == symbol, idx
assert decoder.pull(torch.zeros(1)) is None
if __name__ == "__main__":
test()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
import warnings
from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from .. import distrib
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
diffs = rearrange(samples, "n d -> n () d") - rearrange(
means, "c d -> () c d"
)
dists = -(diffs ** 2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: int = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
@torch.jit.ignore
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) #data不变
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
distrib.broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
distrib.broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
return x
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
return embed_ind
def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = self.preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = self.postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
return quantize, embed_ind
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
decay=decay, epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code)
self.codebook_size = codebook_size
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
# breakpoint()
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
# warnings.warn('When using RVQ in training model, first check '
# 'https://github.com/facebookresearch/encodec/issues/25 . '
# 'The bug wasn\'t fixed here for reproducibility.')
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
def forward(self, x, n_q: tp.Optional[int] = None):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
quantized, indices, loss = layer(residual)
residual = residual - quantized.detach()
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
all_indices.append(indices)
quantized = layer.decode(indices)
residual = residual - quantized.detach()
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
class LanguageVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
# print("core_vq.py:self.layers",self.layers)
def forward(self, x, n_q: tp.Optional[int] = None):
# breakpoint() x[b,t,c] #[64,75,128]
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
# breakpoint()
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
quantized_out, indices, loss = layer(residual) #得到该层的表征,该层的indices,该层的loss [64,75]
# residual = residual - quantized.detach()
# quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
# breakpoint()
# breakpoint()
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
all_indices.append(indices)
quantized = layer.decode(indices)
residual = residual - quantized.detach()
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""
from dataclasses import dataclass, field
import math
import typing as tp
import torch
from torch import nn
from .core_vq import ResidualVectorQuantization,LanguageVectorQuantization
@dataclass
class QuantizedResult:
quantized: torch.Tensor
codes: torch.Tensor
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
penalty: tp.Optional[torch.Tensor] = None
metrics: dict = field(default_factory=dict)
class ResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer.
Args:
dimension (int): Dimension of the codebooks.
n_q (int): Number of residual vector quantizers used.
bins (int): Codebook size.
decay (float): Decay for exponential moving average over the codebooks.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
n_q: int = 8,
bins: int = 1024,
decay: float = 0.99,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.n_q = n_q
self.dimension = dimension
self.bins = bins
self.decay = decay
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.threshold_ema_dead_code = threshold_ema_dead_code
# print(self.bins)
# breakpoint()
self.vq = LanguageVectorQuantization(
dim=self.dimension,
codebook_size=self.bins,
num_quantizers=self.n_q,
decay=self.decay,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
# self.vq = ResidualVectorQuantization(
# dim=self.dimension,
# codebook_size=self.bins,
# num_quantizers=self.n_q,
# decay=self.decay,
# kmeans_init=self.kmeans_init,
# kmeans_iters=self.kmeans_iters,
# threshold_ema_dead_code=self.threshold_ema_dead_code,
# )
def forward(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
frame_rate (int): Sample rate of the input tensor.
bandwidth (float): Target bandwidth.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated bandwidth and any penalty term for the loss.
"""
# breakpoint()
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
# assert n_q==4
# breakpoint()
# nq_choice=[3,4,8]
nq_choice=[4,6,8]
if self.training:
# choice = int(torch.randint(0, 3, (1,)).item())
choice = int(torch.randint(0, 3, (1,)).item())
# breakpoint()
n_q=nq_choice[choice]
# breakpoint()
# n_q=8
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
def infer(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
frame_rate (int): Sample rate of the input tensor.
bandwidth (float): Target bandwidth.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated bandwidth and any penalty term for the loss.
"""
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
# n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
# # assert n_q==4
# # breakpoint()
# # nq_choice=[3,4,8]
# nq_choice=[3,4,5,6,7,8]
# if self.training:
# # choice = int(torch.randint(0, 3, (1,)).item())
# choice = int(torch.randint(0, 6, (1,)).item())
# # breakpoint()
# n_q=nq_choice[choice]
n_q=1
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int:
"""Return n_q based on specified target bandwidth.
"""
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
n_q = self.n_q
if bandwidth and bandwidth > 0.:
# bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
# bandwidth == 6.0
n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
return n_q
def get_bandwidth_per_quantizer(self, frame_rate: int):
"""Return bandwidth per quantizer for a given input frame rate.
Each quantizer encodes a frame with lg(bins) bits.
"""
return math.log2(self.bins) * frame_rate
def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified frame rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizers to use
and returns indices for each quantizer.
"""
n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
codes = self.vq.encode(x, n_q=n_q)
return codes
def decode(self, codes: torch.Tensor) -> torch.Tensor:
"""Decode the given codes to the quantized representation.
"""
quantized = self.vq.decode(codes)
return quantized
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various utilities."""
from hashlib import sha256
from pathlib import Path
import typing as tp
import torch
import torchaudio
def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
# Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
# e.g., more than 2 frames per position.
# The core idea is to use a weight function that is a triangle,
# with a maximum value at the middle of the segment.
# We use this weighting when summing the frames, and divide by the sum of weights
# for each positions at the end. Thus:
# - if a frame is the only one to cover a position, the weighting is a no-op.
# - if 2 frames cover a position:
# ... ...
# / \/ \
# / /\ \
# S T , i.e. S offset of second frame starts, T end of first frame.
# Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
# After the final normalization, the weight of the second frame at position `t` is
# (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
#
# - if more than 2 frames overlap at a given point, we hope that by induction
# something sensible happens.
assert len(frames)
device = frames[0].device
dtype = frames[0].dtype
shape = frames[0].shape[:-1]
total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
frame_length = frames[0].shape[-1]
t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1]
weight = 0.5 - (t - 0.5).abs()
sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
offset: int = 0
for frame in frames:
frame_length = frame.shape[-1]
out[..., offset:offset + frame_length] += weight[:frame_length] * frame
sum_weight[offset:offset + frame_length] += weight[:frame_length]
offset += stride
assert sum_weight.min() > 0
return out / sum_weight
def _get_checkpoint_url(root_url: str, checkpoint: str):
if not root_url.endswith('/'):
root_url += '/'
return root_url + checkpoint
def _check_checksum(path: Path, checksum: str):
sha = sha256()
with open(path, 'rb') as file:
while True:
buf = file.read(2**20)
if not buf:
break
sha.update(buf)
actual_checksum = sha.hexdigest()[:len(checksum)]
if actual_checksum != checksum:
raise RuntimeError(f'Invalid checksum for file {path}, '
f'expected {checksum} but got {actual_checksum}')
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions"
assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
*shape, channels, length = wav.shape
if target_channels == 1:
wav = wav.mean(-2, keepdim=True)
elif target_channels == 2:
wav = wav.expand(*shape, target_channels, length)
elif channels == 1:
wav = wav.expand(target_channels, -1)
else:
raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
return wav
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
sample_rate: int, rescale: bool = False):
limit = 0.99
mx = wav.abs().max()
if rescale:
wav = wav * min(limit / mx, 1)
else:
wav = wav.clamp(-limit, limit)
torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
# 模型编码
modelCode=1412
# 模型名称
modelName=InspireMusic_pytorch
# 模型描述
modelDescription=支持音乐、歌曲及音频的生成,为用户提供多样化选择。
# 应用场景
appScenario=推理,音乐生成,广媒,影视,动漫,医疗,家居,教育
# 框架类型
frameType=pytorch
---
license: apache-2.0
language:
- en
pipeline_tag: text-to-audio
tags:
- music_generation
---
[//]: # (# InspireMusic)
<p align="center">
<a href="https://github.com/FunAudioLLM/InspireMusic" target="_blank">
<img alt="logo" src="./asset/logo.png" width="100%"></a>
</p>
[//]: # (<p align="center">)
[//]: # ( <a href="https://github.com/FunAudioLLM/InspireMusic" target="_blank">)
[//]: # ( <img alt="InspireMusic" src="https://svg-banners.vercel.app/api?type=origin&text1=Inspire%20Music🎶&text2=🤗%20A%20Fundamental%20Music%20Song%20Audio%20Generation%20Toolkit&width=800&height=210"></a>)
[//]: # (</p>)
<p align="center">
<a href="https://iris2c.github.io/InspireMusic" target="_blank">
<img alt="Demo" src="https://img.shields.io/badge/Demo%20👈🏻-InspireMusic?labelColor=%20%23FDB062&label=InspireMusic&color=%20%23f79009"></a>
<a href="https://github.com/FunAudioLLM/InspireMusic" target="_blank">
<img alt="Code" src="https://img.shields.io/badge/Code%20⭐-InspireMusic?labelColor=%20%237372EB&label=InspireMusic&color=%20%235462eb"></a>
<a href="https://modelscope.cn/models/iic/InspireMusic-1.5B-Long" target="_blank">
<img alt="Model" src="https://img.shields.io/badge/InspireMusic-Model-green"></a>
<a href="https://huggingface.co/spaces/FunAudioLLM/InspireMusic" target="_blank">
<img alt="Space" src="https://img.shields.io/badge/Spaces-ModelScope-pink?labelColor=%20%237b8afb&label=Spaces&color=%20%230a5af8"></a>
<a href="https://huggingface.co/spaces/FunAudioLLM/InspireMusic" target="_blank">
<img alt="Space" src="https://img.shields.io/badge/HuggingFace-Spaces?labelColor=%20%239b8afb&label=Spaces&color=%20%237a5af8"></a>
<a href="https://arxiv.org/abs/" target="_blank">
<img alt="Paper" src="https://img.shields.io/badge/arXiv-Paper-lightgrey"></a>
<a href="https://github.com/FunAudioLLM/InspireMusic" target="_blank">
[//]: # (<a href="https://huggingface.co/FunAudioLLM/InspireMusic-Base" target="_blank">)
[//]: # ( <img alt="Model" src="https://img.shields.io/badge/Model-InspireMusic?labelColor=%20%23FDA199&label=InspireMusic&color=orange"></a>)
[//]: # (<a href="https://arxiv.org/abs/" target="_blank">)
[//]: # ( <img alt="Paper" src="https://img.shields.io/badge/Paper-arXiv?labelColor=%20%23528bff&label=arXiv&color=%20%23155EEF"></a>)
[//]: # (<a href="https://github.com/FunAudioLLM/InspireMusic" target="_blank">)
[//]: # ( <img alt="Githube Star" src="https://img.shields.io/github/stars/FunAudioLLM/InspireMusic"></a>)
[//]: # (<a href="https://github.com/FunAudioLLM/InspireMusic/blob/main/asset/QR.jpg" target="_blank">)
[//]: # ( <img src="https://img.shields.io/badge/group%20chat-group?&labelColor=%20%235462eb&color=%20%235462eb" alt="chat on WeChat"></a>)
[//]: # (<a href="https://discord.gg/nSPpRU7fRr" target="_blank">)
[//]: # ( <img src="https://img.shields.io/badge/discord-chat?&labelColor=%20%235462eb&color=%20%235462eb" alt="chat on Discord"></a>)
[//]: # ( <a href="https://github.com/FunAudioLLM/InspireMusic" target="_blank">)
[//]: # ( <img alt="Static Badge" src="https://img.shields.io/badge/v0.1-version?logo=free&color=%20%23155EEF&label=version&labelColor=%20%23528bff"></a>)
[//]: # (<a href="https://github.com/FunAudioLLM/InspireMusic/graphs/commit-activity" target="_blank">)
[//]: # (<img alt="Commits last month" src="https://img.shields.io/github/commit-activity/m/FunAudioLLM/InspireMusic?labelColor=%20%2332b583&color=%20%2312b76a"></a>)
[//]: # ( <a href="https://github.com/FunAudioLLM/InspireMusic" target="_blank">)
[//]: # ( <img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3AFunAudioLLM%2FInspireMusic%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>)
[//]: # ( <a href="https://github.com/FunAudioLLM/InspireMusic/discussions/" target="_blank">)
[//]: # ( <img alt="Discussion posts" src="https://img.shields.io/github/discussions/FunAudioLLM/InspireMusic?labelColor=%20%239b8afb&color=%20%237a5af8"></a>)
</p>
InspireMusic is a fundamental AIGC toolkit and models designed for music, song, and audio generation using PyTorch.
![GitHub Repo stars](https://img.shields.io/github/stars/FunAudioLLM/InspireMusic) Please support our community project 💖 by starring it on GitHub 加⭐支持 🙏
---
<a name="Highligts"></a>
## Highlights
**InspireMusic** focuses on music generation, song generation and audio generation.
- A unified framework for music/song/audio generation. Controllable with text prompts, music genres, music structures, etc.
- Support music generation tasks with high audio quality, with available sampling rates of 24kHz, 48kHz.
- Support long-form audio generation.
- Convenient fine-tuning and inference. Support mixed precision training (FP16, FP32). Provide convenient fine-tuning and inference scripts and strategies, allowing users to easily fine-tune their music generation models.
<a name="What's News"></a>
## What's New 🔥
- 2025/02: InspireMusic demo is available on [ModelScope Space](https://modelscope.cn/studios/iic/InspireMusic/summary) and [HuggingFace Space](https://huggingface.co/spaces/FunAudioLLM/InspireMusic).
- 2025/01: Open-source [InspireMusic-Base](https://modelscope.cn/models/iic/InspireMusic/summary), [InspireMusic-Base-24kHz](https://modelscope.cn/models/iic/InspireMusic-Base-24kHz/summary), [InspireMusic-1.5B](https://modelscope.cn/models/iic/InspireMusic-1.5B/summary), [InspireMusic-1.5B-24kHz](https://modelscope.cn/models/iic/InspireMusic-1.5B-24kHz/summary), [InspireMusic-1.5B-Long](https://modelscope.cn/models/iic/InspireMusic-1.5B-Long/summary) models for music generation. Models are available on both ModelScope and HuggingFace.
- 2024/12: Support to generate 48kHz audio with super resolution flow matching.
- 2024/11: Welcome to preview 👉🏻 [**InspireMusic Demos**](https://iris2c.github.io/InspireMusic) 👈🏻. We're excited to share this with you and are working hard to bring even more features and models soon. Your support and feedback mean a lot to us!
- 2024/11: We are thrilled to announce the open-sourcing of the **InspireMusic** [code repository](https://github.com/FunAudioLLM/InspireMusic) and [demos](https://iris2c.github.io/InspireMusic). **InspireMusic** is a unified framework for music, song, and audio generation, featuring capabilities such as text-to-music conversion, music structure, genre control, and timestamp management. InspireMusic stands out for its exceptional music generation and instruction-following abilities.
## Introduction
> [!Note]
> This repo contains the algorithm infrastructure and some simple examples. Currently only support English text prompts.
> [!Tip]
> To explore the performance, please refer to [InspireMusic Demo Page](https://iris2c.github.io/InspireMusic). We will open-source better & larger models soon.
InspireMusic is a unified music, song and audio generation framework through the audio tokenization and detokenization process integrated with a large autoregressive transformer. The original motive of this toolkit is to empower the common users to innovate soundscapes and enhance euphony in research through music, song, and audio crafting. The toolkit provides both inference and training code for AI generative models that create high-quality music. Featuring a unified framework, InspireMusic incorporates autoregressive Transformer and conditional flow-matching modeling (CFM), allowing for the controllable generation of music, songs, and audio with both textual and structural music conditioning, as well as neural audio tokenizers. Currently, the toolkit supports text-to-music generation and plans to expand its capabilities to include text-to-song and text-to-audio generation in the future.
## InspireMusic
<p align="center">
<table>
<tr>
<td style="text-align:center;">
<img alt="Light" src="asset/InspireMusic.png" width="100%" />
</tr>
<tr>
<td style="text-align:center;">
<b>Figure 1.</b> An overview of the InspireMusic framework.
We introduce InspireMusic, a unified framework for music, song and audio generation, capable of producing 48kHz long-form audio. InspireMusic employs an autoregressive transformer to generate music tokens in response to textual input. Complementing this, an ODE-based diffusion model, specifically flow matching, is utilized to reconstruct latent features from these generated music tokens. Then a vocoder generates audio waveforms from the reconstructed features. for input text, an ODE-based diffusion model, flow matching, to reconstruct latent features from the generated music tokens, and a vocoder to generate audio waveforms. InspireMusic is capable of text-to-music, music continuation, music reconstruction, and music super resolution tasks. It employs WavTokenizer as an audio tokenizer to convert 24kHz audio into 75Hz discrete tokens, while HifiCodec serves as a music tokenizer, transforming 48kHz audio into 150Hz latent features compatible with the flow matching model.
</td>
</tr>
</table>
</p>
## Installation
### Clone
- Clone the repo
``` sh
git clone --recursive https://github.com/FunAudioLLM/InspireMusic.git
# If you failed to clone submodule due to network failures, please run the following command until success
cd InspireMusic
git submodule update --init --recursive
```
### Install
InspireMusic requires Python 3.8, PyTorch 2.0.1. To install InspireMusic, you can run one of the following:
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
- Create Conda env:
``` sh
conda create -n inspiremusic python=3.8
conda activate inspiremusic
cd InspireMusic
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platforms.
conda install -y -c conda-forge pynini==2.1.5
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# install flash attention to speedup training, support version 2.6.3
pip install flash-attn --no-build-isolation
```
Currently support on CUDA Version 11.x.
- Install within the package:
```sh
cd InspireMusic
# You can run to install the packages
python setup.py install
pip install flash-attn --no-build-isolation
```
We also recommend having `sox` or `ffmpeg` installed, either through your system or Anaconda:
```sh
# # Install sox
# ubuntu
sudo apt-get install sox libsox-dev
# centos
sudo yum install sox sox-devel
# Install ffmpeg
# ubuntu
sudo apt-get install ffmpeg
# centos
sudo yum install ffmpeg
```
### Quick Start
Here is a quick example inference script for music generation.
``` sh
cd InspireMusic
mkdir -p pretrained_models
# Download models
# ModelScope
git clone https://www.modelscope.cn/iic/InspireMusic-1.5B-Long.git pretrained_models/InspireMusic-1.5B-Long
# HuggingFace
git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git pretrained_models/InspireMusic-1.5B-Long
cd examples/music_generation
# run a quick inference example
bash infer_1.5b_long.sh
```
Here is a quick start running script to run music generation task including data preparation pipeline, model training, inference.
``` sh
cd InspireMusic/examples/music_generation/
bash run.sh
```
### One-line Inference
#### Text-to-music Task
One-line Shell script for text-to-music task.
``` sh
cd examples/music_generation
# with flow matching
# use one-line command to get a quick try
python -m inspiremusic.cli.inference
# custom the config like the following one-line command
python -m inspiremusic.cli.inference --task text-to-music -m "InspireMusic-1.5B-Long" -g 0 -t "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance." -c intro -s 0.0 -e 30.0 -r "exp/inspiremusic" -o output -f wav
# without flow matching
python -m inspiremusic.cli.inference --task text-to-music -g 0 -t "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance." --fast True
```
Alternatively, you can run the inference with just a few lines of Python code.
```python
from inspiremusic.cli.inference import InspireMusicUnified
from inspiremusic.cli.inference import set_env_variables
if __name__ == "__main__":
set_env_variables()
model = InspireMusicUnified(model_name = "InspireMusic-1.5B-Long")
model.inference("text-to-music", "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.")
```
#### Music Continuation Task
One-line Shell script for music continuation task.
``` sh
cd examples/music_generation
# with flow matching
python -m inspiremusic.cli.inference --task continuation -g 0 -a audio_prompt.wav
# without flow matching
python -m inspiremusic.cli.inference --task continuation -g 0 -a audio_prompt.wav --fast True
```
Alternatively, you can run the inference with just a few lines of Python code.
```python
from inspiremusic.cli.inference import InspireMusicUnified
from inspiremusic.cli.inference import set_env_variables
if __name__ == "__main__":
set_env_variables()
model = InspireMusicUnified(model_name = "InspireMusic-1.5B-Long")
# just use audio prompt
model.inference("continuation", None, "audio_prompt.wav")
# use both text prompt and audio prompt
model.inference("continuation", "Continue to generate jazz music.", "audio_prompt.wav")
```
## Models
### Download Model
We strongly recommend that you download our pretrained `InspireMusic model`.
If you are an expert in this field, and you are only interested in training your own InspireMusic model from scratch, you can skip this step.
``` sh
# git模型下载,请确保已安装git lfs
mkdir -p pretrained_models
git clone https://www.modelscope.cn/iic/InspireMusic-1.5B-Long.git pretrained_models/InspireMusic
```
### Available Models
Currently, we open source the music generation models support 24KHz mono and 48KHz stereo audio.
The table below presents the links to the ModelScope and Huggingface model hub. More models will be available soon.
| Model name | Model Links | Remarks |
|---------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------|
| InspireMusic-Base-24kHz | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic-Base-24kHz/summary) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz) | Pre-trained Music Generation Model, 24kHz mono, 30s |
| InspireMusic-Base | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic/summary) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-Base) | Pre-trained Music Generation Model, 48kHz, 30s |
| InspireMusic-1.5B-24kHz | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic-1.5B-24kHz/summary) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz) | Pre-trained Music Generation 1.5B Model, 24kHz mono, 30s |
| InspireMusic-1.5B | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic-1.5B/summary) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B) | Pre-trained Music Generation 1.5B Model, 48kHz, 30s |
| InspireMusic-1.5B-Long ⭐ | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic-1.5B-Long/summary) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long) | Pre-trained Music Generation 1.5B Model, 48kHz, support long-form music generation more than 5mins |
| InspireSong-1.5B | [![model](https://img.shields.io/badge/ModelScope-Model-lightgrey.svg)]() [![model](https://img.shields.io/badge/HuggingFace-Model-lightgrey.svg)]() | Pre-trained Song Generation 1.5B Model, 48kHz stereo |
| InspireAudio-1.5B | [![model](https://img.shields.io/badge/ModelScope-Model-lightgrey.svg)]() [![model](https://img.shields.io/badge/HuggingFace-Model-lightgrey.svg)]() | Pre-trained Audio Generation 1.5B Model, 48kHz stereo |
| Wavtokenizer[<sup>[1]</sup>](https://openreview.net/forum?id=yBlVlS2Fd9) (75Hz) | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic-1.5B-Long/file/view/master?fileName=wavtokenizer%252Fmodel.pt) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long/tree/main/wavtokenizer) | An extreme low bitrate audio tokenizer for music with one codebook at 24kHz audio. |
| Music_tokenizer (75Hz) | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic-1.5B-24kHz/file/view/master?fileName=music_tokenizer%252Fmodel.pt) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz/tree/main/music_tokenizer) | A music tokenizer based on HifiCodec<sup>[2]</sup> at 24kHz audio. |
| Music_tokenizer (150Hz) | [![model](https://img.shields.io/badge/ModelScope-Model-green.svg)](https://modelscope.cn/models/iic/InspireMusic-1.5B-Long/file/view/master?fileName=music_tokenizer%252Fmodel.pt) [![model](https://img.shields.io/badge/HuggingFace-Model-green.svg)](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long/tree/main/music_tokenizer) | A music tokenizer based on HifiCodec at 48kHz audio. |
## Basic Usage
At the moment, InspireMusic contains the training code and inference code for [music generation](https://github.com/FunAudioLLM/InspireMusic/tree/main/examples/music_generation). More tasks such as song generation and audio generation will be supported in future.
### Training
Here is an example to train LLM model, support FP16 training.
```sh
torchrun --nnodes=1 --nproc_per_node=8 \
--rdzv_id=1024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
inspiremusic/bin/train.py \
--train_engine "torch_ddp" \
--config conf/inspiremusic.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--model llm \
--model_dir `pwd`/exp/music_generation/llm/ \
--tensorboard_dir `pwd`/tensorboard/music_generation/llm/ \
--ddp.dist_backend "nccl" \
--num_workers 8 \
--prefetch 100 \
--pin_memory \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer \
--fp16
```
Here is an example code to train flow matching model, does not support FP16 training.
```sh
torchrun --nnodes=1 --nproc_per_node=8 \
--rdzv_id=1024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
inspiremusic/bin/train.py \
--train_engine "torch_ddp" \
--config conf/inspiremusic.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--model flow \
--model_dir `pwd`/exp/music_generation/flow/ \
--tensorboard_dir `pwd`/tensorboard/music_generation/flow/ \
--ddp.dist_backend "nccl" \
--num_workers 8 \
--prefetch 100 \
--pin_memory \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
```
### Inference
Here is an example script to quickly do model inference.
``` sh
cd InspireMusic/examples/music_generation/
bash infer.sh
```
Here is an example code to run inference with normal mode, i.e., with flow matching model for text-to-music and music continuation tasks.
```sh
pretrained_model_dir = "./pretrained_models/InspireMusic/"
for task in 'text-to-music' 'continuation'; do
python inspiremusic/bin/inference.py --task $task \
--gpu 0 \
--config conf/inspiremusic.yaml \
--prompt_data data/test/parquet/data.list \
--flow_model $pretrained_model_dir/flow.pt \
--llm_model $pretrained_model_dir/llm.pt \
--music_tokenizer $pretrained_model_dir/music_tokenizer \
--wavtokenizer $pretrained_model_dir/wavtokenizer \
--result_dir `pwd`/exp/inspiremusic/${task}_test \
--chorus verse \
--min_generate_audio_seconds 8 \
--max_generate_audio_seconds 30
done
```
Here is an example code to run inference with fast mode, i.e., without flow matching model for text-to-music and music continuation tasks.
```sh
pretrained_model_dir = "./pretrained_models/InspireMusic/"
for task in 'text-to-music' 'continuation'; do
python inspiremusic/bin/inference.py --task $task \
--gpu 0 \
--config conf/inspiremusic.yaml \
--prompt_data data/test/parquet/data.list \
--flow_model $pretrained_model_dir/flow.pt \
--llm_model $pretrained_model_dir/llm.pt \
--music_tokenizer $pretrained_model_dir/music_tokenizer \
--wavtokenizer $pretrained_model_dir/wavtokenizer \
--result_dir `pwd`/exp/inspiremusic/${task}_test \
--chorus verse \
--fast \
--min_generate_audio_seconds 8 \
--max_generate_audio_seconds 30
done
```
## Disclaimer
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
# --extra-index-url https://download.pytorch.org/whl/cu118
conformer==0.3.2
# deepspeed==0.14.2; sys_platform == 'linux'
diffusers==0.27.2
gdown==5.1.0
gradio==4.32.2
grpcio==1.57.0
grpcio-tools==1.57.0
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
networkx==3.1
omegaconf==2.3.0
onnx==1.17.0
# onnxruntime-gpu==1.16.0; sys_platform == 'linux'
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
openai-whisper==20231117
protobuf==4.25
pydantic==2.7.0
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
# torch==2.0.1
# torchaudio==2.0.2
uvicorn==0.30.0
wget==3.2
fastapi==0.111.0
fastapi-cli==0.0.4
WeTextProcessing==1.0.3
transformers
accelerate
huggingface-hub==0.25.2
julius
# https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""InspireMusic setup script."""
import os
from setuptools import find_packages
from setuptools import setup
requirements = {
"install": [
"setuptools",
"conformer==0.3.2",
"diffusers==0.27.2",
"gdown==5.1.0",
"gradio==5.5.0",
"grpcio==1.57.0",
"grpcio-tools==1.57.0",
"hydra-core==1.3.2",
"HyperPyYAML==1.2.2",
"inflect==7.3.1",
"librosa==0.10.2",
"lightning==2.2.4",
"matplotlib==3.7.5",
"modelscope==1.15.0",
"networkx==3.1",
"omegaconf==2.3.0",
"onnx==1.17.0",
"protobuf==4.25",
"pydantic==2.7.0",
"rich==13.7.1",
"soundfile==0.12.1",
"tensorboard==2.14.0",
"torch==2.0.1",
"torchaudio==2.0.2",
"uvicorn==0.30.0",
"wget==3.2",
"fastapi==0.111.0",
"fastapi-cli==0.0.4",
"WeTextProcessing==1.0.3",
"accelerate",
"huggingface-hub==0.25.2",
"julius",
"onnxruntime-gpu==1.16.0",
"onnxruntime==1.16.0",
"transformers",
],
# train: The modules invoked when training only.
"train": [
"deepspeed==0.14.2",
],
# all: The modules should be optionally installled due to some reason.
# Please consider moving them to "install" occasionally
"all": [
# NOTE(kamo): Append modules requiring specific pytorch version or torch>2.0
"transformers",
"openai-whisper==20231117",
],
"setup": [
"numpy",
],
"test": [
"pytest>=3.3.0",
],
}
requirements["all"].extend(requirements["train"])
requirements["test"].extend(requirements["train"])
install_requires = requirements["install"]
setup_requires = requirements["setup"]
tests_require = requirements["test"]
extras_require = {k: v for k, v in requirements.items() if k not in ["install", "setup"]}
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "inspiremusic", "version.txt")
with open(version_file, "r") as f:
version = f.read().strip()
setup(
name="inspiremusic",
version=version,
url="https://github.com/FunAudioLLM/InspireMusic.git",
author="Tongyi Lab, Alibaba Group",
author_email="chong.zhang@alibaba-inc.com",
description="InspireMusic: A Fundamental Music, Song and Audio Generation Framework and Toolkits",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
long_description_content_type="text/markdown",
license="The MIT License",
packages=find_packages(include=["inspiremusic*"]),
package_data={"inspiremusic": ["version.txt"]},
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
extras_require=extras_require,
python_requires=">=3.8.0",
classifiers=[
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"Operating System :: POSIX :: Linux",
"License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={
"console_scripts": [
"inspiremusic = inspiremusic.bin.inference:main",
"inspiremusic-train = inspiremusic.bin.train:main",
]
},
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment