Commit 2a850926 authored by weedge's avatar weedge
Browse files

add inference mode for lm generate


Signed-off-by: weedge <weege007@gmail.com>
parent 0004a354
...@@ -47,6 +47,7 @@ class KimiAudio(object): ...@@ -47,6 +47,7 @@ class KimiAudio(object):
self.kimia_text_audiodelaytokens = 6 self.kimia_text_audiodelaytokens = 6
self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end] self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end]
@torch.inference_mode()
def _generate_loop( def _generate_loop(
self, self,
audio_input_ids: torch.Tensor, # input audio tokens audio_input_ids: torch.Tensor, # input audio tokens
...@@ -204,6 +205,7 @@ class KimiAudio(object): ...@@ -204,6 +205,7 @@ class KimiAudio(object):
) )
return return_audio_tokens, return_text_tokens return return_audio_tokens, return_text_tokens
@torch.inference_mode()
def generate( def generate(
self, self,
chats: list[dict], chats: list[dict],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment