from io import BytesIO
import sys

import librosa
import numpy as np
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from ..utils.misc import print_once
from .base import BaseModel
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer

# ChatML-style Jinja chat template handed to the tokenizer in VITAAudio.__init__.
# The literal is kept exactly as found (including the physical line break inside
# one Jinja expression, where whitespace is insignificant).
# NOTE(review): the tool-calling instructions read "within XML tags" and
# '{"name": , "arguments": }' with no tag / placeholder names -- they look like
# they were stripped at some point; confirm against the upstream template.
chat_template = """ {%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- 
'<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n """

# Directories appended to sys.path in VITAAudio.__init__ before the audio
# tokenizer is built (relative to the current working directory).
_GLM4VOICE_SYS_PATHS = (
    "glm4voice/",
    "glm4voice/cosyvoice/",
    "glm4voice/third_party/Matcha-TTS/",
)


class VITAAudio(BaseModel):
    """Load a VITA-Audio checkpoint and generate text answers for audio inputs."""

    NAME = 'VITA-Audio'

    def __init__(
        self,
        model_path="VITA-MLLM/VITA-Audio-Plus-Boost",
        device='cuda',
        torch_dtype=torch.bfloat16,
        audio_tokenizer_path="/data/models/THUDM/glm-4-voice-tokenizer",
        flow_path="/data/models/THUDM/glm-4-voice-decoder",
        audio_tokenizer_type="sensevoice_glm4voice",
        **kwargs,
    ):
        """Load the language model, its tokenizer and the audio tokenizer.

        Args:
            model_path: Hub id or local path passed to the ``from_pretrained``
                calls for config, model, tokenizer and generation config.
            device: Device the model and the generation inputs are moved to.
            torch_dtype: Dtype the model weights are loaded with.
            audio_tokenizer_path: Path passed to ``get_audio_tokenizer``.
            flow_path: Passed to ``get_audio_tokenizer`` as ``flow_path``.
            audio_tokenizer_type: Tokenizer type passed to
                ``get_audio_tokenizer``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.device = device
        self.config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2",
        )
        self.vita_model = model.to(device).eval()

        self.vita_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            chat_template=chat_template,
        )

        generation_config = GenerationConfig.from_pretrained(
            model_path, trust_remote_code=True
        )
        generation_config.max_new_tokens = 2048
        generation_config.chat_format = "chatml"
        generation_config.max_window_size = 2048
        generation_config.use_cache = True
        generation_config.do_sample = False  # greedy decoding
        self.vita_model.generation_config = generation_config

        # Guard against duplicate entries when several instances are created
        # in the same process.
        for extra_path in _GLM4VOICE_SYS_PATHS:
            if extra_path not in sys.path:
                sys.path.append(extra_path)

        self.audio_tokenizer = get_audio_tokenizer(
            audio_tokenizer_path,
            audio_tokenizer_type,
            flow_path=flow_path,
        )

        # System prompts: none by default, the "Luke" persona otherwise
        # (see get_system_message for the selection rule).
        self.default_system_message = []
        self.luke_system_message = [
            {
                "role": "system",
                "content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
            },
        ]
        self.add_generation_prompt = True

        torch.cuda.empty_cache()

    def get_system_message(self, msg: dict) -> list:
        """Return the system-message list to prepend for ``msg``.

        The empty default is used when ``msg['meta']`` is ``None`` or its task
        is ``'ASR'``; every other sample gets the "Luke" system message.
        """
        meta = msg['meta']
        if meta is None or meta.get('task') == 'ASR':
            return self.default_system_message
        return self.luke_system_message

    def get_task_message(self, msg: dict) -> list:
        """Build the single user message for ``msg``.

        The content always carries the ``<|audio|>`` placeholder that
        ``generate_inner`` later resolves. ASR samples get a fixed
        transcription instruction, ``'Audio-QA'`` samples get the audio alone,
        and everything else (including samples without metadata) gets
        ``msg['text']`` followed by the audio.
        """
        meta = msg['meta']
        if meta is None:
            # Same tolerance as get_system_message: treat as "no metadata".
            meta = {}

        if meta.get('task') == 'ASR':
            content = "Convert the speech to text.\n<|audio|>"
        elif meta.get('interactive') == 'Audio-QA':
            content = "<|audio|>"
        else:
            content = msg['text'] + "\n<|audio|>"

        return [{"role": "user", "content": content}]

    def generate_inner(self, msg: dict):
        """Run one generation for ``msg`` and return ``(prompt, answer)``.

        Args:
            msg: Sample dict; this method reads ``msg['audio']`` and, through
                the helper methods, ``msg['meta']`` and ``msg['text']``.

        Returns:
            A 2-tuple: the decoded prompt (special tokens kept) and the decoded
            text part of the model response (special tokens skipped).
        """
        audio_path = msg['audio']
        # Unwrap a single-element container so the tokenizer gets one item.
        if audio_path is not None and len(audio_path) == 1:
            audio_path = audio_path[0]

        messages = self.get_system_message(msg) + self.get_task_message(msg)
        print_once(f'messages: {messages}')

        has_audio = audio_path is not None

        # Discrete codec: inline the audio as <|audio_N|> tokens in the prompt.
        if has_audio and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            codes = self.audio_tokenizer.encode(audio_path)
            audio_text = "".join(f"<|audio_{i}|>" for i in codes)
            messages[-1]["content"] = messages[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_text}<|end_of_audio|>"
            )

        input_ids = self.vita_tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
        )

        # Contiguous codec: the helper rewrites input_ids and returns the audio
        # inputs that are forwarded to generate().
        if has_audio and self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, [audio_path], self.vita_tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        # Use the configured device rather than a hard-coded "cuda" so the
        # inputs always land where the model was placed in __init__.
        input_ids = torch.tensor([input_ids], dtype=torch.long).to(self.device)

        responses = self.vita_model.generate(
            input_ids,
            audios=audios,
            audio_indices=audio_indices,
        )
        # Drop the echoed prompt; keep only the newly generated ids.
        prompt_len = input_ids.shape[1]
        response = responses[0][prompt_len:]

        # Ids at or above <|begin_of_audio|> are treated as audio-side tokens
        # and left out of the text answer; they are not decoded to speech here.
        audio_offset = self.vita_tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
        text_tokens = [
            token_id for token_id in response.tolist() if token_id < audio_offset
        ]

        out_text = self.vita_tokenizer.decode(
            text_tokens,
            skip_special_tokens=True,
        )
        prompt_text = self.vita_tokenizer.decode(input_ids[0], skip_special_tokens=False)
        return prompt_text, out_text