Commit 2a850926 authored by weedge
Browse files

add inference mode for lm generate


Signed-off-by: weedge <weege007@gmail.com>
parent 0004a354
......@@ -47,6 +47,7 @@ class KimiAudio(object):
self.kimia_text_audiodelaytokens = 6
self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end]
@torch.inference_mode()
def _generate_loop(
self,
audio_input_ids: torch.Tensor, # input audio tokens
......@@ -204,6 +205,7 @@ class KimiAudio(object):
)
return return_audio_tokens, return_text_tokens
@torch.inference_mode()
def generate(
self,
chats: list[dict],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment