# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch modules."""
# flake8: noqa
from .conv import (
pad1d,
unpad1d,
NormConv1d,
NormConvTranspose1d,
NormConv2d,
NormConvTranspose2d,
SConv1d,
SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convolutional layers wrappers and utilities."""
import math
import typing as tp
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from .norm import ConvLayerNorm
CONV_NORMALIZATIONS = frozenset(
[
"none",
"weight_norm",
"spectral_norm",
"time_layer_norm",
"layer_norm",
"time_group_norm",
]
)
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
assert norm in CONV_NORMALIZATIONS
if norm == "weight_norm":
return weight_norm(module)
elif norm == "spectral_norm":
return spectral_norm(module)
else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
return module
def get_norm_module(
module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
"""Return the proper normalization module. If causal is True, this will ensure the returned
module is causal, or return an error if the normalization doesn't support causal evaluation.
"""
assert norm in CONV_NORMALIZATIONS
if norm == "layer_norm":
assert isinstance(module, nn.modules.conv._ConvNd)
return ConvLayerNorm(module.out_channels, **norm_kwargs)
elif norm == "time_group_norm":
if causal:
raise ValueError("GroupNorm doesn't support causal evaluation.")
assert isinstance(module, nn.modules.conv._ConvNd)
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
else:
return nn.Identity()
def get_extra_padding_for_conv1d(
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
"""See `pad_for_conv1d`."""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad_for_conv1d(
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
"""Pad for a convolution to make sure that the last window is full.
Extra padding is added at the end. This is required to ensure that we can rebuild
an output of the same length, as otherwise, even with padding, some time steps
might get removed.
For instance, with total padding = 4, kernel size = 4, stride = 2:
0 0 1 2 3 4 5 0 0 # (0s are padding)
1 2 3 # (output frames of a convolution, last 0 is never used)
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
        1 2 3 4 # once the padding is removed, we are missing one time step!
"""
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
return F.pad(x, (0, extra_padding))
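# Illustrative sketch (not part of the original EnCodec code), reproducing the
# docstring example above: with kernel_size=4, stride=2 and padding_total=4,
# a length-5 input needs one extra sample of right padding for a full last window.
#
#   x = torch.arange(5.0).view(1, 1, -1)
#   assert get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=4) == 1
#   assert pad_for_conv1d(x, kernel_size=4, stride=2, padding_total=4).shape[-1] == 6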
def pad1d(
x: torch.Tensor,
paddings: tp.Tuple[int, int],
mode: str = "zero",
value: float = 0.0,
):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
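# Illustrative sketch (not part of the original code): PyTorch reflect padding
# requires the pad to be smaller than the input length, so for a length-2 signal
# pad1d first appends zeros, reflects, and then trims the temporary samples.
#
#   x = torch.tensor([[[1.0, 2.0]]])
#   assert pad1d(x, (0, 3), mode="reflect").shape[-1] == 5  # 2 + 0 + 3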
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left:end]
class NormConv1d(nn.Module):
"""Wrapper around Conv1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
causal: bool = False,
norm: str = "none",
norm_kwargs: tp.Dict[str, tp.Any] = {},
**kwargs,
):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConv2d(nn.Module):
"""Wrapper around Conv2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
norm: str = "none",
norm_kwargs: tp.Dict[str, tp.Any] = {},
**kwargs,
):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConvTranspose1d(nn.Module):
"""Wrapper around ConvTranspose1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
causal: bool = False,
norm: str = "none",
norm_kwargs: tp.Dict[str, tp.Any] = {},
**kwargs,
):
super().__init__()
self.convtr = apply_parametrization_norm(
nn.ConvTranspose1d(*args, **kwargs), norm
)
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class NormConvTranspose2d(nn.Module):
"""Wrapper around ConvTranspose2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
norm: str = "none",
norm_kwargs: tp.Dict[str, tp.Any] = {},
**kwargs,
):
super().__init__()
self.convtr = apply_parametrization_norm(
nn.ConvTranspose2d(*args, **kwargs), norm
)
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class SConv1d(nn.Module):
"""Conv1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
causal: bool = False,
norm: str = "none",
norm_kwargs: tp.Dict[str, tp.Any] = {},
pad_mode: str = "reflect",
):
super().__init__()
# warn user on unusual setup between dilation and stride
if stride > 1 and dilation > 1:
warnings.warn(
"SConv1d has been initialized with stride > 1 and dilation > 1"
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
)
self.conv = NormConv1d(
in_channels,
out_channels,
kernel_size,
stride,
dilation=dilation,
groups=groups,
bias=bias,
causal=causal,
norm=norm,
norm_kwargs=norm_kwargs,
)
self.causal = causal
self.pad_mode = pad_mode
def forward(self, x):
B, C, T = x.shape
kernel_size = self.conv.conv.kernel_size[0]
stride = self.conv.conv.stride[0]
dilation = self.conv.conv.dilation[0]
padding_total = (kernel_size - 1) * dilation - (stride - 1)
extra_padding = get_extra_padding_for_conv1d(
x, kernel_size, stride, padding_total
)
if self.causal:
# Left padding for causal
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(
x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
)
return self.conv(x)
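# Minimal usage sketch (illustrative): with the padding logic above, SConv1d keeps
# the relation T_out = ceil(T_in / stride), in both causal and non-causal setups.
#
#   conv = SConv1d(1, 8, kernel_size=7, stride=2, causal=True)
#   x = torch.randn(2, 1, 101)
#   assert conv(x).shape == (2, 8, 51)  # ceil(101 / 2)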
class SConvTranspose1d(nn.Module):
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
causal: bool = False,
norm: str = "none",
trim_right_ratio: float = 1.0,
norm_kwargs: tp.Dict[str, tp.Any] = {},
):
super().__init__()
self.convtr = NormConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
causal=causal,
norm=norm,
norm_kwargs=norm_kwargs,
)
self.causal = causal
self.trim_right_ratio = trim_right_ratio
assert (
self.causal or self.trim_right_ratio == 1.0
), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
def forward(self, x):
kernel_size = self.convtr.convtr.kernel_size[0]
stride = self.convtr.convtr.stride[0]
padding_total = kernel_size - stride
y = self.convtr(x)
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
# removed at the very end, when keeping only the right length for the output,
# as removing it here would require also passing the length at the matching layer
# in the encoder.
if self.causal:
# Trim the padding on the right according to the specified ratio
# if trim_right_ratio = 1.0, trim everything from right
padding_right = math.ceil(padding_total * self.trim_right_ratio)
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
return y
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""LSTM layers module."""
from torch import nn
class SLSTM(nn.Module):
"""
    LSTM without worrying about the hidden state, or the layout of the data.
    Expects input in convolutional layout (B, C, T).
"""
def __init__(
self,
dimension: int,
num_layers: int = 2,
skip: bool = True,
bidirectional: bool = False,
):
super().__init__()
self.bidirectional = bidirectional
self.skip = skip
self.lstm = nn.LSTM(
dimension, dimension, num_layers, bidirectional=bidirectional
)
def forward(self, x):
x = x.permute(2, 0, 1)
y, _ = self.lstm(x)
if self.bidirectional:
x = x.repeat(1, 1, 2)
if self.skip:
y = y + x
y = y.permute(1, 2, 0)
return y
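# Minimal usage sketch (illustrative, assumes `import torch`): SLSTM takes and
# returns the convolutional (B, C, T) layout; with bidirectional=True the channel
# dimension doubles and the input is repeated so the skip connection still matches.
#
#   x = torch.randn(4, 64, 50)
#   assert SLSTM(64)(x).shape == (4, 64, 50)
#   assert SLSTM(64, bidirectional=True)(x).shape == (4, 128, 50)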
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Normalization modules."""
import typing as tp
import einops
import torch
from torch import nn
class ConvLayerNorm(nn.LayerNorm):
"""
Convolution-friendly LayerNorm that moves channels to last dimensions
before running the normalization and moves them back to original position right after.
"""
def __init__(
self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
):
super().__init__(normalized_shape, **kwargs)
def forward(self, x):
x = einops.rearrange(x, "b ... t -> b t ...")
x = super().forward(x)
x = einops.rearrange(x, "b t ... -> b ... t")
        return x
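# Minimal usage sketch (illustrative): channels are moved last so nn.LayerNorm
# normalizes over them, then moved back; the (B, C, T) shape is preserved.
#
#   ln = ConvLayerNorm(8)
#   assert ln(torch.randn(2, 8, 100)).shape == (2, 8, 100)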
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Arithmetic coder."""
import io
import math
import random
import typing as tp
import torch
from ..binary import BitPacker, BitUnpacker
def build_stable_quantized_cdf(
pdf: torch.Tensor,
total_range_bits: int,
roundoff: float = 1e-8,
min_range: int = 2,
check: bool = True,
) -> torch.Tensor:
"""Turn the given PDF into a quantized CDF that splits
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
to the PDF.
Args:
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
during the coding process is `[0, 2 ** total_range_bits - 1]`.
roundoff (float): will round the pdf up to that level to remove difference coming
from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior if a value that is
            expected to be rare actually happens in real life.
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
"""
pdf = pdf.detach()
if roundoff:
pdf = (pdf / roundoff).floor() * roundoff
# interpolate with uniform distribution to achieve desired minimum probability.
total_range = 2**total_range_bits
cardinality = len(pdf)
alpha = min_range * cardinality / total_range
assert alpha <= 1, "you must reduce min_range"
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
ranges += min_range
quantized_cdf = torch.cumsum(ranges, dim=-1)
if min_range < 2:
raise ValueError("min_range must be at least 2.")
if check:
assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
if (
(quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
).any() or quantized_cdf[0] < min_range:
raise ValueError("You must increase your total_range_bits.")
return quantized_cdf
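# Illustrative sketch (not part of the original code): the quantized CDF is a
# monotonically increasing integer tensor whose last entry stays within
# 2 ** total_range_bits, with every bucket at least `min_range` wide.
#
#   pdf = torch.tensor([0.5, 0.25, 0.125, 0.125])
#   cdf = build_stable_quantized_cdf(pdf, total_range_bits=24)
#   assert cdf.shape == (4,) and cdf[-1] <= 2 ** 24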
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
sequence `(s_t)` by doing the following:
    1) Initialize the current range to `[0, 2 ** B - 1]`.
2) For each time step t, split the current range into contiguous chunks,
one for each possible outcome, with size roughly proportional to `p`.
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
would be `{[0, 2], [3, 3]}`.
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
4) When done encoding all the values, just select any value remaining in the range.
You will notice that this procedure can fail: for instance if at any point in time
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
possible outcome. Intuitively, the more likely a value is, the less the range width
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
coding scheme, likely outcomes would take less bits, and more of them can be coded
with a fixed budget.
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we mostly encode likely values, we will seldom
need to inject new bits, but a single rare value can deplete our stock of entropy!
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, though the smaller
    the KL divergence between the true distribution and `p_t`, the more efficient the coding will be.
Args:
fo (IO[bytes]): file-like object to which the bytes will be written to.
        total_range_bits (int): the working range described above is `2 ** total_range_bits`.
            Any time the current range width falls under this limit, new bits will
            be injected to rescale the initial range.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
assert total_range_bits <= 30
self.total_range_bits = total_range_bits
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
self.low: int = 0
self.high: int = 0
self.max_bit: int = -1
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
@property
def delta(self) -> int:
"""Return the current range width."""
return self.high - self.low + 1
def _flush_common_prefix(self):
        # If self.low and self.high start with the same bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
assert self.high < 2 ** (self.max_bit + 1)
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= b1 << self.max_bit
self.high -= b1 << self.max_bit
assert self.high >= self.low, (self.high, self.low, self.max_bit)
assert self.low >= 0
self.max_bit -= 1
self.packer.push(b1)
else:
break
def push(self, symbol: int, quantized_cdf: torch.Tensor):
"""Push the given symbol on the stream, flushing out bits
if possible.
Args:
symbol (int): symbol to encode with the AC.
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate.
"""
while self.delta < 2**self.total_range_bits:
self.low *= 2
self.high = self.high * 2 + 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
effective_low = int(
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
)
effective_high = int(
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
)
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
assert self.low <= self.high, (
effective_low,
effective_high,
range_low,
range_high,
)
self._dbg.append((self.low, self.high))
self._dbg2.append((self.low, self.high))
outs = self._flush_common_prefix()
assert self.low <= self.high
assert self.max_bit >= -1
assert self.max_bit <= 61, self.max_bit
return outs
def flush(self):
"""Flush the remaining information to the stream."""
while self.max_bit >= 0:
b1 = (self.low >> self.max_bit) & 1
self.packer.push(b1)
self.max_bit -= 1
self.packer.flush()
class ArithmeticDecoder:
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
    If the AC encoder current range is [L, H], with `L` and `H` having the same common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols, and a binary search allows us to decode those symbols.
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
self.total_range_bits = total_range_bits
self.low: int = 0
self.high: int = 0
self.current: int = 0
self.max_bit: int = -1
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
# Following is for debugging
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
self._last: tp.Any = None
@property
def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
# Given the current range [L, H], if both have a common prefix,
# we know we can remove it from our representation to avoid handling large numbers.
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= b1 << self.max_bit
self.high -= b1 << self.max_bit
self.current -= b1 << self.max_bit
assert self.high >= self.low
assert self.low >= 0
self.max_bit -= 1
else:
break
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
This returns `None` when the stream has been exhausted.
Args:
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
the same cdf as the one used at encoding time.
"""
while self.delta < 2**self.total_range_bits:
bit = self.unpacker.pull()
if bit is None:
return None
self.low *= 2
self.high = self.high * 2 + 1
self.current = self.current * 2 + bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
# Binary search is not just for coding interviews :)
if high_idx < low_idx:
raise RuntimeError("Binary search failed")
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
effective_low = int(
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
)
effective_high = int(
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
)
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
if self.current <= high:
return (mid, low, high, self.current)
else:
return bin_search(mid + 1, high_idx)
else:
return bin_search(low_idx, mid - 1)
self._last = (self.low, self.high, self.current, self.max_bit)
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
self._dbg.append((self.low, self.high, self.current))
self._flush_common_prefix()
self._dbg2.append((self.low, self.high, self.current))
return sym
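# Minimal round-trip sketch (illustrative, see also `test()` below): pushing a few
# symbols under a fixed pdf and pulling them back with the same quantized cdf.
#
#   fo = io.BytesIO()
#   pdf = torch.tensor([0.75, 0.25])
#   coder = ArithmeticCoder(fo)
#   cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
#   for s in (0, 1, 0):
#       coder.push(s, cdf)
#   coder.flush()
#   fo.seek(0)
#   decoder = ArithmeticDecoder(fo)
#   assert [decoder.pull(cdf) for _ in range(3)] == [0, 1, 0]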
def test():
torch.manual_seed(1234)
random.seed(1234)
for _ in range(4):
pdfs = []
cardinality = random.randrange(4000)
steps = random.randrange(100, 500)
fo = io.BytesIO()
encoder = ArithmeticCoder(fo)
symbols = []
for step in range(steps):
pdf = torch.softmax(torch.randn(cardinality), dim=0)
pdfs.append(pdf)
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
symbol = torch.multinomial(pdf, 1).item()
symbols.append(symbol)
encoder.push(symbol, q_cdf)
encoder.flush()
fo.seek(0)
decoder = ArithmeticDecoder(fo)
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
decoded_symbol = decoder.pull(q_cdf)
assert decoded_symbol == symbol, idx
assert decoder.pull(torch.zeros(1)) is None
if __name__ == "__main__":
test()
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from .distrib import broadcast_tensors, rank
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
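# Illustrative sketch (not part of the original code): k-means returns the cluster
# centers and the per-cluster assignment counts from the last iteration.
#
#   samples = torch.randn(1000, 16)
#   means, bins = kmeans(samples, num_clusters=8)
#   assert means.shape == (8, 16) and bins.shape == (8,)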
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
        kmeans_init: bool = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
uniform_init if not kmeans_init else torch.zeros
)
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
@torch.jit.ignore
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
# broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
# broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
return x
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
return embed_ind
def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = self.preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = self.postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
return quantize, embed_ind
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.0,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(
dim=_codebook_dim,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
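# Minimal usage sketch (illustrative): a (B, D, T) input is mapped to its nearest
# codebook entries; in training mode the straight-through estimator
# `x + (quantize - x).detach()` lets gradients flow back to the encoder.
#
#   vq = VectorQuantization(dim=64, codebook_size=512, kmeans_init=False)
#   x = torch.randn(2, 64, 50)
#   quantized, indices, loss = vq(x)
#   assert quantized.shape == x.shape and indices.shape == (2, 50)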
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
def forward(
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
out_quantized = []
n_q = n_q or len(self.layers)
for i, layer in enumerate(self.layers[:n_q]):
quantized, indices, loss = layer(residual)
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
if layers and i in layers:
out_quantized.append(quantized)
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
st = st or 0
for layer in self.layers[st:n_q]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[st + i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
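# Minimal usage sketch (illustrative): each quantizer codes the residual left by
# the previous ones, so the reconstruction is the sum of the per-layer outputs.
#
#   rvq = ResidualVectorQuantization(
#       num_quantizers=4, dim=64, codebook_size=512, kmeans_init=False
#   )
#   x = torch.randn(2, 64, 50)
#   codes = rvq.encode(x)        # (4, 2, 50)
#   x_hat = rvq.decode(codes)    # (2, 64, 50)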
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp
import torch
def rank():
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return 0
def world_size():
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return 1
def is_distributed():
return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
if is_distributed():
return torch.distributed.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
def _check_number_of_params(params: tp.List[torch.Tensor]):
# utility function to check that the number of params in all workers is the same,
# and thus avoid a deadlock with distributed all reduce.
if not is_distributed() or not params:
return
# print('params[0].device ', params[0].device)
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
all_reduce(tensor)
if tensor.item() != len(params) * world_size():
        # If the workers do not all have the same number of params, this inequality
        # will hold for at least one of them, and the error below will be raised.
raise RuntimeError(
f"Mismatch in number of params: ours is {len(params)}, "
"at least one worker has a different one."
)
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
"""Broadcast the tensors from the given parameters to all workers.
This can be used to ensure that all workers have the same model to start with.
"""
if not is_distributed():
return
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
_check_number_of_params(tensors)
handles = []
for tensor in tensors:
# src = int(rank()) # added code
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
def sync_buffer(buffers, average=True):
"""
Sync grad for buffers. If average is False, broadcast instead of averaging.
"""
if not is_distributed():
return
handles = []
for buffer in buffers:
if torch.is_floating_point(buffer.data):
if average:
handle = torch.distributed.all_reduce(
buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
)
else:
handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
handles.append((buffer, handle))
for buffer, handle in handles:
handle.wait()
if average:
            buffer.data /= world_size()
def sync_grad(params):
"""
Simpler alternative to DistributedDataParallel, that doesn't rely
on any black magic. For simple models it can also be as fast.
Just call this on your model parameters after the call to backward!
"""
if not is_distributed():
return
handles = []
for p in params:
if p.grad is not None:
handle = torch.distributed.all_reduce(
p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
)
handles.append((p, handle))
for p, handle in handles:
handle.wait()
p.grad.data /= world_size()
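# Minimal usage sketch (illustrative, hypothetical `model`/`optimizer` names,
# assuming the default process group is already initialized):
#
#   loss = model(batch).mean()
#   loss.backward()
#   sync_grad(model.parameters())  # all-reduce and average the gradients
#   optimizer.step()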
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
"""Average a dictionary of metrics across all workers, using the optional
`count` as unormalized weight.
"""
if not is_distributed():
return metrics
keys, values = zip(*metrics.items())
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
tensor *= count
all_reduce(tensor)
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
return dict(zip(keys, averaged))
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""
from dataclasses import dataclass, field
import math
import typing as tp
import torch
from torch import nn
from .core_vq import ResidualVectorQuantization
@dataclass
class QuantizedResult:
quantized: torch.Tensor
codes: torch.Tensor
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
penalty: tp.Optional[torch.Tensor] = None
metrics: dict = field(default_factory=dict)
class ResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer.
Args:
dimension (int): Dimension of the codebooks.
n_q (int): Number of residual vector quantizers used.
bins (int): Codebook size.
decay (float): Decay for exponential moving average over the codebooks.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
n_q: int = 8,
bins: int = 1024,
decay: float = 0.99,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.n_q = n_q
self.dimension = dimension
self.bins = bins
self.decay = decay
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.threshold_ema_dead_code = threshold_ema_dead_code
self.vq = ResidualVectorQuantization(
dim=self.dimension,
codebook_size=self.bins,
num_quantizers=self.n_q,
decay=self.decay,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
def forward(
self,
x: torch.Tensor,
n_q: tp.Optional[int] = None,
layers: tp.Optional[list] = None,
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            layers (list): Indices of layers whose quantized output should also be returned. Default: None.
        Returns:
            QuantizedResult:
                The quantized (or approximately quantized) representation, together with
                the codes, the averaged commitment loss, and the quantized outputs of the requested layers.
"""
n_q = n_q if n_q else self.n_q
if layers and max(layers) >= n_q:
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.
Args:
x (torch.Tensor): Input tensor.
n_q (int): Number of quantizer used to quantize. Default: All quantizers.
st (int): Start to encode input from which layers. Default: 0.
"""
n_q = n_q if n_q else self.n_q
st = st or 0
codes = self.vq.encode(x, n_q=n_q, st=st)
return codes
def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
"""Decode the given codes to the quantized representation.
Args:
codes (torch.Tensor): Input indices for each quantizer.
            st (int): Index of the layer to start decoding from. Default: 0.
"""
quantized = self.vq.decode(codes, st=st)
return quantized
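# Minimal usage sketch (illustrative): quantize a (B, dimension, T) latent with a
# subset of the quantizers and obtain codes of shape (n_q, B, T).
#
#   rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
#   x = torch.randn(2, 256, 50)
#   codes = rvq.encode(x, n_q=4)   # (4, 2, 50)
#   x_hat = rvq.decode(codes)      # (2, 256, 50)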
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Encodec SEANet-based encoder and decoder implementation."""
import typing as tp
import numpy as np
import torch.nn as nn
import torch
from . import SConv1d, SConvTranspose1d, SLSTM
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
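# The function above implements the Snake activation,
#   snake(x) = x + sin(alpha * x) ** 2 / (alpha + 1e-9),
# and `Snake1d` below learns one alpha parameter per channel.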
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
class SEANetResnetBlock(nn.Module):
"""Residual block from SEANet model.
Args:
dim (int): Dimension of the input/output
kernel_sizes (list): List of kernel sizes for the convolutions.
dilations (list): List of dilations for the convolutions.
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
"""
def __init__(
self,
dim: int,
kernel_sizes: tp.List[int] = [3, 1],
dilations: tp.List[int] = [1, 1],
activation: str = "ELU",
activation_params: dict = {"alpha": 1.0},
norm: str = "weight_norm",
norm_params: tp.Dict[str, tp.Any] = {},
causal: bool = False,
pad_mode: str = "reflect",
compress: int = 2,
true_skip: bool = True,
):
super().__init__()
assert len(kernel_sizes) == len(
dilations
), "Number of kernel sizes should match number of dilations"
act = getattr(nn, activation) if activation != "Snake" else Snake1d
hidden = dim // compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [
act(**activation_params) if activation != "Snake" else act(in_chs),
SConv1d(
in_chs,
out_chs,
kernel_size=kernel_size,
dilation=dilation,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
self.block = nn.Sequential(*block)
self.shortcut: nn.Module
if true_skip:
self.shortcut = nn.Identity()
else:
self.shortcut = SConv1d(
dim,
dim,
kernel_size=1,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
)
def forward(self, x):
return self.shortcut(x) + self.block(x)
class SEANetEncoder(nn.Module):
"""SEANet encoder.
Args:
channels (int): Audio channels.
dimension (int): Intermediate representation dimension.
n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): Kernel size and stride ratios. The encoder uses downsampling ratios
            instead of upsampling ratios, hence it will use the ratios in the reverse order to the
            ones specified here, which must match the decoder order.
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
residual_kernel_size (int): Kernel size for the residual layers.
dilation_base (int): How much to increase the dilation with each layer.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
true_skip (bool): Whether to use true skip connection or a simple
(streamable) convolution as the skip connection in the residual network blocks.
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
lstm (int): Number of LSTM layers at the end of the encoder.
"""
def __init__(
self,
channels: int = 1,
dimension: int = 128,
n_filters: int = 32,
n_residual_layers: int = 1,
ratios: tp.List[int] = [8, 5, 4, 2],
activation: str = "ELU",
activation_params: dict = {"alpha": 1.0},
norm: str = "weight_norm",
norm_params: tp.Dict[str, tp.Any] = {},
kernel_size: int = 7,
last_kernel_size: int = 7,
residual_kernel_size: int = 3,
dilation_base: int = 2,
causal: bool = False,
pad_mode: str = "reflect",
true_skip: bool = False,
compress: int = 2,
lstm: int = 2,
bidirectional: bool = False,
):
super().__init__()
self.channels = channels
self.dimension = dimension
self.n_filters = n_filters
self.ratios = list(reversed(ratios))
del ratios
self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)  # product of the downsampling ratios
act = getattr(nn, activation) if activation != "Snake" else Snake1d
mult = 1
model: tp.List[nn.Module] = [
SConv1d(
channels,
mult * n_filters,
kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
)
]
# Downsample to raw audio scale
for i, ratio in enumerate(self.ratios):
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(
mult * n_filters,
kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base**j, 1],
norm=norm,
norm_params=norm_params,
activation=activation,
activation_params=activation_params,
causal=causal,
pad_mode=pad_mode,
compress=compress,
true_skip=true_skip,
)
]
# Add downsampling layers
model += [
(
act(**activation_params)
if activation != "Snake"
else act(mult * n_filters)
),
SConv1d(
mult * n_filters,
mult * n_filters * 2,
kernel_size=ratio * 2,
stride=ratio,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
mult *= 2
if lstm:
model += [
SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
]
mult = mult * 2 if bidirectional else mult
model += [
(
act(**activation_params)
if activation != "Snake"
else act(mult * n_filters)
),
SConv1d(
mult * n_filters,
dimension,
last_kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
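# Minimal usage sketch (illustrative): with the default ratios [8, 5, 4, 2] the hop
# length is 320, so one second of 24 kHz mono audio maps to 75 latent frames of
# dimension 128 (see also `test()` at the end of this file).
#
#   enc = SEANetEncoder()
#   z = enc(torch.randn(1, 1, 24000))
#   assert list(z.shape) == [1, 128, 75]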
class SEANetDecoder(nn.Module):
"""SEANet decoder.
Args:
channels (int): Audio channels.
dimension (int): Intermediate representation dimension.
n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
ratios (Sequence[int]): kernel size and stride ratios
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
final_activation (str): Final activation function after all convolutions.
final_activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
residual_kernel_size (int): Kernel size for the residual layers.
dilation_base (int): How much to increase the dilation with each layer.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
true_skip (bool): Whether to use true skip connection or a simple
(streamable) convolution as the skip connection in the residual network blocks.
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the start of the decoder.
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
If equal to 1.0, it means that all the trimming is done at the right.
"""
def __init__(
self,
channels: int = 1,
dimension: int = 128,
n_filters: int = 32,
n_residual_layers: int = 1,
ratios: tp.List[int] = [8, 5, 4, 2],
activation: str = "ELU",
activation_params: dict = {"alpha": 1.0},
final_activation: tp.Optional[str] = None,
final_activation_params: tp.Optional[dict] = None,
norm: str = "weight_norm",
norm_params: tp.Dict[str, tp.Any] = {},
kernel_size: int = 7,
last_kernel_size: int = 7,
residual_kernel_size: int = 3,
dilation_base: int = 2,
causal: bool = False,
pad_mode: str = "reflect",
true_skip: bool = False,
compress: int = 2,
lstm: int = 2,
trim_right_ratio: float = 1.0,
bidirectional: bool = False,
):
super().__init__()
self.dimension = dimension
self.channels = channels
self.n_filters = n_filters
self.ratios = ratios
del ratios
self.n_residual_layers = n_residual_layers
self.hop_length = np.prod(self.ratios)
act = getattr(nn, activation) if activation != "Snake" else Snake1d
mult = int(2 ** len(self.ratios))
model: tp.List[nn.Module] = [
SConv1d(
dimension,
mult * n_filters,
kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
)
]
if lstm:
model += [
SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
]
# Upsample to raw audio scale
for i, ratio in enumerate(self.ratios):
# Add upsampling layers
model += [
(
act(**activation_params)
if activation != "Snake"
else act(mult * n_filters)
),
SConvTranspose1d(
mult * n_filters,
mult * n_filters // 2,
kernel_size=ratio * 2,
stride=ratio,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
trim_right_ratio=trim_right_ratio,
),
]
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(
mult * n_filters // 2,
kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base**j, 1],
activation=activation,
activation_params=activation_params,
norm=norm,
norm_params=norm_params,
causal=causal,
pad_mode=pad_mode,
compress=compress,
true_skip=true_skip,
)
]
mult //= 2
# Add final layers
model += [
act(**activation_params) if activation != "Snake" else act(n_filters),
SConv1d(
n_filters,
channels,
last_kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
# Add optional final activation to decoder (eg. tanh)
if final_activation is not None:
final_act = getattr(nn, final_activation)
final_activation_params = final_activation_params or {}
model += [final_act(**final_activation_params)]
self.model = nn.Sequential(*model)
def forward(self, z):
y = self.model(z)
return y
def test():
import torch
encoder = SEANetEncoder()
decoder = SEANetDecoder()
x = torch.randn(1, 1, 24000)
z = encoder(x)
print("z ", z.shape)
assert 1 == 2
assert list(z.shape) == [1, 128, 75], z.shape
y = decoder(z)
assert y.shape == x.shape, (x.shape, y.shape)
if __name__ == "__main__":
test()
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantize(nn.Module):
"""Vector quantization w/ exponential moving averages (EMA)"""
def __init__(
self,
dim: int,
codebook_size: int,
decay=0.8,
commitment=1.0,
eps=1e-5,
n_embed=None,
):
super().__init__()
n_embed = self.default(n_embed, codebook_size)
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
self.commitment = commitment
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
@property
def codebook(self):
return self.embed.transpose(0, 1)
def exists(self, val):
return val is not None
def default(self, val, d):
return val if self.exists(val) else d
def ema_inplace(self, moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(self, x, n_categories, eps=1e-5):
return (x + eps) / (x.sum() + n_categories * eps)
def forward(self, input):
dtype = input.dtype
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
if self.training:
self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = flatten.transpose(0, 1) @ embed_onehot
self.ema_inplace(self.embed_avg, embed_sum, self.decay)
cluster_size = (
self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
loss = F.mse_loss(quantize.detach(), input) * self.commitment
quantize = input + (quantize - input).detach()
avg_probs = torch.mean(embed_onehot, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
return quantize, loss, perplexity
def forward_index(self, input):
dtype = input.dtype
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
quantize = input + (quantize - input).detach()
return quantize, embed_ind
class ResidualVQ(nn.Module):
"""Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantize(**kwargs) for _ in range(num_quantizers)]
)
def forward(self, x):
quantized_out = 0.0
residual = x
all_losses = []
all_perplexities = []
for layer in self.layers:
quantized, loss, perplexity = layer(residual)
# Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
# We found that considering only the 1st layer VQ's gradient results in better performance
# residual = residual - quantized.detach()  # considering all layers' gradients
residual = residual - quantized  # considering only the first layer's gradient
quantized_out = quantized_out + quantized
all_losses.append(loss)
all_perplexities.append(perplexity)
all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
return quantized_out, all_losses, all_perplexities
def forward_index(self, x, flatten_idx=False):
"""
all_indices: [num_of_quantizers, B, T]
"""
quantized_out = 0.0
residual = x
all_indices = []
for i, layer in enumerate(self.layers):
quantized, indices = layer.forward_index(residual)
# residual = residual - quantized.detach()
residual = residual - quantized
quantized_out = quantized_out + quantized
if flatten_idx:
indices += self.codebook_size * i
all_indices.append(indices)
all_indices = torch.stack(all_indices)
return quantized_out, all_indices
def initial(self):
self.codebook = []
for layer in self.layers:
self.codebook.append(layer.codebook)
self.codebook_size = self.codebook[0].size(0)
self.codebook = torch.stack(self.codebook)
self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))
def lookup(self, indices):
quantized_out = F.embedding(indices, self.codebook) # Num x T x C
return torch.sum(quantized_out, dim=0, keepdim=True)
class Quantizer(nn.Module):
def __init__(
self,
code_dim: int,
codebook_num: int,
codebook_size: int,
):
super().__init__()
self.codebook = ResidualVQ(
dim=code_dim, num_quantizers=codebook_num, codebook_size=codebook_size
)
def initial(self):
self.codebook.initial()
def forward(self, z):
zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
zq = zq.transpose(2, 1)
return zq, vqloss, perplexity
def inference(self, z):
zq, indices = self.codebook.forward_index(z.transpose(2, 1))
zq = zq.transpose(2, 1)
return zq, indices
def encode(self, z):
zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
return zq, indices
def decode(self, indices):
z = self.codebook.lookup(indices)
return z
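# Illustrative usage sketch (not part of the original file): a round trip through
# the residual-VQ Quantizer on random features. The hyper-parameters here are
# arbitrary; Quantizer.forward expects (B, C, T) and transposes internally.
def _demo_quantizer():
    quantizer = Quantizer(code_dim=256, codebook_num=4, codebook_size=512)
    quantizer.eval()  # skip the EMA codebook update for this sketch
    z = torch.randn(2, 256, 50)  # (B, C, T)
    zq, vqloss, perplexity = quantizer(z)
    # zq: (2, 256, 50); vqloss and perplexity hold one value per residual layer: (4,)
    return zq.shape, vqloss.shape, perplexity.shape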
class Conv1d1x1(nn.Conv1d):
"""1x1 Conv1d."""
def __init__(self, in_channels, out_channels, bias=True):
super(Conv1d1x1, self).__init__(
in_channels, out_channels, kernel_size=1, bias=bias
)
class Conv1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = -1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
if padding < 0:
padding = (kernel_size - 1) // 2 * dilation
self.dilation = dilation
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x):
"""
Args:
x (Tensor): Float tensor variable with the shape (B, C, T).
Returns:
Tensor: Float tensor variable with the shape (B, C, T).
"""
x = self.conv(x)
return x
class ConvTranspose1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
padding=-1,
output_padding=-1,
groups=1,
bias=True,
):
super().__init__()
if padding < 0:
padding = (stride + 1) // 2
if output_padding < 0:
output_padding = 1 if stride % 2 else 0
self.deconv = nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
)
def forward(self, x):
"""
Args:
x (Tensor): Float tensor variable with the shape (B, C, T).
Returns:
Tensor: Float tensor variable with the shape (B, C', T').
"""
x = self.deconv(x)
return x
class ResidualUnit(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
dilation=1,
bias=False,
nonlinear_activation="ELU",
nonlinear_activation_params={},
):
super().__init__()
self.activation = getattr(nn, nonlinear_activation)(
**nonlinear_activation_params
)
self.conv1 = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
dilation=dilation,
bias=bias,
)
self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
def forward(self, x):
y = self.conv1(self.activation(x))
y = self.conv2(self.activation(y))
return x + y
class Projector(nn.Module):
def __init__(
self, input_channels: int, code_dim: int, kernel_size=3, stride=1, bias=False
):
super().__init__()
self.project = Conv1d(
input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias
)
def forward(self, x):
return self.project(x)
class EncoderBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
dilations=(1, 1),
unit_kernel_size=3,
bias=True,
):
super().__init__()
self.res_units = torch.nn.ModuleList()
for dilation in dilations:
self.res_units += [
ResidualUnit(
in_channels,
in_channels,
kernel_size=unit_kernel_size,
dilation=dilation,
)
]
self.num_res = len(self.res_units)
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(
3 if stride == 1 else (2 * stride)
), # special case: stride=1, do not use kernel=2
stride=stride,
bias=bias,
)
def forward(self, x):
for idx in range(self.num_res):
x = self.res_units[idx](x)
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
input_channels: int,
encode_channels: int,
channel_ratios=(1, 1),
strides=(1, 1),
kernel_size=3,
bias=True,
block_dilations=(1, 1),
unit_kernel_size=3,
):
super().__init__()
assert len(channel_ratios) == len(strides)
self.conv = Conv1d(
in_channels=input_channels,
out_channels=encode_channels,
kernel_size=kernel_size,
stride=1,
bias=False,
)
self.conv_blocks = torch.nn.ModuleList()
in_channels = encode_channels
for idx, stride in enumerate(strides):
out_channels = int(encode_channels * channel_ratios[idx]) # could be float
self.conv_blocks += [
EncoderBlock(
in_channels,
out_channels,
stride,
dilations=block_dilations,
unit_kernel_size=unit_kernel_size,
bias=bias,
)
]
in_channels = out_channels
self.num_blocks = len(self.conv_blocks)
self.out_channels = out_channels
def forward(self, x):
x = self.conv(x)
for i in range(self.num_blocks):
x = self.conv_blocks[i](x)
return x
class DecoderBlock(nn.Module):
"""Decoder block (no up-sampling)"""
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
dilations=(1, 1),
unit_kernel_size=3,
bias=True,
):
super().__init__()
if stride == 1:
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
stride=stride,
bias=bias,
)
else:
self.conv = ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(2 * stride),
stride=stride,
bias=bias,
)
self.res_units = torch.nn.ModuleList()
for idx, dilation in enumerate(dilations):
self.res_units += [
ResidualUnit(
out_channels,
out_channels,
kernel_size=unit_kernel_size,
dilation=dilation,
)
]
self.num_res = len(self.res_units)
def forward(self, x):
x = self.conv(x)
for idx in range(self.num_res):
x = self.res_units[idx](x)
return x
class Decoder(nn.Module):
def __init__(
self,
code_dim: int,
output_channels: int,
decode_channels: int,
channel_ratios=(1, 1),
strides=(1, 1),
kernel_size=3,
bias=True,
block_dilations=(1, 1),
unit_kernel_size=3,
):
super().__init__()
assert len(channel_ratios) == len(strides)
self.conv1 = Conv1d(
in_channels=code_dim,
out_channels=int(decode_channels * channel_ratios[0]),
kernel_size=kernel_size,
stride=1,
bias=False,
)
self.conv_blocks = torch.nn.ModuleList()
for idx, stride in enumerate(strides):
in_channels = int(decode_channels * channel_ratios[idx])
if idx < (len(channel_ratios) - 1):
out_channels = int(decode_channels * channel_ratios[idx + 1])
else:
out_channels = decode_channels
self.conv_blocks += [
DecoderBlock(
in_channels,
out_channels,
stride,
dilations=block_dilations,
unit_kernel_size=unit_kernel_size,
bias=bias,
)
]
self.num_blocks = len(self.conv_blocks)
self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
def forward(self, z):
x = self.conv1(z)
for i in range(self.num_blocks):
x = self.conv_blocks[i](x)
x = self.conv2(x)
return x
class VevoRepCodec(nn.Module):
def __init__(
self,
input_channels=768,
output_channels=768,
encode_channels=768,
decode_channels=768,
code_dim=768,
codebook_num=1,
codebook_size=1024,
bias=True,
enc_ratios=(1, 1),
dec_ratios=(1, 1),
enc_strides=(1, 1),
dec_strides=(1, 1),
enc_kernel_size=3,
dec_kernel_size=3,
enc_block_dilations=(1, 1),
enc_block_kernel_size=3,
dec_block_dilations=(1, 1),
dec_block_kernel_size=3,
):
super().__init__()
self.input_channels = input_channels
self.encoder = Encoder(
input_channels=input_channels,
encode_channels=encode_channels,
channel_ratios=enc_ratios,
strides=enc_strides,
kernel_size=enc_kernel_size,
bias=bias,
block_dilations=enc_block_dilations,
unit_kernel_size=enc_block_kernel_size,
)
self.decoder = Decoder(
code_dim=code_dim,
output_channels=output_channels,
decode_channels=decode_channels,
channel_ratios=dec_ratios,
strides=dec_strides,
kernel_size=dec_kernel_size,
bias=bias,
block_dilations=dec_block_dilations,
unit_kernel_size=dec_block_kernel_size,
)
self.projector = Projector(
input_channels=self.encoder.out_channels,
code_dim=code_dim,
kernel_size=3,
stride=1,
bias=False,
)
self.quantizer = Quantizer(
code_dim=code_dim, codebook_num=codebook_num, codebook_size=codebook_size
)
def forward(self, x):
x = self.encoder(x)
z = self.projector(x)
zq, vqloss, perplexity = self.quantizer(z)
y = self.decoder(zq)
return y, zq, z, vqloss, perplexity
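# Illustrative usage sketch (not part of the original file): run VevoRepCodec on a
# dummy feature sequence. With the default strides of (1, 1) the temporal length is
# preserved, so the reconstruction y has the same shape as the input features.
def _demo_vevo_repcodec():
    codec = VevoRepCodec()
    feats = torch.randn(1, 768, 100)  # (B, input_channels, T)
    y, zq, z, vqloss, perplexity = codec(feats)
    return y.shape, zq.shape  # both torch.Size([1, 768, 100])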
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import math
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
# sinusoidal positional encoding
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :] * 1.0
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
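# Illustrative usage sketch (not part of the original file): embed a batch of
# continuous diffusion steps into a fixed-size vector (first half sine, second
# half cosine), as consumed by the diff_step_mlp of the models below.
def _demo_sinusoidal_pos_emb():
    pos_emb = SinusoidalPosEmb(dim=1024)
    t = torch.rand(4)  # (B,) diffusion steps in [0, 1]
    emb = pos_emb(t)   # (4, 1024)
    return emb.shape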
class LlamaAdaptiveRMSNorm(nn.Module):
def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
super().__init__()
self.to_weight = nn.Linear(dim_cond, hidden_size)
nn.init.zeros_(self.to_weight.weight)
nn.init.ones_(self.to_weight.bias)
self.variance_epsilon = eps
self._is_hf_initialized = True # disable automatic init
def forward(self, hidden_states, cond_embedding):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
weight = self.to_weight(cond_embedding)
if len(weight.shape) == 2:
weight = weight.unsqueeze(1)
return (weight * hidden_states).to(input_dtype)
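# Illustrative usage sketch (not part of the original file): the adaptive RMSNorm
# rescales the normalized hidden states with a per-sample weight predicted from a
# conditioning embedding (e.g. the diffusion-step embedding).
def _demo_adaptive_rmsnorm():
    norm = LlamaAdaptiveRMSNorm(hidden_size=1024, dim_cond=1024)
    hidden = torch.randn(2, 10, 1024)  # (B, T, C)
    cond = torch.randn(2, 1024)        # (B, C); broadcast over T inside forward
    return norm(hidden, cond).shape    # torch.Size([2, 10, 1024])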
class LlamaNARDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
"""Override to adaptive layer norm"""
super().__init__(config, layer_idx) # init attention, mlp, etc.
self.input_layernorm = LlamaAdaptiveRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
)
self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
)
# add `cond` in forward function
def forward(
self,
hidden_states: torch.Tensor,
cond_embedding: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(
hidden_states, cond_embedding=cond_embedding
)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(
hidden_states, cond_embedding=cond_embedding
)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class DiffLlama(LlamaModel):
def __init__(
self,
hidden_size=1024,
num_heads=16,
num_layers=16,
config=LlamaConfig(0, 256, 1024, 1, 1),
):
super().__init__(config)
self.layers = nn.ModuleList(
[
LlamaNARDecoderLayer(
LlamaConfig(
hidden_size=hidden_size,
num_attention_heads=num_heads,
max_position_embeddings=4096,
intermediate_size=hidden_size * 4,
),
layer_idx=i,
)
for i in range(num_layers)
]
)
self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
self.diff_step_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.SiLU(),
nn.Linear(hidden_size * 4, hidden_size),
)
# self.position_embedding = PositionalEncoding(hidden_size, dropout=0.0)
self.cond_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.SiLU(),
nn.Linear(hidden_size * 4, hidden_size),
)
for layer in self.layers:
layer.input_layernorm = LlamaAdaptiveRMSNorm(
hidden_size, dim_cond=hidden_size
)
layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
hidden_size, dim_cond=hidden_size
)
self.post_init()
# self.reset_parameters()
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
def _expand_mask(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
x,
diffusion_step,
cond,
x_mask,
input_ids: torch.LongTensor = None, # [num_quant, B, T]
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# retrieve some shape info
batch_size, seq_length, _ = x.shape
# condition MLP
cond_embedding = self.cond_mlp(cond) # (B, T, C)
# diffusion step embedding
diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C)
x = x + cond_embedding
inputs_embeds = x
attention_mask = x_mask
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
if self.gradient_checkpointing and self.training:
raise NotImplementedError
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cond_embedding=diffusion_step,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
return hidden_states
class DiffLlamaPrefix(LlamaModel):
def __init__(
self,
hidden_size=1024,
num_heads=16,
num_layers=16,
config=LlamaConfig(0, 256, 1024, 1, 1),
):
super().__init__(config)
self.layers = nn.ModuleList(
[
LlamaNARDecoderLayer(
LlamaConfig(
hidden_size=hidden_size,
num_attention_heads=num_heads,
max_position_embeddings=4096,
intermediate_size=hidden_size * 4,
),
layer_idx=i,
)
for i in range(num_layers)
]
)
self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
self.diff_step_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.SiLU(),
nn.Linear(hidden_size * 4, hidden_size),
)
self.cond_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.SiLU(),
nn.Linear(hidden_size * 4, hidden_size),
)
for layer in self.layers:
layer.input_layernorm = LlamaAdaptiveRMSNorm(
hidden_size, dim_cond=hidden_size
)
layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
hidden_size, dim_cond=hidden_size
)
self.embed_tokens = None
self.post_init()
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
def _expand_mask(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
x,
diffusion_step,
x_mask,
phone_embedding: Optional[torch.LongTensor] = None,
phone_mask: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None, # [num_quant, B, T]
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# retrieve some shape info
phone_embedding = self.cond_mlp(phone_embedding) # (B, T, C)
phone_length = phone_embedding.shape[1]
inputs_embeds = torch.cat([phone_embedding, x], dim=1)
attention_mask = torch.cat([phone_mask, x_mask], dim=1)
# diffusion step embedding
diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C)
batch_size, seq_length, _ = inputs_embeds.shape
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
if self.gradient_checkpointing and self.training:
raise NotImplementedError
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cond_embedding=diffusion_step,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
return hidden_states[
:,
phone_length:,
]
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
import torch.nn as nn
import math
from einops import rearrange
from indextts.utils.maskgct.models.tts.maskgct.llama_nar import DiffLlama
def top_k(logits, thres=0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, ind = logits.topk(k, dim=-1)
probs = torch.full_like(logits, float("-inf"))
probs.scatter_(2, ind, val)
return probs
def log(t, eps=1e-10):
return torch.log(t + eps)
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature=1.0, dim=-1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
class MaskGCT_S2A(nn.Module):
def __init__(
self,
num_quantizer=12,
hidden_size=1024,
num_layers=16,
num_heads=16,
codebook_size=1024,
cfg_scale=0.15,
mask_layer_schedule="linear",
cond_codebook_size=1024,
cond_dim=1024,
predict_layer_1=True,
cfg=None,
):
super().__init__()
num_quantizer = (
cfg.num_quantizer
if cfg is not None and hasattr(cfg, "num_quantizer")
else num_quantizer
)
hidden_size = (
cfg.hidden_size
if cfg is not None and hasattr(cfg, "hidden_size")
else hidden_size
)
num_layers = (
cfg.num_layers
if cfg is not None and hasattr(cfg, "num_layers")
else num_layers
)
num_heads = (
cfg.num_heads
if cfg is not None and hasattr(cfg, "num_heads")
else num_heads
)
codebook_size = (
cfg.codebook_size
if cfg is not None and hasattr(cfg, "codebook_size")
else codebook_size
)
cfg_scale = (
cfg.cfg_scale
if cfg is not None and hasattr(cfg, "cfg_scale")
else cfg_scale
)
mask_layer_schedule = (
cfg.mask_layer_schedule
if cfg is not None and hasattr(cfg, "mask_layer_schedule")
else mask_layer_schedule
)
cond_codebook_size = (
cfg.cond_codebook_size
if cfg is not None and hasattr(cfg, "cond_codebook_size")
else cond_codebook_size
)
cond_dim = (
cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
)
predict_layer_1 = (
cfg.predict_layer_1
if cfg is not None and hasattr(cfg, "predict_layer_1")
else predict_layer_1
)
self.num_quantizer = num_quantizer
self.hidden_size = hidden_size
self.codebook_size = codebook_size
self.num_layers = num_layers
self.num_heads = num_heads
self.cfg_scale = cfg_scale
self.mask_layer_schedule = mask_layer_schedule
self.cond_codebook_size = cond_codebook_size
self.cond_dim = cond_dim
self.predict_layer_1 = predict_layer_1
self.layer_emb = nn.Embedding(self.num_quantizer, self.hidden_size)
self.mask_emb = nn.Embedding(1, self.hidden_size)
self.token_emb = torch.nn.ModuleList(
[
nn.Embedding(self.codebook_size, self.hidden_size)
for _ in range(self.num_quantizer)
]
)
self.to_logits = torch.nn.ModuleList(
[
nn.Linear(self.hidden_size, self.codebook_size)
for _ in range(self.num_quantizer)
]
)
self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
self.reset_parameters()
self.diff_estimator = DiffLlama(
hidden_size=hidden_size,
num_heads=self.num_heads,
num_layers=num_layers,
)
def mask_prob(self, t):
return torch.sin(t * np.pi / 2).to(t.device)
def mask_layer(self, t):
# print(self.predict_layer_1)
if self.mask_layer_schedule == "uniform":
if self.predict_layer_1:
mask_layer = torch.randint(0, self.num_quantizer, (1,)).to(t.device)
else:
mask_layer = torch.randint(1, self.num_quantizer, (1,)).to(t.device)
elif self.mask_layer_schedule == "cosine":
if self.predict_layer_1:
weights = torch.tensor(
[
np.cos(i / self.num_quantizer * np.pi / 2)
for i in range(self.num_quantizer)
]
)
else:
weights = torch.tensor(
[0]
+ [
np.cos((i - 1) / self.num_quantizer * np.pi / 2)
for i in range(1, self.num_quantizer)
]
)
mask_layer = torch.multinomial(weights, 1).to(t.device)
elif self.mask_layer_schedule == "linear":
if self.predict_layer_1:
weights = torch.tensor(
[self.num_quantizer - i for i in range(self.num_quantizer)]
)
else:
weights = torch.tensor(
[0]
+ [
self.num_quantizer - (i - 1)
for i in range(1, self.num_quantizer)
]
)
weights = weights / weights.sum()
mask_layer = torch.multinomial(weights, 1).to(t.device)
# print(mask_layer)
new_t = t
return mask_layer, new_t
def forward_diffusion(self, x0, t):
# x0: (B, T, num_quantizer)
mask_layer, new_t = self.mask_layer(t) # (1,)
mask_prob = self.mask_prob(new_t) # (B,)
mask_token = self.mask_emb(torch.zeros_like(mask_layer)) # (1, hidden_size)
xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)
cfg_scale = self.cfg_scale
# get prompt len
if torch.rand(1) > cfg_scale:
prompt_len = torch.randint(
min(x0.shape[1] // 4, 5), x0.shape[1] // 2, (x0.shape[0],)
).to(
x0.device
) # (B,)
else:
prompt_len = torch.zeros(x0.shape[0]).to(x0) # (B,)
# get is prompt
is_prompt = torch.zeros_like(x0[:, :, 0]) # (B, T)
col_indices = (
torch.arange(is_prompt.shape[1])
.repeat(is_prompt.shape[0], 1)
.to(prompt_len)
) # (B, T)
is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # (B, T) 1 if prompt
for idx, token_emb_idx in enumerate(self.token_emb):
if idx < mask_layer:
xt = xt + token_emb_idx(x0[:, :, idx]) # (B, T, hidden_size)
elif idx == mask_layer:
mask = torch.bernoulli(
torch.ones_like(x0[:, :, idx]) * mask_prob[..., None]
) # mask if 1, not mask if 0
# prompt part don't need to be masked
mask[is_prompt.bool()] = 0
# Ensure at least one token is masked
mask_num = mask.sum(dim=1, keepdim=False)
all_zero_mask = (mask_num == 0).bool()
row_indices_to_modify = torch.nonzero(all_zero_mask)
# mask the first token if all tokens are not masked (may mask pad if random indices)
mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1
mask = mask[..., None] # (B, T, 1)
xt = (
xt
+ mask * mask_token[:, None, :]
+ (1 - mask) * token_emb_idx(x0[:, :, idx])
) # (B, T, hidden_size)
else:
# prompt part don't need to be masked
xt = (
xt
+ token_emb_idx(x0[:, :, idx]) * is_prompt[..., None]
+ mask_token * (1 - is_prompt[..., None])
)
return xt, new_t, mask_layer, mask, prompt_len, mask_prob
def loss_t(self, x0, x_mask, t, cond=None):
xt, new_t, mask_layer, mask, prompt_len, mask_prob = self.forward_diffusion(
x0, t
)
# xt: (B, T, hidden_size)
# new_t: (B,)
# mask_layer: (1,)
# mask: (B, T, 1) mask if 1, not mask if 0
# prompt_len: (B,)
# mask_prob: (B,)
mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(1) # (1, 1, hidden_size)
cond = cond + mask_layer_cond # (B, T, hidden_size)
embeds = self.diff_estimator(xt, new_t, cond, x_mask) # (B, T, hidden_size)
logits = self.to_logits[mask_layer.item()](embeds) # (B, T, codebook_size)
# final mask used for loss calculation
final_mask = mask * x_mask[..., None] # (B, T, 1)
return logits, mask_layer, final_mask, x0, prompt_len, mask_prob
def compute_loss(self, x0, x_mask, cond=None):
# x0: (B, T, num_quantizer)
# x_mask: (B, T) mask is 0 for padding
t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
t = torch.clamp(t, 1e-5, 1.0)
return self.loss_t(x0, x_mask, t, cond)
def reset_parameters(self):
def _reset_parameters(m):
if isinstance(m, nn.MultiheadAttention):
if m._qkv_same_embed_dim:
nn.init.normal_(m.in_proj_weight, std=0.02)
else:
nn.init.normal_(m.q_proj_weight, std=0.02)
nn.init.normal_(m.k_proj_weight, std=0.02)
nn.init.normal_(m.v_proj_weight, std=0.02)
if m.in_proj_bias is not None:
nn.init.constant_(m.in_proj_bias, 0.0)
nn.init.constant_(m.out_proj.bias, 0.0)
if m.bias_k is not None:
nn.init.xavier_normal_(m.bias_k)
if m.bias_v is not None:
nn.init.xavier_normal_(m.bias_v)
elif (
isinstance(m, nn.Conv1d)
or isinstance(m, nn.ConvTranspose1d)
or isinstance(m, nn.Conv2d)
or isinstance(m, nn.ConvTranspose2d)
):
m.weight.data.normal_(0.0, 0.02)
elif isinstance(m, nn.Linear):
m.weight.data.normal_(mean=0.0, std=0.02)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Embedding):
m.weight.data.normal_(mean=0.0, std=0.02)
if m.padding_idx is not None:
m.weight.data[m.padding_idx].zero_()
self.apply(_reset_parameters)
@torch.no_grad()
def reverse_diffusion(
self,
cond,
prompt,
x_mask=None,
prompt_mask=None,
temp=1.5,
filter_thres=0.98,
max_layer=None,
gt_code=None,
n_timesteps=[10, 4, 4, 4, 4, 4, 4, 4],
cfg=1.0,
rescale_cfg=1.0,
):
assert (
len(n_timesteps) == self.num_quantizer
) # each layer has a number of steps
prompt_code = prompt # (B, prompt_len, num_quantizer)
prompt_len = prompt_code.shape[1]
target_len = cond.shape[1] - prompt_len
if x_mask is None:
x_mask = torch.ones(cond.shape[0], target_len).to(cond.device) # (B, T)
if prompt_mask is None:
prompt_mask = torch.ones(cond.shape[0], prompt_len).to(
cond.device
) # (B, prompt_len)
cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
x_mask.device
) # (B, T, hidden_size)
bsz, seq_len, _ = cum.shape
choice_temp = 1.0
start_temp = temp # temperature for sampling
start_choice_temp = choice_temp # temperature for choicing mask tokens
if max_layer is None:
max_layer = self.num_quantizer
xt = torch.LongTensor(bsz, seq_len, max_layer).to(x_mask.device)
if gt_code is not None:
gt_layer = gt_code.shape[-1]
xt[:, :, :gt_layer] = gt_code
for i in range(gt_layer):
cum += self.token_emb[i](xt[:, :, i])
else:
gt_layer = 0
for mask_layer in range(gt_layer, max_layer):
steps = n_timesteps[mask_layer]
to_logits = self.to_logits[mask_layer]
token_emb = self.token_emb[mask_layer]
mask_layer = torch.tensor(mask_layer).to(x_mask.device).long().unsqueeze(0)
mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(
1
) # (1,) -> (1, 1, hidden_size)
temp_cond = cond + mask_layer_cond # (B, T, hidden_size)
mask_token = self.mask_emb(torch.zeros_like(mask_layer)) # (1, hidden_size)
mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device) # (B, T, 1)
seq = torch.full((bsz, seq_len), 0).to(x_mask.device)
h = 1.0 / steps
# prompt_code: (B, prompt_len, num_quantizer)
cur_prompt = 0
for idx, emb in enumerate(self.token_emb):
cur_prompt = cur_prompt + emb(
prompt_code[:, :, idx]
) # (B, prompt_len, hidden_size)
t_list = [1.0 - i * h for i in range(steps)]
t_list.append(0.0)
for i in range(steps):
t = t_list[i] * torch.ones(bsz).to(x_mask.device)
token = token_emb(seq) # (B, T, hidden_size)
cur = cum + mask * mask_token[:, None, :] + (~mask) * token
cur = cur + mask_token[:, None, :] * (max_layer - 1 - mask_layer)
xt_input = torch.cat([cur_prompt, cur], dim=1) # (B, T, hidden_size)
xt_mask = torch.cat(
[prompt_mask, x_mask], dim=1
) # (B, T), mask is 0 for padding
embeds = self.diff_estimator(xt_input, t, temp_cond, xt_mask)
embeds = embeds[:, prompt_len:, :]
# cfg
if cfg > 0:
mask_embeds = self.diff_estimator(
cur, t, temp_cond[:, prompt_len:, :], x_mask
)
pos_emb_std = embeds.std() # std(g_cond)
embeds = embeds + cfg * (embeds - mask_embeds) # g_cfg
rescale_embeds = embeds * pos_emb_std / embeds.std() # g_final
embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds
logits = to_logits(embeds) # (B, T, codebook_size)
annealing_scale = t_list[i]
choice_temp = start_choice_temp * annealing_scale
temp = start_temp * annealing_scale
logits = top_k(logits, filter_thres)
if i == steps - 1:
# greedy
if steps == 1:
temp = 0.2
sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
else:
sampled_ids = logits.argmax(dim=-1)
else:
# sampling
sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
seq = torch.where(mask.squeeze(-1), sampled_ids, seq)
scores = logits.softmax(dim=-1)
scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
scores = rearrange(scores, "b n 1 -> b n")
scores = choice_temp * gumbel_noise(scores) + scores
scores = 1 - scores
next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device)
next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()
if next_mask_num == 0:
break
scores = scores.masked_fill(
~mask.squeeze(-1), -torch.finfo(scores.dtype).max
)
mask_indices = scores.topk(next_mask_num, dim=-1).indices
mask = torch.zeros_like(scores, dtype=torch.bool).scatter(
1, mask_indices, True
)
seq = seq.masked_fill(mask, 0)
mask = mask.unsqueeze(-1)
cum = cum + token_emb(seq)
xt[..., mask_layer.squeeze(0).item()] = seq
return xt
def forward(self, x0, x_mask, cond_code=None):
# x0: (B, T, num_quantizer)
# x_mask: (B, T) mask is 0 for padding
# cond_code: semantic token (B, T)
cond = self.cond_emb(cond_code)
logits, mask_layer, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
x0,
x_mask,
cond,
)
return logits, mask_layer, final_mask, x0, prompt_len, mask_prob
import torch
import librosa
import os
import json5
from huggingface_hub import hf_hub_download
from transformers import SeamlessM4TFeatureExtractor, Wav2Vec2BertModel
import safetensors
import numpy as np
from indextts.utils.maskgct.models.codec.kmeans.repcodec_model import RepCodec
from indextts.utils.maskgct.models.tts.maskgct.maskgct_s2a import MaskGCT_S2A
from indextts.utils.maskgct.models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
import time
def _load_config(config_fn, lowercase=False):
"""Load configurations into a dictionary
Args:
config_fn (str): path to configuration file
lowercase (bool, optional): whether changing keys to lower case. Defaults to False.
Returns:
dict: dictionary that stores configurations
"""
with open(config_fn, "r") as f:
data = f.read()
config_ = json5.loads(data)
if "base_config" in config_:
# load configurations from new path
p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
p_config_ = _load_config(p_config_path)
config_ = override_config(p_config_, config_)
if lowercase:
# change keys in config_ to lower case
config_ = get_lowercase_keys_config(config_)
return config_
def load_config(config_fn, lowercase=False):
"""Load configurations into a dictionary
Args:
config_fn (str): path to configuration file
lowercase (bool, optional): whether changing keys to lower case. Defaults to False.
Returns:
JsonHParams: an object that stores configurations
"""
config_ = _load_config(config_fn, lowercase=lowercase)
# create a JsonHParams object from the configuration dict
cfg = JsonHParams(**config_)
return cfg
class JsonHParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = JsonHParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()
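# Illustrative usage sketch (not part of the original file): JsonHParams wraps a
# nested configuration dict so values can be read as attributes or items, which is
# how load_config above exposes the parsed JSON.
def _demo_json_hparams():
    cfg = JsonHParams(**{"model": {"hidden_size": 1024}, "sample_rate": 24000})
    return cfg.model.hidden_size, cfg["sample_rate"], "model" in cfg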
def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt', bert_path=None):
semantic_model = Wav2Vec2BertModel.from_pretrained(
# "facebook/w2v-bert-2.0"
bert_path
)
semantic_model.eval()
stat_mean_var = torch.load(path_)
semantic_mean = stat_mean_var["mean"]
semantic_std = torch.sqrt(stat_mean_var["var"])
return semantic_model, semantic_mean, semantic_std
def build_semantic_codec(cfg):
semantic_codec = RepCodec(cfg=cfg)
semantic_codec.eval()
return semantic_codec
def build_s2a_model(cfg, device):
soundstorm_model = MaskGCT_S2A(cfg=cfg)
soundstorm_model.eval()
soundstorm_model.to(device)
return soundstorm_model
def build_acoustic_codec(cfg, device):
codec_encoder = CodecEncoder(cfg=cfg.encoder)
codec_decoder = CodecDecoder(cfg=cfg.decoder)
codec_encoder.eval()
codec_decoder.eval()
codec_encoder.to(device)
codec_decoder.to(device)
return codec_encoder, codec_decoder
class Inference_Pipeline():
def __init__(
self,
semantic_model,
semantic_codec,
semantic_mean,
semantic_std,
codec_encoder,
codec_decoder,
s2a_model_1layer,
s2a_model_full,
):
self.semantic_model = semantic_model
self.semantic_codec = semantic_codec
self.semantic_mean = semantic_mean
self.semantic_std = semantic_std
self.codec_encoder = codec_encoder
self.codec_decoder = codec_decoder
self.s2a_model_1layer = s2a_model_1layer
self.s2a_model_full = s2a_model_full
@torch.no_grad()
def get_emb(self, input_features, attention_mask):
vq_emb = self.semantic_model(
input_features=input_features,
attention_mask=attention_mask,
output_hidden_states=True,
)
feat = vq_emb.hidden_states[17] # (B, T, C)
feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)
return feat
@torch.no_grad()
def extract_acoustic_code(self, speech):
vq_emb = self.codec_encoder(speech.unsqueeze(1))
_, vq, _, _, _ = self.codec_decoder.quantizer(vq_emb)
acoustic_code = vq.permute(1, 2, 0)
return acoustic_code
@torch.no_grad()
def get_scode(self, inputs):
semantic_code, feat = self.semantic_codec.quantize(inputs)
# vq = self.semantic_codec.quantizer.vq2emb(semantic_code.unsqueeze(1))
# vq = vq.transpose(1,2)
return semantic_code
@torch.no_grad()
def semantic2acoustic(
self,
combine_semantic_code,
acoustic_code,
n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
cfg=2.5,
rescale_cfg=0.75,
):
semantic_code = combine_semantic_code
cond = self.s2a_model_1layer.cond_emb(semantic_code)
prompt = acoustic_code[:, :, :]
predict_1layer = self.s2a_model_1layer.reverse_diffusion(
cond=cond,
prompt=prompt,
temp=1.5,
filter_thres=0.98,
n_timesteps=n_timesteps[:1],
cfg=cfg,
rescale_cfg=rescale_cfg,
)
cond = self.s2a_model_full.cond_emb(semantic_code)
prompt = acoustic_code[:, :, :]
predict_full = self.s2a_model_full.reverse_diffusion(
cond=cond,
prompt=prompt,
temp=1.5,
filter_thres=0.98,
n_timesteps=n_timesteps,
cfg=cfg,
rescale_cfg=rescale_cfg,
gt_code=predict_1layer,
)
vq_emb = self.codec_decoder.vq2emb(
predict_full.permute(2, 0, 1), n_quantizers=12
)
recovered_audio = self.codec_decoder(vq_emb)
prompt_vq_emb = self.codec_decoder.vq2emb(
prompt.permute(2, 0, 1), n_quantizers=12
)
recovered_prompt_audio = self.codec_decoder(prompt_vq_emb)
recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
recovered_audio = recovered_audio[0][0].cpu().numpy()
combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
return combine_audio, recovered_audio
def s2a_inference(
self,
prompt_speech_path,
combine_semantic_code,
cfg=2.5,
n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
cfg_s2a=2.5,
rescale_cfg_s2a=0.75,
):
speech = librosa.load(prompt_speech_path, sr=24000)[0]
acoustic_code = self.extract_acoustic_code(
torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device)
)
_, recovered_audio = self.semantic2acoustic(
combine_semantic_code,
acoustic_code,
n_timesteps=n_timesteps_s2a,
cfg=cfg_s2a,
rescale_cfg=rescale_cfg_s2a,
)
return recovered_audio
@torch.no_grad()
def gt_inference(
self,
prompt_speech_path,
combine_semantic_code,
):
speech = librosa.load(prompt_speech_path, sr=24000)[0]
'''
acoustic_code = self.extract_acoustic_code(
torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device)
)
prompt = acoustic_code[:, :, :]
prompt_vq_emb = self.codec_decoder.vq2emb(
prompt.permute(2, 0, 1), n_quantizers=12
)
'''
prompt_vq_emb = self.codec_encoder(torch.tensor(speech).unsqueeze(0).unsqueeze(1).to(combine_semantic_code.device))
recovered_prompt_audio = self.codec_decoder(prompt_vq_emb)
recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
return recovered_prompt_audio
import re
from textstat import textstat
def contains_chinese(text):
# Regex matching Chinese characters + digits -> both are treated as zh
if re.search(r'[\u4e00-\u9fff0-9]', text):
return True
return False
def get_text_syllable_num(text):
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
number_char_pattern = re.compile(r'[0-9]')
syllable_num = 0
tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z]+|[0-9]+', text)
# print(tokens)
if contains_chinese(text):
for token in tokens:
if chinese_char_pattern.search(token) or number_char_pattern.search(token):
syllable_num += len(token)
else:
syllable_num += textstat.syllable_count(token)
else:
syllable_num = textstat.syllable_count(text)
return syllable_num
def get_text_tts_dur(text):
min_speed = 3 # 2.18 #
max_speed = 5.50
ratio = 0.8517 if contains_chinese(text) else 1.0
syllable_num = get_text_syllable_num(text)
max_dur = syllable_num * ratio / max_speed
min_dur = syllable_num * ratio / min_speed
return max_dur, min_dur
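# Illustrative usage sketch (not part of the original file): count syllables for
# mixed Chinese/English text and derive the duration bounds (in seconds) implied
# by the speed limits defined in get_text_tts_dur above.
def _demo_text_tts_dur():
    text = "今天天气不错 hello world"
    syllable_num = get_text_syllable_num(text)  # CJK chars/digits count one syllable each
    max_dur, min_dur = get_text_tts_dur(text)   # computed from max_speed / min_speed above
    return syllable_num, max_dur, min_dur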
import torch
from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper
class TypicalLogitsWarper(BaseTypicalLogitsWarper):
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
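# Illustrative usage sketch (not part of the original file): apply the typical
# sampling warper to a batch of next-token logits; atypical tokens are set to
# filter_value (-inf) before sampling.
def _demo_typical_warper():
    warper = TypicalLogitsWarper(mass=0.9)
    input_ids = torch.zeros(2, 1, dtype=torch.long)  # not used by __call__, kept for the interface
    scores = torch.randn(2, 100)                     # (batch, vocab)
    return warper(input_ids, scores).shape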
import os
import re
import random
import torch
import torchaudio
MATPLOTLIB_FLAG = False
def load_audio(audiopath, sampling_rate):
audio, sr = torchaudio.load(audiopath)
#print(f"wave shape: {audio.shape}, sample_rate: {sr}")
if audio.size(0) > 1: # mix to mono
audio = audio[0].unsqueeze(0)
if sr != sampling_rate:
try:
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
except Exception as e:
print(f"Warning: {audiopath}, wave shape: {audio.shape}, sample_rate: {sr}")
return None
# clip audio invalid values
audio.clip_(-1, 1)
return audio
def tokenize_by_CJK_char(line: str) -> str:
"""
Tokenize a line of text with CJK char.
Note: all returned characters will be upper case.
Example:
input = "你好世界是 hello world 的中文"
output = "你 好 世 界 是 HELLO WORLD 的 中 文"
Args:
line:
The input text.
Return:
A new string tokenized by CJK char.
"""
# The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
pattern = re.compile(
r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
)
chars = pattern.split(line.strip().upper())
return " ".join([w.strip() for w in chars if w.strip()])
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
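# Illustrative usage sketch (not part of the original file): reproduce the
# docstring example of make_pad_mask with a tensor of sequence lengths.
def _demo_make_pad_mask():
    lengths = torch.tensor([5, 3, 2])
    mask = make_pad_mask(lengths)  # (3, 5) bool tensor, True at padded positions
    return mask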
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))