Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Kimi-Audio
Commits
44de67a3
Unverified
Commit
44de67a3
authored
Apr 27, 2025
by
bigmoyan
Committed by
GitHub
Apr 27, 2025
Browse files
Merge pull request #20 from MoonshotAI/fix-bug-cannot-load-from-model-id
Fix bug: cannot load from model-id
parents
8d79a4e4
da6e22f7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
24 deletions
+15
-24
kimia_infer/api/kimia.py
kimia_infer/api/kimia.py
+13
-12
kimia_infer/api/prompt_manager.py
kimia_infer/api/prompt_manager.py
+2
-12
No files found.
kimia_infer/api/kimia.py
View file @
44de67a3
...
@@ -9,13 +9,23 @@ from transformers import AutoModelForCausalLM
...
@@ -9,13 +9,23 @@ from transformers import AutoModelForCausalLM
from
kimia_infer.models.detokenizer
import
get_audio_detokenizer
from
kimia_infer.models.detokenizer
import
get_audio_detokenizer
from
.prompt_manager
import
KimiAPromptManager
from
.prompt_manager
import
KimiAPromptManager
from
kimia_infer.utils.sampler
import
KimiASampler
from
kimia_infer.utils.sampler
import
KimiASampler
from
huggingface_hub
import
snapshot_download
class
KimiAudio
(
object
):
class
KimiAudio
(
object
):
def
__init__
(
self
,
model_path
:
str
,
load_detokenizer
:
bool
=
True
):
def
__init__
(
self
,
model_path
:
str
,
load_detokenizer
:
bool
=
True
):
logger
.
info
(
f
"Loading kimi-audio main model"
)
logger
.
info
(
f
"Loading kimi-audio main model"
)
if
os
.
path
.
exists
(
model_path
):
# local path
cache_path
=
model_path
else
:
# cache everything if model_path is a model-id
cache_path
=
snapshot_download
(
model_path
)
logger
.
info
(
f
"Looking for resources in
{
cache_path
}
"
)
logger
.
info
(
f
"Loading whisper model"
)
self
.
alm
=
AutoModelForCausalLM
.
from_pretrained
(
self
.
alm
=
AutoModelForCausalLM
.
from_pretrained
(
model
_path
,
torch_dtype
=
torch
.
bfloat16
,
trust_remote_code
=
True
cache
_path
,
torch_dtype
=
torch
.
bfloat16
,
trust_remote_code
=
True
)
)
self
.
alm
=
self
.
alm
.
to
(
torch
.
cuda
.
current_device
())
self
.
alm
=
self
.
alm
.
to
(
torch
.
cuda
.
current_device
())
...
@@ -23,18 +33,9 @@ class KimiAudio(object):
...
@@ -23,18 +33,9 @@ class KimiAudio(object):
self
.
kimia_token_offset
=
model_config
.
kimia_token_offset
self
.
kimia_token_offset
=
model_config
.
kimia_token_offset
self
.
prompt_manager
=
KimiAPromptManager
(
self
.
prompt_manager
=
KimiAPromptManager
(
model_path
=
model
_path
,
kimia_token_offset
=
self
.
kimia_token_offset
model_path
=
cache
_path
,
kimia_token_offset
=
self
.
kimia_token_offset
)
)
if
os
.
path
.
exists
(
model_path
):
# local path
cache_path
=
model_path
else
:
# model_id
cache_path
=
cached_assets_path
(
library_name
=
"transformers"
,
namespace
=
model_path
)
if
load_detokenizer
:
if
load_detokenizer
:
logger
.
info
(
f
"Loading detokenizer"
)
logger
.
info
(
f
"Loading detokenizer"
)
# need to compile extension moudules for the first time, it may take several minutes.
# need to compile extension moudules for the first time, it may take several minutes.
...
...
kimia_infer/api/prompt_manager.py
View file @
44de67a3
...
@@ -4,7 +4,6 @@ import os
...
@@ -4,7 +4,6 @@ import os
import
librosa
import
librosa
import
torch
import
torch
from
loguru
import
logger
from
loguru
import
logger
from
huggingface_hub
import
cached_assets_path
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -13,25 +12,16 @@ from kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer
...
@@ -13,25 +12,16 @@ from kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer
from
kimia_infer.utils.data
import
KimiAContent
from
kimia_infer.utils.data
import
KimiAContent
from
kimia_infer.utils.special_tokens
import
instantiate_extra_tokens
from
kimia_infer.utils.special_tokens
import
instantiate_extra_tokens
class
KimiAPromptManager
:
class
KimiAPromptManager
:
def
__init__
(
self
,
model_path
:
str
,
kimia_token_offset
:
int
):
def
__init__
(
self
,
model_path
:
str
,
kimia_token_offset
:
int
):
self
.
audio_tokenizer
=
Glm4Tokenizer
(
"THUDM/glm-4-voice-tokenizer"
)
self
.
audio_tokenizer
=
Glm4Tokenizer
(
"THUDM/glm-4-voice-tokenizer"
)
self
.
audio_tokenizer
=
self
.
audio_tokenizer
.
to
(
torch
.
cuda
.
current_device
())
self
.
audio_tokenizer
=
self
.
audio_tokenizer
.
to
(
torch
.
cuda
.
current_device
())
if
os
.
path
.
exists
(
model_path
):
logger
.
info
(
f
"Looking for resources in
{
model_path
}
"
)
# local path
cache_path
=
model_path
else
:
# model_id
cache_path
=
cached_assets_path
(
library_name
=
"transformers"
,
namespace
=
model_path
)
logger
.
info
(
f
"Looking for resources in
{
cache_path
}
"
)
logger
.
info
(
f
"Loading whisper model"
)
logger
.
info
(
f
"Loading whisper model"
)
self
.
whisper_model
=
WhisperEncoder
(
self
.
whisper_model
=
WhisperEncoder
(
os
.
path
.
join
(
cache
_path
,
"whisper-large-v3"
),
mel_batch_size
=
20
os
.
path
.
join
(
model
_path
,
"whisper-large-v3"
),
mel_batch_size
=
20
)
)
self
.
whisper_model
=
self
.
whisper_model
.
to
(
torch
.
cuda
.
current_device
())
self
.
whisper_model
=
self
.
whisper_model
.
to
(
torch
.
cuda
.
current_device
())
self
.
whisper_model
=
self
.
whisper_model
.
bfloat16
()
self
.
whisper_model
=
self
.
whisper_model
.
bfloat16
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment