Unverified Commit 9ab5ac25 authored by bigmoyan's avatar bigmoyan Committed by GitHub
Browse files

Merge pull request #21 from weedge/fix/no_grade

lm generate add torch inference mode wrap
parents 44de67a3 2a850926
...@@ -48,6 +48,7 @@ class KimiAudio(object): ...@@ -48,6 +48,7 @@ class KimiAudio(object):
self.kimia_text_audiodelaytokens = 6 self.kimia_text_audiodelaytokens = 6
self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end] self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end]
@torch.inference_mode()
def _generate_loop( def _generate_loop(
self, self,
audio_input_ids: torch.Tensor, # input audio tokens audio_input_ids: torch.Tensor, # input audio tokens
...@@ -205,6 +206,7 @@ class KimiAudio(object): ...@@ -205,6 +206,7 @@ class KimiAudio(object):
) )
return return_audio_tokens, return_text_tokens return return_audio_tokens, return_text_tokens
@torch.inference_mode()
def generate( def generate(
self, self,
chats: list[dict], chats: list[dict],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment