Unverified Commit 44de67a3 authored by bigmoyan's avatar bigmoyan Committed by GitHub
Browse files

Merge pull request #20 from MoonshotAI/fix-bug-cannot-load-from-model-id

Fix bug: cannot load from model-id
parents 8d79a4e4 da6e22f7
...@@ -9,13 +9,23 @@ from transformers import AutoModelForCausalLM ...@@ -9,13 +9,23 @@ from transformers import AutoModelForCausalLM
from kimia_infer.models.detokenizer import get_audio_detokenizer from kimia_infer.models.detokenizer import get_audio_detokenizer
from .prompt_manager import KimiAPromptManager from .prompt_manager import KimiAPromptManager
from kimia_infer.utils.sampler import KimiASampler from kimia_infer.utils.sampler import KimiASampler
from huggingface_hub import snapshot_download
class KimiAudio(object): class KimiAudio(object):
def __init__(self, model_path: str, load_detokenizer: bool = True): def __init__(self, model_path: str, load_detokenizer: bool = True):
logger.info(f"Loading kimi-audio main model") logger.info(f"Loading kimi-audio main model")
if os.path.exists(model_path):
# local path
cache_path = model_path
else:
# cache everything if model_path is a model-id
cache_path = snapshot_download(model_path)
logger.info(f"Looking for resources in {cache_path}")
logger.info(f"Loading whisper model")
self.alm = AutoModelForCausalLM.from_pretrained( self.alm = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.bfloat16, trust_remote_code=True cache_path, torch_dtype=torch.bfloat16, trust_remote_code=True
) )
self.alm = self.alm.to(torch.cuda.current_device()) self.alm = self.alm.to(torch.cuda.current_device())
...@@ -23,16 +33,7 @@ class KimiAudio(object): ...@@ -23,16 +33,7 @@ class KimiAudio(object):
self.kimia_token_offset = model_config.kimia_token_offset self.kimia_token_offset = model_config.kimia_token_offset
self.prompt_manager = KimiAPromptManager( self.prompt_manager = KimiAPromptManager(
model_path=model_path, kimia_token_offset=self.kimia_token_offset model_path=cache_path, kimia_token_offset=self.kimia_token_offset
)
if os.path.exists(model_path):
# local path
cache_path = model_path
else:
# model_id
cache_path = cached_assets_path(
library_name="transformers", namespace=model_path
) )
if load_detokenizer: if load_detokenizer:
......
...@@ -4,7 +4,6 @@ import os ...@@ -4,7 +4,6 @@ import os
import librosa import librosa
import torch import torch
from loguru import logger from loguru import logger
from huggingface_hub import cached_assets_path
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -13,25 +12,16 @@ from kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer ...@@ -13,25 +12,16 @@ from kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer
from kimia_infer.utils.data import KimiAContent from kimia_infer.utils.data import KimiAContent
from kimia_infer.utils.special_tokens import instantiate_extra_tokens from kimia_infer.utils.special_tokens import instantiate_extra_tokens
class KimiAPromptManager: class KimiAPromptManager:
def __init__(self, model_path: str, kimia_token_offset: int): def __init__(self, model_path: str, kimia_token_offset: int):
self.audio_tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer") self.audio_tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer")
self.audio_tokenizer = self.audio_tokenizer.to(torch.cuda.current_device()) self.audio_tokenizer = self.audio_tokenizer.to(torch.cuda.current_device())
if os.path.exists(model_path): logger.info(f"Looking for resources in {model_path}")
# local path
cache_path = model_path
else:
# model_id
cache_path = cached_assets_path(
library_name="transformers", namespace=model_path
)
logger.info(f"Looking for resources in {cache_path}")
logger.info(f"Loading whisper model") logger.info(f"Loading whisper model")
self.whisper_model = WhisperEncoder( self.whisper_model = WhisperEncoder(
os.path.join(cache_path, "whisper-large-v3"), mel_batch_size=20 os.path.join(model_path, "whisper-large-v3"), mel_batch_size=20
) )
self.whisper_model = self.whisper_model.to(torch.cuda.current_device()) self.whisper_model = self.whisper_model.to(torch.cuda.current_device())
self.whisper_model = self.whisper_model.bfloat16() self.whisper_model = self.whisper_model.bfloat16()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment