OpenDAS / ktransformers — commit 25cee581
Authored Mar 31, 2025 by Atream
Commit message: add balance-serve, support concurrence
Parent: 8d0292aa

Showing 16 changed files with 1026 additions and 62 deletions (+1026 −62):
  ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py  +83 −0
  ktransformers/server/balance_serve/inference/sampling/sampler.py  +100 −0
  ktransformers/server/balance_serve/sched_rpc.py  +213 −0
  ktransformers/server/balance_serve/settings.py  +76 −0
  ktransformers/server/config/config.py  +23 −10
  ktransformers/server/main.py  +31 −12
  ktransformers/server/requirements.txt  +4 −2
  ktransformers/server/schemas/endpoints/chat.py  +3 −3
  ktransformers/server/utils/create_interface.py  +6 −3
  ktransformers/tests/mmlu_test_multi.py  +155 −0
  ktransformers/tests/test_client.py  +115 −0
  ktransformers/tests/test_speed.py  +146 −0
  ktransformers/util/utils.py  +20 −3
  merge_tensors/merge_safetensor_gguf.py  +1 −1
  requirements-local_chat.txt  +1 −1
  setup.py  +49 −27
ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py (new file, +83 −0)

import typing

import torch

from ..orchestrator import _BatchedPenalizer, _TokenIDs


class BatchedRepetitionPenalizer(_BatchedPenalizer):
    """
    Repetition penalizer penalizes tokens based on their repetition in the input and output.
    """

    repetition_penalties: torch.Tensor = None
    cumulated_repetition_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.repetition_penalty != 1.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        self.cumulated_repetition_penalties = (
            torch.tensor(
                data=[1.0 for _ in self.orchestrator.reqs()],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .repeat(1, self.orchestrator.vocab_size)
        )

        self.repetition_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.repetition_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_repetition_penalties)
        )

    def _teardown(self):
        del self.repetition_penalties
        del self.cumulated_repetition_penalties

        self.repetition_penalties = None
        self.cumulated_repetition_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        mask = input_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        mask = output_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.where(
            logits > 0,
            logits / self.cumulated_repetition_penalties,
            logits * self.cumulated_repetition_penalties,
        )

    def _filter(
        self,
        indices_to_keep: typing.List[int],
        indices_tensor_to_keep: torch.Tensor,
    ):
        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
            indices_tensor_to_keep
        ]

    def _merge(self, their: "BatchedRepetitionPenalizer"):
        self.repetition_penalties = torch.cat(
            [self.repetition_penalties, their.repetition_penalties], dim=0
        )
        self.cumulated_repetition_penalties = torch.cat(
            [
                self.cumulated_repetition_penalties,
                their.cumulated_repetition_penalties,
            ],
            dim=0,
        )
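
For orientation, here is a minimal standalone sketch of the penalty math in `_apply`, taken out of the orchestrator machinery that this file assumes: positive logits are divided by the cumulated penalty and negative logits are multiplied by it, so already-seen tokens become less likely regardless of the sign of their logit. The shapes and values below are purely illustrative.

import torch

# Hypothetical batch of 2 requests over a 4-token vocabulary.
logits = torch.tensor([[2.0, -1.0, 0.5, -0.2],
                       [1.5,  0.0, -3.0, 0.7]])
# Cumulated penalties: 1.0 means "token not seen yet", >1.0 means it occurred
# in the prompt or the generated output of that request.
cumulated = torch.tensor([[1.2, 1.0, 1.2, 1.0],
                          [1.0, 1.1, 1.1, 1.0]])

# Same expression as BatchedRepetitionPenalizer._apply.
penalized = torch.where(logits > 0, logits / cumulated, logits * cumulated)
print(penalized)  # repeated tokens are pushed down in both the positive and negative case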
ktransformers/server/balance_serve/inference/sampling/sampler.py (new file, +100 −0)

'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging

import torch
from torch import nn
from transformers import GenerationConfig
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_probs,
    top_k_top_p_sampling_from_logits,
    top_p_renorm_probs,
)

logger = logging.getLogger(__name__)


class SamplingOptions():
    # Batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # All requests use greedy sampling
    is_all_greedy: bool

    # Dispatch in CUDA graph
    need_min_p_sampling: bool

    def __init__(
        self,
        bsz=1,
        device=torch.device('cuda'),
        pretrained_config: GenerationConfig = None,
        temperatures: torch.Tensor = None,
        top_ps: torch.Tensor = None,
    ):
        if pretrained_config is None and temperatures is None:
            self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
            self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.need_min_p_sampling = False
            self.is_all_greedy = True
        else:
            if temperatures is not None:
                self.temperatures = temperatures.unsqueeze(-1)
            else:
                self.temperatures = torch.full(
                    (bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32
                )
            if top_ps is not None:
                self.top_ps = top_ps.unsqueeze(-1)
            else:
                self.top_ps = torch.full(
                    (bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32
                )
            self.top_ks = torch.full(
                (bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32
            )
            self.need_min_p_sampling = False
            self.is_all_greedy = False


class Sampler(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_config: SamplingOptions = None,
    ):
        if sampling_config == None:
            sampling_config = SamplingOptions()
        logits = logits.contiguous()
        origin_logits = logits.clone()
        if sampling_config.is_all_greedy:
            # Use torch.argmax if all requests use greedy sampling
            probs = logits
            batch_next_token_ids = torch.argmax(logits, -1)
        else:
            # Post process logits
            logits.div_(sampling_config.temperatures)
            max_top_k_round, batch_size = 32, logits.shape[0]
            if sampling_config.need_min_p_sampling:
                probs = torch.softmax(logits, dim=-1)
                logits = None
                del logits
                probs = top_k_renorm_probs(probs, sampling_config.top_ks)
                probs = top_p_renorm_probs(probs, sampling_config.top_ps)
                batch_next_token_ids = min_p_sampling_from_probs(probs, sampling_config.min_ps)
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                batch_next_token_ids[temperature_0_idx] = torch.argmax(
                    origin_logits[temperature_0_idx], -1
                ).to(torch.int32)
            else:
                # TODO: use different kernel when don't need top_k or top_p
                # @TODO get probs
                probs = logits
                batch_next_token_ids = top_k_top_p_sampling_from_logits(
                    logits,
                    sampling_config.top_ks,
                    sampling_config.top_ps,
                    filter_apply_order="joint",
                )
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                batch_next_token_ids[temperature_0_idx] = torch.argmax(
                    origin_logits[temperature_0_idx], -1
                ).to(torch.int32)

        return batch_next_token_ids.to(torch.int32), probs
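
A hedged usage sketch, not part of the commit: the default `SamplingOptions` marks every request as greedy, so `Sampler.forward` reduces to an argmax and never touches the flashinfer kernels; per-request temperatures/top-p values send it down the flashinfer path instead. Importing this module still requires flashinfer to be installed, and the batch size, vocabulary size, and GenerationConfig values below are illustrative.

import torch
from transformers import GenerationConfig
# from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions

sampler = Sampler()
logits = torch.randn(4, 32000)                       # batch of 4 sequences, 32k vocab

# Greedy path: temperature 0 for all requests -> plain argmax, no sampling kernels.
next_ids, probs = sampler(logits, SamplingOptions(bsz=4, device=torch.device("cpu")))
print(next_ids.shape)                                # torch.Size([4])

# Non-greedy path: per-request temperature/top_p tensors (1-D, length bsz); top_k
# still comes from the GenerationConfig. This call dispatches to
# flashinfer.sampling.top_k_top_p_sampling_from_logits, so it needs CUDA tensors.
gen_cfg = GenerationConfig(temperature=0.7, top_p=0.95, top_k=40)
opts = SamplingOptions(
    bsz=4,
    device=torch.device("cuda"),
    pretrained_config=gen_cfg,
    temperatures=torch.tensor([0.3, 0.7, 1.0, 0.0], device="cuda"),
    top_ps=torch.tensor([1.0, 0.9, 0.95, 1.0], device="cuda"),
)
# next_ids, probs = sampler(logits.cuda(), opts)     # requires a GPU with flashinfer installed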
ktransformers/server/balance_serve/sched_rpc.py (new file, +213 −0)

from datetime import datetime
import os
from typing import Optional
import zmq
import pickle
import threading
import torch.multiprocessing as mp
import sys

current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import pickle
import argparse
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings


if mp.get_start_method(allow_none=True) is None:
    print('set start method')
    mp.set_start_method('spawn')
else:
    print(f'start method already set to {mp.get_start_method(allow_none=True)}')


class SchedulerServer:
    def __init__(self, settings, main_args):
        # Create the Scheduler instance and initialize it
        self.sched = sched_ext.create_scheduler(settings)

        # Initialize the ZeroMQ context and sockets
        self.context = zmq.Context()
        self.frontend = self.context.socket(zmq.ROUTER)
        print(f"sched zmq rpc server on port {main_args.sched_port}")
        self.frontend.bind(f"tcp://*:{main_args.sched_port}")

        # Internal DEALER socket used to communicate with the worker threads
        self.backend = self.context.socket(zmq.DEALER)
        self.backend.bind("inproc://backend")

    # Start the scheduler
    def run_scheduler(self):
        self.sched.run()

    # Stop the scheduler
    def stop_scheduler(self):
        self.sched.stop()

    # Handle client requests
    def start_proxy(self):
        # Use ZMQ's built-in proxy to dispatch frontend requests to the backend workers
        zmq.proxy(self.frontend, self.backend)

    # Worker thread that processes requests
    def worker_routine(self):
        worker = self.context.socket(zmq.REP)
        worker.connect("inproc://backend")
        while True:
            try:
                # Receive a client request
                message = worker.recv()
                data = pickle.loads(message)
                method = data.get('method')
                params = data.get('params', {})
                # print(f"Received request: {method}")
                if method == 'add_query':
                    query_add = params.get('query')  # a QueryAdd object
                    # Add the query
                    query_id = self.sched.add_query(query_add)
                    # Send the response
                    response = {'status': 'ok', 'query_id': query_id}
                    worker.send(pickle.dumps(response))
                elif method == 'cancel_query':
                    query_id = params.get('query_id')
                    # Assumes the Scheduler class implements a cancel method
                    self.sched.cancel(query_id)
                    response = {'status': 'ok'}
                    worker.send(pickle.dumps(response))
                elif method == 'update_last_batch':
                    updates = params.get('updates')  # a list of QueryUpdate objects
                    # Update the last batch
                    batch_todo = self.sched.update_last_batch(updates)
                    # Send the batch_todo object directly
                    response = {'status': 'ok', 'batch_todo': batch_todo}
                    # print(batch_todo.query_lengths, batch_todo.query_ids)
                    worker.send(pickle.dumps(response))
                elif method == 'get_inference_context':
                    inference_context = self.sched.get_inference_context()
                    data = {
                        "k_cache": inference_context.k_cache,
                        "v_cache": inference_context.v_cache,
                    }
                    print(f"Serializing KVCache")
                    data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
                    data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
                    # print(data)
                    response = {'status': 'ok', 'inference_context': data}
                    worker.send(pickle.dumps(response))
                    # response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
                    # print("k_cache update")
                else:
                    # Unknown method
                    response = {'status': 'error', 'message': 'Unknown method'}
                    worker.send(pickle.dumps(response))
            except Exception as e:
                # Handle the exception and send an error response
                response = {'status': 'error', 'message': str(e)}
                worker.send(pickle.dumps(response))

    # Start the RPC service
    def start_rpc_service(self):
        try:
            print("Scheduler RPC service is running...")
            # Run the scheduler in a separate thread
            threading.Thread(target=self.run_scheduler, daemon=True).start()
            # Start the worker threads
            for _ in range(10):  # adjust the thread count as needed
                threading.Thread(target=self.worker_routine, daemon=True).start()
            # Start the proxy and begin listening for requests
            self.start_proxy()
        except KeyboardInterrupt:
            print("Shutting down scheduler RPC service...")
            self.stop_rpc_service()

    # Stop the RPC service
    def stop_rpc_service(self):
        self.stop_scheduler()
        self.frontend.close()
        self.backend.close()
        self.context.term()


def start_server(settings, main_args):
    server = SchedulerServer(settings, main_args)
    server.start_rpc_service()


# Add async client for webserver
class SchedulerClient:
    def __init__(self, sched_port):
        address = f'tcp://localhost:{sched_port}'
        self.address = address
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(self.address)
        print(f"Connected to server at {self.address}")

    def __del__(self):
        self.socket.close()
        self.context.term()

    def send_request(self, method, params=None):
        if params is None:
            params = {}
        request = {'method': method, 'params': params}
        # print(f'send request {request}')
        self.socket.send(pickle.dumps(request))
        response = self.socket.recv()
        # print(response)
        response = pickle.loads(response)
        if response.get('status') == 'ok':
            return response
        else:
            raise Exception(f"Error from server: {response.get('message')}")

    def add_query(self, query):
        response = self.send_request('add_query', {'query': query})
        return response.get('query_id')

    def cancel_query(self, query_id):
        self.send_request('cancel_query', {'query_id': query_id})

    def update_last_batch(self, updates):
        response = self.send_request('update_last_batch', {'updates': updates})
        # print(f"update_last_batch response {response}")
        return response.get('batch_todo')

    def rebuild_inferece_context(self, response):
        data = response.get('inference_context')
        inference_context = sched_ext.InferenceContext()
        print('Rebuilding kvcache')
        inference_context.k_cache = [fn(*args) for fn, args in data['k_cache']]
        inference_context.v_cache = [fn(*args) for fn, args in data['v_cache']]
        return inference_context

    def get_inference_context_raw(self):
        response = self.send_request('get_inference_context')
        return response


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        main_args = pickle.load(f)
    settings = create_sched_settings(main_args)
    start_server(settings, main_args)
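
A hedged sketch, not part of the commit, of how the web-server side can talk to this RPC process once it is running. The port is illustrative, and the QueryAdd structure comes from the bundled sched_ext extension, which is not shown in this commit.

from ktransformers.server.balance_serve.sched_rpc import SchedulerClient

client = SchedulerClient(sched_port=10005)   # illustrative port; must match the server's --sched_port

# query = sched_ext.QueryAdd(...)            # field layout defined by sched_ext, not reproduced here
# query_id = client.add_query(query)

# Fetch the shared KV-cache handles and rebuild them as local tensors.
raw = client.get_inference_context_raw()
ctx = client.rebuild_inferece_context(raw)   # method name spelled as in the commit
print(len(ctx.k_cache), "k-cache tensors shared via torch.multiprocessing")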
ktransformers/server/balance_serve/settings.py (new file, +76 −0)

'''
Date: 2024-11-13 09:43:39
LastEditors: djw
LastEditTime: 2024-11-18 16:41:03
'''
import sys, os
import yaml, json
from time import sleep

current_dir = os.path.dirname(__file__)
# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched'))
# sys.path.insert(0, sched_path)
import sched_ext

from transformers import AutoConfig


def create_sched_settings(args):
    default_sample_options = sched_ext.SampleOptions()
    model_name = os.path.basename(os.path.normpath(args.model_dir))
    input_model_settings = sched_ext.ModelSettings()
    input_model_settings.model_path = args.model_dir
    input_model_settings.params_count = int(0)
    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)

    input_model_settings.layer_count = model_config.num_hidden_layers
    input_model_settings.num_k_heads = 1  # model_config["num_key_value_heads"]
    input_model_settings.k_head_dim = 576
    input_model_settings.bytes_per_params = 2
    input_model_settings.bytes_per_kv_cache_element = 2

    settings = sched_ext.Settings()
    settings.model_name = model_name
    settings.quant_type = "BF16"
    settings.model_settings = input_model_settings
    settings.page_size = args.page_size
    settings.gpu_device_count = 1  # tp
    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
    # settings.gpu_memory_size = args.cache_lens*576*2
    settings.gpu_memory_size = args.gpu_memory_size
    settings.memory_utilization_percentage = args.utilization_percentage
    max_batch_size = args.max_batch_size
    chunk_size = args.chunk_size
    max_decode_batch_size = max_batch_size - 2

    settings.max_batch_size = max_batch_size
    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
    settings.sample_options = default_sample_options
    settings.sched_metrics_port = args.sched_metrics_port
    settings.gpu_only = args.memory_gpu_only
    settings.use_self_defined_head_dim = True
    settings.self_defined_head_dim = 576
    settings.full_kv_cache_on_each_gpu = True
    settings.k_cache_on = True
    settings.v_cache_on = False

    settings.kvc2_root_path = '/mnt/data/persist-kvc'
    settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs")
    print(os.path.join(current_dir, "..", "..", "configs"))
    settings.memory_pool_size_GB = args.cpu_memory_size_GB
    settings.evict_count = 40
    settings.kvc2_metrics_port = args.kvc2_metrics_port
    settings.load_from_disk = False
    settings.save_to_disk = True
    settings.strategy_name = args.sched_strategy

    settings.auto_derive()

    return settings
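
A hedged sketch, not part of the commit, of the argument namespace `create_sched_settings` expects. The attribute names are exactly the ones the function reads off `args`; every value here is illustrative (in the server they come from the parsed config/CLI arguments), and the scheduler strategy string is an assumption.

from types import SimpleNamespace

args = SimpleNamespace(
    model_dir="/models/DeepSeek-V3",          # illustrative path
    page_size=256,
    gpu_memory_size=2 * 576 * 61 * 32768,     # mirrors the formula added to config.py below
    utilization_percentage=1.0,
    max_batch_size=32,
    chunk_size=8192,
    sched_metrics_port=10007,
    memory_gpu_only=False,
    cpu_memory_size_GB=64,
    kvc2_metrics_port=10008,
    sched_strategy="FCFS",                    # assumed value; the real set is defined by sched_ext
)
# settings = create_sched_settings(args)     # requires the compiled sched_ext module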
ktransformers/server/config/config.py (+23 −10)

@@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14
 import os
 import shutil
 import yaml
+import psutil
 from ktransformers.server.config.singleton import Singleton
 from typing import Optional
@@ -60,7 +61,7 @@ class Config(metaclass=Singleton):
         self.user_path: str = os.path.expanduser("~")
         self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
         # log configs
-        self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
+        self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
         self.log_file = cfg["log"]["file"]
         self.log_level = cfg["log"]["level"]
         self.backup_count = cfg["log"]["backup_count"]
@@ -74,7 +75,7 @@ class Config(metaclass=Singleton):
         # db configs
         self.db_configs: dict = cfg.get("db", {})
         self.db_type = self.db_configs.get("type", "")
-        self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
+        self.db_host = Config.to_path(self.db_configs.get("host", ""))
         self.db_port = self.db_configs.get("port", "")
         self.db_name = self.db_configs.get("database", "")
         self.db_pool_size = self.db_configs.get("pool_size")
@@ -101,11 +102,6 @@ class Config(metaclass=Singleton):
         self.optimize_config_path: Optional[str] = self.model.get(
            "optimize_config_path", None
         )
-        self.paged = self.model.get("paged", True)
-        self.total_context = self.model.get("total_context", 2**18)
-        self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
-        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
         self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
@@ -138,7 +134,6 @@ class Config(metaclass=Singleton):
         self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
         self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
         self.presence_penalty = self.model.get("presence_penalty", 0.0)
-        self.max_response_tokens = self.model.get("max_response_tokens", 300)
         self.response_chunk = self.model.get("response_chunk", 250)
         self.no_code_formatting = self.model.get("no_code_formatting", False)
         self.cache_8bit = self.model.get("cache_8bit", False)
@@ -155,8 +150,9 @@ class Config(metaclass=Singleton):
         self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
         self.mount_web: bool = self.web.get("mount", False)
+        # ext
         self.ext: dict = cfg.get("ext", {})
-        self.cpu_infer = self.ext.get("cpu_infer", 10)
+        self.cpu_infer = psutil.cpu_count(logical=False) - 3
         # file config
         self.local_store_configs: dict = cfg.get("local_store", {})
@@ -169,7 +165,6 @@ class Config(metaclass=Singleton):
         # long context config
         self.long_context_config: dict = cfg.get("long_context", {})
-        self.chunk_size = self.long_context_config.get("chunk_size", 4096)
         self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
         self.block_size = self.long_context_config.get("block_size", 128)
         self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
@@ -187,3 +182,21 @@ class Config(metaclass=Singleton):
         # local chat
         self.local_chat_config: dict = cfg.get("local_chat", {})
         self.prompt_file = self.local_chat_config.get("prompt_file", None)
+
+        # asyncserver
+        self.sched_strategy = cfg['async_server']['sched_strategy']
+        self.sched_port = cfg['async_server']['sched_port']
+        self.sched_metrics_port = cfg['async_server']['sched_metrics_port']
+        self.kvc2_metrics_port = cfg['async_server']['kvc2_metrics_port']
+        self.max_batch_size = cfg['async_server']['max_batch_size']
+        self.page_size = cfg['attn']['page_size']
+        self.chunk_size = cfg['attn']['chunk_size']
+        self.memory_gpu_only = cfg['kvc2']['gpu_only']
+        self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
+        self.gpu_memory_size = 2 * 576 * 61 * self.cache_lens
+        self.utilization_percentage = 1.0  # cfg['kvc2']['utilization_percentage']
+        self.cpu_memory_size_GB = cfg['kvc2']['cpu_memory_size_GB']
+        # only support 2 prefill task
+        self.max_prefill_batch_size = 2
+        self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size
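
Read together, the new Config fields imply extra sections in the server's YAML config. A hedged sketch of the structure the new code indexes, expressed here as the equivalent Python dict: the key names come straight from the diff above, while all the values are illustrative.

# Hedged sketch, not part of the commit: the cfg sections the new Config code reads.
cfg = {
    "async_server": {
        "sched_strategy": "FCFS",        # illustrative value
        "sched_port": 10005,
        "sched_metrics_port": 10007,
        "kvc2_metrics_port": 10008,
        "max_batch_size": 32,
    },
    "attn": {
        "page_size": 256,
        "chunk_size": 8192,
    },
    "kvc2": {
        "gpu_only": False,
        "cpu_memory_size_GB": 64,
    },
}

# cache_lens is rounded up to a whole number of pages, then used to size GPU KV memory
# (plausibly 2 bytes per element * 576 head dim * 61 layers per cached token).
cache_lens = 32768
page_size = cfg["attn"]["page_size"]
cache_lens = ((cache_lens + page_size - 1) // page_size) * page_size
gpu_memory_size = 2 * 576 * 61 * cache_lens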
ktransformers/server/main.py (+31 −12)

@@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
 import sys
+import atexit
 project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+sys.path.insert(0, project_dir)
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
 from ktransformers.server.config.config import Config
-from ktransformers.server.utils.create_interface import create_interface
-from ktransformers.server.backend.args import default_args
+from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
 from fastapi.openapi.utils import get_openapi
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.api import router, post_db_creation_operations
 from ktransformers.server.utils.sql_utils import Base, SQLUtil
 from ktransformers.server.config.log import logger
+import subprocess
+import tempfile

 def mount_app_routes(mount_app: FastAPI):
     sql_util = SQLUtil()
@@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):
 def create_app():
     cfg = Config()
-    app = FastAPI()
+    if (hasattr(GlobalInterface.interface, "lifespan")):
+        app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
+    else:
+        app = FastAPI()
     if Config().web_cross_domain:
         app.add_middleware(
             CORSMiddleware,
@@ -108,11 +107,32 @@ def main():
     arg_parser = ArgumentParser(cfg)
     # initialization messages
     args = arg_parser.parse_args()
+    if args.backend_type == "balance_serve":
+        import pickle
+        def cleanup():
+            if sched_process.poll() is None:
+                sched_process.terminate()
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            pickle.dump(args, temp_file)
+            temp_file_path = temp_file.name
+        current_file = __file__
+        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
+        target_file = os.path.normpath(target_file)
+        log_path = os.path.join(args.log_dir, "rpc.log")
+        log = open(log_path, "a")
+        sched_process = subprocess.Popen(
+            ["python3", target_file, "--config", temp_file_path],
+            stdout=log,
+            stderr=log
+        )
+        print("sched_rpc started with PID:", sched_process.pid)
+        atexit.register(cleanup)
+    create_interface(config=cfg, default_args=cfg)
     app = create_app()
     custom_openapi(app)
-    create_interface(config=cfg, default_args=cfg)
     run_api(
         app=app,
         host=args.host,
@@ -121,6 +141,5 @@ def main():
         ssl_certfile=args.ssl_certfile,
     )

 if __name__ == "__main__":
     main()
ktransformers/server/requirements.txt (+4 −2)

-torch >= 2.3.0,<=2.3.1
+torch >= 2.3.0
 transformers == 4.43.2
 fastapi >= 0.111.0
 langchain >= 0.2.0
@@ -11,4 +11,6 @@ build
 ninja
 wheel
 colorlog
-fire
\ No newline at end of file
+fire
+zmq
+psutil
ktransformers/server/schemas/endpoints/chat.py (+3 −3)

@@ -2,7 +2,7 @@ from typing import List, Optional
 from typing_extensions import Literal
 from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from ktransformers.server.schemas.base import Object
@@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=1.0)
+    top_p: Optional[float] = Field(default=1.0)

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
ktransformers/server/utils/create_interface.py (+6 −3)

@@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager
 from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
 from ktransformers.server.backend.interfaces.transformers import TransformersInterface
 from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface

 def create_interface(config: Config, default_args: ConfigArgs):
     if config.backend_type == 'transformers':
         from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
@@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs):
         from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
     elif config.backend_type == 'ktransformers':
         from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
+    elif config.backend_type == 'balance_serve':
+        from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
     else:
         raise NotImplementedError(f'{config.backend_type} not implemented')
     GlobalInterface.interface = BackendInterface(default_args)
@@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs):
 class GlobalContextManager:
     context_manager: ThreadContextManager

 class GlobalInterface:
     interface: TransformersInterface | KTransformersInterface | ExllamaInterface

-def get_thread_context_manager() -> ThreadContextManager:
+def get_thread_context_manager() -> GlobalContextManager:
     return GlobalContextManager.context_manager

-def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
+def get_interface() -> GlobalInterface:
     return GlobalInterface.interface
\ No newline at end of file
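
With this change, `backend_type: balance_serve` becomes a valid setting: `create_interface` resolves it to BalanceServeInterface (added elsewhere in this commit) and stores it on GlobalInterface. A hedged sketch of the selection path as main.py drives it; nothing here is an extra API, and the actual instantiation implies loading a model.

# Hedged sketch, not part of the commit: how the backend is resolved at server start-up.
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface, get_interface

cfg = Config()                                    # reads the YAML config; backend_type would be "balance_serve"
create_interface(config=cfg, default_args=cfg)    # imports BalanceServeInterface and instantiates it

interface = get_interface()                       # the process-wide GlobalInterface.interface
print(type(interface).__name__)                   # "BalanceServeInterface" when backend_type is balance_serve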
ktransformers/tests/mmlu_test_multi.py (new file, +155 −0)

import argparse
import random
import time
import json
import requests
import pandas as pd
from datasets import load_dataset
import os
import concurrent.futures
import threading

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''

hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'


class DataEvaluator:
    def __init__(self):
        self.data = []

    def load_data(self, file_path):
        """
        Load the data file; each record corresponds to one instance.
        """
        ds = load_dataset(file_path, "all")
        df = pd.DataFrame(ds['test'])
        for _, row in df.iterrows():
            self.data.append(row.to_dict())

    def get_prompt(self, record):
        """
        Combine the hint with the record to build the full question.
        """
        options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
        prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
        return prompt

    def post_processing(self, text):
        """
        Post-process the generated text and extract the final answer (only the last character is returned).
        """
        text = text.lstrip('\n').split('\n')[-1]
        return text[-1:]

    def score(self, pred, answer):
        """
        Compare the prediction against the correct answer and return the score.
        """
        if pred == answer:
            return 1
        return 0


def generate_text(api_url, question, model_name, stream=False):
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
        'Authorization': 'Bearer '  # fill in an API key here if required
    }
    data = {
        "messages": [{"content": question, "role": "user"}],
        "model": model_name,
        "stream": stream,
    }
    print("POST data:", data)

    response = requests.post(api_url, headers=headers, json=data, timeout=5000000)
    if response.status_code == 200:
        result = response.json()
        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
    else:
        print(f"API Request failed with status code {response.status_code}")
        return None


def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
    start_total_time = time.time()
    total_score = 0
    results = []
    file_lock = threading.Lock()

    # Shuffle the data and pick the number of instances to test
    random.seed(42)
    random.shuffle(data_evaluator.data)
    data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]

    batch_size = 10  # at most 10 instances per batch

    def worker(index, data_item):
        nonlocal total_score
        question = data_evaluator.get_prompt(data_item)
        start_time = time.time()
        try:
            prediction = generate_text(api_url, question, model_name)
            if prediction is None:
                raise Exception(f"Failed to get prediction for question: {question}")
            # Correct answer: convert the numeric label to a letter (0->A, 1->B, 2->C, 3->D)
            answer = chr(data_item['answer'] + 65)
            processed_prediction = data_evaluator.post_processing(prediction)
            score = data_evaluator.score(processed_prediction, answer)
            elapsed_time = time.time() - start_time

            result_data = {
                "question_id": index,
                "answer": answer,
                "prediction": processed_prediction,
                "real_prediction": prediction,
                "score": score,
                "time": elapsed_time
            }

            # Take the lock while writing results to keep it thread-safe
            with file_lock:
                with open(result_file, 'a', encoding='utf-8') as f:
                    json.dump(result_data, f, ensure_ascii=False, indent=4)
                    f.write("\n")

            return result_data
        except Exception as e:
            print(f"Error processing request {index}: {e}")
            return None

    # Process in batches, at most 10 tasks per batch
    for batch_start in range(0, len(data_subset), batch_size):
        batch = data_subset[batch_start: batch_start + batch_size]
        with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
            futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)]
            for future in concurrent.futures.as_completed(futures):
                res = future.result()
                if res is not None:
                    results.append(res)
                    total_score += res['score']

    total_time = time.time() - start_total_time
    throughput = len(data_subset) / total_time if total_time > 0 else 0

    with open(log_file, 'a', encoding='utf-8') as log_f:
        log_f.write(f"Total Time: {total_time:.2f} seconds\n")
        log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
        average_score = total_score / len(data_subset) if data_subset else 0
        log_f.write(f"Average Score: {average_score}\n")
        log_f.write('-' * 40 + '\n')

    print(f"Results saved to {result_file}")
    print(f"Log saved to {log_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="API Generate Tester")
    parser.add_argument("--concurrent", type=int, default=1000, help="total number of instances to test")
    parser.add_argument("--file", type=str, default="cais/mmlu", help="path to the data file")
    parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="path to save the result file")
    parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="path to save the log file")
    parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="model name or path")
    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
    args = parser.parse_args()

    data_evaluator = DataEvaluator()
    data_evaluator.load_data(args.file)
    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
ktransformers/tests/test_client.py (new file, +115 −0)

import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep

decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
prompt_list = ['请你介绍下秦始皇', '3.9 和 3.11 哪个大', '抗衰老有何妙招', '给我讲个故事']

async def fetch_event_stream(session, request_id):
    try:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt_list[request_id]}
            ],
            "model": "DeepSeek-V3",
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True  # enable streaming output
        }

        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
        }

        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:
            print(f"Request {request_id}: Connected, status {response.status}")

            if response.status != 200:
                print(f"Request {request_id}: Error, status {response.status}")
                return

            output_text = ""          # accumulate all tokens of the current response
            total_tokens = 0          # total token count
            decode_start_time = None  # time when the decode phase started
            decode_end_time = None    # time when decoding finished

            async for line in response.content:
                try:
                    decoded_line = line.decode("utf-8").strip()

                    # skip empty lines
                    if not decoded_line or not decoded_line.startswith("data: "):
                        continue

                    decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix

                    # make sure the JSON payload is not empty
                    if not decoded_line:
                        continue

                    response_data = json.loads(decoded_line)  # parse the JSON

                    # make sure choices exists
                    choices = response_data.get("choices", [])
                    if not choices:
                        continue

                    delta = choices[0].get("delta", {})
                    token = delta.get("content", "")

                    if token:
                        if decode_start_time is None:
                            decode_start_time = time.time()  # record decode start time

                        output_text += token       # append the token
                        sys.stdout.write(token)    # print the token directly
                        sys.stdout.flush()         # flush immediately so the token shows up in the terminal
                        total_tokens += 1          # bump the token counter
                        decode_end_time = time.time()  # update the decode end time on every token

                    # check whether the stream has finished
                    finish_reason = choices[0].get("finish_reason", None)
                    if finish_reason:
                        # print(f"\nRequest {request_id}: Done")
                        break  # stop streaming

                except json.JSONDecodeError as e:
                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                except IndexError:
                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
                except Exception as e:
                    print(f"\nRequest {request_id}: Error parsing stream - {e}")

            # compute decode speed
            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")

    except Exception as e:
        print(f"\nRequest {request_id}: Exception - {e}")

async def main(prompt_id):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_event_stream(session, prompt_id)]
        await asyncio.gather(*tasks)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--question_id", type=int, default=0, required=False)
    args = parser.parse_args()
    output_file = "ktransformer_test_results.txt"
    asyncio.run(main(args.question_id))
ktransformers/tests/test_speed.py (new file, +146 −0)

import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep

decodesz = 128
# Server URL (replace with your server URL)
decodesz_list = [128]
ktansformer_prompt1024 = """在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。想。
请简述这个故事的内涵 故事的内涵这个故事的内涵写10000个字"""

async def fetch_event_stream(session, request_id, prompt):
    try:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt}
            ],
            "model": "DeepSeek-V3",
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True  # enable streaming output
        }

        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
        }

        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:
            print(f"Request {request_id}: Connected, status {response.status}")

            if response.status != 200:
                print(f"Request {request_id}: Error, status {response.status}")
                return

            output_text = ""          # accumulate all tokens of the current response
            total_tokens = 0          # total token count
            decode_start_time = None  # time when the decode phase started
            decode_end_time = None    # time when decoding finished

            async for line in response.content:
                try:
                    decoded_line = line.decode("utf-8").strip()

                    # skip empty lines
                    if not decoded_line or not decoded_line.startswith("data: "):
                        continue

                    decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix

                    # make sure the JSON payload is not empty
                    if not decoded_line:
                        continue

                    response_data = json.loads(decoded_line)  # parse the JSON

                    # make sure choices exists
                    choices = response_data.get("choices", [])
                    if not choices:
                        continue

                    delta = choices[0].get("delta", {})
                    token = delta.get("content", "")

                    if token:
                        if decode_start_time is None:
                            decode_start_time = time.time()  # record decode start time

                        output_text += token                 # append the token
                        sys.stdout.write(str(request_id))
                        sys.stdout.write(token)              # print the token directly
                        sys.stdout.flush()                   # flush immediately so the token shows up in the terminal
                        total_tokens += 1                    # bump the token counter
                        decode_end_time = time.time()        # update the decode end time on every token

                    # check whether the stream has finished
                    finish_reason = choices[0].get("finish_reason", None)
                    if finish_reason:
                        # print(f"\nRequest {request_id}: Done")
                        break  # stop streaming

                except json.JSONDecodeError as e:
                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                except IndexError:
                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
                except Exception as e:
                    print(f"\nRequest {request_id}: Error parsing stream - {e}")

            # compute decode speed
            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")

    except Exception as e:
        print(f"\nRequest {request_id}: Exception - {e}")

async def main(concurrent_requests, prompt):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_event_stream(session, i, prompt) for i in range(concurrent_requests)]
        await asyncio.gather(*tasks)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
    args = parser.parse_args()
    SERVER_URL = args.api_url
    if args.prompt_lens == 1024:
        prompt = ktansformer_prompt1024
    elif args.prompt_lens == 2048:
        prompt = ktansformer_prompt1024 * 2
    asyncio.run(main(args.concurrent, prompt))
ktransformers/util/utils.py (+20 −3)

@@ -18,9 +18,26 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
 from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+import socket

 warm_uped = False

+def get_free_ports(n: int, continue_prot: list):
+    sockets = []
+    ports = []
+    for _ in range(n):
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(("", 0))
+        port = s.getsockname()[1]
+        if port in continue_prot:
+            s.close()
+            continue
+        ports.append(port)
+        sockets.append(s)
+    for s in sockets:
+        s.close()
+    return ports
+
 def get_compute_capability(device: torch.device = None):
     if torch.cuda.is_available():
         if device is None:
@@ -110,7 +127,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
+                         mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -202,11 +219,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         chunk_start = 0
         while chunk_start < seq_length:
-            chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+            chunk_end = min(chunk_start + chunk_size, seq_length)
             if past_key_values != None:
                 past_key_values.cur_idx = cache_position[chunk_start:chunk_end]
             logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
-            chunk_start += chunk_prefill_size
+            chunk_start += chunk_size
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
merge_tensors/merge_safetensor_gguf.py (+1 −1)

@@ -3,7 +3,7 @@
 import os
 # insert the path of the project
 import sys
-sys.path.insert(0, "/home/azure/ktransformers")
+# sys.path.insert(0, "/home/azure/ktransformers")
 import argparse
 import torch
 from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
requirements-local_chat.txt (+1 −1)

@@ -6,4 +6,4 @@ packaging
 cpufeature
 protobuf
 tiktoken
-blobfile
\ No newline at end of file
+blobfile
setup.py (+49 −27)

@@ -35,6 +35,8 @@ try:
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
 except ImportError:
     MUSA_HOME = None

+with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"
+
 class CpuInstructInfo:
     CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -212,7 +214,7 @@ class VersionInfo:
         cpu_instruct = self.get_cpu_instruct()
         backend_version = ""
         if CUDA_HOME is not None:
-            backend_version = f""
+            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
         elif MUSA_HOME is not None:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
         elif ROCM_HOME is not None:
@@ -274,11 +276,10 @@ PLAT_TO_CMAKE = {
 class CMakeExtension(Extension):
-    def __init__(self, name: str, sourcedir: str = "") -> None:
+    def __init__(self, name: str, sourcedir: str) -> None:
         super().__init__(name, sources=[])
-        self.sourcedir = os.fspath(
-            Path(sourcedir).resolve() / "ktransformers" / "ktransformers_ext"
-        )
+        print(name, sourcedir)
+        self.sourcedir = sourcedir

 class CMakeBuild(BuildExtension):
@@ -342,16 +343,17 @@ class CMakeBuild(BuildExtension):
             f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"
         ]
         if self.compiler.compiler_type != "msvc":
             if not cmake_generator or cmake_generator == "Ninja":
-                try:
-                    import ninja
-
-                    ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
-                    cmake_args += [
-                        "-GNinja",
-                        f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
-                    ]
-                except ImportError:
-                    pass
+                pass
+                # try:
+                #     import ninja
+                #     ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
+                #     cmake_args += [
+                #         "-GNinja",
+                #         f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
+                #     ]
+                # except ImportError:
+                #     pass
             else:
                 # Single config generators are handled "normally"
@@ -387,10 +389,12 @@ class CMakeBuild(BuildExtension):
             build_args += [f"--parallel={cpu_count}"]
         print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
+        print("build_temp:", build_temp)
         if not build_temp.exists():
             build_temp.mkdir(parents=True)

         result = subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True
+            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True, text=True
         )
         print("Standard output:", result.stdout)
         print("Standard error:", result.stderr)
@@ -400,9 +404,9 @@ class CMakeBuild(BuildExtension):
 if CUDA_HOME is not None or ROCM_HOME is not None:
     ops_module = CUDAExtension('KTransformersOps', [
-        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
-        'ktransformers/ktransformers_ext/cuda/binding.cpp',
-        'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
+        'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu',
+        'csrc/ktransformers_ext/cuda/binding.cpp',
+        'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
         ],
         extra_compile_args={
             'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
@@ -415,7 +419,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None:
         }
     )
 elif MUSA_HOME is not None:
-    SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
+    SimplePorting(cuda_dir_path="csrc/ktransformers_ext/cuda", mapping_rule={
         # Common rules
         "at::cuda": "at::musa",
         "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
@@ -423,10 +427,10 @@ elif MUSA_HOME is not None:
         "nv_bfloat16": "mt_bfloat16",
     }).run()
     ops_module = MUSAExtension('KTransformersOps', [
-        'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
-        'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
+        'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
+        'csrc/ktransformers_ext/cuda_musa/binding.cpp',
         # TODO: Add Marlin support for MUSA.
-        # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
+        # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
         ],
         extra_compile_args={
             'cxx': ['force_mcc'],
@@ -440,12 +444,30 @@ elif MUSA_HOME is not None:
 else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

+ext_modules = [
+    CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+    ops_module,
+    CUDAExtension('vLLMMarlin', [
+        'csrc/custom_marlin/binding.cpp',
+        'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
+        'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'],
+            'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
+        },
+    )
+]
+if with_balance:
+    print("using balance_serve")
+    ext_modules.append(
+        CMakeExtension("balance_serve", os.fspath(Path("").resolve() / "csrc" / "balance_serve"))
+    )
+
 setup(
     name=VersionInfo.PACKAGE_NAME,
     version=VersionInfo().get_package_version(),
     cmdclass={"bdist_wheel": BuildWheelsCommand, "build_ext": CMakeBuild},
-    ext_modules=[
-        CMakeExtension("cpuinfer_ext"),
-        ops_module,
-    ]
+    ext_modules=ext_modules
 )