Commit a8d3703b authored by AlanSwift's avatar AlanSwift
Browse files

Add multiturn audio processing in infer.py and update KimiAPromptManager to...

Add multiturn audio processing in infer.py and update KimiAPromptManager to handle audio-text messages. Introduce new audio and text files for testing multiturn cases.
parent 430be2d4
...@@ -51,3 +51,56 @@ if __name__ == "__main__": ...@@ -51,3 +51,56 @@ if __name__ == "__main__":
24000, 24000,
) )
print(">>> output text: ", text) print(">>> output text: ", text)
# Audio-to-audio multiturn generation.
# Each case replays a recorded user/assistant exchange (q1 + a1) and asks the
# model to answer the follow-up question (q2), writing the reply audio to disk.
# The two cases were previously copy-pasted; they differ only in the case
# directory, the assistant transcript, and the output filename, so drive them
# from a table instead.
multiturn_cases = [
    (
        "case1",
        "当然可以,李白的诗很多,比如这句:“床前明月光,疑是地上霜。举头望明月,低头思故乡。",
        "case_1_multiturn_a2.wav",
    ),
    (
        "case2",
        "当然可以,这很简单。一二三四五六七八九十。",
        "case_2_multiturn_a2.wav",
    ),
]
for case_name, assistant_text, out_wav in multiturn_cases:
    messages = [
        {
            "role": "user",
            "message_type": "audio",
            "content": f"test_audios/multiturn/{case_name}/multiturn_q1.wav",
        },
        {
            # History turn: the assistant's spoken reply plus its transcript,
            # consumed by the prompt manager's "audio-text" message type.
            "role": "assistant",
            "message_type": "audio-text",
            "content": [
                f"test_audios/multiturn/{case_name}/multiturn_a1.wav",
                assistant_text,
            ],
        },
        {
            "role": "user",
            "message_type": "audio",
            "content": f"test_audios/multiturn/{case_name}/multiturn_q2.wav",
        },
    ]
    wav, text = model.generate(messages, **sampling_params, output_type="both")
    # Flatten the waveform tensor to 1-D on CPU; output sample rate is 24 kHz.
    sf.write(
        os.path.join(output_dir, out_wav),
        wav.detach().cpu().view(-1).numpy(),
        24000,
    )
    print(">>> output text: ", text)
...@@ -30,10 +30,11 @@ class KimiAudio(object): ...@@ -30,10 +30,11 @@ class KimiAudio(object):
self.alm = self.alm.to(torch.cuda.current_device()) self.alm = self.alm.to(torch.cuda.current_device())
model_config = self.alm.config model_config = self.alm.config
self.kimia_text_audiodelaytokens = model_config.kimia_mimo_audiodelaytokens
self.kimia_token_offset = model_config.kimia_token_offset self.kimia_token_offset = model_config.kimia_token_offset
self.prompt_manager = KimiAPromptManager( self.prompt_manager = KimiAPromptManager(
model_path=cache_path, kimia_token_offset=self.kimia_token_offset model_path=cache_path, kimia_token_offset=self.kimia_token_offset, kimia_text_audiodelaytokens=self.kimia_text_audiodelaytokens
) )
if load_detokenizer: if load_detokenizer:
...@@ -45,7 +46,6 @@ class KimiAudio(object): ...@@ -45,7 +46,6 @@ class KimiAudio(object):
self.detokenizer = None self.detokenizer = None
self.extra_tokens = self.prompt_manager.extra_tokens self.extra_tokens = self.prompt_manager.extra_tokens
self.kimia_text_audiodelaytokens = 6
self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end] self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end]
@torch.inference_mode() @torch.inference_mode()
......
...@@ -13,7 +13,7 @@ from kimia_infer.utils.data import KimiAContent ...@@ -13,7 +13,7 @@ from kimia_infer.utils.data import KimiAContent
from kimia_infer.utils.special_tokens import instantiate_extra_tokens from kimia_infer.utils.special_tokens import instantiate_extra_tokens
class KimiAPromptManager: class KimiAPromptManager:
def __init__(self, model_path: str, kimia_token_offset: int): def __init__(self, model_path: str, kimia_token_offset: int, kimia_text_audiodelaytokens: int):
self.audio_tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer") self.audio_tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer")
self.audio_tokenizer = self.audio_tokenizer.to(torch.cuda.current_device()) self.audio_tokenizer = self.audio_tokenizer.to(torch.cuda.current_device())
...@@ -34,6 +34,8 @@ class KimiAPromptManager: ...@@ -34,6 +34,8 @@ class KimiAPromptManager:
self.extra_tokens = instantiate_extra_tokens(self.text_tokenizer) self.extra_tokens = instantiate_extra_tokens(self.text_tokenizer)
self.kimia_text_audiodelaytokens = kimia_text_audiodelaytokens
self.kimia_token_offset = kimia_token_offset self.kimia_token_offset = kimia_token_offset
def _tokenize_text(self, text): def _tokenize_text(self, text):
...@@ -124,6 +126,17 @@ class KimiAPromptManager: ...@@ -124,6 +126,17 @@ class KimiAPromptManager:
if extract_whisper_feature: if extract_whisper_feature:
whisper_feature = self.extract_whisper_feat(audio_path) whisper_feature = self.extract_whisper_feat(audio_path)
kimia_content_msg.continuous_feature.append(whisper_feature) kimia_content_msg.continuous_feature.append(whisper_feature)
# History message supplied as an (audio_path, transcript) pair: extend the
# audio and text token streams together so they stay length-aligned.
elif message["message_type"] == "audio-text":
audio_path, text = message["content"]
speech_tokens = self._tokenize_audio(audio_path)
text_tokens = self._tokenize_text(text)
# The audio stream leads with kimia_text_audiodelaytokens blank tokens
# (delay value comes from the model config's kimia_mimo_audiodelaytokens).
kimia_content_msg.audio_extend([self.extra_tokens.kimia_text_blank] * self.kimia_text_audiodelaytokens)
kimia_content_msg.audio_extend(speech_tokens, is_continuous=False)
kimia_content_msg.text_extend(text_tokens)
# Pad the text stream with blanks up to the audio stream's total length.
# NOTE(review): if len(text_tokens) > delay + len(speech_tokens) the count
# below is negative, `negative * [tok]` is an empty list, and the two
# streams end up misaligned — confirm transcripts are always shorter than
# the tokenized audio, or clamp the count at 0.
text_pad_tokens = (self.kimia_text_audiodelaytokens + len(speech_tokens) - len(text_tokens)) * [self.extra_tokens.kimia_text_blank]
kimia_content_msg.text_extend(text_pad_tokens)
elif message["message_type"] == None: elif message["message_type"] == None:
pass pass
else: else:
......
当然可以,李白的诗很多,比如这句:“床前明月光,疑是地上霜。举头望明月,低头思故乡。
\ No newline at end of file
当然可以,这很简单。一二三四五六七八九十。
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment