Commit 0112b0f0 authored by chenzk

v1.0
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from inspiremusic.cli.inspiremusic import InspireMusic
def get_args():
parser = argparse.ArgumentParser(description='export your model for deployment')
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/InspireMusic',
help='local path')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False)
# 1. export llm text_encoder
llm_text_encoder = inspiremusic.model.llm.text_encoder.half()
script = torch.jit.script(llm_text_encoder)
script = torch.jit.freeze(script)
script = torch.jit.optimize_for_inference(script)
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
# 2. export llm llm
llm_llm = inspiremusic.model.llm.llm.half()
script = torch.jit.script(llm_llm)
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
script = torch.jit.optimize_for_inference(script)
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
# 3. export flow encoder
flow_encoder = inspiremusic.model.flow.encoder
script = torch.jit.script(flow_encoder)
script = torch.jit.freeze(script)
script = torch.jit.optimize_for_inference(script)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
if __name__ == '__main__':
main()
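# Hedged sketch (not part of the original script): the TorchScript archives written above
# can be reloaded for deployment with torch.jit.load. File names mirror the save() calls;
# the default model_dir and device are illustrative assumptions, and the fp16 modules
# expect half-precision inputs on GPU.
def _load_exported_jit_modules(model_dir='pretrained_models/InspireMusic', device='cuda'):
    import torch
    text_encoder = torch.jit.load('{}/llm.text_encoder.fp16.zip'.format(model_dir), map_location=device)
    llm = torch.jit.load('{}/llm.llm.fp16.zip'.format(model_dir), map_location=device)
    flow_encoder = torch.jit.load('{}/flow.encoder.fp32.zip'.format(model_dir), map_location=device)
    return text_encoder, llm, flow_encoder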
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import onnxruntime
import random
import torch
from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from inspiremusic.cli.inspiremusic import InspireMusic
def get_dummy_input(batch_size, seq_len, out_channels, device):
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
t = torch.rand((batch_size), dtype=torch.float32, device=device)
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
return x, mask, mu, t, spks, cond
def get_args():
parser = argparse.ArgumentParser(description='export your model for deployment')
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/InspireMusic',
help='local path')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False)
# 1. export flow decoder estimator
estimator = inspiremusic.model.flow.decoder.estimator
device = inspiremusic.model.device
batch_size, seq_len = 1, 256
out_channels = inspiremusic.model.flow.decoder.estimator.out_channels
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
torch.onnx.export(
estimator,
(x, mask, mu, t, spks, cond),
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
output_names=['estimator_out'],
dynamic_axes={
'x': {0: 'batch_size', 2: 'seq_len'},
'mask': {0: 'batch_size', 2: 'seq_len'},
'mu': {0: 'batch_size', 2: 'seq_len'},
'cond': {0: 'batch_size', 2: 'seq_len'},
't': {0: 'batch_size'},
'spks': {0: 'batch_size'},
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
}
)
# 2. test computation consistency
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
sess_options=option, providers=providers)
for _ in tqdm(range(10)):
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
output_pytorch = estimator(x, mask, mu, t, spks, cond)
ort_inputs = {
'x': x.cpu().numpy(),
'mask': mask.cpu().numpy(),
'mu': mu.cpu().numpy(),
't': t.cpu().numpy(),
'spks': spks.cpu().numpy(),
'cond': cond.cpu().numpy()
}
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
        torch.testing.assert_close(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
if __name__ == "__main__":
main()
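# Hedged sketch (not part of the original script): inspect the exported estimator's
# input/output signature with onnxruntime to confirm the dynamic batch_size/seq_len axes
# declared in torch.onnx.export above. The default path mirrors the --model_dir default.
def _inspect_estimator_graph(model_dir='pretrained_models/InspireMusic'):
    import onnxruntime
    sess = onnxruntime.InferenceSession(
        '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
        providers=['CPUExecutionProvider'])
    for i in sess.get_inputs():
        print('input ', i.name, i.shape, i.type)
    for o in sess.get_outputs():
        print('output', o.name, o.shape, o.type)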
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.dataset.dataset import Dataset
from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
def get_args():
parser = argparse.ArgumentParser(description='inference only with flow model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--flow_model', required=True, help='flow model file')
    parser.add_argument('--llm_model', default=None, required=False, help='llm model file')
parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file')
parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
    parser.add_argument('--chorus', default="random", required=False, help='chorus tag generation mode, e.g. random, verse, chorus, intro, outro.')
parser.add_argument('--sample_rate', type=int, default=48000, required=False,
help='sampling rate of generated audio')
parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False,
help='the minimum generated audio length in seconds')
parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
help='the maximum generated audio length in seconds')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--result_dir', required=True, help='directory to save generated audio and wav.scp')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init inspiremusic models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = InspireMusicModel(None, configs['flow'], configs['hift'], configs['wavtokenizer'])
model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
if args.llm_model is None:
model.llm = None
else:
model.llm = model.llm.to(torch.float32)
if args.flow_model is None:
model.flow = None
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
            assert len(utts) == 1, "inference mode only supports batch size 1"
if "semantic_token" in batch:
token = batch["semantic_token"].to(device)
token_len = batch["semantic_token_len"].to(device)
else:
if audio_token is None:
token = None
token_len = None
else:
token = audio_token.view(audio_token.size(0),-1,4)[:,:,0]
token_len = audio_token_len / 4
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
text = batch["text"]
if "time_start" not in batch.keys():
batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64)
if "time_end" not in batch.keys():
batch["time_end"] = torch.randint(args.min_generate_audio_seconds, args.max_generate_audio_seconds, (1,)).to(torch.float64)
elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds:
batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
if "chorus" not in batch.keys():
batch["chorus"] = torch.randint(1, 5, (1,))
if args.chorus == "random":
batch["chorus"] = torch.randint(1, 5, (1,))
elif args.chorus == "intro":
batch["chorus"] = torch.Tensor([0])
elif "verse" in args.chorus:
batch["chorus"] = torch.Tensor([1])
elif args.chorus == "chorus":
batch["chorus"] = torch.Tensor([2])
elif args.chorus == "outro":
batch["chorus"] = torch.Tensor([4])
time_start = batch["time_start"].to(device)
time_end = batch["time_end"].to(device)
chorus = batch["chorus"].to(torch.int)
text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
chorus = chorus.to(device)
model_input = {"text": text, "audio_token": token, "audio_token_len": token_len,
"text_token": text_token, "text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus], "raw_text":text}
music_audios = []
for model_output in model.inference(**model_input):
music_audios.append(model_output['music_audio'])
music_key = utts[0]
music_fn = os.path.join(args.result_dir, '{}.wav'.format(music_key))
torchaudio.save(music_fn, music_audios[0], sample_rate=args.sample_rate)
f.write('{} {}\n'.format(music_key, music_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
main()
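# Hedged summary of the chorus-tag convention used in main() above: the integer stored in
# batch["chorus"] indexes MUSIC_STRUCTURE_LABELS (intro=0, verse=1, chorus=2, outro=4;
# "random" draws from 1..4). A compact sketch for illustration only; the fallback to 1
# ("verse") for unknown tags is an assumption.
def _chorus_tag_to_id(tag):
    import torch
    mapping = {"intro": 0, "verse": 1, "chorus": 2, "outro": 4}
    if tag == "random":
        return int(torch.randint(1, 5, (1,)).item())
    return mapping.get(tag, 1)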
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.dataset.dataset import Dataset
import time
from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_args():
parser = argparse.ArgumentParser(description='inference only with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--flow_model', default=None, required=False, help='flow model file')
    parser.add_argument('--llm_model', default=None, required=False, help='llm model file')
parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file')
parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
    parser.add_argument('--chorus', default="random", required=False, help='chorus tag generation mode, e.g. random, verse, chorus, intro, outro.')
parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. False: normal inference mode, with flow matching for high quality.')
parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model')
parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio')
parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds')
parser.add_argument('--trim', default=False, type=bool, required=False, help='trim the silence ending of generated audio')
parser.add_argument('--format', type=str, default="wav", required=False,
choices=["wav", "mp3", "m4a", "flac"],
                        help='format of generated output audio')
parser.add_argument('--sample_rate', type=int, default=24000, required=False,
help='sampling rate of input audio')
parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000],
help='sampling rate of generated output audio')
parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False,
help='the minimum generated audio length in seconds')
parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
help='the maximum generated audio length in seconds')
parser.add_argument('--gpu',
type=int,
default=0,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--task',
default='text-to-music',
choices=['text-to-music', 'continuation', "reconstruct", "super_resolution"],
help='choose inference task type. text-to-music: text-to-music task. continuation: music continuation task. reconstruct: reconstruction of original music. super_resolution: convert original 24kHz music into 48kHz music.')
    parser.add_argument('--result_dir', required=True, help='directory to save generated audio, wav.scp and captions.txt')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
if args.fast:
args.output_sample_rate = 24000
min_generate_audio_length = int(args.output_sample_rate * args.min_generate_audio_seconds)
max_generate_audio_length = int(args.output_sample_rate * args.max_generate_audio_seconds)
assert args.min_generate_audio_seconds <= args.max_generate_audio_seconds
# Init inspiremusic models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.fast, args.fp16)
model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
if args.llm_model is None:
model.llm = None
else:
model.llm = model.llm.to(torch.float32)
if args.flow_model is None:
model.flow = None
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
caption_fn = os.path.join(args.result_dir, 'captions.txt')
caption_f = open(caption_fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
            assert len(utts) == 1, "inference mode only supports batch size 1"
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
if "time_start" not in batch.keys():
batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64)
if batch["time_start"].numpy()[0] > 300:
batch["time_start"] = torch.Tensor([0]).to(torch.float64)
if "time_end" not in batch.keys():
batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
else:
if (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds:
batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) > args.max_generate_audio_seconds:
batch["time_end"] = torch.Tensor([(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds)]).to(torch.float64)
if "chorus" not in batch.keys():
batch["chorus"] = torch.randint(1, 5, (1,))
if args.chorus == "random":
batch["chorus"] = torch.randint(1, 5, (1,))
elif args.chorus == "intro":
batch["chorus"] = torch.Tensor([0])
elif "verse" in args.chorus:
batch["chorus"] = torch.Tensor([1])
elif args.chorus == "chorus":
batch["chorus"] = torch.Tensor([2])
elif args.chorus == "outro":
batch["chorus"] = torch.Tensor([4])
else:
batch["chorus"] = batch["chorus"]
time_start = batch["time_start"].to(device)
time_end = batch["time_end"].to(device)
chorus = batch["chorus"].to(torch.int)
text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
chorus = chorus.to(device)
if batch["acoustic_token"] is None:
audio_token = None
audio_token_len = None
else:
audio_token = batch["acoustic_token"].to(device)
audio_token_len = batch["acoustic_token_len"].to(device)
text = batch["text"]
if "semantic_token" in batch:
token = batch["semantic_token"].to(device)
token_len = batch["semantic_token_len"].to(device)
else:
if audio_token is None:
token = None
token_len = None
else:
token = audio_token.view(audio_token.size(0), -1, 4)[:, :, 0]
token_len = audio_token_len / 4
if args.task in ['text-to-music', 'continuation']:
# text to music, music continuation
model_input = {"text": text, "audio_token": token,
"audio_token_len": token_len,
"text_token": text_token,
"text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus],
"raw_text": text,
"sample_rate": args.output_sample_rate,
"duration_to_gen": args.max_generate_audio_seconds,
"task": args.task}
elif args.task in ['reconstruct', 'super_resolution']:
# audio reconstruction, audio super resolution
model_input = {"text": text, "audio_token": audio_token,
"audio_token_len": audio_token_len,
"text_token": text_token,
"text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus],
"raw_text": text,
"sample_rate": args.output_sample_rate,
"duration_to_gen": args.max_generate_audio_seconds,
"task": args.task}
else:
# zero-shot
model_input = {'text' : text,
'text_len' : text_token_len,
'prompt_text' : text_token,
'prompt_text_len' : text_token_len,
'llm_prompt_audio_token' : token,
'llm_prompt_audio_token_len' : token_len,
'flow_prompt_audio_token' : audio_token,
'flow_prompt_audio_token_len': audio_token_len,
'prompt_audio_feat' : audio_feat,
'prompt_audio_feat_len' : audio_feat_len,
"embeddings" : [time_start,
time_end,
chorus]}
music_key = utts[0]
music_audios = []
music_fn = os.path.join(args.result_dir, f'{music_key}.{args.format}')
bench_start = time.time()
for model_output in model.inference(**model_input):
music_audios.append(model_output['music_audio'])
bench_end = time.time()
if args.trim:
music_audio = trim_audio(music_audios[0],
sample_rate=args.output_sample_rate,
threshold=0.05,
min_silence_duration=0.8)
else:
music_audio = music_audios[0]
if music_audio.shape[0] != 0:
if music_audio.shape[1] > max_generate_audio_length:
music_audio = music_audio[:, :max_generate_audio_length]
if music_audio.shape[1] >= min_generate_audio_length:
try:
if args.fade_out:
music_audio = fade_out(music_audio, args.output_sample_rate, args.fade_out_duration)
music_audio = music_audio.repeat(2, 1)
if args.format in ["wav", "flac"]:
torchaudio.save(music_fn, music_audio, sample_rate=args.output_sample_rate, encoding="PCM_S", bits_per_sample=24)
elif args.format in ["mp3", "m4a"]:
torchaudio.backend.sox_io_backend.save(filepath=music_fn, src=music_audio, sample_rate=args.output_sample_rate, format=args.format)
else:
logging.info(f"Format is not supported. Please choose from wav, mp3, m4a, flac.")
except Exception as e:
logging.info(f"Error saving file: {e}")
raise
audio_duration = music_audio.shape[1] / args.output_sample_rate
rtf = (bench_end - bench_start) / audio_duration
logging.info(f"processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")
f.write('{} {}\n'.format(music_key, music_fn))
f.flush()
caption_f.write('{}\t{}\n'.format(music_key, text_prompt))
caption_f.flush()
else:
logging.info(f"Generate audio length {music_audio.shape[1]} is shorter than min_generate_audio_length.")
else:
logging.info(f"Generate audio is empty, dim = {music_audio.shape[0]}.")
    f.close()
    caption_f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
main()
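# Hedged sketch of the output manifests written by main() above: wav.scp holds
# "<utt_id> <audio_path>" per line and captions.txt holds "<utt_id>\t<text_prompt>".
# A minimal reader for illustration only.
def _read_result_manifests(result_dir):
    import os
    wavs, captions = {}, {}
    with open(os.path.join(result_dir, 'wav.scp')) as f:
        for line in f:
            key, path = line.strip().split(' ', 1)
            wavs[key] = path
    with open(os.path.join(result_dir, 'captions.txt')) as f:
        for line in f:
            key, text = line.rstrip('\n').split('\t', 1)
            captions[key] = text
    return wavs, captions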
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import torch
import torch.distributed as dist
import deepspeed
import glob
import os
from hyperpyyaml import load_hyperpyyaml
from torch.cuda.amp import GradScaler, autocast
from torch.distributed.elastic.multiprocessing.errors import record
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from inspiremusic.utils.executor import Executor
from inspiremusic.utils.train_utils import (
init_distributed,
init_dataset_and_dataloader,
init_optimizer_and_scheduler,
init_summarywriter, save_model,
wrap_cuda_model, check_modify_and_save_config)
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
help='Engine for paralleled training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--num_workers',
default=0,
type=int,
help='number of subprocess workers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--pin_memory',
action='store_true',
default=True,
help='Use pinned memory buffers used for reading')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
parser.add_argument('--timeout',
default=30,
type=int,
help='timeout (in seconds) of inspiremusic_join.')
parser.add_argument('--fp16',
action='store_true',
default=False,
help='Enable fp16 mixed precision training')
parser.add_argument('--lora',
action='store_true',
default=False,
help='Enable LoRA training')
parser.add_argument('--lora_rank',
default=4,
type=int,
help='LoRA rank')
parser.add_argument('--lora_alpha',
default=16,
type=int,
help='LoRA alpha')
parser.add_argument('--lora_dropout',
default=0.1,
type=float,
help='LoRA dropout rate')
parser.add_argument('--lora_target_modules',
nargs='+',
default=["k_proj","v_proj"],
help='Target modules to apply LoRA (e.g., ["q_proj", "v_proj"])')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args
@record
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
configs['train_conf'].update(vars(args))
# Init env for ddp
init_distributed(args)
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs)
    # Do some sanity checks and save config to args.model_dir
configs = check_modify_and_save_config(args, configs)
# Tensorboard summary
writer = init_summarywriter(args)
# load checkpoint
model = configs[args.model]
if args.checkpoint is not None:
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
else:
# Find and load the latest checkpoint
checkpoint_files = glob.glob(os.path.join(args.model_dir, '*.pt'))
if checkpoint_files:
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
logging.info(f"Loaded latest checkpoint from {latest_checkpoint}")
model.load_state_dict(torch.load(latest_checkpoint, map_location='cpu'))
if args.lora:
logging.info("Applying LoRA to the model...")
if not args.lora_target_modules:
raise ValueError("No target modules specified for LoRA. Please provide --lora_target_modules.")
lora_config = LoraConfig(
task_type="CAUSAL_LM", # Change to appropriate task type
inference_mode=False,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules
)
model.llm.model = get_peft_model(model.llm.model, lora_config)
# Optionally freeze the base model
else:
logging.info("LoRA is not enabled. Training the full model.")
# Dispatch model from cpu to gpu
model = wrap_cuda_model(args, model)
# Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
# Initialize AMP for torch_ddp if fp16 is enabled
scaler = None
if args.fp16:
scaler = GradScaler()
logging.info("Initialized AMP GradScaler for mixed precision training.")
# Save init checkpoints
info_dict = deepcopy(configs['train_conf'])
# Get executor
executor = Executor()
# Start training loop
for epoch in range(info_dict['max_epoch']):
executor.epoch = epoch
train_dataset.set_epoch(epoch)
dist.barrier()
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
executor.train_one_epoch(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=scaler)
dist.destroy_process_group(group_join)
if __name__ == '__main__':
main()
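# Hedged sketch: when --lora is set, main() above wraps model.llm.model with get_peft_model().
# peft's print_trainable_parameters() is a quick way to confirm that only the LoRA adapters
# (default target modules k_proj/v_proj) remain trainable. Illustrative only; `peft_model`
# is assumed to be the wrapped module.
def _report_lora_params(peft_model):
    peft_model.print_trainable_parameters()
    trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in peft_model.parameters())
    print('trainable ratio: {:.4%}'.format(trainable / total))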
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
from typing import Callable
import re
import inflect
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
class InspireMusicFrontEnd:
def __init__(self,
configs: Callable,
get_tokenizer: Callable,
llm_model: str,
flow_model: str,
music_tokenizer_dir: str,
audio_tokenizer_dir: str,
instruct: bool = False,
dtype: str = "fp16",
fast: bool = False,
fp16: bool = True,
allowed_special: str = 'all'):
self.tokenizer = get_tokenizer()
self.audio_tokenizer_dir = audio_tokenizer_dir
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.bandwidth_id = torch.tensor([0]).to(self.device)
self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device)
self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
        self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir)
self.instruct = instruct
self.allowed_special = allowed_special
self.inflect_parser = inflect.engine()
def _extract_text_token(self, text):
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
return text_token, text_token_len
def _extract_audio_token(self, audio, sample_rate=24000):
audio = torch.tensor(audio, dtype=torch.float32, device=self.device)
_, audio_token = self.wavtokenizer.encode_infer(audio, bandwidth_id=self.bandwidth_id)
audio_token = audio_token.squeeze(0)
audio_token_len = torch.tensor([audio_token.shape[1]], dtype=torch.int32, device=self.device)
return audio_token, audio_token_len
def text_normalize(self, text, split=True):
text = text.strip()
if contains_chinese(text):
text = text.replace("\n", "")
text = replace_blank(text)
text = replace_corner_mark(text)
text = text.replace(".", "、")
text = text.replace(" - ", ",")
text = remove_bracket(text)
text = re.sub(r'[,,]+$', '。', text)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
else:
text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
if split is False:
return text
return texts
def frontend_text_to_music(self, text, time_start, time_end, chorus):
text_token, text_token_len = self._extract_text_token(text)
model_input = {"text": text, "audio_token": None, "audio_token_len": None,
"text_token": text_token, "text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus], "raw_text":text}
return model_input
def frontend_continuation(self, text, audio, time_start, time_end, chorus, target_sr=24000):
if text is None:
text_token = None
text_token_len = None
else:
text_token, text_token_len = self._extract_text_token(text)
audio_token, audio_token_len = self._extract_audio_token(audio, target_sr)
model_input = {"text": text, "audio_token": audio_token, "audio_token_len": audio_token_len,
"text_token": text_token, "text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus], "raw_text":text}
return model_input
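# Hedged usage sketch: given an initialized InspireMusicFrontEnd (here named `frontend`, an
# assumption), a text-to-music model input follows the conventions used by the inference
# scripts in this commit: float64 time bounds in seconds and an int chorus id (1 == "verse").
def _example_text_to_music_input(frontend, text):
    import torch
    time_start = torch.tensor([0.0], dtype=torch.float64)
    time_end = torch.tensor([30.0], dtype=torch.float64)
    chorus = torch.tensor([1], dtype=torch.int)
    return frontend.frontend_text_to_music(text, time_start, time_end, chorus)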
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torchaudio
import time
import logging
import argparse
from inspiremusic.cli.inspiremusic import InspireMusic
from inspiremusic.utils.file_utils import logging
import torch
from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
def set_env_variables():
os.environ['PYTHONIOENCODING'] = 'UTF-8'
os.environ['TOKENIZERS_PARALLELISM'] = 'False'
current_working_dir = os.getcwd()
main_root = os.path.realpath(os.path.join(current_working_dir, '../../'))
bin_dir = os.path.join(main_root, 'inspiremusic')
third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')
python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}"
os.environ['PYTHONPATH'] = python_path
sys.path.extend([main_root, third_party_matcha_tts_path])
class InspireMusicUnified:
def __init__(self,
model_name: str = "InspireMusic-1.5B-Long",
model_dir: str = None,
min_generate_audio_seconds: float = 10.0,
max_generate_audio_seconds: float = 30.0,
sample_rate: int = 24000,
output_sample_rate: int = 48000,
load_jit: bool = True,
load_onnx: bool = False,
dtype: str = "fp16",
fast: bool = False,
fp16: bool = True,
gpu: int = 0,
result_dir: str = None,
hub="modelscope"):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
# Set model_dir or default to downloading if it doesn't exist
if model_dir is None:
model_dir = f"../../pretrained_models/{model_name}"
if not os.path.isfile(f"{model_dir}/llm.pt"):
if hub == "modelscope":
from modelscope import snapshot_download
if model_name == "InspireMusic-Base":
snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
else:
snapshot_download(f"iic/{model_name}", local_dir=model_dir)
self.model_dir = model_dir
self.sample_rate = sample_rate
self.output_sample_rate = 24000 if fast else output_sample_rate
self.result_dir = result_dir or f"exp/{model_name}"
os.makedirs(self.result_dir, exist_ok=True)
self.min_generate_audio_seconds = min_generate_audio_seconds
self.max_generate_audio_seconds = max_generate_audio_seconds
self.min_generate_audio_length = int(self.output_sample_rate * self.min_generate_audio_seconds)
self.max_generate_audio_length = int(self.output_sample_rate * self.max_generate_audio_seconds)
assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds"
use_cuda = gpu >= 0 and torch.cuda.is_available()
self.device = torch.device('cuda' if use_cuda else 'cpu')
self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, dtype=dtype, fast=fast, fp16=fp16)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@torch.inference_mode()
def inference(self,
task: str = 'text-to-music',
text: str = None,
audio_prompt: str = None, # audio prompt file path
instruct: str = None,
chorus: str = "verse",
time_start: float = 0.0,
time_end: float = 30.0,
output_fn: str = "output_audio",
max_audio_prompt_length: float = 5.0,
fade_out_duration: float = 1.0,
output_format: str = "wav",
fade_out_mode: bool = True,
trim: bool = False,
):
with torch.no_grad():
text_prompt = f"<|{time_start}|><|{chorus}|><|{text}|><|{time_end}|>"
chorus_dict = {"random": torch.randint(1, 5, (1,)).item(), "intro" : 0, "verse": 1, "chorus": 2, "outro": 4}
chorus = chorus_dict.get(chorus, 1)
chorus = torch.tensor([chorus], dtype=torch.int).to(self.device)
time_start_tensor = torch.tensor([time_start], dtype=torch.float64).to(self.device)
time_end_tensor = torch.tensor([time_end], dtype=torch.float64).to(self.device)
music_fn = os.path.join(self.result_dir, f'{output_fn}.{output_format}')
bench_start = time.time()
if task == 'text-to-music':
model_input = {
"text" : text,
"audio_prompt" : audio_prompt,
"time_start" : time_start_tensor,
"time_end" : time_end_tensor,
"chorus" : chorus,
"task" : task,
"stream" : False,
"duration_to_gen": self.max_generate_audio_seconds,
"sr" : self.sample_rate
}
elif task == 'continuation':
if audio_prompt is not None:
audio, _ = process_audio(audio_prompt, self.sample_rate)
if audio.size(1) < self.sample_rate:
logging.warning("Warning: Input prompt audio length is shorter than 1s. Please provide an appropriate length audio prompt and try again.")
audio = None
else:
max_audio_prompt_length_samples = int(max_audio_prompt_length * self.sample_rate)
audio = audio[:, :max_audio_prompt_length_samples] # Trimming prompt audio
model_input = {
"text" : text,
"audio_prompt" : audio,
"time_start" : time_start_tensor,
"time_end" : time_end_tensor,
"chorus" : chorus,
"task" : task,
"stream" : False,
"duration_to_gen": self.max_generate_audio_seconds,
"sr" : self.sample_rate
}
music_audios = []
for model_output in self.model.cli_inference(**model_input):
music_audios.append(model_output['music_audio'])
bench_end = time.time()
if trim:
music_audio = trim_audio(music_audios[0],
sample_rate=self.output_sample_rate,
threshold=0.05,
min_silence_duration=0.8)
else:
music_audio = music_audios[0]
if music_audio.shape[0] != 0:
if music_audio.shape[1] > self.max_generate_audio_length:
music_audio = music_audio[:, :self.max_generate_audio_length]
if music_audio.shape[1] >= self.min_generate_audio_length:
try:
if fade_out_mode:
music_audio = fade_out(music_audio, self.output_sample_rate, fade_out_duration)
music_audio = music_audio.repeat(2, 1)
if output_format in ["wav", "flac"]:
torchaudio.save(music_fn, music_audio,
sample_rate=self.output_sample_rate,
encoding="PCM_S",
bits_per_sample=24)
elif output_format in ["mp3", "m4a"]:
torchaudio.backend.sox_io_backend.save(
filepath=music_fn, src=music_audio,
sample_rate=self.output_sample_rate,
format=output_format)
else:
logging.info("Format is not supported. Please choose from wav, mp3, m4a, flac.")
except Exception as e:
logging.error(f"Error saving file: {e}")
raise
audio_duration = music_audio.shape[1] / self.output_sample_rate
rtf = (bench_end - bench_start) / audio_duration
logging.info(f"Processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")
else:
logging.error(f"Generated audio length is shorter than minimum required audio length.")
if music_fn:
if os.path.exists(music_fn):
logging.info(f"Generated audio file {music_fn} is saved.")
return music_fn
else:
logging.error(f"{music_fn} does not exist.")
def get_args():
parser = argparse.ArgumentParser(description='Run inference with your model')
parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long",
help='Model name')
parser.add_argument('-d', '--model_dir',
help='Model folder path')
parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
help='Prompt text')
parser.add_argument('-a', '--audio_prompt', default=None,
help='Prompt audio')
parser.add_argument('-c', '--chorus', default="intro",
help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
parser.add_argument('-f', '--fast', type=bool, default=False,
help='Enable fast inference mode (without flow matching)')
parser.add_argument('-g', '--gpu', type=int, default=0,
help='GPU ID for this rank, -1 for CPU')
parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'],
help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
parser.add_argument('-r', '--result_dir', default="exp/inspiremusic",
help='Directory to save generated audio')
parser.add_argument('-o', '--output_fn', default="output_audio",
help='Output file name')
parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"],
help='Format of output audio')
parser.add_argument('--sample_rate', type=int, default=24000,
help='Sampling rate of input audio')
parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000],
help='Sampling rate of generated output audio')
parser.add_argument('-s', '--time_start', type=float, default=0.0,
help='Start time in seconds')
parser.add_argument('-e', '--time_end', type=float, default=30.0,
help='End time in seconds')
parser.add_argument('--max_audio_prompt_length', type=float, default=5.0,
help='Maximum audio prompt length in seconds')
parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
help='Minimum generated audio length in seconds')
parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0,
help='Maximum generated audio length in seconds')
parser.add_argument('--fp16', type=bool, default=True,
help='Inference with fp16 model')
parser.add_argument('--fade_out', type=bool, default=True,
help='Apply fade out effect to generated audio')
parser.add_argument('--fade_out_duration', type=float, default=1.0,
help='Fade out duration in seconds')
parser.add_argument('--trim', type=bool, default=False,
help='Trim the silence ending of generated audio')
args = parser.parse_args()
if not args.model_dir:
args.model_dir = os.path.join("../../pretrained_models", args.model_name)
print(args)
return args
def main():
set_env_variables()
args = get_args()
model = InspireMusicUnified(model_name = args.model_name,
model_dir = args.model_dir,
min_generate_audio_seconds = args.min_generate_audio_seconds,
max_generate_audio_seconds = args.max_generate_audio_seconds,
sample_rate = args.sample_rate,
output_sample_rate = args.output_sample_rate,
load_jit = True,
load_onnx = False,
dtype="fp16",
fast = args.fast,
fp16 = args.fp16,
gpu = args.gpu,
result_dir = args.result_dir)
model.inference(task = args.task,
text = args.text,
audio_prompt = args.audio_prompt,
chorus = args.chorus,
time_start = args.time_start,
time_end = args.time_end,
output_fn = args.output_fn,
max_audio_prompt_length = args.max_audio_prompt_length,
fade_out_duration = args.fade_out_duration,
output_format = args.format,
fade_out_mode = args.fade_out,
trim = args.trim)
if __name__ == "__main__":
main()
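# Hedged usage sketch of InspireMusicUnified above: programmatic text-to-music generation
# without the argparse wrapper. The model name and result_dir mirror the get_args() defaults;
# the prompt text and output_fn are illustrative assumptions.
def _example_unified_inference():
    set_env_variables()
    model = InspireMusicUnified(model_name="InspireMusic-1.5B-Long",
                                result_dir="exp/inspiremusic")
    return model.inference(task='text-to-music',
                           text="Relaxing instrumental jazz with a touch of Bossa Nova.",
                           chorus="intro",
                           time_start=0.0,
                           time_end=30.0,
                           output_fn="demo_audio")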
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from inspiremusic.cli.frontend import InspireMusicFrontEnd
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.utils.file_utils import logging
import torch
class InspireMusic:
def __init__(self, model_dir, load_jit=True, load_onnx=False, dtype = "fp16", fast = False, fp16=True, hub="modelscope"):
        if model_dir is None:
            model_dir = "../../pretrained_models/InspireMusic-1.5B-Long"
        self.model_dir = model_dir
        instruct = True if '-Instruct' in model_dir else False
if not os.path.isfile(f"{model_dir}/llm.pt"):
model_name = model_dir.split("/")[-1]
if hub == "modelscope":
from modelscope import snapshot_download
if model_name == "InspireMusic-Base":
snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
else:
snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f:
configs = load_hyperpyyaml(f)
self.frontend = InspireMusicFrontEnd(configs,
configs['get_tokenizer'],
'{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/music_tokenizer/'.format(model_dir),
'{}/wavtokenizer/'.format(model_dir),
instruct,
dtype,
fast,
fp16,
configs['allowed_special'])
self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/music_tokenizer/'.format(model_dir),
'{}/wavtokenizer/model.pt'.format(model_dir))
del configs
@torch.inference_mode()
def inference(self, task, text, audio, time_start, time_end, chorus, stream=False, sr=24000):
if task == "text-to-music":
for i in tqdm(self.frontend.text_normalize(text, split=True)):
model_input = self.frontend.frontend_text_to_music(i, time_start, time_end, chorus)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
elif task == "continuation":
if text is None:
if audio is not None:
for i in tqdm(audio):
                        model_input = self.frontend.frontend_continuation(None, i, time_start, time_end, chorus, sr)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.continuation_inference(**model_input, stream=stream):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
else:
if audio is not None:
for i in tqdm(self.frontend.text_normalize(text, split=True)):
                        model_input = self.frontend.frontend_continuation(i, audio, time_start, time_end, chorus, sr)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.continuation_inference(**model_input, stream=stream):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
else:
print("Please input text or audio.")
else:
print("Currently only support text-to-music and music continuation tasks.")
@torch.inference_mode()
def cli_inference(self, text, audio_prompt, time_start, time_end, chorus, task, stream=False, duration_to_gen=30, sr=24000):
if task == "text-to-music":
model_input = self.frontend.frontend_text_to_music(text, time_start, time_end, chorus)
logging.info('prompt text {}'.format(text))
elif task == "continuation":
model_input = self.frontend.frontend_continuation(text, audio_prompt, time_start, time_end, chorus, sr)
logging.info('prompt audio length: {}'.format(len(audio_prompt)))
start_time = time.time()
for model_output in self.model.inference(**model_input, duration_to_gen=duration_to_gen, task=task):
music_audios_len = model_output['music_audio'].shape[1] / sr
logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
yield model_output
start_time = time.time()
@torch.inference_mode()
def inference_zero_shot(self, text, prompt_text, prompt_audio_16k, stream=False, sr=24000):
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
for i in tqdm(self.frontend.text_normalize(text, split=True)):
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_audio_16k)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
audio_len = model_output['music_audio'].shape[1] / sr
logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len))
yield model_output
start_time = time.time()
@torch.inference_mode()
def inference_instruct(self, text, spk_id, instruct_text, stream=False, sr=24000):
if self.frontend.instruct is False:
            raise ValueError('{} does not support instruct inference'.format(self.model_dir))
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
for i in tqdm(self.frontend.text_normalize(text, split=True)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
logging.info('prompt text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
audio_len = model_output['music_audio'].shape[1] / sr
logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len))
yield model_output
start_time = time.time()
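# Hedged usage sketch of the InspireMusic wrapper above: a text-to-music call using the
# chorus/time conventions from the inference scripts in this commit. The fast mode, 24 kHz
# save rate and output path are illustrative assumptions.
def _example_text_to_music(model_dir, out_path="demo.wav"):
    import torch
    import torchaudio
    model = InspireMusic(model_dir, load_jit=False, load_onnx=False, fast=True)
    time_start = torch.tensor([0.0], dtype=torch.float64)
    time_end = torch.tensor([30.0], dtype=torch.float64)
    chorus = torch.tensor([1], dtype=torch.int)
    for output in model.inference("text-to-music", "An upbeat electronic track.",
                                  None, time_start, time_end, chorus):
        torchaudio.save(out_path, output['music_audio'], sample_rate=24000)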
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import threading
import time
from contextlib import nullcontext
import uuid
from inspiremusic.utils.common import DTYPES
from inspiremusic.music_tokenizer.vqvae import VQVAE
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
from torch.cuda.amp import autocast
import logging
import torch
import os
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class InspireMusicModel:
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
music_tokenizer: torch.nn.Module,
wavtokenizer: torch.nn.Module,
dtype: str = "fp16",
fast: bool = False,
fp16: bool = True,
):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.dtype = DTYPES.get(dtype, torch.float32)
        self.llm = llm.to(self.dtype) if llm is not None else None
self.flow = flow
self.music_tokenizer = music_tokenizer
self.wavtokenizer = wavtokenizer
self.fp16 = fp16
self.token_min_hop_len = 100
self.token_max_hop_len = 200
self.token_overlap_len = 20
# mel fade in out
self.mel_overlap_len = 34
self.mel_window = np.hamming(2 * self.mel_overlap_len)
# hift cache
self.mel_cache_len = 20
self.source_cache_len = int(self.mel_cache_len * 256)
# rtf and decoding related
self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.music_token_dict = {}
self.llm_end_dict = {}
self.mel_overlap_dict = {}
self.fast = fast
self.generator = "hifi"
def load(self, llm_model, flow_model, hift_model, wavtokenizer_model):
if llm_model is not None:
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
self.llm.to(self.device).to(self.dtype).eval()
else:
self.llm = None
if flow_model is not None:
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
self.flow.to(self.device).eval()
if hift_model is not None:
if ".pt" not in hift_model:
self.music_tokenizer = VQVAE( hift_model + '/config.json',
hift_model + '/model.pt', with_encoder=True)
else:
self.music_tokenizer = VQVAE(os.path.dirname(hift_model) + '/config.json',
hift_model, with_encoder=True)
self.music_tokenizer.to(self.device).eval()
if wavtokenizer_model is not None:
if ".pt" not in wavtokenizer_model:
self.wavtokenizer = WavTokenizer.from_pretrained_feat( wavtokenizer_model + '/config.yaml',
wavtokenizer_model + '/model.pt')
else:
self.wavtokenizer = WavTokenizer.from_pretrained_feat( os.path.dirname(wavtokenizer_model) + '/config.yaml',
wavtokenizer_model )
self.wavtokenizer.to(self.device)
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
self.llm.text_encoder = llm_text_encoder
llm_llm = torch.jit.load(llm_llm_model)
self.llm.llm = llm_llm
flow_encoder = torch.jit.load(flow_encoder_model)
self.flow.encoder = flow_encoder
def load_onnx(self, flow_decoder_estimator_model):
import onnxruntime
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
del self.flow.decoder.estimator
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task):
with self.llm_context:
local_res = []
with autocast(enabled=self.fp16, dtype=self.dtype, cache_enabled=True):
inference_kwargs = {
'text': text.to(self.device),
'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
'prompt_text': prompt_text.to(self.device),
'prompt_text_len': torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
'prompt_audio_token': llm_prompt_audio_token.to(self.device),
'prompt_audio_token_len': torch.tensor([llm_prompt_audio_token.shape[1]], dtype=torch.int32).to(self.device),
'embeddings': embeddings,
'duration_to_gen': duration_to_gen,
'task': task
}
if audio_token is not None:
inference_kwargs['audio_token'] = audio_token.to(self.device)
else:
inference_kwargs['audio_token'] = torch.Tensor([0]).to(self.device)
if audio_token_len is not None:
inference_kwargs['audio_token_len'] = audio_token_len.to(self.device)
else:
inference_kwargs['audio_token_len'] = torch.Tensor([0]).to(self.device)
for i in self.llm.inference(**inference_kwargs):
local_res.append(i)
self.music_token_dict[uuid] = local_res
self.llm_end_dict[uuid] = True
# def token2wav(self, token, token_len, text, text_len, uuid, sample_rate, finalize=False):
def token2wav(self, token, token_len, uuid, sample_rate, finalize=False, flow_cfg=None):
# if self.flow is not None:
# if isinstance(self.flow,MaskedDiffWithText):
# codec_embed = self.flow.inference(token=token.to(self.device),
# token_len=token_len.to(self.device),
# text_token=text,
# text_token_len=text_len,
# )
# else:
if flow_cfg is not None:
codec_embed = self.flow.inference_cfg(token=token.to(self.device),
token_len=token_len.to(self.device),
sample_rate=sample_rate
)
else:
codec_embed = self.flow.inference(token=token.to(self.device),
token_len=token_len.to(self.device),
sample_rate=sample_rate
)
# use music_tokenizer decoder
wav = self.music_tokenizer.generator(codec_embed)
wav = wav.squeeze(0).cpu().detach()
return wav
def acoustictoken2wav(self, token):
# use music_tokenizer to generate waveform from token
token = token.view(token.size(0), -1, 4)
# codec = token.view(1, -1, 4)
codec_embed = self.music_tokenizer.quantizer.embed(torch.tensor(token).long().to(self.device)).cuda()
wav = self.music_tokenizer.generator(codec_embed)
wav = wav.squeeze(0).cpu().detach()
return wav
def semantictoken2wav(self, token):
# fast mode, use wavtokenizer decoder
new_tensor = torch.tensor(token.to(self.device)).unsqueeze(0)
features = self.wavtokenizer.codes_to_features(new_tensor)
bandwidth_id = torch.tensor([0]).to(self.device)
wav = self.wavtokenizer.to(self.device).decode(features, bandwidth_id=bandwidth_id)
wav = wav.cpu().detach()
return wav
@torch.inference_mode()
def inference(self, text, audio_token, audio_token_len, text_token, text_token_len, embeddings=None,
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_audio_feat=torch.zeros(1, 0, 80), sample_rate=48000, duration_to_gen = 30, task="continuation", trim = True, stream=False, **kwargs):
# this_uuid is used to track variables related to this inference thread
# support tasks:
# text to music task
# music continuation task
# require either audio input only or text and audio inputs
this_uuid = str(uuid.uuid1())
if self.llm:
with self.lock:
self.music_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
p = threading.Thread(target=self.llm_job, args=(text_token, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, this_uuid, duration_to_gen, task))
p.start()
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
time.sleep(0.1)
if len(self.music_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_music_audio = self.token2wav(token=text_token,
token_len=text_token_len,
uuid=this_uuid,
sample_rate=sample_rate,
finalize=False)
yield {'music_audio': this_music_audio.cpu()}
with self.lock:
self.music_token_dict[this_uuid] = self.music_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better audio quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.music_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
p.join()
            # deal with remaining tokens, make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
            this_music_audio = self.token2wav(token=this_music_token,
                                              token_len=torch.LongTensor([this_music_token.size(1)]),
                                              uuid=this_uuid,
                                              sample_rate=sample_rate,
                                              finalize=True)
            yield {'music_audio': this_music_audio.cpu()}
else:
# deal with all tokens
if self.fast:
if task == "reconstruct":
                    assert audio_token is not None
this_music_token = audio_token
this_music_audio = self.acoustictoken2wav(token=this_music_token)
else:
if self.llm:
p.join()
print(len(self.music_token_dict[this_uuid]))
this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
print(this_music_token.shape)
else:
this_music_token = text_token
logging.info("using wavtokenizer generator without flow matching")
this_music_audio = self.semantictoken2wav(token=this_music_token)
print(this_music_audio.shape)
else:
if self.llm:
p.join()
if len(self.music_token_dict[this_uuid]) != 0:
this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
else:
print(f"The list of tensors is empty for UUID: {this_uuid}")
else:
this_music_token = text_token
logging.info(f"LLM generated audio token length: {this_music_token.shape[1]}")
logging.info(f"using flow matching and {self.generator} generator")
if self.generator == "hifi":
if (embeddings[1] - embeddings[0]) <= duration_to_gen:
if trim:
trim_length = (int((embeddings[1] - embeddings[0])*75))
this_music_token = this_music_token[:, :trim_length]
logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}")
elif (embeddings[1] - embeddings[0]) < 1:
logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.")
this_music_audio = self.token2wav(token=this_music_token,
token_len=torch.LongTensor([this_music_token.size(1)]),
uuid=this_uuid,
sample_rate=sample_rate,
finalize=True)
logging.info(f"Generated audio sequence length: {this_music_audio.shape[1]}")
elif self.generator == "wavtokenizer":
if (embeddings[1] - embeddings[0]) < duration_to_gen:
if trim:
trim_length = (int((embeddings[1] - embeddings[0])*75))
this_music_token = this_music_token[:,:trim_length]
logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}")
elif (embeddings[1] - embeddings[0]) < 1:
logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.")
this_music_audio = self.semantictoken2wav(token=this_music_token)
yield {'music_audio': this_music_audio.cpu()}
            if torch.cuda.is_available():
                torch.cuda.synchronize()
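# Hedged sketch tying load_jit()/load_onnx() above to the export scripts earlier in this
# commit: the file names match what the JIT and ONNX export scripts write into --model_dir.
# Illustrative only; call after InspireMusicModel.load().
def _load_exported_artifacts(model, model_dir):
    model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                   '{}/llm.llm.fp16.zip'.format(model_dir),
                   '{}/flow.encoder.fp32.zip'.format(model_dir))
    model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))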