"src/targets/vscode:/vscode.git/clone" did not exist on "9d12476ecf77d0542778175a40e23f2e353ed283"
Commit 0112b0f0 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2394 canceled with stages
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
from typing import Callable
import re
import inflect
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
class InspireMusicFrontEnd:
def __init__(self,
configs: Callable,
get_tokenizer: Callable,
llm_model: str,
flow_model: str,
music_tokenizer_dir: str,
audio_tokenizer_dir: str,
instruct: bool = False,
dtype: str = "fp16",
fast: bool = False,
fp16: bool = True,
allowed_special: str = 'all'):
self.tokenizer = get_tokenizer()
self.audio_tokenizer_dir = audio_tokenizer_dir
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.bandwidth_id = torch.tensor([0]).to(self.device)
self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device)
self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir)
self.instruct = instruct
self.allowed_special = allowed_special
self.inflect_parser = inflect.engine()
def _extract_text_token(self, text):
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
return text_token, text_token_len
def _extract_audio_token(self, audio, sample_rate=24000):
        audio = torch.as_tensor(audio, dtype=torch.float32, device=self.device)
_, audio_token = self.wavtokenizer.encode_infer(audio, bandwidth_id=self.bandwidth_id)
audio_token = audio_token.squeeze(0)
audio_token_len = torch.tensor([audio_token.shape[1]], dtype=torch.int32, device=self.device)
return audio_token, audio_token_len
def text_normalize(self, text, split=True):
text = text.strip()
if contains_chinese(text):
text = text.replace("\n", "")
text = replace_blank(text)
text = replace_corner_mark(text)
text = text.replace(".", "、")
text = text.replace(" - ", ",")
text = remove_bracket(text)
text = re.sub(r'[,,]+$', '。', text)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
else:
text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
if split is False:
return text
return texts
def frontend_text_to_music(self, text, time_start, time_end, chorus):
text_token, text_token_len = self._extract_text_token(text)
model_input = {"text": text, "audio_token": None, "audio_token_len": None,
"text_token": text_token, "text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus], "raw_text":text}
return model_input
def frontend_continuation(self, text, audio, time_start, time_end, chorus, target_sr=24000):
if text is None:
text_token = None
text_token_len = None
else:
text_token, text_token_len = self._extract_text_token(text)
audio_token, audio_token_len = self._extract_audio_token(audio, target_sr)
model_input = {"text": text, "audio_token": audio_token, "audio_token_len": audio_token_len,
"text_token": text_token, "text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus], "raw_text":text}
return model_input
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torchaudio
import time
import logging
import argparse
from inspiremusic.cli.inspiremusic import InspireMusic
from inspiremusic.utils.file_utils import logging
import torch
from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
def set_env_variables():
os.environ['PYTHONIOENCODING'] = 'UTF-8'
os.environ['TOKENIZERS_PARALLELISM'] = 'False'
current_working_dir = os.getcwd()
main_root = os.path.realpath(os.path.join(current_working_dir, '../../'))
bin_dir = os.path.join(main_root, 'inspiremusic')
third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')
python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}"
os.environ['PYTHONPATH'] = python_path
sys.path.extend([main_root, third_party_matcha_tts_path])
class InspireMusicUnified:
def __init__(self,
model_name: str = "InspireMusic-1.5B-Long",
model_dir: str = None,
min_generate_audio_seconds: float = 10.0,
max_generate_audio_seconds: float = 30.0,
sample_rate: int = 24000,
output_sample_rate: int = 48000,
load_jit: bool = True,
load_onnx: bool = False,
dtype: str = "fp16",
fast: bool = False,
fp16: bool = True,
gpu: int = 0,
result_dir: str = None,
hub="modelscope"):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
# Set model_dir or default to downloading if it doesn't exist
if model_dir is None:
model_dir = f"../../pretrained_models/{model_name}"
if not os.path.isfile(f"{model_dir}/llm.pt"):
if hub == "modelscope":
from modelscope import snapshot_download
if model_name == "InspireMusic-Base":
snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
else:
snapshot_download(f"iic/{model_name}", local_dir=model_dir)
self.model_dir = model_dir
self.sample_rate = sample_rate
self.output_sample_rate = 24000 if fast else output_sample_rate
self.result_dir = result_dir or f"exp/{model_name}"
os.makedirs(self.result_dir, exist_ok=True)
self.min_generate_audio_seconds = min_generate_audio_seconds
self.max_generate_audio_seconds = max_generate_audio_seconds
self.min_generate_audio_length = int(self.output_sample_rate * self.min_generate_audio_seconds)
self.max_generate_audio_length = int(self.output_sample_rate * self.max_generate_audio_seconds)
assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds"
use_cuda = gpu >= 0 and torch.cuda.is_available()
self.device = torch.device('cuda' if use_cuda else 'cpu')
self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, dtype=dtype, fast=fast, fp16=fp16)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@torch.inference_mode()
def inference(self,
task: str = 'text-to-music',
text: str = None,
audio_prompt: str = None, # audio prompt file path
instruct: str = None,
chorus: str = "verse",
time_start: float = 0.0,
time_end: float = 30.0,
output_fn: str = "output_audio",
max_audio_prompt_length: float = 5.0,
fade_out_duration: float = 1.0,
output_format: str = "wav",
fade_out_mode: bool = True,
trim: bool = False,
):
with torch.no_grad():
text_prompt = f"<|{time_start}|><|{chorus}|><|{text}|><|{time_end}|>"
chorus_dict = {"random": torch.randint(1, 5, (1,)).item(), "intro" : 0, "verse": 1, "chorus": 2, "outro": 4}
chorus = chorus_dict.get(chorus, 1)
chorus = torch.tensor([chorus], dtype=torch.int).to(self.device)
time_start_tensor = torch.tensor([time_start], dtype=torch.float64).to(self.device)
time_end_tensor = torch.tensor([time_end], dtype=torch.float64).to(self.device)
music_fn = os.path.join(self.result_dir, f'{output_fn}.{output_format}')
bench_start = time.time()
if task == 'text-to-music':
model_input = {
"text" : text,
"audio_prompt" : audio_prompt,
"time_start" : time_start_tensor,
"time_end" : time_end_tensor,
"chorus" : chorus,
"task" : task,
"stream" : False,
"duration_to_gen": self.max_generate_audio_seconds,
"sr" : self.sample_rate
}
elif task == 'continuation':
if audio_prompt is not None:
audio, _ = process_audio(audio_prompt, self.sample_rate)
if audio.size(1) < self.sample_rate:
logging.warning("Warning: Input prompt audio length is shorter than 1s. Please provide an appropriate length audio prompt and try again.")
audio = None
else:
max_audio_prompt_length_samples = int(max_audio_prompt_length * self.sample_rate)
audio = audio[:, :max_audio_prompt_length_samples] # Trimming prompt audio
model_input = {
"text" : text,
"audio_prompt" : audio,
"time_start" : time_start_tensor,
"time_end" : time_end_tensor,
"chorus" : chorus,
"task" : task,
"stream" : False,
"duration_to_gen": self.max_generate_audio_seconds,
"sr" : self.sample_rate
}
music_audios = []
for model_output in self.model.cli_inference(**model_input):
music_audios.append(model_output['music_audio'])
bench_end = time.time()
if trim:
music_audio = trim_audio(music_audios[0],
sample_rate=self.output_sample_rate,
threshold=0.05,
min_silence_duration=0.8)
else:
music_audio = music_audios[0]
if music_audio.shape[0] != 0:
if music_audio.shape[1] > self.max_generate_audio_length:
music_audio = music_audio[:, :self.max_generate_audio_length]
if music_audio.shape[1] >= self.min_generate_audio_length:
try:
if fade_out_mode:
music_audio = fade_out(music_audio, self.output_sample_rate, fade_out_duration)
music_audio = music_audio.repeat(2, 1)
if output_format in ["wav", "flac"]:
torchaudio.save(music_fn, music_audio,
sample_rate=self.output_sample_rate,
encoding="PCM_S",
bits_per_sample=24)
elif output_format in ["mp3", "m4a"]:
torchaudio.backend.sox_io_backend.save(
filepath=music_fn, src=music_audio,
sample_rate=self.output_sample_rate,
format=output_format)
else:
logging.info("Format is not supported. Please choose from wav, mp3, m4a, flac.")
except Exception as e:
logging.error(f"Error saving file: {e}")
raise
audio_duration = music_audio.shape[1] / self.output_sample_rate
rtf = (bench_end - bench_start) / audio_duration
logging.info(f"Processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")
else:
logging.error(f"Generated audio length is shorter than minimum required audio length.")
if music_fn:
if os.path.exists(music_fn):
logging.info(f"Generated audio file {music_fn} is saved.")
return music_fn
else:
logging.error(f"{music_fn} does not exist.")
def get_args():
parser = argparse.ArgumentParser(description='Run inference with your model')
parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long",
help='Model name')
parser.add_argument('-d', '--model_dir',
help='Model folder path')
parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
help='Prompt text')
parser.add_argument('-a', '--audio_prompt', default=None,
help='Prompt audio')
parser.add_argument('-c', '--chorus', default="intro",
help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
    parser.add_argument('-f', '--fast', type=lambda x: str(x).lower() in ('true', '1', 'yes'), default=False,
                        help='Enable fast inference mode (without flow matching)')
parser.add_argument('-g', '--gpu', type=int, default=0,
help='GPU ID for this rank, -1 for CPU')
parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'],
help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
parser.add_argument('-r', '--result_dir', default="exp/inspiremusic",
help='Directory to save generated audio')
parser.add_argument('-o', '--output_fn', default="output_audio",
help='Output file name')
parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"],
help='Format of output audio')
parser.add_argument('--sample_rate', type=int, default=24000,
help='Sampling rate of input audio')
parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000],
help='Sampling rate of generated output audio')
parser.add_argument('-s', '--time_start', type=float, default=0.0,
help='Start time in seconds')
parser.add_argument('-e', '--time_end', type=float, default=30.0,
help='End time in seconds')
parser.add_argument('--max_audio_prompt_length', type=float, default=5.0,
help='Maximum audio prompt length in seconds')
parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
help='Minimum generated audio length in seconds')
parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0,
help='Maximum generated audio length in seconds')
    parser.add_argument('--fp16', type=lambda x: str(x).lower() in ('true', '1', 'yes'), default=True,
                        help='Inference with fp16 model')
    parser.add_argument('--fade_out', type=lambda x: str(x).lower() in ('true', '1', 'yes'), default=True,
                        help='Apply fade out effect to generated audio')
    parser.add_argument('--fade_out_duration', type=float, default=1.0,
                        help='Fade out duration in seconds')
    parser.add_argument('--trim', type=lambda x: str(x).lower() in ('true', '1', 'yes'), default=False,
                        help='Trim the silence ending of generated audio')
args = parser.parse_args()
if not args.model_dir:
args.model_dir = os.path.join("../../pretrained_models", args.model_name)
print(args)
return args
def main():
set_env_variables()
args = get_args()
model = InspireMusicUnified(model_name = args.model_name,
model_dir = args.model_dir,
min_generate_audio_seconds = args.min_generate_audio_seconds,
max_generate_audio_seconds = args.max_generate_audio_seconds,
sample_rate = args.sample_rate,
output_sample_rate = args.output_sample_rate,
load_jit = True,
load_onnx = False,
dtype="fp16",
fast = args.fast,
fp16 = args.fp16,
gpu = args.gpu,
result_dir = args.result_dir)
model.inference(task = args.task,
text = args.text,
audio_prompt = args.audio_prompt,
chorus = args.chorus,
time_start = args.time_start,
time_end = args.time_end,
output_fn = args.output_fn,
max_audio_prompt_length = args.max_audio_prompt_length,
fade_out_duration = args.fade_out_duration,
output_format = args.format,
fade_out_mode = args.fade_out,
trim = args.trim)
if __name__ == "__main__":
main()
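# Illustrative CLI usage (a sketch; it assumes this file is saved as inference.py,
# and the prompt text and paths are placeholders -- the flags map to get_args above):
#   python inference.py --task text-to-music -m InspireMusic-1.5B-Long \
#       -t "Relaxing instrumental jazz for a spa ambiance" -g 0 -r exp/inspiremusic --format wav
#   python inference.py --task continuation -a prompt_audio.wav -g 0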
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from inspiremusic.cli.frontend import InspireMusicFrontEnd
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.utils.file_utils import logging
import torch
class InspireMusic:
def __init__(self, model_dir, load_jit=True, load_onnx=False, dtype = "fp16", fast = False, fp16=True, hub="modelscope"):
        if model_dir is None:
            model_dir = "../../pretrained_models/InspireMusic-1.5B-Long"
        self.model_dir = model_dir
        instruct = True if '-Instruct' in model_dir else False
if not os.path.isfile(f"{model_dir}/llm.pt"):
model_name = model_dir.split("/")[-1]
if hub == "modelscope":
from modelscope import snapshot_download
if model_name == "InspireMusic-Base":
snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
else:
snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f:
configs = load_hyperpyyaml(f)
self.frontend = InspireMusicFrontEnd(configs,
configs['get_tokenizer'],
'{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/music_tokenizer/'.format(model_dir),
'{}/wavtokenizer/'.format(model_dir),
instruct,
dtype,
fast,
fp16,
configs['allowed_special'])
self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/music_tokenizer/'.format(model_dir),
'{}/wavtokenizer/model.pt'.format(model_dir))
del configs
@torch.inference_mode()
def inference(self, task, text, audio, time_start, time_end, chorus, stream=False, sr=24000):
if task == "text-to-music":
for i in tqdm(self.frontend.text_normalize(text, split=True)):
model_input = self.frontend.frontend_text_to_music(i, time_start, time_end, chorus)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
elif task == "continuation":
if text is None:
if audio is not None:
for i in tqdm(audio):
                        model_input = self.frontend.frontend_continuation(None, i, time_start, time_end, chorus, sr)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.continuation_inference(**model_input, stream=stream):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
else:
if audio is not None:
for i in tqdm(self.frontend.text_normalize(text, split=True)):
                        model_input = self.frontend.frontend_continuation(i, audio, time_start, time_end, chorus, sr)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.continuation_inference(**model_input, stream=stream):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
else:
print("Please input text or audio.")
else:
print("Currently only support text-to-music and music continuation tasks.")
@torch.inference_mode()
def cli_inference(self, text, audio_prompt, time_start, time_end, chorus, task, stream=False, duration_to_gen=30, sr=24000):
if task == "text-to-music":
model_input = self.frontend.frontend_text_to_music(text, time_start, time_end, chorus)
logging.info('prompt text {}'.format(text))
elif task == "continuation":
model_input = self.frontend.frontend_continuation(text, audio_prompt, time_start, time_end, chorus, sr)
logging.info('prompt audio length: {}'.format(len(audio_prompt)))
start_time = time.time()
for model_output in self.model.inference(**model_input, duration_to_gen=duration_to_gen, task=task):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
@torch.inference_mode()
def inference_zero_shot(self, text, prompt_text, prompt_audio_16k, stream=False, sr=24000):
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
for i in tqdm(self.frontend.text_normalize(text, split=True)):
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_audio_16k)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
audio_len = model_output['music_audio'].shape[1] / sr
logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len))
yield model_output
start_time = time.time()
@torch.inference_mode()
def inference_instruct(self, text, spk_id, instruct_text, stream=False, sr=24000):
if self.frontend.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
for i in tqdm(self.frontend.text_normalize(text, split=True)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
audio_len = model_output['music_audio'].shape[1] / sr
logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len))
yield model_output
start_time = time.time()
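# Minimal usage sketch (illustrative, not part of the original file). It assumes a
# downloaded model directory and mirrors how InspireMusicUnified (the CLI wrapper
# earlier in this commit) drives cli_inference: chorus and time values are passed as
# tensors.
#
#   model = InspireMusic("../../pretrained_models/InspireMusic-1.5B-Long")
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   chorus = torch.tensor([1], dtype=torch.int).to(device)               # "verse"
#   time_start = torch.tensor([0.0], dtype=torch.float64).to(device)
#   time_end = torch.tensor([30.0], dtype=torch.float64).to(device)
#   for out in model.cli_inference(text="Uplifting electronic track with piano",
#                                  audio_prompt=None, time_start=time_start,
#                                  time_end=time_end, chorus=chorus,
#                                  task="text-to-music"):
#       wav = out["music_audio"]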
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import threading
import time
from contextlib import nullcontext
import uuid
from inspiremusic.utils.common import DTYPES
from inspiremusic.music_tokenizer.vqvae import VQVAE
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
from torch.cuda.amp import autocast
import logging
import torch
import os
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class InspireMusicModel:
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
music_tokenizer: torch.nn.Module,
wavtokenizer: torch.nn.Module,
dtype: str = "fp16",
fast: bool = False,
fp16: bool = True,
):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.dtype = DTYPES.get(dtype, torch.float32)
self.llm = llm.to(self.dtype)
self.flow = flow
self.music_tokenizer = music_tokenizer
self.wavtokenizer = wavtokenizer
self.fp16 = fp16
self.token_min_hop_len = 100
self.token_max_hop_len = 200
self.token_overlap_len = 20
# mel fade in out
self.mel_overlap_len = 34
self.mel_window = np.hamming(2 * self.mel_overlap_len)
# hift cache
self.mel_cache_len = 20
self.source_cache_len = int(self.mel_cache_len * 256)
# rtf and decoding related
self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.music_token_dict = {}
self.llm_end_dict = {}
self.mel_overlap_dict = {}
self.fast = fast
self.generator = "hifi"
def load(self, llm_model, flow_model, hift_model, wavtokenizer_model):
if llm_model is not None:
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
self.llm.to(self.device).to(self.dtype).eval()
else:
self.llm = None
if flow_model is not None:
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
self.flow.to(self.device).eval()
if hift_model is not None:
if ".pt" not in hift_model:
self.music_tokenizer = VQVAE( hift_model + '/config.json',
hift_model + '/model.pt', with_encoder=True)
else:
self.music_tokenizer = VQVAE(os.path.dirname(hift_model) + '/config.json',
hift_model, with_encoder=True)
self.music_tokenizer.to(self.device).eval()
if wavtokenizer_model is not None:
if ".pt" not in wavtokenizer_model:
self.wavtokenizer = WavTokenizer.from_pretrained_feat( wavtokenizer_model + '/config.yaml',
wavtokenizer_model + '/model.pt')
else:
self.wavtokenizer = WavTokenizer.from_pretrained_feat( os.path.dirname(wavtokenizer_model) + '/config.yaml',
wavtokenizer_model )
self.wavtokenizer.to(self.device)
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
self.llm.text_encoder = llm_text_encoder
llm_llm = torch.jit.load(llm_llm_model)
self.llm.llm = llm_llm
flow_encoder = torch.jit.load(flow_encoder_model)
self.flow.encoder = flow_encoder
def load_onnx(self, flow_decoder_estimator_model):
import onnxruntime
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
del self.flow.decoder.estimator
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task):
with self.llm_context:
local_res = []
with autocast(enabled=self.fp16, dtype=self.dtype, cache_enabled=True):
inference_kwargs = {
'text': text.to(self.device),
'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
'prompt_text': prompt_text.to(self.device),
'prompt_text_len': torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
'prompt_audio_token': llm_prompt_audio_token.to(self.device),
'prompt_audio_token_len': torch.tensor([llm_prompt_audio_token.shape[1]], dtype=torch.int32).to(self.device),
'embeddings': embeddings,
'duration_to_gen': duration_to_gen,
'task': task
}
if audio_token is not None:
inference_kwargs['audio_token'] = audio_token.to(self.device)
else:
inference_kwargs['audio_token'] = torch.Tensor([0]).to(self.device)
if audio_token_len is not None:
inference_kwargs['audio_token_len'] = audio_token_len.to(self.device)
else:
inference_kwargs['audio_token_len'] = torch.Tensor([0]).to(self.device)
for i in self.llm.inference(**inference_kwargs):
local_res.append(i)
self.music_token_dict[uuid] = local_res
self.llm_end_dict[uuid] = True
# def token2wav(self, token, token_len, text, text_len, uuid, sample_rate, finalize=False):
def token2wav(self, token, token_len, uuid, sample_rate, finalize=False, flow_cfg=None):
# if self.flow is not None:
# if isinstance(self.flow,MaskedDiffWithText):
# codec_embed = self.flow.inference(token=token.to(self.device),
# token_len=token_len.to(self.device),
# text_token=text,
# text_token_len=text_len,
# )
# else:
if flow_cfg is not None:
codec_embed = self.flow.inference_cfg(token=token.to(self.device),
token_len=token_len.to(self.device),
sample_rate=sample_rate
)
else:
codec_embed = self.flow.inference(token=token.to(self.device),
token_len=token_len.to(self.device),
sample_rate=sample_rate
)
# use music_tokenizer decoder
wav = self.music_tokenizer.generator(codec_embed)
wav = wav.squeeze(0).cpu().detach()
return wav
def acoustictoken2wav(self, token):
# use music_tokenizer to generate waveform from token
token = token.view(token.size(0), -1, 4)
# codec = token.view(1, -1, 4)
        codec_embed = self.music_tokenizer.quantizer.embed(token.long().to(self.device))
wav = self.music_tokenizer.generator(codec_embed)
wav = wav.squeeze(0).cpu().detach()
return wav
def semantictoken2wav(self, token):
# fast mode, use wavtokenizer decoder
        new_tensor = token.to(self.device).unsqueeze(0)
features = self.wavtokenizer.codes_to_features(new_tensor)
bandwidth_id = torch.tensor([0]).to(self.device)
wav = self.wavtokenizer.to(self.device).decode(features, bandwidth_id=bandwidth_id)
wav = wav.cpu().detach()
return wav
@torch.inference_mode()
def inference(self, text, audio_token, audio_token_len, text_token, text_token_len, embeddings=None,
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_audio_feat=torch.zeros(1, 0, 80), sample_rate=48000, duration_to_gen = 30, task="continuation", trim = True, stream=False, **kwargs):
# this_uuid is used to track variables related to this inference thread
# support tasks:
# text to music task
# music continuation task
# require either audio input only or text and audio inputs
this_uuid = str(uuid.uuid1())
if self.llm:
with self.lock:
self.music_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
p = threading.Thread(target=self.llm_job, args=(text_token, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, this_uuid, duration_to_gen, task))
p.start()
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
time.sleep(0.1)
if len(self.music_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_music_token = torch.concat(self.music_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                    this_music_audio = self.token2wav(token=this_music_token,
                                                      token_len=torch.LongTensor([this_music_token.size(1)]),
                                                      uuid=this_uuid,
                                                      sample_rate=sample_rate,
                                                      finalize=False)
yield {'music_audio': this_music_audio.cpu()}
with self.lock:
self.music_token_dict[this_uuid] = self.music_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better audio quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.music_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
            this_music_audio = self.token2wav(token=this_music_token,
                                              token_len=torch.LongTensor([this_music_token.size(1)]),
                                              uuid=this_uuid,
                                              sample_rate=sample_rate,
                                              finalize=True)
yield {'music_audio': this_music_audio.cpu()}
else:
# deal with all tokens
if self.fast:
if task == "reconstruct":
                    assert audio_token is not None, "reconstruct task requires an input audio token"
this_music_token = audio_token
this_music_audio = self.acoustictoken2wav(token=this_music_token)
else:
if self.llm:
p.join()
                        logging.info(f"collected {len(self.music_token_dict[this_uuid])} music token chunks")
this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
                        logging.info(f"music token shape: {this_music_token.shape}")
else:
this_music_token = text_token
logging.info("using wavtokenizer generator without flow matching")
this_music_audio = self.semantictoken2wav(token=this_music_token)
                    logging.info(f"generated audio shape: {this_music_audio.shape}")
else:
if self.llm:
p.join()
if len(self.music_token_dict[this_uuid]) != 0:
this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
else:
print(f"The list of tensors is empty for UUID: {this_uuid}")
else:
this_music_token = text_token
logging.info(f"LLM generated audio token length: {this_music_token.shape[1]}")
logging.info(f"using flow matching and {self.generator} generator")
if self.generator == "hifi":
if (embeddings[1] - embeddings[0]) <= duration_to_gen:
if trim:
trim_length = (int((embeddings[1] - embeddings[0])*75))
this_music_token = this_music_token[:, :trim_length]
logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}")
elif (embeddings[1] - embeddings[0]) < 1:
logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.")
this_music_audio = self.token2wav(token=this_music_token,
token_len=torch.LongTensor([this_music_token.size(1)]),
uuid=this_uuid,
sample_rate=sample_rate,
finalize=True)
logging.info(f"Generated audio sequence length: {this_music_audio.shape[1]}")
elif self.generator == "wavtokenizer":
if (embeddings[1] - embeddings[0]) < duration_to_gen:
if trim:
trim_length = (int((embeddings[1] - embeddings[0])*75))
this_music_token = this_music_token[:,:trim_length]
logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}")
elif (embeddings[1] - embeddings[0]) < 1:
logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.")
this_music_audio = self.semantictoken2wav(token=this_music_token)
yield {'music_audio': this_music_audio.cpu()}
torch.cuda.synchronize()
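        # Flow summary (descriptive note, not original code): llm_job runs in a
        # background thread and fills music_token_dict[this_uuid]; in non-stream mode
        # the collected tokens are concatenated after the thread joins and decoded
        # either via flow matching plus the music_tokenizer generator (token2wav) or
        # directly with the wavtokenizer decoder (semantictoken2wav) in fast mode.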
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import json
import math
from functools import partial
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from inspiremusic.utils.file_utils import read_lists, read_json_lists
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
            List: index list after sampling
"""
data = list(range(len(data)))
# force datalist even
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
if len(data) < self.world_size:
print(len(data), self.world_size)
data = data * math.ceil(self.world_size / len(data))
data = data[:self.world_size]
data = data[self.rank::self.world_size]
if len(data) < self.num_workers:
data = data * math.ceil(self.num_workers / len(data))
data = data[:self.num_workers]
data = data[self.worker_id::self.num_workers]
return data
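    # Worked example (illustrative): with 8 items, world_size=2, num_workers=2,
    # partition=True and shuffle=False, rank 0 keeps indices [0, 2, 4, 6]
    # (data[self.rank::self.world_size]) and its worker 0 then keeps [0, 4]
    # (data[self.worker_id::self.num_workers]); rank 1 / worker 1 get the rest.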
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
def Dataset(data_list_file,
data_pipeline,
mode='train',
shuffle=True,
partition=True
):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
shuffle at shards tar/raw file level. The second is global shuffle
at training samples level.
Args:
data_type(str): raw/shard
tokenizer (BaseTokenizer): tokenizer to tokenize
partition(bool): whether to do data partition in terms of rank
"""
assert mode in ['train', 'inference', 'processing']
lists = read_lists(data_list_file)
dataset = DataList(lists,
shuffle=shuffle,
partition=partition)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset
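# Composition sketch (illustrative; it assumes the processor functions shown later in
# this commit -- parquet_opener, tokenize, filter, shuffle, sort, batch, padding --
# are importable, e.g. from inspiremusic.dataset.processor, and that get_tokenizer
# comes from the model config as in the frontend; module paths are assumptions):
#
#   data_pipeline = [
#       parquet_opener,
#       partial(tokenize, get_tokenizer=get_tokenizer, allowed_special='all'),
#       partial(filter, max_length=22500, min_length=10),
#       partial(shuffle, shuffle_size=10000),
#       partial(sort, sort_size=500),
#       partial(batch, batch_type='dynamic', max_frames_in_batch=12000),
#       padding,
#   ]
#   train_dataset = Dataset('train.data.list', data_pipeline, mode='train',
#                           shuffle=True, partition=True)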
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import numpy as np
import re
torchaudio.set_audio_backend('soundfile')
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2,
"outro": 4}
metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$')
timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$')
def parquet_opener(data, mode='train', audio_data=None):
    """ Given a list of parquet urls or local files, read each table and yield
    one sample per row.
    Inplace operation.
    Args:
        data(Iterable[str]): url or local file list
    Returns:
        Iterable[dict]: the input sample updated with the parquet row fields
    """
for sample in data:
assert 'src' in sample
url = sample['src']
try:
df = pq.read_table(url).to_pandas()
for i in df.index:
sample.update(dict(df.loc[i]))
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
def clean_lyrics(data, mode="train"):
for sample in data:
lyrics = sample["text"]
cleaned = []
for line in lyrics.splitlines():
if metadata_pattern.match(line):
continue
timestamp_match = timestamp_pattern.match(line)
if timestamp_match:
lyric = timestamp_match.group(1).strip()
if lyric:
cleaned.append(lyric)
else:
if line.strip():
cleaned.append(line.strip())
sample["text"] = '\n'.join(cleaned)
yield sample
def cut_by_length(data, max_length=8000, num_times=4, mode="train"):
for sample in data:
if "semantic_token" in sample:
sample["semantic_token"] = [
sample["semantic_token"][0][:max_length]]
if "acoustic_token" not in sample:
sample["acoustic_token"] = sample["speech_token"]
sample["acoustic_token"] = sample["acoustic_token"][
:max_length * num_times]
yield sample
def filter(data,
max_length=22500, # 22500 #5min #10240
max_acoustic_length=45000,
min_length=10,
min_acoustic_length=150,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
if mode == "train":
for sample in data:
if "semantic_token" in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
else:
                new_sample_frames = len(sample['speech_token'])
if "text_token" in sample:
new_sample_frames += len(sample['text_token'])
if new_sample_frames > max_length or new_sample_frames < min_length:
print(f"skipped 1 item length={new_sample_frames}")
continue
sample["chorus"] = sample["chorus"].split(",")
if not isinstance(sample["time_start"], np.ndarray):
sample["time_start"] = [sample["time_start"]]
sample["time_end"] = [sample["time_end"]]
for i, t in enumerate(sample["chorus"]):
if sample["chorus"][i] == "verse":
sample["chorus"][i] = "verse1"
yield sample
if mode == "train_flow":
for sample in data:
if "semantic_token" in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
if "acoustic_token" in sample:
target_sample_frames = sample['acoustic_token'][0].shape[0]
if new_sample_frames > max_length or new_sample_frames < min_acoustic_length or new_sample_frames < min_length or target_sample_frames > max_acoustic_length:
print(
f"skipped 1 item length={new_sample_frames}, target_length={target_sample_frames}")
continue
yield sample
elif mode == "inference":
for sample in data:
yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def truncate(data, truncate_length=24576, mode='train'):
""" Truncate data.
Args:
data: Iterable[{key, wav, label, sample_rate}]
truncate_length: truncate length
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
waveform = sample['audio']
if waveform.shape[1] > truncate_length:
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length -
waveform.shape[1])],
dim=1)
sample['audio'] = waveform
yield sample
def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train',
n_codebook=4):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'semantic_token' in sample
# TODO: unify data processing key names
if 'acoustic_token' not in sample:
continue
if 'sample_rate' in sample.keys():
sample_rate = sample['sample_rate']
else:
sample_rate = 24000
token = np.array(sample['semantic_token'][0][:-1])
# Calculate the repetition factor for resampling
repetition_factor = int(n_codebook * resample_rate / sample_rate)
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['semantic_token'] = np.array(
[np.repeat(token, repetition_factor)])
yield sample
def compute_fbank(data,
feat_extractor,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
del sample['speech']
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'],
dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'],
dtype=torch.float32)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'],
dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'],
dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'],
allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
if sample["chorus"] == "verse":
sample["chorus"] = "verse1"
if sample["acoustic_token"].shape[0] == 1:
sample["acoustic_token"] = np.concatenate(
sample["acoustic_token"][0])
else:
sample["acoustic_token"] = np.concatenate(sample["acoustic_token"])
sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"])
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['acoustic_token'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['acoustic_token'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=32):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
data_empty = True
for sample in data:
data_empty = False
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if data_empty:
raise ValueError("data is empty")
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'acoustic_token' in sample
assert isinstance(sample['acoustic_token'], torch.Tensor)
if 'semantic_token' in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
else:
            new_sample_frames = sample['acoustic_token'].size(0)
if "text_token" in sample:
new_sample_frames += len(sample['text_token'])
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
if len(buf) > 0:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
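# Worked example (illustrative): with max_frames_in_batch=12000 and successive
# samples of 5000, 4000 and 4000 frames, the first two fit (longest 5000 * 2
# samples = 10000 padded frames), but adding the third would give 5000 * 3 =
# 15000 > 12000, so the first two are yielded and the third starts a new batch.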
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000,
mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
elif mode == 'processing':
return static_batch(data, batch_size)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
if mode == "train":
for sample in data:
assert isinstance(sample, list)
if len(sample) != 0:
acoustic_feat_len = torch.tensor(
[x['acoustic_token'].size(0) for x in sample],
dtype=torch.int32)
order = torch.argsort(acoustic_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
acoustic_token = [
sample[i]['acoustic_token'].clone().to(torch.int32) for i in
order]
acoustic_token_len = torch.tensor(
[i.size(0) for i in acoustic_token], dtype=torch.int32)
acoustic_token = pad_sequence(acoustic_token,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']).long() for i
in order]
text_token_len = torch.tensor([i.size(0) for i in text_token],
dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True,
padding_value=0)
time_start = torch.tensor(
[sample[i]['time_start'] for i in order])
time_end = torch.tensor([sample[i]['time_end'] for i in order])
if isinstance(sample[0]['chorus'], str):
chorus = torch.tensor(
[CHORUS[sample[i]['chorus']] for i in order])
else:
chorus = [
torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
for i in order]
chorus = pad_sequence(chorus, batch_first=True,
padding_value=-1)
batch = {
"utts" : utts,
"acoustic_token" : acoustic_token,
"acoustic_token_len": acoustic_token_len,
"time_start" : time_start,
"time_end" : time_end,
"chorus" : chorus,
"text" : text,
"text_token" : text_token,
"text_token_len" : text_token_len,
}
if "semantic_token" in sample[0]:
semantic_token = [
torch.tensor(sample[i]['semantic_token'][0],
dtype=torch.int32) for i in order]
semantic_token_len = torch.tensor(
[i.size(0) for i in semantic_token],
dtype=torch.int32)
semantic_token = pad_sequence(semantic_token,
batch_first=True,
padding_value=0)
batch.update({"semantic_token" : semantic_token,
"semantic_token_len": semantic_token_len})
yield batch
else:
logging.info("WARNING: sample is empty []!")
elif mode == "inference":
for sample in data:
assert isinstance(sample, list)
utts = [sample[i]['utt'] for i in range(len(sample))]
text = [sample[i]['text'] for i in range(len(sample))]
text_token = [torch.tensor(sample[i]['text_token']).long() for i in
range(len(sample))]
text_token_len = torch.tensor([i.size(0) for i in text_token],
dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True,
padding_value=0)
time_start = torch.tensor(
[sample[i]['time_start'] for i in range(len(sample))])
time_end = torch.tensor(
[sample[i]['time_end'] for i in range(len(sample))])
if isinstance(sample[0]['chorus'], str):
chorus = torch.tensor([CHORUS[sample[i]['chorus']] for i in
range(len(sample))])
else:
chorus = [torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
for i in range(len(sample))]
chorus = pad_sequence(chorus, batch_first=True,
padding_value=-1)
if "acoustic_token" in sample[0]:
acoustic_token = [
sample[i]['acoustic_token'].clone().to(torch.int32) for i in
range(len(sample))]
acoustic_token_len = torch.tensor(
[i.size(0) for i in acoustic_token], dtype=torch.int32)
acoustic_token = pad_sequence(acoustic_token,
batch_first=True,
padding_value=0)
else:
acoustic_token = None
acoustic_token_len = None
batch = {
"utts" : utts,
"acoustic_token" : acoustic_token,
"acoustic_token_len": acoustic_token_len,
"time_start" : time_start,
"time_end" : time_end,
"chorus" : chorus,
"text" : text,
"text_token" : text_token,
"text_token_len" : text_token_len,
}
if "semantic_token" in sample[0]:
semantic_token = [torch.tensor(sample[i]['semantic_token'][0],
dtype=torch.int32) for i in
range(len(sample))]
semantic_token_len = torch.tensor(
[i.size(0) for i in semantic_token], dtype=torch.int32)
semantic_token = pad_sequence(semantic_token,
batch_first=True,
padding_value=0)
batch.update({"semantic_token" : semantic_token,
"semantic_token_len": semantic_token_len})
yield batch
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock
class Transpose(torch.nn.Module):
def __init__(self, dim0: int, dim1: int):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: torch.Tensor):
x = torch.transpose(x, self.dim0, self.dim1)
return x
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor):
output = self.block(x * mask)
return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = (kernel_size - 1, 0)
def forward(self, x: torch.Tensor):
x = F.pad(x, self.causal_padding)
x = super(CausalConv1d, self).forward(x)
return x
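    # Causality note (illustrative): with kernel_size=3 the layer left-pads by
    # (kernel_size - 1, 0) = (2, 0), so the output length equals the input length
    # and position t only depends on inputs at positions <= t, e.g.:
    #   conv = CausalConv1d(4, 4, kernel_size=3)
    #   y = conv(torch.randn(1, 4, 16))   # y.shape == torch.Size([1, 4, 16])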
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
"""
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
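    # Shape note (descriptive, not original code): x and mu (plus the time-repeated
    # spks and cond, when given) are packed along the channel axis before the first
    # down block; the mask is subsampled by 2 after each downsample and popped back
    # on the way up; the final 1x1 projection maps to out_channels and the result is
    # re-masked before being returned.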