Commit 0112b0f0 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2394 canceled with stages
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import json
import math
from functools import partial
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from inspiremusic.utils.file_utils import read_lists, read_json_lists
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
# force datalist even
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
if len(data) < self.world_size:
print(len(data), self.world_size)
data = data * math.ceil(self.world_size / len(data))
data = data[:self.world_size]
data = data[self.rank::self.world_size]
if len(data) < self.num_workers:
data = data * math.ceil(self.num_workers / len(data))
data = data[:self.num_workers]
data = data[self.worker_id::self.num_workers]
return data
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
def Dataset(data_list_file,
data_pipeline,
mode='train',
shuffle=True,
partition=True
):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
shuffle at shards tar/raw file level. The second is global shuffle
at training samples level.
Args:
data_type(str): raw/shard
tokenizer (BaseTokenizer): tokenizer to tokenize
partition(bool): whether to do data partition in terms of rank
"""
assert mode in ['train', 'inference', 'processing']
lists = read_lists(data_list_file)
dataset = DataList(lists,
shuffle=shuffle,
partition=partition)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import numpy as np
import re
torchaudio.set_audio_backend('soundfile')
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2,
"outro": 4}
metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$')
timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$')
def parquet_opener(data, mode='train', audio_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
df = pq.read_table(url).to_pandas()
for i in df.index:
sample.update(dict(df.loc[i]))
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
def clean_lyrics(data, mode="train"):
for sample in data:
lyrics = sample["text"]
cleaned = []
for line in lyrics.splitlines():
if metadata_pattern.match(line):
continue
timestamp_match = timestamp_pattern.match(line)
if timestamp_match:
lyric = timestamp_match.group(1).strip()
if lyric:
cleaned.append(lyric)
else:
if line.strip():
cleaned.append(line.strip())
sample["text"] = '\n'.join(cleaned)
yield sample
def cut_by_length(data, max_length=8000, num_times=4, mode="train"):
for sample in data:
if "semantic_token" in sample:
sample["semantic_token"] = [
sample["semantic_token"][0][:max_length]]
if "acoustic_token" not in sample:
sample["acoustic_token"] = sample["speech_token"]
sample["acoustic_token"] = sample["acoustic_token"][
:max_length * num_times]
yield sample
def filter(data,
max_length=22500, # 22500 #5min #10240
max_acoustic_length=45000,
min_length=10,
min_acoustic_length=150,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
if mode == "train":
for sample in data:
if "semantic_token" in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
else:
new_sample_frames = sample['speech_token']
if "text_token" in sample:
new_sample_frames += len(sample['text_token'])
if new_sample_frames > max_length or new_sample_frames < min_length:
print(f"skipped 1 item length={new_sample_frames}")
continue
sample["chorus"] = sample["chorus"].split(",")
if not isinstance(sample["time_start"], np.ndarray):
sample["time_start"] = [sample["time_start"]]
sample["time_end"] = [sample["time_end"]]
for i, t in enumerate(sample["chorus"]):
if sample["chorus"][i] == "verse":
sample["chorus"][i] = "verse1"
yield sample
if mode == "train_flow":
for sample in data:
if "semantic_token" in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
if "acoustic_token" in sample:
target_sample_frames = sample['acoustic_token'][0].shape[0]
if new_sample_frames > max_length or new_sample_frames < min_acoustic_length or new_sample_frames < min_length or target_sample_frames > max_acoustic_length:
print(
f"skipped 1 item length={new_sample_frames}, target_length={target_sample_frames}")
continue
yield sample
elif mode == "inference":
for sample in data:
yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def truncate(data, truncate_length=24576, mode='train'):
""" Truncate data.
Args:
data: Iterable[{key, wav, label, sample_rate}]
truncate_length: truncate length
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
waveform = sample['audio']
if waveform.shape[1] > truncate_length:
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length -
waveform.shape[1])],
dim=1)
sample['audio'] = waveform
yield sample
def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train',
n_codebook=4):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'semantic_token' in sample
# TODO: unify data processing key names
if 'acoustic_token' not in sample:
continue
if 'sample_rate' in sample.keys():
sample_rate = sample['sample_rate']
else:
sample_rate = 24000
token = np.array(sample['semantic_token'][0][:-1])
# Calculate the repetition factor for resampling
repetition_factor = int(n_codebook * resample_rate / sample_rate)
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['semantic_token'] = np.array(
[np.repeat(token, repetition_factor)])
yield sample
def compute_fbank(data,
feat_extractor,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
del sample['speech']
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'],
dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'],
dtype=torch.float32)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'],
dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'],
dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'],
allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
if sample["chorus"] == "verse":
sample["chorus"] = "verse1"
if sample["acoustic_token"].shape[0] == 1:
sample["acoustic_token"] = np.concatenate(
sample["acoustic_token"][0])
else:
sample["acoustic_token"] = np.concatenate(sample["acoustic_token"])
sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"])
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['acoustic_token'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['acoustic_token'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=32):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
data_empty = True
for sample in data:
data_empty = False
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if data_empty:
raise ValueError("data is empty")
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'acoustic_token' in sample
assert isinstance(sample['acoustic_token'], torch.Tensor)
if 'semantic_token' in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
else:
new_sample_frames = sample['semantic_token']
if "text_token" in sample:
new_sample_frames += len(sample['text_token'])
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
if len(buf) > 0:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000,
mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
elif mode == 'processing':
return static_batch(data, batch_size)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
if mode == "train":
for sample in data:
assert isinstance(sample, list)
if len(sample) != 0:
acoustic_feat_len = torch.tensor(
[x['acoustic_token'].size(0) for x in sample],
dtype=torch.int32)
order = torch.argsort(acoustic_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
acoustic_token = [
sample[i]['acoustic_token'].clone().to(torch.int32) for i in
order]
acoustic_token_len = torch.tensor(
[i.size(0) for i in acoustic_token], dtype=torch.int32)
acoustic_token = pad_sequence(acoustic_token,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']).long() for i
in order]
text_token_len = torch.tensor([i.size(0) for i in text_token],
dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True,
padding_value=0)
time_start = torch.tensor(
[sample[i]['time_start'] for i in order])
time_end = torch.tensor([sample[i]['time_end'] for i in order])
if isinstance(sample[0]['chorus'], str):
chorus = torch.tensor(
[CHORUS[sample[i]['chorus']] for i in order])
else:
chorus = [
torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
for i in order]
chorus = pad_sequence(chorus, batch_first=True,
padding_value=-1)
batch = {
"utts" : utts,
"acoustic_token" : acoustic_token,
"acoustic_token_len": acoustic_token_len,
"time_start" : time_start,
"time_end" : time_end,
"chorus" : chorus,
"text" : text,
"text_token" : text_token,
"text_token_len" : text_token_len,
}
if "semantic_token" in sample[0]:
semantic_token = [
torch.tensor(sample[i]['semantic_token'][0],
dtype=torch.int32) for i in order]
semantic_token_len = torch.tensor(
[i.size(0) for i in semantic_token],
dtype=torch.int32)
semantic_token = pad_sequence(semantic_token,
batch_first=True,
padding_value=0)
batch.update({"semantic_token" : semantic_token,
"semantic_token_len": semantic_token_len})
yield batch
else:
logging.info("WARNING: sample is empty []!")
elif mode == "inference":
for sample in data:
assert isinstance(sample, list)
utts = [sample[i]['utt'] for i in range(len(sample))]
text = [sample[i]['text'] for i in range(len(sample))]
text_token = [torch.tensor(sample[i]['text_token']).long() for i in
range(len(sample))]
text_token_len = torch.tensor([i.size(0) for i in text_token],
dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True,
padding_value=0)
time_start = torch.tensor(
[sample[i]['time_start'] for i in range(len(sample))])
time_end = torch.tensor(
[sample[i]['time_end'] for i in range(len(sample))])
if isinstance(sample[0]['chorus'], str):
chorus = torch.tensor([CHORUS[sample[i]['chorus']] for i in
range(len(sample))])
else:
chorus = [torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
for i in range(len(sample))]
chorus = pad_sequence(chorus, batch_first=True,
padding_value=-1)
if "acoustic_token" in sample[0]:
acoustic_token = [
sample[i]['acoustic_token'].clone().to(torch.int32) for i in
range(len(sample))]
acoustic_token_len = torch.tensor(
[i.size(0) for i in acoustic_token], dtype=torch.int32)
acoustic_token = pad_sequence(acoustic_token,
batch_first=True,
padding_value=0)
else:
acoustic_token = None
acoustic_token_len = None
batch = {
"utts" : utts,
"acoustic_token" : acoustic_token,
"acoustic_token_len": acoustic_token_len,
"time_start" : time_start,
"time_end" : time_end,
"chorus" : chorus,
"text" : text,
"text_token" : text_token,
"text_token_len" : text_token_len,
}
if "semantic_token" in sample[0]:
semantic_token = [torch.tensor(sample[i]['semantic_token'][0],
dtype=torch.int32) for i in
range(len(sample))]
semantic_token_len = torch.tensor(
[i.size(0) for i in semantic_token], dtype=torch.int32)
semantic_token = pad_sequence(semantic_token,
batch_first=True,
padding_value=0)
batch.update({"semantic_token" : semantic_token,
"semantic_token_len": semantic_token_len})
yield batch
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from einops import pack, rearrange, repeat
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock
class Transpose(torch.nn.Module):
def __init__(self, dim0: int, dim1: int):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: torch.Tensor):
x = torch.transpose(x, self.dim0, self.dim1)
return x
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor):
output = self.block(x * mask)
return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = (kernel_size - 1, 0)
def forward(self, x: torch.Tensor):
x = F.pad(x, self.causal_padding)
x = super(CausalConv1d, self).forward(x)
return x
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
"""
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from inspiremusic.utils.mask import make_pad_mask
from inspiremusic.music_tokenizer.vqvae import VQVAE
class MaskedDiff(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 128,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000},
generator_model_dir: str = "../../pretrained_models/InspireMusic-Base/music_tokenizer",
num_codebooks: int = 4
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
self.quantizer = VQVAE( f'{generator_model_dir}/config.json',
f'{generator_model_dir}/model.pt',with_encoder=True).quantizer
self.quantizer.eval()
self.num_codebooks = num_codebooks
self.cond = None
self.interpolate = False
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
audio_token = batch['acoustic_token'].to(device)
audio_token_len = batch['acoustic_token_len'].to(device)
audio_token = audio_token.view(audio_token.size(0),-1,self.num_codebooks)
if "semantic_token" not in batch:
token = audio_token[:,:,0]
token_len = (audio_token_len/self.num_codebooks).long()
else:
token = batch['semantic_token'].to(device)
token_len = batch['semantic_token_len'].to(device)
with torch.no_grad():
feat = self.quantizer.embed(audio_token)
feat_len = (audio_token_len/self.num_codebooks).long()
token = self.input_embedding(token)
h, h_lengths = self.encoder(token, token_len)
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
if self.cond:
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
else:
conds = None
mask = (~make_pad_mask(feat_len)).to(h)
loss, _ = self.decoder.compute_loss(
feat,
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
None,
cond=conds
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
sample_rate):
assert token.shape[0] == 1
token = self.input_embedding(torch.clamp(token, min=0))
h, h_lengths = self.encoder(token, token_len)
if sample_rate == 48000:
token_len = 2 * token_len
h, h_lengths = self.length_regulator(h, token_len)
# get conditions
conds = None
mask = (~make_pad_mask(token_len)).to(h)
feat = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=None,
cond=conds,
n_timesteps=10
)
return feat
\ No newline at end of file
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, estimator: torch.nn.Module = None):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
)
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
# Just change the architecture of the estimator here
self.estimator = estimator
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.forward_estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
torch.zeros_like(cond) if cond is not None else None
)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond)
elif isinstance(self.estimator, onnxruntime.InferenceSession):
ort_inputs = {
'x': x.cpu().numpy(),
'mask': mask.cpu().numpy(),
'mu': mu.cpu().numpy(),
't': t.cpu().numpy(),
'spks': spks.cpu().numpy(),
'cond': cond.cpu().numpy()
}
output = self.estimator.run(None, ort_inputs)[0]
return torch.tensor(output, dtype=x.dtype, device=x.device)
else:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mo)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1, 1)
if cond is not None:
cond = cond * cfg_mask.view(-1, 1, 1)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
import torch
from torch.nn import functional as F
from inspiremusic.utils.mask import make_pad_mask
class InterpolateRegulator(nn.Module):
def __init__(
self,
channels: int,
sampling_ratios: Tuple,
out_channels: int = None,
groups: int = 1,
):
super().__init__()
self.sampling_ratios = sampling_ratios
out_channels = out_channels or channels
model = nn.ModuleList([])
if len(sampling_ratios) > 0:
for _ in sampling_ratios:
module = nn.Conv1d(channels, channels, 3, 1, 1)
norm = nn.GroupNorm(groups, channels)
act = nn.Mish()
model.extend([module, norm, act])
model.append(
nn.Conv1d(channels, out_channels, 1, 1)
)
self.model = nn.Sequential(*model)
def forward(self, x, ylens=None):
# x in (B, T, D)
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
out = self.model(x).transpose(1, 2).contiguous()
olens = ylens
return out * mask, olens
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
# x in (B, T, D)
if x2.shape[1] > 40:
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
else:
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
if x1.shape[1] != 0:
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
x = torch.concat([x1, x2], dim=2)
else:
x = x2
out = self.model(x).transpose(1, 2).contiguous()
return out, mel_len1 + mel_len2
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from typing import List, Optional, Tuple
from einops import rearrange
from torchaudio.transforms import Spectrogram
class MultipleDiscriminator(nn.Module):
def __init__(
self, mpd: nn.Module, mrd: nn.Module
):
super().__init__()
self.mpd = mpd
self.mrd = mrd
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
y_d_rs += this_y_d_rs
y_d_gs += this_y_d_gs
fmap_rs += this_fmap_rs
fmap_gs += this_fmap_gs
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
y_d_rs += this_y_d_rs
y_d_gs += this_y_d_gs
fmap_rs += this_fmap_rs
fmap_gs += this_fmap_gs
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
num_embeddings: Optional[int] = None,
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(
self,
window_length: int,
num_embeddings: Optional[int] = None,
channels: int = 32,
hop_factor: float = 0.25,
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
):
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.spec_fn = Spectrogram(
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
if num_embeddings is not None:
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x):
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.spec_fn(x)
x = torch.view_as_real(x)
x = rearrange(x, "b f t c -> b c t f")
# Split into bands
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
return x_bands
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
x_bands = self.spectrogram(x)
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for i, layer in enumerate(stack):
band = layer(band)
band = torch.nn.functional.leaky_relu(band, 0.1)
if i > 0:
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
return x, fmap
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class ConvRNNF0Predictor(nn.Module):
def __init__(self,
num_class: int = 1,
in_channels: int = 80,
cond_channels: int = 512
):
super().__init__()
self.num_class = num_class
self.condnet = nn.Sequential(
weight_norm(
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.condnet(x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HIFI-GAN"""
from typing import Dict, Optional, List
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform
from inspiremusic.transformer.activation import Snake
from inspiremusic.utils.common import get_padding
from inspiremusic.utils.common import init_weights
"""hifigan based generator implementation.
This code is modified from https://github.com/jik876/hifi-gan
,https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""
class ResBlock(torch.nn.Module):
"""Residual block module in HiFiGAN/BigVGAN."""
def __init__(
self,
channels: int = 512,
kernel_size: int = 3,
dilations: List[int] = [1, 3, 5],
):
super(ResBlock, self).__init__()
self.convs1 = nn.ModuleList()
self.convs2 = nn.ModuleList()
for dilation in dilations:
self.convs1.append(
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
padding=get_padding(kernel_size, dilation)
)
)
)
self.convs2.append(
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)
)
)
)
self.convs1.apply(init_weights)
self.convs2.apply(init_weights)
self.activations1 = nn.ModuleList([
Snake(channels, alpha_logscale=False)
for _ in range(len(self.convs1))
])
self.activations2 = nn.ModuleList([
Snake(channels, alpha_logscale=False)
for _ in range(len(self.convs2))
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for idx in range(len(self.convs1)):
xt = self.activations1[idx](x)
xt = self.convs1[idx](xt)
xt = self.activations2[idx](xt)
xt = self.convs2[idx](xt)
x = xt + x
return x
def remove_weight_norm(self):
for idx in range(len(self.convs1)):
remove_weight_norm(self.convs1[idx])
remove_weight_norm(self.convs2[idx])
class SineGen(torch.nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(self, samp_rate, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
def _f02uv(self, f0):
# generate uv signal
uv = (f0 > self.voiced_threshold).type(torch.float32)
return uv
@torch.no_grad()
def forward(self, f0):
"""
:param f0: [B, 1, sample_len], Hz
:return: [B, 1, sample_len]
"""
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
for i in range(self.harmonic_num + 1):
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
phase_vec[:, 0, :] = 0
# generate sine waveforms
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
# generate uv signal
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threhold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
sine_wavs = sine_wavs.transpose(1, 2)
uv = uv.transpose(1, 2)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
class HiFTGenerator(nn.Module):
"""
HiFTNet Generator: Neural Source Filter + ISTFTNet
https://arxiv.org/abs/2309.09493
"""
def __init__(
self,
in_channels: int = 80,
base_channels: int = 512,
nb_harmonics: int = 8,
sampling_rate: int = 22050,
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: List[int] = [8, 8],
upsample_kernel_sizes: List[int] = [16, 16],
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: List[int] = [7, 11],
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
lrelu_slope: float = 0.1,
audio_limit: float = 0.99,
f0_predictor: torch.nn.Module = None,
):
super(HiFTGenerator, self).__init__()
self.out_channels = 1
self.nb_harmonics = nb_harmonics
self.sampling_rate = sampling_rate
self.istft_params = istft_params
self.lrelu_slope = lrelu_slope
self.audio_limit = audio_limit
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
Conv1d(in_channels, base_channels, 7, 1, padding=3)
)
# Up
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
base_channels // (2**i),
base_channels // (2**(i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# Down
self.source_downs = nn.ModuleList()
self.source_resblocks = nn.ModuleList()
downsample_rates = [1] + upsample_rates[::-1][:-1]
downsample_cum_rates = np.cumprod(downsample_rates)
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
if u == 1:
self.source_downs.append(
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
)
else:
self.source_downs.append(
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
)
self.source_resblocks.append(
ResBlock(base_channels // (2 ** (i + 1)), k, d)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = base_channels // (2**(i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(ResBlock(ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.f0_predictor = f0_predictor
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
self.m_source.remove_weight_norm()
for l in self.source_downs:
remove_weight_norm(l)
for l in self.source_resblocks:
l.remove_weight_norm()
def _stft(self, x):
spec = torch.stft(
x,
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
return_complex=True)
spec = torch.view_as_real(spec) # [B, F, TT, 2]
return spec[..., 0], spec[..., 1]
def _istft(self, magnitude, phase):
magnitude = torch.clip(magnitude, max=1e2)
real = magnitude * torch.cos(phase)
img = magnitude * torch.sin(phase)
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
return inverse_transform
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, self.lrelu_slope)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
# fusion
si = self.source_downs[i](s_stft)
si = self.source_resblocks[i](si)
x = x + si
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
x = self._istft(magnitude, phase)
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
# mel->f0
f0 = self.f0_predictor(speech_feat)
# f0->source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
# mel+source->speech
generated_speech = self.decode(x=speech_feat, s=s)
return generated_speech, f0
@torch.inference_mode()
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
# mel->f0
f0 = self.f0_predictor(speech_feat)
# f0->source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
# use cache_source to avoid glitch
if cache_source.shape[2] != 0:
s[:, :, :cache_source.shape[2]] = cache_source
generated_speech = self.decode(x=speech_feat, s=s)
return generated_speech, s
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
from inspiremusic.utils.losses import tpr_loss, mel_loss
class HiFiGan(nn.Module):
def __init__(self, generator, discriminator, mel_spec_transform,
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
tpr_loss_weight=1.0, tpr_loss_tau=0.04):
super(HiFiGan, self).__init__()
self.generator = generator
self.discriminator = discriminator
self.mel_spec_transform = mel_spec_transform
self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
self.feat_match_loss_weight = feat_match_loss_weight
self.tpr_loss_weight = tpr_loss_weight
self.tpr_loss_tau = tpr_loss_tau
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
if batch['turn'] == 'generator':
return self.forward_generator(batch, device)
else:
return self.forward_discriminator(batch, device)
def forward_generator(self, batch, device):
real_speech = batch['speech'].to(device)
pitch_feat = batch['pitch_feat'].to(device)
# 1. calculate generator outputs
generated_speech, generated_f0 = self.generator(batch, device)
# 2. calculate discriminator outputs
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
# 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
loss_gen, _ = generator_loss(y_d_gs)
loss_fm = feature_loss(fmap_rs, fmap_gs)
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
if self.tpr_loss_weight != 0:
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
else:
loss_tpr = torch.zeros(1).to(device)
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
self.multi_mel_spectral_recon_loss_weight * loss_mel + \
self.tpr_loss_weight * loss_tpr + loss_f0
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
def forward_discriminator(self, batch, device):
real_speech = batch['speech'].to(device)
# 1. calculate generator outputs
with torch.no_grad():
generated_speech, generated_f0 = self.generator(batch, device)
# 2. calculate discriminator outputs
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
# 3. calculate discriminator losses, tpr losses [Optional]
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
if self.tpr_loss_weight != 0:
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
else:
loss_tpr = torch.zeros(1).to(device)
loss = loss_disc + self.tpr_loss_weight * loss_tpr
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from inspiremusic.utils.common import IGNORE_ID
from inspiremusic.transformer.label_smoothing_loss import LabelSmoothingLoss
from inspiremusic.utils.common import th_accuracy, DTYPES
from torch import Tensor
from math import log
from einops import rearrange, reduce, repeat
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class SinusoidalEmbedding(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device, half_dim = x.device, self.dim // 2
emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
return torch.cat((emb.sin(), emb.cos()), dim=-1).to(torch.float16)
class LLM(torch.nn.Module):
def __init__(
self,
text_encoder_input_size: int,
llm_input_size: int,
llm_output_size: int,
audio_token_size: int,
llm: torch.nn.Module,
sampling: Callable,
text_encoder_conf: Dict = None,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
frozen_input_embed: bool = False,
dtype: str = "fp16",
**kwargs,
):
super().__init__()
self.dtype = DTYPES.get(dtype, torch.float32)
self.llm_input_size = llm_input_size
self.audio_token_size = audio_token_size
# 1. build text token inputs related modules
if llm is None:
self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
else:
self.text_embedding = llm.model.model.embed_tokens
if frozen_input_embed:
print("Freezing input embedding layer")
for p in self.text_embedding.parameters():
p.requires_grad = False
self.chorus_embedding = torch.nn.Embedding(5, llm_input_size) # intro, chorus, verse1, verse2 , outro
self.text_encoder_conf = text_encoder_conf
self.text_encoder = self.build_encoder(text_encoder_conf)
self.infer_cfg_ratio = kwargs.get("infer_cfg_ratio", None)
logging.info(f"infer_cfg_ratio: {self.infer_cfg_ratio}")
self.train_cfg_ratio = kwargs.get("train_cfg_ratio", None)
logging.info(f"train_cfg_ratio: {self.train_cfg_ratio}")
# 2. build audio token language model related modules
self.sos_eos = 0
self.task_id = 1
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, audio_token_size + 1)
self.criterion_ce = LabelSmoothingLoss(
size=audio_token_size + 1,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build audio token related modules
self.speech_embedding = torch.nn.Embedding(audio_token_size, llm_input_size)
self.spk_embed_affine_layer = torch.nn.Linear(192, llm_input_size)
self.num_codebooks = 4
# 4. sampling method
self.sampling = sampling
self.time_embedding = SinusoidalEmbedding(llm_input_size)
def cfg_dropout(self, text_token, text_token_len, p):
# Classifier-Free Guidance Dropout
B = text_token.size(0)
num_samples_to_mask = int(p * B)
if num_samples_to_mask == 0:
num_samples_to_mask = 1
indices_to_mask = torch.randperm(B, device=text_token.device)[:num_samples_to_mask]
text_token[indices_to_mask] = 0
text_token_len[indices_to_mask] = 0
return text_token, text_token_len
def build_encoder(self, encoder_conf=None):
if encoder_conf is None:
assert hasattr(self, "encoder_conf"), \
"function param encoder_conf is None and model doesn't has encoder_conf attribute either."
encoder_conf = self.encoder_conf
encoder_name = encoder_conf.pop("name", "transformer")
model = None
if encoder_name == "transformer":
from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder
model = ConformerEncoder(
**encoder_conf,
input_size=self.input_size,
use_cnn_module=False,
macaron_style=False,
)
elif encoder_name == "conformer":
from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder
model = ConformerEncoder(
**encoder_conf,
input_size=self.input_size,
)
elif encoder_name == "llama_encoder":
from inspiremusic.transformer.encoder.llama_encoder import LlamaEncoder
model = LlamaEncoder(
**encoder_conf,
input_size=self.input_size,
)
elif encoder_name == "qwen2":
from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
model = QwenEncoder(
**encoder_conf,
input_size=self.input_size,
)
elif encoder_name == "qwen2.5":
from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
model = QwenEncoder(
**encoder_conf,
input_size=self.input_size,
)
encoder_conf["name"] = encoder_name
return model
def encode(self,
text: torch.Tensor,
text_lengths: torch.Tensor):
if self.text_encoder is not None:
encoder_out, encoder_mask = self.text_encoder(text, text_lengths,
decoding_chunk_size=1,
num_decoding_left_chunks=-1)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out = self.text_encoder_affine_layer(encoder_out)
else:
encoder_out, encoder_out_lens = text, text_lengths
return encoder_out, encoder_out_lens
def pad_unpad_sequence(self, sos_eos_emb, embeddings, text_token,
text_token_len, task_id_emb, audio_token,
audio_token_len, seg_len):
text_token = unpad_sequence(text_token, text_token_len.cpu(),
batch_first=True)
audio_token = unpad_sequence(audio_token, audio_token_len.cpu(),
batch_first=True)
for i in range(len(embeddings)):
embeddings[i] = unpad_sequence(embeddings[i], seg_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0)] + [embedding[i] for embedding in embeddings] + [text_token[i], task_id_emb.squeeze(dim=0), audio_token[i]], dim=0) for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
return lm_input, lm_input_len
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
mask = True
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
if "semantic_token" not in batch:
audio_token = batch['acoustic_token'].to(device)
audio_token_len = batch['acoustic_token_len'].to(device)
audio_token = audio_token.view(audio_token.size(0), -1, self.num_codebooks)
audio_token = audio_token[:, :, 0]
audio_token_len = (audio_token_len / self.num_codebooks).long()
else:
audio_token = batch['semantic_token'].to(device)
audio_token_len = batch['semantic_token_len'].to(device)
time_start = batch['time_start'].to(device)
time_end = batch['time_end'].to(device)
chorus = batch['chorus'].to(device)
# 1. encode text_token
if self.train_cfg_ratio > 0:
# Classifier-Free Guidance
text_token, _ = self.cfg_dropout(text_token, text_token_len, self.train_cfg_ratio)
# 2. Time Embedding & chorus embedding
text_token = self.text_embedding(text_token)
text_token, text_token_len = self.encode(text_token, text_token_len)
if mask:
time_mask = time_start != -1.0
seg_len = time_mask.sum(-1)
time_start = time_start.masked_fill(~time_mask, 0.0)
time_end = time_end.masked_fill(~time_mask, 0.0)
chorus = chorus.masked_fill(~time_mask, 0)
time_start_embed = self.time_embedding(time_start.view(-1)).to(text_token.dtype)
time_end_embed = self.time_embedding(time_end.view(-1)).to(text_token.dtype)
time_start_embed = time_start_embed.view(chorus.size(0), chorus.size(1), -1)
time_end_embed = time_end_embed.view(chorus.size(0), chorus.size(1), -1)
chorus_embed = self.chorus_embedding(chorus)
lm_target = [torch.tensor([IGNORE_ID] * (1 + 3 * seg_len[i] + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
else:
time_start_embed = self.time_embedding(time_start).to(text_token.dtype)
time_end_embed = self.time_embedding(time_end).to(text_token.dtype)
chorus_embed = self.chorus_embedding(chorus)
lm_target = [torch.tensor(
[IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
# 3. eos and task_id
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 4. encode audio_token
audio_token = self.speech_embedding(audio_token)
# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb,
[time_start_embed,
time_end_embed,
chorus_embed],
text_token,
text_token_len,
task_id_emb,
audio_token,
audio_token_len,
seg_len)
# 6. run lm forward
lm_output, lm_output_mask = self.llm(lm_input.to(self.dtype), lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target)
acc = th_accuracy(logits.view(-1, self.audio_token_size + 1), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
def sampling_ids(
self,
weighted_scores: torch.Tensor,
decoded_tokens: List,
ignore_eos: bool = True,
):
top_ids = self.sampling(weighted_scores, decoded_tokens)
return top_ids
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
audio_token: torch.Tensor,
audio_token_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_audio_token: torch.Tensor,
prompt_audio_token_len: torch.Tensor,
embeddings: List,
duration_to_gen: float = 30,
task: str = "continuation",
token_rate: int = 75,
limit_audio_prompt_len: int = 5,
) -> Generator[torch.Tensor, None, None]:
device = text.device
if text is not None:
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
infer_cfg = self.infer_cfg_ratio >= 0.0
if infer_cfg:
text_cfg = self.text_embedding(text.new_zeros(text.shape))
text = self.text_embedding(text)
# 1. encode text
text, text_len = self.encode(text, text_len)
# 2. encode embedding
if embeddings is not None:
time_start, time_end, chorus = embeddings
if len(chorus.shape) == 1:
time_start_embed = self.time_embedding(time_start).reshape(1, 1, -1) # .half()
time_end_embed = self.time_embedding(time_end).reshape(1, 1, -1) # .half()
chorus_embed = self.chorus_embedding(chorus).reshape(1, 1, -1) # .half()
else:
time_start_embed = self.time_embedding(
time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half()
time_end_embed = self.time_embedding(time_end.view(-1)).reshape(1, chorus.size(1), -1) # .half()
chorus_embed = self.chorus_embedding(chorus) # .half()
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if audio_token_len:
audio_token = audio_token[:, :(limit_audio_prompt_len * token_rate)]
audio_token_emb = self.speech_embedding(audio_token)
else:
audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
if prompt_audio_token_len:
prompt_audio_token_emb = self.speech_embedding(prompt_audio_token)
else:
prompt_audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
# Check if removing prompt audio token will fail decoding.
if task == "continuation":
lm_input = torch.concat(
[sos_eos_emb, time_start_embed, time_end_embed,
chorus_embed, text, task_id_emb, audio_token_emb], dim=1)
if infer_cfg:
audio_cfg = self.speech_embedding(
audio_token.new_zeros(audio_token.shape))
lm_cf_input = torch.concat(
[sos_eos_emb, torch.rand_like(time_start_embed),
torch.rand_like(time_end_embed),
torch.rand_like(chorus_embed), text_cfg, task_id_emb,
audio_cfg], dim=1)
lm_input = torch.cat([lm_input, lm_cf_input], 0)
else:
lm_input = torch.concat(
[sos_eos_emb, time_start_embed, time_end_embed,
chorus_embed, text, task_id_emb], dim=1)
if infer_cfg:
lm_cf_input = torch.concat(
[sos_eos_emb, torch.rand_like(time_start_embed),
torch.rand_like(time_end_embed),
torch.rand_like(chorus_embed), text_cfg, task_id_emb],
dim=1)
lm_input = torch.cat([lm_input, lm_cf_input], 0)
# 4. cal min/max_length
min_len = int(0.9 * duration_to_gen * token_rate)
max_len = duration_to_gen * token_rate
# 5. step by step decode
out_tokens = []
offset = 0
state = None
for i in range(int(max_len)):
y_pred, _, state = self.llm.forward_one_step(lm_input.to(self.dtype), torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state)
logits = self.llm_decoder(y_pred[:, -1])
if infer_cfg:
# perform context free guidance
logits_cf = logits[1]
logits = logits[0]
infer_cfg_ratio = self.infer_cfg_ratio
logits = infer_cfg_ratio * logits + (1 - infer_cfg_ratio) * logits_cf
logp = logits.log_softmax(dim=-1)
logp = logp.squeeze(dim=0)
if i < int(min_len):
logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=self.dtype)
top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()
if top_ids == self.audio_token_size:
break
# # in stream mode, yield token one by one
yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
out_tokens.append(top_ids)
offset += lm_input.size(1)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
if infer_cfg:
lm_input = lm_input.repeat(2, 1, 1)
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import requests
from tqdm import tqdm
import torch
import numpy as np
import laion_clap
from clap_module.factory import load_state_dict
import librosa
import pyloudnorm as pyln
# following documentation from https://github.com/LAION-AI/CLAP
def int16_to_float32(x):
return (x / 32767.0).astype(np.float32)
def float32_to_int16(x):
x = np.clip(x, a_min=-1., a_max=1.)
return (x * 32767.).astype(np.int16)
def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_audioset_epoch_15_esc_90.14.pt'):
"""
Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and
the LAION-CLAP audio embedding of the generated audio. LION-CLAP: https://github.com/LAION-AI/CLAP
This evaluation script assumes that audio_path files are identified with the ids in id2text.
clap_score() evaluates all ids in id2text.
GPU-based computation.
Select one of the following models from https://github.com/LAION-AI/CLAP:
- music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen)
- music_audioset_epoch_15_esc_90.14.pt
- music_speech_epoch_15_esc_89.25.pt
- 630k-audioset-fusion-best.pt (our default, with "fusion" to handle longer inputs)
Params:
-- id2text: dictionary with the mapping between id (generated audio filenames in audio_path)
and text (prompt used to generate audio). clap_score() evaluates all ids in id2text.
-- audio_path: path where the generated audio files to evaluate are available.
-- audio_files_extension: files extension (default .wav) in eval_path.
-- clap_model: choose one of the above clap_models (default: '630k-audioset-fusion-best.pt').
Returns:
-- CLAP-LION score
"""
# load model
if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
clap_path = 'CLAP/music_speech_audioset_epoch_15_esc_89.98.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
clap_path = 'CLAP/music_audioset_epoch_15_esc_90.14.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == 'music_speech_epoch_15_esc_89.25.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt'
clap_path = 'CLAP/music_speech_epoch_15_esc_89.25.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == '630k-audioset-fusion-best.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt'
clap_path = 'CLAP/630k-audioset-fusion-best.pt'
model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda')
else:
raise ValueError('clap_model not implemented')
# download clap_model if not already downloaded
if not os.path.exists(clap_path):
print('Downloading ', clap_model, '...')
os.makedirs(os.path.dirname(clap_path), exist_ok=True)
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
with open(clap_path, 'wb') as file:
with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
for data in response.iter_content(chunk_size=8192):
file.write(data)
progress_bar.update(len(data))
# fixing CLAP-LION issue, see: https://github.com/LAION-AI/CLAP/issues/118
pkg = load_state_dict(clap_path)
pkg.pop('text_branch.embeddings.position_ids', None)
model.model.load_state_dict(pkg)
model.eval()
if not os.path.isdir(audio_path):
raise ValueError(f'audio_path: {audio_path} does not exist')
if id2text:
print('[EXTRACTING TEXT EMBEDDINGS] ')
batch_size = 64
text_emb = {}
for i in tqdm(range(0, len(id2text), batch_size)):
batch_ids = list(id2text.keys())[i:i+batch_size]
batch_texts = [id2text[id] for id in batch_ids]
with torch.no_grad():
embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
for id, emb in zip(batch_ids, embeddings):
text_emb[id] = emb
else:
raise ValueError('Must specify id2text')
print('[EVALUATING GENERATIONS] ', audio_path)
score = 0
count = 0
for id in tqdm(id2text.keys()):
file_path = os.path.join(audio_path, str(id)+audio_files_extension)
if os.path.isfile(file_path):
with torch.no_grad():
audio, _ = librosa.load(file_path, sr=48000, mono=True) # sample rate should be 48000
audio = pyln.normalize.peak(audio, -1.0)
audio = audio.reshape(1, -1) # unsqueeze (1,T)
audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
audio_embeddings = model.get_audio_embedding_from_data(x = audio, use_tensor=True)
cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0]
print(f"{id} | CLAP score = {cosine_sim}")
score += cosine_sim
count += 1
return score / count if count > 0 else 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment