OpenDAS / ktransformers · Commit 0f1684c2
Authored Mar 15, 2025 by SkqLiao
Parent: 12949c8a

local chat for cicd test

Showing 1 changed file with 171 additions and 0 deletions

ktransformers/local_chat_test.py  (new file, 0 → 100644)  +171  −0
"""
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import
os
import
platform
import
sys
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
sys
.
path
.
insert
(
0
,
project_dir
)
import
torch
import
logging
from
transformers
import
(
AutoTokenizer
,
AutoConfig
,
AutoModelForCausalLM
,
GenerationConfig
,
TextStreamer
,
)
import
json
import
fire
from
ktransformers.optimize.optimize
import
optimize_and_load_gguf
from
ktransformers.models.modeling_deepseek
import
DeepseekV2ForCausalLM
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeForCausalLM
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3ForCausalLM
from
ktransformers.models.modeling_llama
import
LlamaForCausalLM
from
ktransformers.models.modeling_mixtral
import
MixtralForCausalLM
from
ktransformers.util.utils
import
prefill_and_generate
,
get_compute_capability
from
ktransformers.server.config.config
import
Config
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
}

ktransformer_rules_dir = (
    os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
    "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
    "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
    "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
}
def local_chat(
    model_path: str | None = None,
    optimize_config_path: str | None = None,
    gguf_path: str | None = None,
    max_new_tokens: int = 1000,
    cpu_infer: int = Config().cpu_infer,
    use_cuda_graph: bool = True,
    prompt_file: str | None = None,
    mode: str = "normal",
    force_think: bool = False,
    chunk_prefill_size: int = 8192,
):
    torch.set_grad_enabled(False)

    Config().cpu_infer = cpu_infer

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if mode == 'long_context':
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if "Qwen2Moe" in config.architectures[0]:
                # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"
            model = custom_models[config.architectures[0]](config)
        else:
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation="flash_attention_2"
            )
    if optimize_config_path is None:
        if config.architectures[0] in default_optimize_rules:
            print("using default_optimize_rule for", config.architectures[0])
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = input(
                "please input the path of your rule file(yaml file containing optimize rules):"
            )

    if gguf_path is None:
        gguf_path = input(
            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
        )
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
    try:
        model.generation_config = GenerationConfig.from_pretrained(model_path)
    except Exception as e:
        print(f"generation config can't auto create, make default. Message: {e}")
        gen_config = GenerationConfig(
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
        )
        model.generation_config = gen_config
    # model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.eval()
    logging.basicConfig(level=logging.INFO)
    system = platform.system()
    if system == "Windows":
        os.system("cls")
    else:
        os.system("clear")

    if prompt_file is not None:
        assert os.path.isfile(prompt_file), "prompt file does not exist"
        print(f"prompt file is {prompt_file}")
        content = open(prompt_file, "r").read()
    else:
        content = "Please write a piece of quicksort code in C++."
    print('Start Testing...(1 round)')
    print('Prompt:', content)
    while True:
        messages = [{"role": "user", "content": content}]
        input_tensor = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        if force_think:
            token_thinks = torch.tensor(
                [tokenizer.encode("<think>\\n", add_special_tokens=False)],
                device=input_tensor.device,
            )
            input_tensor = torch.cat([input_tensor, token_thinks], dim=1)
        if mode == 'long_context':
            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
                "please change max_seq_len in ~/.ktransformers/config.yaml"
        if (
            system != "Windows"
            and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM")
            and flashinfer_enabled
            and get_compute_capability() >= 8
        ):
            generated = prefill_and_generate(
                model,
                tokenizer,
                input_tensor.cuda(),
                max_new_tokens,
                use_cuda_graph,
                mode=mode,
                force_think=force_think,
                chunk_prefill_size=chunk_prefill_size,
                use_flashinfer_mla=True,
                num_heads=config.num_attention_heads,
                head_dim_ckv=config.kv_lora_rank,
                head_dim_kpe=config.qk_rope_head_dim,
                q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim,
            )
        else:
            generated = prefill_and_generate(
                model,
                tokenizer,
                input_tensor.cuda(),
                max_new_tokens,
                use_cuda_graph,
                mode=mode,
                force_think=force_think,
                chunk_prefill_size=chunk_prefill_size,
            )
        break

if __name__ == "__main__":
    fire.Fire(local_chat)
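
Usage note: because the entry point wraps local_chat in fire.Fire, every keyword argument of local_chat is also exposed as a command-line flag. The sketch below calls the function directly from Python for reference; the model and GGUF paths are hypothetical placeholders, not paths shipped with the repository.

# Minimal sketch, assuming a DeepSeek-V2 checkpoint; both paths are hypothetical placeholders.
from ktransformers.local_chat_test import local_chat

local_chat(
    model_path="/models/DeepSeek-V2-Chat",       # HF directory with config + tokenizer (placeholder)
    gguf_path="/models/DeepSeek-V2-Chat-GGUF",   # directory holding the matching GGUF weights (placeholder)
    max_new_tokens=256,                          # leave prompt_file unset to use the built-in quicksort prompt
)
# Equivalent CLI form (flags generated by fire):
#   python ktransformers/local_chat_test.py --model_path ... --gguf_path ... --max_new_tokens 256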