Unverified Commit 44de67a3 authored by bigmoyan, committed by GitHub
Browse files

Merge pull request #20 from MoonshotAI/fix-bug-cannot-load-from-model-id

Fix bug: cannot load from model-id
parents 8d79a4e4 da6e22f7
......@@ -9,13 +9,23 @@ from transformers import AutoModelForCausalLM
from kimia_infer.models.detokenizer import get_audio_detokenizer
from .prompt_manager import KimiAPromptManager
from kimia_infer.utils.sampler import KimiASampler
from huggingface_hub import snapshot_download
class KimiAudio(object):
def __init__(self, model_path: str, load_detokenizer: bool = True):
logger.info(f"Loading kimi-audio main model")
if os.path.exists(model_path):
# local path
cache_path = model_path
else:
# cache everything if model_path is a model-id
cache_path = snapshot_download(model_path)
logger.info(f"Looking for resources in {cache_path}")
logger.info(f"Loading whisper model")
self.alm = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
cache_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
self.alm = self.alm.to(torch.cuda.current_device())
......@@ -23,18 +33,9 @@ class KimiAudio(object):
self.kimia_token_offset = model_config.kimia_token_offset
self.prompt_manager = KimiAPromptManager(
model_path=model_path, kimia_token_offset=self.kimia_token_offset
model_path=cache_path, kimia_token_offset=self.kimia_token_offset
)
if os.path.exists(model_path):
# local path
cache_path = model_path
else:
# model_id
cache_path = cached_assets_path(
library_name="transformers", namespace=model_path
)
if load_detokenizer:
logger.info(f"Loading detokenizer")
# need to compile extension modules for the first time; it may take several minutes.
......
......@@ -4,7 +4,6 @@ import os
import librosa
import torch
from loguru import logger
from huggingface_hub import cached_assets_path
from transformers import AutoTokenizer
......@@ -13,25 +12,16 @@ from kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer
from kimia_infer.utils.data import KimiAContent
from kimia_infer.utils.special_tokens import instantiate_extra_tokens
class KimiAPromptManager:
def __init__(self, model_path: str, kimia_token_offset: int):
self.audio_tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer")
self.audio_tokenizer = self.audio_tokenizer.to(torch.cuda.current_device())
if os.path.exists(model_path):
# local path
cache_path = model_path
else:
# model_id
cache_path = cached_assets_path(
library_name="transformers", namespace=model_path
)
logger.info(f"Looking for resources in {cache_path}")
logger.info(f"Looking for resources in {model_path}")
logger.info(f"Loading whisper model")
self.whisper_model = WhisperEncoder(
os.path.join(cache_path, "whisper-large-v3"), mel_batch_size=20
os.path.join(model_path, "whisper-large-v3"), mel_batch_size=20
)
self.whisper_model = self.whisper_model.to(torch.cuda.current_device())
self.whisper_model = self.whisper_model.bfloat16()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment