# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Arithmetic coder."""
import io
import math
import random
import typing as tp
import torch
from ..binary import BitPacker, BitUnpacker
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
roundoff: float = 1e-8, min_range: int = 2,
check: bool = True) -> torch.Tensor:
"""Turn the given PDF into a quantized CDF that splits
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
to the PDF.
Args:
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
during the coding process is `[0, 2 ** total_range_bits - 1]`.
roundoff (float): will round the pdf up to that level to remove differences coming
from e.g. evaluating the Language Model on different architectures.
min_range (int): minimum range width. Should always be at least 2 for numerical
stability. Use this to avoid pathological behavior if a value
that is expected to be rare actually happens in real life.
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
"""
pdf = pdf.detach()
if roundoff:
pdf = (pdf / roundoff).floor() * roundoff
# interpolate with uniform distribution to achieve desired minimum probability.
total_range = 2 ** total_range_bits
cardinality = len(pdf)
alpha = min_range * cardinality / total_range
assert alpha <= 1, "you must reduce min_range"
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
ranges += min_range
quantized_cdf = torch.cumsum(ranges, dim=-1)
if min_range < 2:
raise ValueError("min_range must be at least 2.")
if check:
assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
raise ValueError("You must increase your total_range_bits.")
return quantized_cdf
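# A minimal usage sketch (assumption: the sizes below are illustrative, not part of
# the original API): turn a softmax PDF into a quantized CDF and check its invariants.
def _example_build_cdf():
    import torch
    pdf = torch.softmax(torch.randn(16), dim=0)
    q_cdf = build_stable_quantized_cdf(pdf, total_range_bits=24)
    # The CDF is non-decreasing and its last entry stays within the coding range.
    assert (q_cdf[1:] >= q_cdf[:-1]).all()
    assert q_cdf[-1] <= 2 ** 24
    return q_cdf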
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
sequence `(s_t)` by doing the following:
1) Initialize the current range to `[0, 2 ** B - 1]`.
2) For each time step t, split the current range into contiguous chunks,
one for each possible outcome, with size roughly proportional to `p`.
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
would be `{[0, 2], [3, 3]}`.
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
4) When done encoding all the values, just select any value remaining in the range.
You will notice that this procedure can fail: for instance if at any point in time
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
possible outcome. Intuitively, the more likely a value is, the less the range width
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
coding scheme, likely outcomes would take less bits, and more of them can be coded
with a fixed budget.
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
when the current range decreases below a given limit (given by `total_range_bits`), without
having to redo all the computations. If we mostly encode likely values, we will seldom
need to inject new bits, but a single rare value can deplete our stock of entropy!
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
code works for any sequence `(p_t)` possibly different for each timestep.
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
the KL between the true distribution and `p_t`, the more efficient the coding will be.
Args:
fo (IO[bytes]): file-like object to which the bytes will be written.
total_range_bits (int): the range described above is `[0, 2 ** total_range_bits - 1]`.
Any time the current range width falls below this limit, new bits will
be injected to rescale the initial range.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
assert total_range_bits <= 30
self.total_range_bits = total_range_bits
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
self.low: int = 0
self.high: int = 0
self.max_bit: int = -1
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
@property
def delta(self) -> int:
"""Return the current range width."""
return self.high - self.low + 1
def _flush_common_prefix(self):
# If self.low and self.high start with the same bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
assert self.high < 2 ** (self.max_bit + 1)
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= (b1 << self.max_bit)
self.high -= (b1 << self.max_bit)
assert self.high >= self.low, (self.high, self.low, self.max_bit)
assert self.low >= 0
self.max_bit -= 1
self.packer.push(b1)
else:
break
def push(self, symbol: int, quantized_cdf: torch.Tensor):
"""Push the given symbol on the stream, flushing out bits
if possible.
Args:
symbol (int): symbol to encode with the AC.
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate.
"""
while self.delta < 2 ** self.total_range_bits:
self.low *= 2
self.high = self.high * 2 + 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
self._dbg.append((self.low, self.high))
self._dbg2.append((self.low, self.high))
outs = self._flush_common_prefix()
assert self.low <= self.high
assert self.max_bit >= -1
assert self.max_bit <= 61, self.max_bit
return outs
def flush(self):
"""Flush the remaining information to the stream.
"""
while self.max_bit >= 0:
b1 = (self.low >> self.max_bit) & 1
self.packer.push(b1)
self.max_bit -= 1
self.packer.flush()
class ArithmeticDecoder:
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
If the AC encoder's current range is `[L, H]`, with `L` and `H` having the same common
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
`[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
self.total_range_bits = total_range_bits
self.low: int = 0
self.high: int = 0
self.current: int = 0
self.max_bit: int = -1
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
# Following is for debugging
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
self._last: tp.Any = None
@property
def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
# Given the current range [L, H], if both have a common prefix,
# we know we can remove it from our representation to avoid handling large numbers.
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= (b1 << self.max_bit)
self.high -= (b1 << self.max_bit)
self.current -= (b1 << self.max_bit)
assert self.high >= self.low
assert self.low >= 0
self.max_bit -= 1
else:
break
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
This returns `None` when the stream has been exhausted.
Args:
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate. This must be **exactly**
the same cdf as the one used at encoding time.
"""
while self.delta < 2 ** self.total_range_bits:
bit = self.unpacker.pull()
if bit is None:
return None
self.low *= 2
self.high = self.high * 2 + 1
self.current = self.current * 2 + bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
# Binary search is not just for coding interviews :)
if high_idx < low_idx:
raise RuntimeError("Binary search failed")
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
if self.current <= high:
return (mid, low, high, self.current)
else:
return bin_search(mid + 1, high_idx)
else:
return bin_search(low_idx, mid - 1)
self._last = (self.low, self.high, self.current, self.max_bit)
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
self._dbg.append((self.low, self.high, self.current))
self._flush_common_prefix()
self._dbg2.append((self.low, self.high, self.current))
return sym
def test():
torch.manual_seed(1234)
random.seed(1234)
for _ in range(4):
pdfs = []
cardinality = random.randrange(4000)
steps = random.randrange(100, 500)
fo = io.BytesIO()
encoder = ArithmeticCoder(fo)
symbols = []
for step in range(steps):
pdf = torch.softmax(torch.randn(cardinality), dim=0)
pdfs.append(pdf)
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
symbol = torch.multinomial(pdf, 1).item()
symbols.append(symbol)
encoder.push(symbol, q_cdf)
encoder.flush()
fo.seek(0)
decoder = ArithmeticDecoder(fo)
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
decoded_symbol = decoder.pull(q_cdf)
assert decoded_symbol == symbol, idx
assert decoder.pull(torch.zeros(1)) is None
if __name__ == "__main__":
test()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
import warnings
from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from .. import distrib
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
diffs = rearrange(samples, "n d -> n () d") - rearrange(
means, "c d -> () c d"
)
dists = -(diffs ** 2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
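# Brief sketch of how `kmeans` is used for codebook initialization (assumption: the
# shapes below are toy values chosen only for illustration).
def _example_kmeans():
    import torch
    samples = torch.randn(1024, 128)               # [num_samples, dim]
    means, bins = kmeans(samples, num_clusters=8, num_iters=5)
    # `means` holds one centroid per cluster, `bins` the final cluster sizes.
    assert means.shape == (8, 128) and bins.shape == (8,)
    return means, bins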
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
a randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: bool = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
@torch.jit.ignore
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)  # data is left unchanged
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
distrib.broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
distrib.broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
return x
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
return embed_ind
def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = self.preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = self.postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
return quantize, embed_ind
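# A small sketch of the codebook round trip (assumption: toy sizes; kmeans_init is
# disabled so the example does not depend on a training batch or distributed setup).
def _example_euclidean_codebook():
    import torch
    codebook = EuclideanCodebook(dim=32, codebook_size=64, kmeans_init=False)
    x = torch.randn(2, 10, 32)                 # [batch, time, dim]
    codes = codebook.encode(x)                 # [batch, time] indices into the codebook
    recon = codebook.decode(codes)             # [batch, time, dim] nearest codebook vectors
    assert codes.shape == (2, 10) and recon.shape == x.shape
    return codes, recon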
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
a randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
decay=decay, epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code)
self.codebook_size = codebook_size
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
# breakpoint()
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
# warnings.warn('When using RVQ in training model, first check '
# 'https://github.com/facebookresearch/encodec/issues/25 . '
# 'The bug wasn\'t fixed here for reproducibility.')
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
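# Sketch of a single VQ layer on a [B, D, N] tensor (assumption: toy dimensions, eval
# mode so codebook updates and the commitment loss are skipped). During training the
# straight-through estimator lets gradients flow back to `x`.
def _example_vector_quantization():
    import torch
    vq = VectorQuantization(dim=64, codebook_size=128, kmeans_init=False)
    vq.eval()
    x = torch.randn(2, 64, 50)                 # [batch, dim, num_frames]
    quantized, codes, loss = vq(x)
    assert quantized.shape == x.shape and codes.shape == (2, 50)
    return quantized, codes, loss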
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
def forward(self, x, n_q: tp.Optional[int] = None):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
quantized, indices, loss = layer(residual)
residual = residual - quantized.detach()
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
all_indices.append(indices)
quantized = layer.decode(indices)
residual = residual - quantized.detach()
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
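# Sketch of the residual quantizer stack (assumption: toy sizes). Each layer quantizes
# the residual left over by the previous layers, and decoding sums the per-layer
# reconstructions back together.
def _example_residual_vq():
    import torch
    rvq = ResidualVectorQuantization(num_quantizers=4, dim=64, codebook_size=128,
                                     kmeans_init=False)
    rvq.eval()
    x = torch.randn(2, 64, 50)                 # [batch, dim, num_frames]
    codes = rvq.encode(x, n_q=4)               # [n_q, batch, num_frames]
    recon = rvq.decode(codes)                  # [batch, dim, num_frames]
    assert codes.shape == (4, 2, 50) and recon.shape == x.shape
    return codes, recon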
class LanguageVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
# print("core_vq.py:self.layers",self.layers)
def forward(self, x, n_q: tp.Optional[int] = None):
# x: [b, t, c], e.g. [64, 75, 128]
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
# breakpoint()
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
quantized_out, indices, loss = layer(residual)  # this layer's quantized output, indices and loss; indices e.g. [64, 75]
# residual = residual - quantized.detach()
# quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
# breakpoint()
# breakpoint()
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
all_indices.append(indices)
quantized = layer.decode(indices)
residual = residual - quantized.detach()
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""
from dataclasses import dataclass, field
import math
import typing as tp
import torch
from torch import nn
from .core_vq import ResidualVectorQuantization, LanguageVectorQuantization
@dataclass
class QuantizedResult:
quantized: torch.Tensor
codes: torch.Tensor
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
penalty: tp.Optional[torch.Tensor] = None
metrics: dict = field(default_factory=dict)
class ResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer.
Args:
dimension (int): Dimension of the codebooks.
n_q (int): Number of residual vector quantizers used.
bins (int): Codebook size.
decay (float): Decay for exponential moving average over the codebooks.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
a randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
n_q: int = 8,
bins: int = 1024,
decay: float = 0.99,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.n_q = n_q
self.dimension = dimension
self.bins = bins
self.decay = decay
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.threshold_ema_dead_code = threshold_ema_dead_code
# print(self.bins)
# breakpoint()
self.vq = LanguageVectorQuantization(
dim=self.dimension,
codebook_size=self.bins,
num_quantizers=self.n_q,
decay=self.decay,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
# self.vq = ResidualVectorQuantization(
# dim=self.dimension,
# codebook_size=self.bins,
# num_quantizers=self.n_q,
# decay=self.decay,
# kmeans_init=self.kmeans_init,
# kmeans_iters=self.kmeans_iters,
# threshold_ema_dead_code=self.threshold_ema_dead_code,
# )
def forward(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
frame_rate (int): Frame rate of the input tensor.
bandwidth (float): Target bandwidth.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated bandwidth and any penalty term for the loss.
"""
# breakpoint()
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
# assert n_q==4
# breakpoint()
# nq_choice=[3,4,8]
nq_choice=[4,6,8]
if self.training:
# choice = int(torch.randint(0, 3, (1,)).item())
choice = int(torch.randint(0, 3, (1,)).item())
# breakpoint()
n_q=nq_choice[choice]
# breakpoint()
# n_q=8
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
def infer(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
frame_rate (int): Frame rate of the input tensor.
bandwidth (float): Target bandwidth.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated bandwidth and any penalty term for the loss.
"""
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
# n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
# # assert n_q==4
# # breakpoint()
# # nq_choice=[3,4,8]
# nq_choice=[3,4,5,6,7,8]
# if self.training:
# # choice = int(torch.randint(0, 3, (1,)).item())
# choice = int(torch.randint(0, 6, (1,)).item())
# # breakpoint()
# n_q=nq_choice[choice]
n_q=1
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int:
"""Return n_q based on specified target bandwidth.
"""
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
n_q = self.n_q
if bandwidth and bandwidth > 0.:
# bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
# bandwidth == 6.0
n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
return n_q
def get_bandwidth_per_quantizer(self, frame_rate: int):
"""Return bandwidth per quantizer for a given input frame rate.
Each quantizer encodes a frame with lg(bins) bits.
"""
return math.log2(self.bins) * frame_rate
def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified frame rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizers to use
and returns indices for each quantizer.
"""
n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
codes = self.vq.encode(x, n_q=n_q)
return codes
def decode(self, codes: torch.Tensor) -> torch.Tensor:
"""Decode the given codes to the quantized representation.
"""
quantized = self.vq.decode(codes)
return quantized
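# Sketch of the top-level quantizer (assumption: toy sizes, eval mode, no target
# bandwidth). With bins=1024 and a 75 Hz frame rate, each quantizer costs
# log2(1024) * 75 = 750 bits/s, so 8 quantizers would use 6 kb/s.
def _example_residual_vector_quantizer():
    import torch
    quantizer = ResidualVectorQuantizer(dimension=64, n_q=4, bins=1024,
                                        kmeans_init=False)
    quantizer.eval()
    x = torch.randn(2, 64, 75)                 # [batch, dimension, num_frames]
    result = quantizer(x, frame_rate=75)       # QuantizedResult
    assert result.quantized.shape == x.shape
    assert result.codes.shape == (4, 2, 75)    # [n_q, batch, num_frames]
    return result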
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various utilities."""
from hashlib import sha256
from pathlib import Path
import typing as tp
import torch
import torchaudio
def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
# Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
# e.g., more than 2 frames per position.
# The core idea is to use a weight function that is a triangle,
# with a maximum value at the middle of the segment.
# We use this weighting when summing the frames, and divide by the sum of weights
# for each positions at the end. Thus:
# - if a frame is the only one to cover a position, the weighting is a no-op.
# - if 2 frames cover a position:
#        ...  ...
#       /   \/   \
#      /    /\    \
#          S  T      , i.e. S is where the second frame starts and T where the first frame ends.
# Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
# After the final normalization, the weight of the second frame at position `t` is
# (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
#
# - if more than 2 frames overlap at a given point, we hope that by induction
# something sensible happens.
assert len(frames)
device = frames[0].device
dtype = frames[0].dtype
shape = frames[0].shape[:-1]
total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
frame_length = frames[0].shape[-1]
t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1]
weight = 0.5 - (t - 0.5).abs()
sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
offset: int = 0
for frame in frames:
frame_length = frame.shape[-1]
out[..., offset:offset + frame_length] += weight[:frame_length] * frame
sum_weight[offset:offset + frame_length] += weight[:frame_length]
offset += stride
assert sum_weight.min() > 0
return out / sum_weight
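# Quick sketch of the overlap-add above (assumption: toy frame sizes). Two frames of
# length 6 with stride 3 overlap over 3 samples; on constant input the triangular
# weights normalize away and the constant value is recovered exactly.
def _example_linear_overlap_add():
    import torch
    frames = [torch.ones(1, 6), torch.ones(1, 6)]
    out = _linear_overlap_add(frames, stride=3)
    assert out.shape == (1, 9)
    assert torch.allclose(out, torch.ones(1, 9))
    return out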
def _get_checkpoint_url(root_url: str, checkpoint: str):
if not root_url.endswith('/'):
root_url += '/'
return root_url + checkpoint
def _check_checksum(path: Path, checksum: str):
sha = sha256()
with open(path, 'rb') as file:
while True:
buf = file.read(2**20)
if not buf:
break
sha.update(buf)
actual_checksum = sha.hexdigest()[:len(checksum)]
if actual_checksum != checksum:
raise RuntimeError(f'Invalid checksum for file {path}, '
f'expected {checksum} but got {actual_checksum}')
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions"
assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
*shape, channels, length = wav.shape
if target_channels == 1:
wav = wav.mean(-2, keepdim=True)
elif target_channels == 2:
wav = wav.expand(*shape, target_channels, length)
elif channels == 1:
wav = wav.expand(target_channels, -1)
else:
raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
return wav
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
sample_rate: int, rescale: bool = False):
limit = 0.99
mx = wav.abs().max()
if rescale:
wav = wav * min(limit / mx, 1)
else:
wav = wav.clamp(-limit, limit)
torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
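# Sketch of the two helpers above (assumption: hypothetical file name and rates):
# downmix a stereo 44.1 kHz clip to mono 24 kHz, then write it as 16-bit PCM.
def _example_convert_and_save(out_path: str = "example_out.wav"):
    import torch
    wav = torch.randn(2, 44100)                # [channels, samples], 1 s of stereo
    mono = convert_audio(wav, sr=44100, target_sr=24000, target_channels=1)
    assert mono.shape[0] == 1
    save_audio(mono, out_path, sample_rate=24000, rescale=True)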
#!/bin/bash
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
remove_archive=false
if [ "$1" == --remove-archive ]; then
remove_archive=true
shift
fi
if [ $# -ne 3 ]; then
echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
echo "With --remove-archive it will remove the archive after successfully un-tarring it."
echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
echo " train-clean-100, train-clean-360, train-other-500."
exit 1
fi
data=$1
url=$2
part=$3
if [ ! -d "$data" ]; then
echo "$0: no such directory $data"
exit 1
fi
part_ok=false
list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
for x in $list; do
if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
exit 1
fi
if [ -z "$url" ]; then
echo "$0: empty URL base."
exit 1
fi
if [ -f $data/LibriTTS/$part/.complete ]; then
echo "$0: data part $part was already successfully extracted, nothing to do."
exit 0
fi
# sizes of the archive files in bytes. These are from some older versions of the corpus.
sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
# sizes_new is the archive file sizes of the final release. Some of these sizes are of
# things we probably won't download.
sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
if [ -f $data/$part.tar.gz ]; then
size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
size_ok=false
for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
if ! $size_ok; then
echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
echo "does not equal the size of one of the archives."
rm $data/$part.tar.gz
else
echo "$data/$part.tar.gz exists and appears to be complete."
fi
fi
if [ ! -f $data/$part.tar.gz ]; then
if ! which wget >/dev/null; then
echo "$0: wget is not installed."
exit 1
fi
full_url=$url/$part.tar.gz
echo "$0: downloading data from $full_url. This may take some time, please be patient."
if ! wget -P $data --no-check-certificate $full_url; then
echo "$0: error executing wget $full_url"
exit 1
fi
fi
if ! tar -C $data -xvzf $data/$part.tar.gz; then
echo "$0: error un-tarring archive $data/$part.tar.gz"
exit 1
fi
touch $data/LibriTTS/$part/.complete
echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
if $remove_archive; then
echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
rm $data/$part.tar.gz
fi
# Copyright (c) 2024 Alibaba Inc All Rights Reserved.
import argparse
import logging
import glob
import os
from tqdm import tqdm
logger = logging.getLogger()
def main():
wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
for wav in tqdm(wavs):
txt = wav.replace('.wav', '.normalized.txt')
if not os.path.exists(txt):
logger.warning('{} does not exist'.format(txt))
continue
with open(txt) as f:
content = ''.join(l.replace('\n', '') for l in f.readlines())
utt = os.path.basename(wav).replace('.wav', '')
spk = utt.split('_')[0]
utt2wav[utt] = wav
utt2text[utt] = content
utt2spk[utt] = spk
if spk not in spk2utt:
spk2utt[spk] = []
spk2utt[spk].append(utt)
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
for k, v in utt2wav.items():
f.write('{} {}\n'.format(k, v))
with open('{}/text'.format(args.des_dir), 'w') as f:
for k, v in utt2text.items():
f.write('{} {}\n'.format(k, v))
with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
for k, v in utt2spk.items():
f.write('{} {}\n'.format(k, v))
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
for k, v in spk2utt.items():
f.write('{} {}\n'.format(k, ' '.join(v)))
return
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir',
type=str)
parser.add_argument('--des_dir',
type=str)
args = parser.parse_args()
main()
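# Example of the Kaldi-style files written above (hypothetical utterance ids and paths):
#   wav.scp : 1088_134315_000001_000001 /path/to/LibriTTS/train-clean-100/1088/134315/1088_134315_000001_000001.wav
#   text    : 1088_134315_000001_000001 This is the normalized transcript.
#   utt2spk : 1088_134315_000001_000001 1088
#   spk2utt : 1088 1088_134315_000001_000001 1088_134315_000001_000002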
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../:../../third_party/Matcha-TTS:$PYTHONPATH
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../`
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export BIN_DIR=${MAIN_ROOT}/inspiremusic
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script just shows an example of building your own music generation model.
# You may need to prepare your own dataset to fine-tune or train from scratch.
# Here we take the MusicCaps [1] dataset as an example.
# Download MusicCaps from: https://huggingface.co/datasets/google/MusicCaps
# Reference:
# 1. Agostinelli, A., Denk, T. I., Borsos, Z., Engel, J., Verzetti, M., Caillon, A., Huang, Q., Jansen, A., Roberts, A., Tagliasacchi, M., Sharifi, M., Zeghidour, N., & Frank, C. (2023). MusicLM: Generating music from text. Google Research. https://doi.org/10.48550/arXiv.2301.11325
. ./path.sh || exit 1;
stage=1
stop_stage=5
model_name=InspireMusic-Base
pretrained_model_dir=../../pretrained_models/${model_name}
dataset_name=musiccaps
# data preparation
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Download dataset and prepare wav.scp/text."
# Here you may need to download MusicCaps dataset
for x in ${dataset_name}_train ${dataset_name}_dev; do
[ -d data/$x/ ] || mkdir -p data/$x/
done
fi
export CUDA_VISIBLE_DEVICES="0"
# extract acoustic tokens
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract acoustic token, you should have prepared acoustic tokenizer model, wav.scp/text"
for x in ${dataset_name}_dev ${dataset_name}_train; do
echo "$x"
tools/extract_acoustic_token.py --dir data/$x \
--ckpt_path ${pretrained_model_dir}/music_tokenizer/model.pt \
--config_path ${pretrained_model_dir}/music_tokenizer/config.json
done
fi
# extract semantic tokens
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract semantic token, you should have prepared semantic tokenizer model, wav.scp/text"
for x in ${dataset_name}_dev ${dataset_name}_train; do
echo "$x"
tools/extract_semantic_token.py --dir data/$x \
--ckpt_path ${pretrained_model_dir}/wavtokenizer/model.pt \
--config_path ${pretrained_model_dir}/wavtokenizer/config.yaml
done
fi
# data packing
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2acoustic_token.pt/utt2semantic_token.pt"
for x in ${dataset_name}_train ${dataset_name}_dev; do
echo $x
[ -d data/$x/parquet ] || mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 10000 \
--num_processes 10 \
--semantic_token_dir `pwd`/data/$x/ \
--acoustic_token_dir `pwd`/data/$x/ \
--des_dir `pwd`/data/$x/parquet \
--src_dir `pwd`/data/$x/
done
fi
# inference
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
test_set=${dataset_name}_dev
echo "Run inference."
expr_name="${test_set}"
for task in 'text-to-music' 'continuation'; do
[ -d `pwd`/exp/${model_name}/${task}_${expr_name} ] && rm -rf `pwd`/exp/${model_name}/${task}_${expr_name}
echo `pwd`/exp/${model_name}/${task}_${expr_name}
python inspiremusic/bin/inference.py --task $task \
--gpu 0 \
--config conf/inspiremusic.yaml \
--prompt_data data/${test_set}/parquet/data.list \
--flow_model $pretrained_model_dir/flow.pt \
--llm_model $pretrained_model_dir/llm.pt \
--music_tokenizer $pretrained_model_dir/music_tokenizer \
--wavtokenizer $pretrained_model_dir/wavtokenizer \
--result_dir `pwd`/exp/${model_name}/${task}_${expr_name}
done
fi
# train llm and flow models fp16
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
job_id=1024
dist_backend="nccl"
num_workers=8
prefetch=100
train_engine=torch_ddp
expr_name="InspireMusic-Base-musiccaps-ft"
echo "Run model training. We support llm and flow traning."
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/${dataset_name}_train/parquet/data.list > data/${dataset_name}_train.data.list
cat data/${dataset_name}_dev/parquet/data.list > data/${dataset_name}_dev.data.list
# train llm, support fp16 training
model="llm"
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
inspiremusic/bin/train.py \
--train_engine $train_engine \
--config conf/inspiremusic.yaml \
--train_data data/${dataset_name}_train.data.list \
--cv_data data/${dataset_name}_dev.data.list \
--model $model \
--model_dir `pwd`/exp/${expr_name}/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/${expr_name}/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer \
--fp16 \
--checkpoint ../../pretrained_models/InspireMusic-Base/llm.pt
# train flow matching model, only support fp32 training
model="flow"
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
inspiremusic/bin/train.py \
--train_engine $train_engine \
--config conf/inspiremusic.yaml \
--train_data data/${dataset_name}_train.data.list \
--cv_data data/${dataset_name}_dev.data.list \
--model $model \
--model_dir `pwd`/exp/${expr_name}/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/${expr_name}/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
fi
# 1. Text-to-music task
# with flow matching
# use a one-line command for a quick try
python -m inspiremusic.cli.inference
# customize the config with a one-line command like the following
# python -m inspiremusic.cli.inference --task text-to-music -m "InspireMusic-1.5B-Long" -g 0 -t "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance." -c intro -s 0.0 -e 30.0 -r "exp/inspiremusic" -o output -f wav
# without flow matching
# python -m inspiremusic.cli.inference --task text-to-music -g 0 -t "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance." --fast True
# 2. Music continuation task
# with flow matching
# python -m inspiremusic.cli.inference --task continuation -g 0 -a exp/inspiremusic/output_audio.wav
# without flow matching
# python -m inspiremusic.cli.inference --task continuation -g 0 -a exp/inspiremusic/output_audio.wav --fast True
from inspiremusic.cli.inference import InspireMusicUnified
from inspiremusic.cli.inference import set_env_variables
if __name__ == "__main__":
set_env_variables()
model = InspireMusicUnified(model_name = "InspireMusic-1.5B-Long")
model.inference("text-to-music", "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.")
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from tqdm import tqdm
import numpy as np
import torchaudio
from inspiremusic.utils.audio_utils import normalize, split_wav_into_chunks
from inspiremusic.music_tokenizer.vqvae import VQVAE
import time
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def main(args):
audio_min_length = 1.0
audio_max_length = 30.0
max_chunk_size = int(args.sample_rate * audio_max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
utt2wav = {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
model = VQVAE(args.config_path, args.ckpt_path, with_encoder=True)
model.cuda()
model.eval()
utt2acoustic_token = {}
start_time = time.time()
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != args.sample_rate:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=args.sample_rate)(audio)
audio_length = audio.shape[1]
if audio_length > args.sample_rate * audio_min_length:
if audio_length > max_chunk_size:
wav_chunks = split_wav_into_chunks(audio_length, audio, max_chunk_size)
for chunk in wav_chunks:
chunk = torch.tensor(chunk, dtype=torch.float32).to(device)
acoustic_token = model.encode(chunk)
if acoustic_token.is_cuda:
acoustic_token = acoustic_token.cpu()
acoustic_token = acoustic_token.numpy().astype(np.int16)
if utt not in utt2acoustic_token.keys():
utt2acoustic_token[utt] = acoustic_token
else:
utt2acoustic_token[utt] = np.concatenate((utt2acoustic_token[utt], acoustic_token), axis=1)
else:
audio = torch.tensor(audio, dtype=torch.float32).to(device)
acoustic_token = model.encode(audio)
if acoustic_token.is_cuda:
acoustic_token = acoustic_token.cpu()
acoustic_token = acoustic_token.numpy().astype(np.int16)
utt2acoustic_token[utt] = acoustic_token
else:
logging.warning('This audio length is too short.')
torch.save(utt2acoustic_token, '{}/utt2acoustic_token.pt'.format(args.dir))
logging.info('spend time {}'.format(time.time() - start_time))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--config_path',
type=str, default="pretrained_models/InspireMusic-Base/music_tokenizer/config.json")
parser.add_argument('--ckpt_path',
type=str, default="pretrained_models/InspireMusic-Base/music_tokenizer/model.pt")
parser.add_argument('--sample_rate',
default=24000,
type=int)
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torchaudio
from tqdm import tqdm
import onnxruntime
import torchaudio.compliance.kaldi as kaldi
def main(args):
utt2wav, utt2spk = {}, {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
with open('{}/utt2spk'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2spk[l[0]] = l[1]
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
utt2embedding, spk2embedding = {}, {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
feat = kaldi.fbank(audio,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
utt2embedding[utt] = embedding
spk = utt2spk[utt]
if spk not in spk2embedding:
spk2embedding[spk] = []
spk2embedding[spk].append(embedding)
for k, v in spk2embedding.items():
spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--onnx_path',
type=str)
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from tqdm import tqdm
import numpy as np
import torchaudio
import time
import os
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
from inspiremusic.utils.audio_utils import split_wav_into_chunks
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def main(args):
audio_min_length = 1.0
audio_max_length = 30.0
max_chunk_size = int(args.sample_rate * audio_max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
utt2wav = {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
wavtokenizer = WavTokenizer.from_pretrained_feat(args.config_path, args.ckpt_path).to(device)
bandwidth_id = torch.tensor([0]).to(device)
start_time = time.time()
utt2semantic_token = {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != args.sample_rate:
audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=args.sample_rate)
audio_length = audio.shape[1]
if audio_length > args.sample_rate * audio_min_length:
if audio_length > max_chunk_size:
wav_batch = split_wav_into_chunks(audio_length, audio, max_chunk_size)
for chunk in wav_batch:
chunk = torch.tensor(chunk, dtype=torch.float32).to(device)
_, semantic_token = wavtokenizer.encode_infer(chunk, bandwidth_id=bandwidth_id)
if semantic_token.is_cuda:
semantic_token = semantic_token.cpu()
semantic_token = semantic_token.squeeze(0).numpy().astype(np.int16)
if utt not in utt2semantic_token.keys():
utt2semantic_token[utt] = semantic_token
else:
utt2semantic_token[utt] = np.concatenate((utt2semantic_token[utt], semantic_token), axis=1)
else:
audio = torch.tensor(audio, dtype=torch.float32).to(device)
_, semantic_token = wavtokenizer.encode_infer(audio, bandwidth_id=bandwidth_id)
if semantic_token.is_cuda:
semantic_token = semantic_token.cpu()
semantic_token = semantic_token.squeeze(0).numpy().astype(np.int16)
utt2semantic_token[utt] = semantic_token
else:
logging.warning('This audio length is too short.')
torch.save(utt2semantic_token, '{}/utt2semantic_token.pt'.format(args.dir))
logging.info('spend time {}'.format(time.time() - start_time))
def reconstruct(semantic_token_file, config_path, ckpt_path, outdir, sample_rate=24000):
if not os.path.isdir(outdir):
os.makedirs(outdir, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bandwidth_id = torch.tensor([0]).to(device)
wavtokenizer = WavTokenizer.from_pretrained_feat(config_path, ckpt_path).to(device)
utt2semantic_token = torch.load(semantic_token_file)
for utt in tqdm(utt2semantic_token.keys()):
token = utt2semantic_token[utt]
new_tensor = torch.tensor(token).to(device).unsqueeze(0)
features = wavtokenizer.codes_to_features(new_tensor)
wav = wavtokenizer.decode(features, bandwidth_id=bandwidth_id)
wav = wav.cpu().detach()
torchaudio.save(outdir + "/" + utt + ".wav", wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--config_path',
type=str, default="pretrained_models/InspireMusic-Base/wavtokenizer/config.yaml")
parser.add_argument('--ckpt_path',
type=str, default="pretrained_models/InspireMusic-Base/wavtokenizer/model.pt")
parser.add_argument('--sample_rate',
default=24000,
type=int)
parser.add_argument('--outwavdir',
type=str, default="./exp/wavs")
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from tqdm import tqdm
import onnxruntime
import numpy as np
import torchaudio
import whisper
def main(args):
utt2wav = {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ["CUDAExecutionProvider"]
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
utt2speech_token = {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
if audio.shape[1] / 16000 > 30:
logging.warning('does not support extracting speech tokens for audio longer than 30s')
speech_token = []
else:
feat = whisper.log_mel_spectrogram(audio, n_mels=128)
speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
utt2speech_token[utt] = speech_token
torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--onnx_path',
type=str)
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import json
from tqdm import tqdm
import pandas as pd
import multiprocessing
import time
import torch
import numpy as np
import random
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
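# Write one shard of utterances (text, timing, structure label, and acoustic/semantic tokens) to a parquet file.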
def job(utt_list, token_list, parquet_file, utt2text, utt2time, utt2chorus, semantic_token_list):
start_time = time.time()
text_list = [utt2text[utt] for utt in utt_list]
time_start = [utt2time[utt][0] for utt in utt_list]
time_end = [utt2time[utt][1] for utt in utt_list]
chorus_list = [utt2chorus[utt] for utt in utt_list]
    print(len(token_list))
    if semantic_token_list is not None:
        print(len(semantic_token_list))
try:
df = pd.DataFrame()
df['utt'] = utt_list
df['text'] = text_list
df['chorus'] = chorus_list
df['time_start'] = time_start
df['time_end'] = time_end
df["semantic_token"] = semantic_token_list
df["acoustic_token"] = token_list
logging.info(f'Starting to save parquet file: {parquet_file}')
df.to_parquet(parquet_file)
logging.info(f'Successfully saved parquet file: {parquet_file}')
except Exception as e:
        logging.error(f'Error saving parquet file: {e}')
logging.info('Processing time {}s'.format(time.time() - start_time))
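# Same as job() but without token columns: only text, timing, and structure labels are written.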
def text_only_job(utt_list, parquet_file, utt2text, utt2time, utt2chorus):
start_time = time.time()
text_list = [utt2text[utt] for utt in utt_list]
time_start = [utt2time[utt][0] for utt in utt_list]
time_end = [utt2time[utt][1] for utt in utt_list]
chorus_list = [utt2chorus[utt] for utt in utt_list]
try:
        # save to parquet
df = pd.DataFrame()
df['utt'] = utt_list
df['text'] = text_list
df['chorus'] = chorus_list
df['time_start'] = time_start
df['time_end'] = time_end
logging.info(f'Starting to save parquet file: {parquet_file}')
df.to_parquet(parquet_file)
logging.info(f'Successfully saved parquet file: {parquet_file}')
except Exception as e:
        logging.error(f'Error saving parquet file: {e}')
logging.info('Processing time {}s'.format(time.time() - start_time))
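# Parse one transcript line of the form '<uid>\t<|time_start|><|structure|><|lyrics or music|>...<|time_end|>'
# (fields separated by '|><|'); returns (uid, time_start, time_end, structure_label, text), or None if malformed.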
def parse_trans(line):
music_structure_labels = ["intro", "verse1", "chorus", "verse2", "verse", "outro"]
uid,l = line.strip().split("\t")
split = l.split("|><|")
time_start = float(split[0].replace("<|",""))
time_end = float(split[-1].replace("|>", ""))
    chorus = split[1]
    # reject malformed or zero-length segments before indexing the lyrics field
    if len(split) < 4 or time_start >= time_end:
        print(line, split, time_start, time_end)
        return None
    if split[2] == "lyrics":
        text = "<|lyrics|> " + split[3]
    elif split[2] == "music":
        text = "<|music|>"
    else:
        text = split[2]
    if chorus not in music_structure_labels:
        chorus = random.choice(music_structure_labels)
    if chorus in ["verse1", "verse2"]:
        chorus = "verse"
    if time_start < 0:
        time_start = 0.0
return (uid, time_start, time_end, chorus, text)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_utts_per_parquet',
type=int,
default=1000,
required=False,
help='num utts per parquet')
parser.add_argument('--num_processes',
type=int,
default=1,
required=False,
help='num processes for make parquets')
parser.add_argument('--src_dir',
type=str, required=True)
parser.add_argument('--des_dir',
type=str, required=True)
parser.add_argument('--semantic_token_dir',
type=str,
default=None, required=False)
parser.add_argument('--acoustic_token_dir',
type=str,
default=None, required=False)
args = parser.parse_args()
parquet_list = []
cnt = 0
utt2text = {}
utt2time = {}
utt2chorus = {}
uid_list = []
print(args)
if not os.path.exists(f'{args.src_dir}/text'):
raise FileNotFoundError(
f"Please check: {args.src_dir}/text file does not exist")
with open(f'{args.src_dir}/text', 'r') as f:
for l in f:
res = parse_trans(l)
if res is None:
continue
uid, time_start, time_end, chorus, text = res
uid_list.append(uid)
utt2time[uid] = (time_start, time_end)
utt2chorus[uid] = chorus
utt2text[uid] = text
utt2semantic_token = None
utt2acoustic_token = None
if args.semantic_token_dir is not None:
utt2semantic_token = {}
for fn in os.listdir(args.semantic_token_dir):
if fn.endswith("pt") and fn.startswith("utt2semantic_"):
print(f"Starting {fn}")
try:
utt2semantic_token.update(
torch.load('{}/{}'.format(args.semantic_token_dir, fn)))
                except Exception:
                    print('{}/{} failed'.format(args.semantic_token_dir, fn))
print(len(utt2semantic_token))
    # Use a process pool to speed up parquet writing.
pool = multiprocessing.Pool(processes=args.num_processes)
if args.acoustic_token_dir is not None:
for fn in os.listdir(args.acoustic_token_dir):
if fn.endswith("pt") and fn.startswith("utt2acoustic_"):
print(f"Starting {fn}")
utt2token = torch.load(
'{}/{}'.format(args.acoustic_token_dir, fn))
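                # keep only utterances that also have a transcript (and semantic tokens, when provided)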
utts = [utt for utt in utt2token.keys() if utt in utt2text.keys()]
if utt2semantic_token:
utts = [utt for utt in utts if
utt in utt2semantic_token.keys()]
if len(utts) == 0:
print("0 lines remained.")
continue
if isinstance(utt2token[utts[0]], np.ndarray):
token_lists = [utt2token[utt][0].tolist() for utt in utts]
else:
token_lists = [
utt2token[utt].tolist() if utt2token[
utt].dim() == 2 else
utt2token[utt][0].tolist()
for utt in utts
]
print(len(token_lists))
semantic_token_lists = [
utt2semantic_token[utt].tolist() if not isinstance(
utt2semantic_token[utt], list) else
utt2semantic_token[utt] for utt in
utts] if utt2semantic_token else None
for i, j in enumerate(
range(0, len(utts), args.num_utts_per_parquet)):
parquet_file = os.path.join(args.des_dir,
'parquet_{:09d}.tar'.format(
cnt + i))
print(f"process {parquet_file}")
parquet_list.append(parquet_file)
token_list = token_lists[j: j + args.num_utts_per_parquet]
if semantic_token_lists:
semantic_token_list = semantic_token_lists[
j: j + args.num_utts_per_parquet]
else:
semantic_token_list = None
pool.apply_async(job, (
utts[j: j + args.num_utts_per_parquet], token_list,
parquet_file, utt2text, utt2time, utt2chorus,
semantic_token_list))
                # advance the shard counter past the last index used for this token file
                cnt += i + 1
if args.semantic_token_dir is None and args.acoustic_token_dir is None:
for i, j in enumerate(
range(0, len(uid_list), args.num_utts_per_parquet)):
parquet_file = os.path.join(args.des_dir,
'parquet_{:09d}.tar'.format(cnt + i))
print(f"process {parquet_file}")
parquet_list.append(parquet_file)
pool.apply_async(text_only_job, (
uid_list[j: j + args.num_utts_per_parquet], parquet_file, utt2text,
utt2time, utt2chorus))
cnt += i
pool.close()
pool.join()
print("DONE")
with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1:
for name in parquet_list:
f1.write(name + '\n')
from inspiremusic.cli.inference import InspireMusicUnified
from inspiremusic.cli.inference import set_env_variables
if __name__ == "__main__":
set_env_variables()
    model = InspireMusicUnified(model_name="InspireMusic-1.5B-Long")
model.inference("text-to-music", "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.")