ox696c / ktransformers · Commits · dfe09b05

Unverified commit dfe09b05, authored Mar 15, 2025 by Jiaqi Liao, committed by GitHub on Mar 15, 2025.

Merge pull request #897 from SkqLiao/main

Add Unit Test for Local Chat

Parents: 8320ae7d, c66ca657
Showing 2 changed files with 183 additions and 6 deletions (+183 −6):

  .github/workflows/install.yml      +12  −6
  ktransformers/local_chat_test.py   +171 −0
.github/workflows/install.yml  (view file @ dfe09b05)

-name: Install and Test KTransformers
-run-name: Install and Test KTransformers
+name: Install / Test KTransformers
+run-name: Install / Test KTransformers
 on:
   workflow_dispatch:
     inputs:
       job_to_run:
         description: "Which job to run?"
         required: true
-        default: "install&test"
+        default: "test"
         type: choice
         options:
           - create-install-test
...
@@ -52,14 +52,20 @@ jobs:
           git submodule init
           git submodule update
           bash install.sh
-      - name: Test Local Chat
+      - name: Test Local Chat 1
         run: |
           set -e
           source /home/qujing3/anaconda3/etc/profile.d/conda.sh
           conda activate ktransformers-dev
           export PATH=/usr/local/cuda-12.4/bin:$PATH
           export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH
           export CUDA_HOME=/usr/local/cuda-12.4
           cd ${{ github.workspace }}
-          python ktransformers/local_chat.py --model_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/config --gguf_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/ --max_new_tokens 100 --cache_len 1536 --cpu_infer 64 --prompt_file /home/qujing3/prompts/book.txt
-          python ktransformers/local_chat.py --model_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/config --gguf_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/ --max_new_tokens 100 --cache_len 1536 --cpu_infer 64 --prompt_file /home/qujing3/prompts/chinese.txt
+          echo "Running Local Chat 1 (book.txt) ..."
+          python ktransformers/local_chat_test.py --model_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/config --gguf_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/ --max_new_tokens 256 --cpu_infer 64 --prompt_file /home/qujing3/prompts/book.txt > log1.txt
+          sed -n '/Prompt:/,$p' log1.txt
+          echo "Running Local Chat 2 [force think] (chinese.txt) ..."
+          python ktransformers/local_chat_test.py --model_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/config --gguf_path /home/qujing3/models/DeepSeek-R1-Q4_K_M/ --max_new_tokens 256 --cpu_infer 64 --prompt_file /home/qujing3/prompts/chinese.txt -f > log2.txt
+          sed -n '/Prompt:/,$p' log2.txt
+      - run: echo "This job's status is ${{ job.status }}."
ktransformers/local_chat_test.py  (new file, 0 → 100644, view file @ dfe09b05)
"""
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import
os
import
platform
import
sys
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
sys
.
path
.
insert
(
0
,
project_dir
)
import
torch
import
logging
from
transformers
import
(
AutoTokenizer
,
AutoConfig
,
AutoModelForCausalLM
,
GenerationConfig
,
TextStreamer
,
)
import
json
import
fire
from
ktransformers.optimize.optimize
import
optimize_and_load_gguf
from
ktransformers.models.modeling_deepseek
import
DeepseekV2ForCausalLM
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeForCausalLM
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3ForCausalLM
from
ktransformers.models.modeling_llama
import
LlamaForCausalLM
from
ktransformers.models.modeling_mixtral
import
MixtralForCausalLM
from
ktransformers.util.utils
import
prefill_and_generate
,
get_compute_capability
from
ktransformers.server.config.config
import
Config
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
custom_models
=
{
"DeepseekV2ForCausalLM"
:
DeepseekV2ForCausalLM
,
"DeepseekV3ForCausalLM"
:
DeepseekV3ForCausalLM
,
"Qwen2MoeForCausalLM"
:
Qwen2MoeForCausalLM
,
"LlamaForCausalLM"
:
LlamaForCausalLM
,
"MixtralForCausalLM"
:
MixtralForCausalLM
,
}
ktransformer_rules_dir
=
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
+
"/optimize/optimize_rules/"
)
default_optimize_rules
=
{
"DeepseekV2ForCausalLM"
:
ktransformer_rules_dir
+
"DeepSeek-V2-Chat.yaml"
,
"DeepseekV3ForCausalLM"
:
ktransformer_rules_dir
+
"DeepSeek-V3-Chat.yaml"
,
"Qwen2MoeForCausalLM"
:
ktransformer_rules_dir
+
"Qwen2-57B-A14B-Instruct.yaml"
,
"LlamaForCausalLM"
:
ktransformer_rules_dir
+
"Internlm2_5-7b-Chat-1m.yaml"
,
"MixtralForCausalLM"
:
ktransformer_rules_dir
+
"Mixtral.yaml"
,
}
def
local_chat
(
model_path
:
str
|
None
=
None
,
optimize_config_path
:
str
=
None
,
gguf_path
:
str
|
None
=
None
,
max_new_tokens
:
int
=
1000
,
cpu_infer
:
int
=
Config
().
cpu_infer
,
use_cuda_graph
:
bool
=
True
,
prompt_file
:
str
|
None
=
None
,
mode
:
str
=
"normal"
,
force_think
:
bool
=
False
,
chunk_prefill_size
:
int
=
8192
):
torch
.
set_grad_enabled
(
False
)
Config
().
cpu_infer
=
cpu_infer
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
if
mode
==
'long_context'
:
assert
config
.
architectures
[
0
]
==
"LlamaForCausalLM"
,
"only LlamaForCausalLM support long_context mode"
torch
.
set_default_dtype
(
torch
.
float16
)
else
:
torch
.
set_default_dtype
(
config
.
torch_dtype
)
with
torch
.
device
(
"meta"
):
if
config
.
architectures
[
0
]
in
custom_models
:
print
(
"using custom modeling_xxx.py."
)
if
(
"Qwen2Moe"
in
config
.
architectures
[
0
]
):
# Qwen2Moe must use flash_attention_2 to avoid overflow.
config
.
_attn_implementation
=
"flash_attention_2"
if
"Llama"
in
config
.
architectures
[
0
]:
config
.
_attn_implementation
=
"eager"
if
"Mixtral"
in
config
.
architectures
[
0
]:
config
.
_attn_implementation
=
"flash_attention_2"
model
=
custom_models
[
config
.
architectures
[
0
]](
config
)
else
:
model
=
AutoModelForCausalLM
.
from_config
(
config
,
trust_remote_code
=
True
,
attn_implementation
=
"flash_attention_2"
)
if
optimize_config_path
is
None
:
if
config
.
architectures
[
0
]
in
default_optimize_rules
:
print
(
"using default_optimize_rule for"
,
config
.
architectures
[
0
])
optimize_config_path
=
default_optimize_rules
[
config
.
architectures
[
0
]]
else
:
optimize_config_path
=
input
(
"please input the path of your rule file(yaml file containing optimize rules):"
)
if
gguf_path
is
None
:
gguf_path
=
input
(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf
(
model
,
optimize_config_path
,
gguf_path
,
config
)
try
:
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
model_path
)
except
Exception
as
e
:
print
(
f
"generation config can't auto create, make default. Message:
{
e
}
"
)
gen_config
=
GenerationConfig
(
temperature
=
0.6
,
top_p
=
0.95
,
do_sample
=
True
)
model
.
generation_config
=
gen_config
# model.generation_config = GenerationConfig.from_pretrained(model_path)
if
model
.
generation_config
.
pad_token_id
is
None
:
model
.
generation_config
.
pad_token_id
=
model
.
generation_config
.
eos_token_id
model
.
eval
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
system
=
platform
.
system
()
if
system
==
"Windows"
:
os
.
system
(
"cls"
)
else
:
os
.
system
(
"clear"
)
if
prompt_file
!=
None
:
assert
os
.
path
.
isfile
(
prompt_file
),
"prompt file not exist"
print
(
f
"prompt file is
{
prompt_file
}
"
)
content
=
open
(
prompt_file
,
"r"
).
read
()
else
:
content
=
"Please write a piece of quicksort code in C++."
print
(
'Start Testing...(1 round)'
)
print
(
'Prompt:'
,
content
)
while
True
:
messages
=
[{
"role"
:
"user"
,
"content"
:
content
}]
input_tensor
=
tokenizer
.
apply_chat_template
(
messages
,
add_generation_prompt
=
True
,
return_tensors
=
"pt"
)
if
force_think
:
token_thinks
=
torch
.
tensor
([
tokenizer
.
encode
(
"<think>
\\
n"
,
add_special_tokens
=
False
)],
device
=
input_tensor
.
device
)
input_tensor
=
torch
.
cat
(
[
input_tensor
,
token_thinks
],
dim
=
1
)
if
mode
==
'long_context'
:
assert
Config
().
long_context_config
[
'max_seq_len'
]
>
input_tensor
.
shape
[
1
]
+
max_new_tokens
,
\
"please change max_seq_len in ~/.ktransformers/config.yaml"
if
system
!=
"Windows"
and
(
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
)
and
flashinfer_enabled
and
get_compute_capability
()
>=
8
:
generated
=
prefill_and_generate
(
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
=
mode
,
force_think
=
force_think
,
chunk_prefill_size
=
chunk_prefill_size
,
use_flashinfer_mla
=
True
,
num_heads
=
config
.
num_attention_heads
,
head_dim_ckv
=
config
.
kv_lora_rank
,
head_dim_kpe
=
config
.
qk_rope_head_dim
,
q_head_dim
=
config
.
qk_rope_head_dim
+
config
.
qk_nope_head_dim
)
else
:
generated
=
prefill_and_generate
(
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
=
mode
,
force_think
=
force_think
,
chunk_prefill_size
=
chunk_prefill_size
,
)
break
if
__name__
==
"__main__"
:
fire
.
Fire
(
local_chat
)
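For running the test script outside CI, it can be invoked the same way the workflow step does; in this sketch the model, GGUF, and prompt paths are placeholders for a local setup, -f enables force_think as in the second CI run, and sed reprints the output from the "Prompt:" marker onward:

  python ktransformers/local_chat_test.py --model_path /path/to/model/config --gguf_path /path/to/gguf/ --max_new_tokens 256 --cpu_infer 64 --prompt_file /path/to/prompt.txt -f > log.txt
  sed -n '/Prompt:/,$p' log.txt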