OpenDAS / ktransformers / Commits

Commit 25cee581, authored Mar 31, 2025 by Atream
add balance-serve, support concurrence
Parent: 8d0292aa
Showing 16 changed files with 1026 additions and 62 deletions (+1026, -62)
  ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py  (+83, -0)
  ktransformers/server/balance_serve/inference/sampling/sampler.py  (+100, -0)
  ktransformers/server/balance_serve/sched_rpc.py  (+213, -0)
  ktransformers/server/balance_serve/settings.py  (+76, -0)
  ktransformers/server/config/config.py  (+23, -10)
  ktransformers/server/main.py  (+31, -12)
  ktransformers/server/requirements.txt  (+4, -2)
  ktransformers/server/schemas/endpoints/chat.py  (+3, -3)
  ktransformers/server/utils/create_interface.py  (+6, -3)
  ktransformers/tests/mmlu_test_multi.py  (+155, -0)
  ktransformers/tests/test_client.py  (+115, -0)
  ktransformers/tests/test_speed.py  (+146, -0)
  ktransformers/util/utils.py  (+20, -3)
  merge_tensors/merge_safetensor_gguf.py  (+1, -1)
  requirements-local_chat.txt  (+1, -1)
  setup.py  (+49, -27)

ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py (new file, mode 100644)
import typing

import torch

from ..orchestrator import _BatchedPenalizer, _TokenIDs


class BatchedRepetitionPenalizer(_BatchedPenalizer):
    """
    Repetition penalizer penalizes tokens based on their repetition in the input and output.
    """

    repetition_penalties: torch.Tensor = None
    cumulated_repetition_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.repetition_penalty != 1.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        self.cumulated_repetition_penalties = (
            torch.tensor(
                data=[1.0 for _ in self.orchestrator.reqs()],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .repeat(1, self.orchestrator.vocab_size)
        )

        self.repetition_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.repetition_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_repetition_penalties)
        )

    def _teardown(self):
        del self.repetition_penalties
        del self.cumulated_repetition_penalties

        self.repetition_penalties = None
        self.cumulated_repetition_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        mask = input_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        mask = output_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.where(
            logits > 0,
            logits / self.cumulated_repetition_penalties,
            logits * self.cumulated_repetition_penalties,
        )

    def _filter(self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor):
        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[indices_tensor_to_keep]

    def _merge(self, their: "BatchedRepetitionPenalizer"):
        self.repetition_penalties = torch.cat(
            [self.repetition_penalties, their.repetition_penalties], dim=0
        )
        self.cumulated_repetition_penalties = torch.cat(
            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
            dim=0,
        )
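
Reviewer note: the _apply rule above divides positive logits and multiplies negative logits by the accumulated penalty, so a repeated token loses probability mass regardless of the sign of its logit. A minimal standalone sketch of that rule on toy tensors (it does not use the orchestrator types above):

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])            # toy batch of 1, vocab of 3
penalty = torch.full_like(logits, 1.5)               # pretend every token already occurred
penalized = torch.where(logits > 0, logits / penalty, logits * penalty)
print(penalized)                                      # tensor([[ 1.3333, -1.5000,  0.3333]])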

ktransformers/server/balance_serve/inference/sampling/sampler.py (new file, mode 100644)
'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging

import torch
from torch import nn
from transformers import GenerationConfig
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_probs,
    top_k_top_p_sampling_from_logits,
    top_p_renorm_probs,
)

logger = logging.getLogger(__name__)


class SamplingOptions():
    # Batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # All requests use greedy sampling
    is_all_greedy: bool

    # Dispatch in CUDA graph
    need_min_p_sampling: bool

    def __init__(self, bsz=1, device=torch.device('cuda'), pretrained_config: GenerationConfig = None,
                 temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
        if pretrained_config is None and temperatures is None:
            self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
            self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.need_min_p_sampling = False
            self.is_all_greedy = True
        else:
            if temperatures is not None:
                self.temperatures = temperatures.unsqueeze(-1)
            else:
                self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)
            if top_ps is not None:
                self.top_ps = top_ps.unsqueeze(-1)
            else:
                self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
            self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
            self.need_min_p_sampling = False
            self.is_all_greedy = False


class Sampler(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_config: SamplingOptions = None,
    ):
        if sampling_config == None:
            sampling_config = SamplingOptions()

        logits = logits.contiguous()
        origin_logits = logits.clone()
        if sampling_config.is_all_greedy:
            # Use torch.argmax if all requests use greedy sampling
            probs = logits
            batch_next_token_ids = torch.argmax(logits, -1)
        else:
            # Post process logits
            logits.div_(sampling_config.temperatures)
            max_top_k_round, batch_size = 32, logits.shape[0]
            if sampling_config.need_min_p_sampling:
                probs = torch.softmax(logits, dim=-1)
                logits = None
                del logits
                probs = top_k_renorm_probs(probs, sampling_config.top_ks)
                probs = top_p_renorm_probs(probs, sampling_config.top_ps)
                batch_next_token_ids = min_p_sampling_from_probs(probs, sampling_config.min_ps)
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
            else:
                # TODO: use different kernel when don't need top_k or top_p
                # @TODO get probs
                probs = logits
                batch_next_token_ids = top_k_top_p_sampling_from_logits(
                    logits,
                    sampling_config.top_ks,
                    sampling_config.top_ps,
                    filter_apply_order="joint",
                )
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)

        return batch_next_token_ids.to(torch.int32), probs
\ No newline at end of file
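
Reviewer note: with the default SamplingOptions() every request is treated as greedy, so Sampler.forward reduces to an argmax over the logits. A rough usage sketch, assuming flashinfer is installed and a CUDA device is available (batch size and vocab size below are arbitrary):

import torch
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions

logits = torch.randn(4, 32000, device="cuda")               # fake logits for a batch of 4
opts = SamplingOptions(bsz=4, device=torch.device("cuda"))  # defaults to all-greedy
next_ids, probs = Sampler()(logits, opts)
print(next_ids.shape, next_ids.dtype)                       # torch.Size([4]) torch.int32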

ktransformers/server/balance_serve/sched_rpc.py (new file, mode 100644)
from datetime import datetime
import os
from typing import Optional
import zmq
import pickle
import threading
import torch.multiprocessing as mp
import sys
current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import pickle
import argparse
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings


if mp.get_start_method(allow_none=True) is None:
    print('set start method')
    mp.set_start_method('spawn')
else:
    print(f'start method already set to {mp.get_start_method(allow_none=True)}')


class SchedulerServer:
    def __init__(self, settings, main_args):
        # Create and initialize the Scheduler instance
        self.sched = sched_ext.create_scheduler(settings)

        # Initialize the ZeroMQ context and sockets
        self.context = zmq.Context()
        self.frontend = self.context.socket(zmq.ROUTER)
        print(f"sched zmq rpc server on port {main_args.sched_port}")
        self.frontend.bind(f"tcp://*:{main_args.sched_port}")

        # Create an internal DEALER socket for talking to the worker threads
        self.backend = self.context.socket(zmq.DEALER)
        self.backend.bind("inproc://backend")

    # Start the scheduler
    def run_scheduler(self):
        self.sched.run()

    # Stop the scheduler
    def stop_scheduler(self):
        self.sched.stop()

    # Handle client requests
    def start_proxy(self):
        # Use ZMQ's built-in proxy to dispatch frontend requests to the backend worker threads
        zmq.proxy(self.frontend, self.backend)

    # Worker thread that processes requests
    def worker_routine(self):
        worker = self.context.socket(zmq.REP)
        worker.connect("inproc://backend")
        while True:
            try:
                # Receive a client request
                message = worker.recv()
                data = pickle.loads(message)
                method = data.get('method')
                params = data.get('params', {})
                # print(f"Received request: {method}")

                if method == 'add_query':
                    query_add = params.get('query')  # this is a QueryAdd object
                    # Add the query
                    query_id = self.sched.add_query(query_add)
                    # Send the response
                    response = {'status': 'ok', 'query_id': query_id}
                    worker.send(pickle.dumps(response))

                elif method == 'cancel_query':
                    query_id = params.get('query_id')
                    # Assumes the Scheduler class implements a cancel method
                    self.sched.cancel(query_id)
                    response = {'status': 'ok'}
                    worker.send(pickle.dumps(response))

                elif method == 'update_last_batch':
                    updates = params.get('updates')  # a list of QueryUpdate objects
                    # Update the last batch
                    batch_todo = self.sched.update_last_batch(updates)
                    # Send the batch_todo object directly
                    response = {'status': 'ok', 'batch_todo': batch_todo}
                    # print(batch_todo.query_lengths, batch_todo.query_ids)
                    worker.send(pickle.dumps(response))

                elif method == 'get_inference_context':
                    inference_context = self.sched.get_inference_context()
                    data = {"k_cache": inference_context.k_cache, "v_cache": inference_context.v_cache}
                    print(f"Serializing KVCache")
                    data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
                    data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
                    # print(data)
                    response = {'status': 'ok', 'inference_context': data}
                    worker.send(pickle.dumps(response))
                    # response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
                    # print("k_cache update")

                else:
                    # Unknown method
                    response = {'status': 'error', 'message': 'Unknown method'}
                    worker.send(pickle.dumps(response))
            except Exception as e:
                # Handle the exception and send an error response
                response = {'status': 'error', 'message': str(e)}
                worker.send(pickle.dumps(response))

    # Start the RPC service
    def start_rpc_service(self):
        try:
            print("Scheduler RPC service is running...")
            # Run the scheduler in a separate thread
            threading.Thread(target=self.run_scheduler, daemon=True).start()
            # Start the worker threads
            for _ in range(10):  # adjust the number of threads as needed
                threading.Thread(target=self.worker_routine, daemon=True).start()
            # Start the proxy and begin listening for requests
            self.start_proxy()
        except KeyboardInterrupt:
            print("Shutting down scheduler RPC service...")
            self.stop_rpc_service()

    # Stop the RPC service
    def stop_rpc_service(self):
        self.stop_scheduler()
        self.frontend.close()
        self.backend.close()
        self.context.term()


def start_server(settings, main_args):
    server = SchedulerServer(settings, main_args)
    server.start_rpc_service()


# Add async client for webserver
class SchedulerClient:
    def __init__(self, sched_port):
        address = f'tcp://localhost:{sched_port}'
        self.address = address
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(self.address)
        print(f"Connected to server at {self.address}")

    def __del__(self):
        self.socket.close()
        self.context.term()

    def send_request(self, method, params=None):
        if params is None:
            params = {}
        request = {'method': method, 'params': params}
        # print(f'send request {request}')
        self.socket.send(pickle.dumps(request))
        response = self.socket.recv()
        # print(response)
        response = pickle.loads(response)
        if response.get('status') == 'ok':
            return response
        else:
            raise Exception(f"Error from server: {response.get('message')}")

    def add_query(self, query):
        response = self.send_request('add_query', {'query': query})
        return response.get('query_id')

    def cancel_query(self, query_id):
        self.send_request('cancel_query', {'query_id': query_id})

    def update_last_batch(self, updates):
        response = self.send_request('update_last_batch', {'updates': updates})
        # print(f"update_last_batch response {response}")
        return response.get('batch_todo')

    def rebuild_inferece_context(self, response):
        data = response.get('inference_context')
        inference_context = sched_ext.InferenceContext()
        print('Rebuilding kvcache')
        inference_context.k_cache = [fn(*args) for fn, args in data['k_cache']]
        inference_context.v_cache = [fn(*args) for fn, args in data['v_cache']]
        return inference_context

    def get_inference_context_raw(self):
        response = self.send_request('get_inference_context')
        return response


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        main_args = pickle.load(f)
    settings = create_sched_settings(main_args)
    start_server(settings, main_args)
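
Reviewer note: the RPC layer is a plain pickle-over-ZeroMQ request/response protocol, so the web server side talks to the scheduler through SchedulerClient. A rough usage sketch, assuming the sched_rpc.py server above is already running and the compiled sched_ext module is importable; the port value is a placeholder:

from ktransformers.server.balance_serve.sched_rpc import SchedulerClient

client = SchedulerClient(sched_port=15000)     # placeholder; use the configured sched_port
raw = client.get_inference_context_raw()       # {'status': 'ok', 'inference_context': {...}}
ctx = client.rebuild_inferece_context(raw)     # rebuilds the shared CUDA KV-cache tensors in this process
print(len(ctx.k_cache), len(ctx.v_cache))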

ktransformers/server/balance_serve/settings.py (new file, mode 100644)
'''
Date: 2024-11-13 09:43:39
LastEditors: djw
LastEditTime: 2024-11-18 16:41:03
'''
import sys, os
import yaml, json
from time import sleep

current_dir = os.path.dirname(__file__)
# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched'))
# sys.path.insert(0, sched_path)
import sched_ext

from transformers import AutoConfig


def create_sched_settings(args):
    default_sample_options = sched_ext.SampleOptions()
    model_name = os.path.basename(os.path.normpath(args.model_dir))
    input_model_settings = sched_ext.ModelSettings()
    input_model_settings.model_path = args.model_dir
    input_model_settings.params_count = int(0)
    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)

    input_model_settings.layer_count = model_config.num_hidden_layers
    input_model_settings.num_k_heads = 1  # model_config["num_key_value_heads"]
    input_model_settings.k_head_dim = 576
    input_model_settings.bytes_per_params = 2
    input_model_settings.bytes_per_kv_cache_element = 2

    settings = sched_ext.Settings()
    settings.model_name = model_name
    settings.quant_type = "BF16"
    settings.model_settings = input_model_settings
    settings.page_size = args.page_size
    settings.gpu_device_count = 1  # tp
    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
    # settings.gpu_memory_size = args.cache_lens*576*2
    settings.gpu_memory_size = args.gpu_memory_size
    settings.memory_utilization_percentage = args.utilization_percentage
    max_batch_size = args.max_batch_size
    chunk_size = args.chunk_size
    max_decode_batch_size = max_batch_size - 2

    settings.max_batch_size = max_batch_size
    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
    settings.sample_options = default_sample_options
    settings.sched_metrics_port = args.sched_metrics_port
    settings.gpu_only = args.memory_gpu_only
    settings.use_self_defined_head_dim = True
    settings.self_defined_head_dim = 576
    settings.full_kv_cache_on_each_gpu = True
    settings.k_cache_on = True
    settings.v_cache_on = False
    settings.kvc2_root_path = '/mnt/data/persist-kvc'
    settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs")
    print(os.path.join(current_dir, "..", "..", "configs"))
    settings.memory_pool_size_GB = args.cpu_memory_size_GB
    settings.evict_count = 40
    settings.kvc2_metrics_port = args.kvc2_metrics_port
    settings.load_from_disk = False
    settings.save_to_disk = True
    settings.strategy_name = args.sched_strategy

    settings.auto_derive()
    return settings
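
Reviewer note: create_sched_settings only reads a handful of attributes off args. A sketch of the expected fields, with placeholder values rather than recommendations (the strategy name in particular is an assumption, not taken from the diff):

from types import SimpleNamespace
from ktransformers.server.balance_serve.settings import create_sched_settings

args = SimpleNamespace(
    model_dir="/models/DeepSeek-V3",        # placeholder path
    page_size=256,
    gpu_memory_size=2 * 576 * 61 * 32768,
    utilization_percentage=1.0,
    max_batch_size=4,
    chunk_size=8192,
    sched_metrics_port=8101,
    memory_gpu_only=False,
    cpu_memory_size_GB=64,
    kvc2_metrics_port=8102,
    sched_strategy="FCFS",                  # assumption; pass the project's configured strategy
)
settings = create_sched_settings(args)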

ktransformers/server/config/config.py

@@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14
import os
import shutil
import yaml
import psutil
from ktransformers.server.config.singleton import Singleton
from typing import Optional
@@ -60,7 +61,7 @@ class Config(metaclass=Singleton):
        self.user_path: str = os.path.expanduser("~")
        self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
        # log configs
        self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
        self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
        self.log_file = cfg["log"]["file"]
        self.log_level = cfg["log"]["level"]
        self.backup_count = cfg["log"]["backup_count"]
@@ -74,7 +75,7 @@ class Config(metaclass=Singleton):
        # db configs
        self.db_configs: dict = cfg.get("db", {})
        self.db_type = self.db_configs.get("type", "")
        self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
        self.db_host = Config.to_path(self.db_configs.get("host", ""))
        self.db_port = self.db_configs.get("port", "")
        self.db_name = self.db_configs.get("database", "")
        self.db_pool_size = self.db_configs.get("pool_size")
@@ -101,11 +102,6 @@ class Config(metaclass=Singleton):
        self.optimize_config_path: Optional[str] = self.model.get("optimize_config_path", None)
        self.paged = self.model.get("paged", True)
        self.total_context = self.model.get("total_context", 2**18)
        self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
        self.json_mode = self.model.get("json_mode", False)
@@ -138,7 +134,6 @@ class Config(metaclass=Singleton):
        self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
        self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
        self.presence_penalty = self.model.get("presence_penalty", 0.0)
        self.max_response_tokens = self.model.get("max_response_tokens", 300)
        self.response_chunk = self.model.get("response_chunk", 250)
        self.no_code_formatting = self.model.get("no_code_formatting", False)
        self.cache_8bit = self.model.get("cache_8bit", False)
@@ -155,8 +150,9 @@ class Config(metaclass=Singleton):
        self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
        self.mount_web: bool = self.web.get("mount", False)

        # ext
        self.ext: dict = cfg.get("ext", {})
        self.cpu_infer = self.ext.get("cpu_infer", 10)
        self.cpu_infer = psutil.cpu_count(logical=False) - 3

        # file config
        self.local_store_configs: dict = cfg.get("local_store", {})
@@ -169,7 +165,6 @@ class Config(metaclass=Singleton):
        # long context config
        self.long_context_config: dict = cfg.get("long_context", {})
        self.chunk_size = self.long_context_config.get("chunk_size", 4096)
        self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
        self.block_size = self.long_context_config.get("block_size", 128)
        self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
@@ -187,3 +182,21 @@ class Config(metaclass=Singleton):
        # local chat
        self.local_chat_config: dict = cfg.get("local_chat", {})
        self.prompt_file = self.local_chat_config.get("prompt_file", None)

        # asyncserver
        self.sched_strategy = cfg['async_server']['sched_strategy']
        self.sched_port = cfg['async_server']['sched_port']
        self.sched_metrics_port = cfg['async_server']['sched_metrics_port']
        self.kvc2_metrics_port = cfg['async_server']['kvc2_metrics_port']
        self.max_batch_size = cfg['async_server']['max_batch_size']
        self.page_size = cfg['attn']['page_size']
        self.chunk_size = cfg['attn']['chunk_size']
        self.memory_gpu_only = cfg['kvc2']['gpu_only']
        self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
        self.gpu_memory_size = 2 * 576 * 61 * self.cache_lens
        self.utilization_percentage = 1.0  # cfg['kvc2']['utilization_percentage']
        self.cpu_memory_size_GB = cfg['kvc2']['cpu_memory_size_GB']
        # only support 2 prefill task
        self.max_prefill_batch_size = 2
        self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size
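
Reviewer note: the new async_server wiring rounds cache_lens up to a whole number of pages and then sizes the GPU KV budget as 2 bytes x 576 x 61 per cached token (576 and 61 appear to correspond to the compressed MLA head dim and layer count of the DeepSeek-V3-class models this backend targets; that interpretation is mine, not stated in the diff). A worked example of the arithmetic:

page_size = 256
cache_lens = 4000
cache_lens = ((cache_lens + page_size - 1) // page_size) * page_size   # -> 4096
gpu_memory_size = 2 * 576 * 61 * cache_lens                            # -> 287834112 bytes, about 274.5 MiB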

ktransformers/server/main.py

@@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import uvicorn
import sys
import atexit

project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
import subprocess
import tempfile

def mount_app_routes(mount_app: FastAPI):
    sql_util = SQLUtil()
@@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):
def create_app():
    cfg = Config()
    app = FastAPI()
    if (hasattr(GlobalInterface.interface, "lifespan")):
        app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
    else:
        app = FastAPI()
    if Config().web_cross_domain:
        app.add_middleware(
            CORSMiddleware,
@@ -108,11 +107,32 @@ def main():
    arg_parser = ArgumentParser(cfg)

    # initialization message
    args = arg_parser.parse_args()

    if args.backend_type == "balance_serve":
        import pickle
        def cleanup():
            if sched_process.poll() is None:
                sched_process.terminate()
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            pickle.dump(args, temp_file)
            temp_file_path = temp_file.name
        current_file = __file__
        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
        target_file = os.path.normpath(target_file)
        log_path = os.path.join(args.log_dir, "rpc.log")
        log = open(log_path, "a")
        sched_process = subprocess.Popen(
            ["python3", target_file, "--config", temp_file_path],
            stdout=log,
            stderr=log
        )
        print("sched_rpc started with PID:", sched_process.pid)
        atexit.register(cleanup)

    create_interface(config=cfg, default_args=cfg)
    app = create_app()
    custom_openapi(app)
    create_interface(config=cfg, default_args=cfg)
    run_api(
        app=app,
        host=args.host,
@@ -121,6 +141,5 @@ def main():
        ssl_certfile=args.ssl_certfile,
    )

if __name__ == "__main__":
    main()
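
Reviewer note: when the backend is balance_serve, main() hands the parsed arguments to the scheduler subprocess through a pickled temporary file, and sched_rpc.py loads them back with pickle.load. A toy round-trip of that hand-off (the Args class is a stand-in for the real argparse namespace):

import pickle
import tempfile

class Args:                       # stand-in for the argparse.Namespace that main() pickles
    sched_port = 15000

with tempfile.NamedTemporaryFile(delete=False) as tf:
    pickle.dump(Args(), tf)
    path = tf.name

with open(path, "rb") as f:
    restored = pickle.load(f)
print(restored.sched_port)        # 15000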

ktransformers/server/requirements.txt

torch >= 2.3.0,<=2.3.1
torch >= 2.3.0
transformers == 4.43.2
fastapi >= 0.111.0
langchain >= 0.2.0
@@ -11,4 +11,6 @@ build
ninja
wheel
colorlog
fire
\ No newline at end of file
fire
zmq
psutil
\ No newline at end of file

ktransformers/server/schemas/endpoints/chat.py

@@ -2,7 +2,7 @@ from typing import List, Optional
from typing_extensions import Literal
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Field

from ktransformers.server.schemas.base import Object
@@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel):
    messages: List[Message]
    model: str
    stream: bool = False
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = Field(default=1.0)
    top_p: Optional[float] = Field(default=1.0)

    def get_tokenizer_messages(self):
        return [m.to_tokenizer_message() for m in self.messages]

ktransformers/server/utils/create_interface.py

@@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface

def create_interface(config: Config, default_args: ConfigArgs):
    if config.backend_type == 'transformers':
        from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
@@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs):
        from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
    elif config.backend_type == 'ktransformers':
        from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
    elif config.backend_type == 'balance_serve':
        from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
    else:
        raise NotImplementedError(f'{config.backend_type} not implemented')
    GlobalInterface.interface = BackendInterface(default_args)
@@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs):
class GlobalContextManager:
    context_manager: ThreadContextManager

class GlobalInterface:
    interface: TransformersInterface | KTransformersInterface | ExllamaInterface
    interface: TransformersInterface | KTransformersInterface | ExllamaInterface

def get_thread_context_manager() -> ThreadContextManager:
def get_thread_context_manager() -> GlobalContextManager:
    return GlobalContextManager.context_manager

def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
def get_interface() -> GlobalInterface:
    return GlobalInterface.interface
\ No newline at end of file

ktransformers/tests/mmlu_test_multi.py (new file, mode 100644)

import argparse
import random
import time
import json
import requests
import pandas as pd
from datasets import load_dataset
import os
import concurrent.futures
import threading

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''

hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'


class DataEvaluator:
    def __init__(self):
        self.data = []

    def load_data(self, file_path):
        """
        Load data from the dataset; each record corresponds to one test instance.
        """
        ds = load_dataset(file_path, "all")
        df = pd.DataFrame(ds['test'])
        for _, row in df.iterrows():
            self.data.append(row.to_dict())

    def get_prompt(self, record):
        """
        Combine the hint with the record to build the full question prompt.
        """
        options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
        prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
        return prompt

    def post_processing(self, text):
        """
        Post-process the generated text and extract the final answer (only the last character is returned).
        """
        text = text.lstrip('\n').split('\n')[-1]
        return text[-1:]

    def score(self, pred, answer):
        """
        Compare the prediction with the ground-truth answer and return the score.
        """
        if pred == answer:
            return 1
        return 0


def generate_text(api_url, question, model_name, stream=False):
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
        'Authorization': 'Bearer '  # fill in an API key here if needed
    }
    data = {
        "messages": [{"content": question, "role": "user"}],
        "model": model_name,
        "stream": stream,
    }
    print("POST data:", data)
    response = requests.post(api_url, headers=headers, json=data, timeout=5000000)
    if response.status_code == 200:
        result = response.json()
        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
    else:
        print(f"API Request failed with status code {response.status_code}")
        return None


def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
    start_total_time = time.time()
    total_score = 0
    results = []
    file_lock = threading.Lock()

    # Shuffle the data and select the number of instances to test
    random.seed(42)
    random.shuffle(data_evaluator.data)
    data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]
    batch_size = 10  # at most 10 instances per batch

    def worker(index, data_item):
        nonlocal total_score
        question = data_evaluator.get_prompt(data_item)
        start_time = time.time()
        try:
            prediction = generate_text(api_url, question, model_name)
            if prediction is None:
                raise Exception(f"Failed to get prediction for question: {question}")
            # Ground-truth answer: convert the index to a letter (0->A, 1->B, 2->C, 3->D)
            answer = chr(data_item['answer'] + 65)
            processed_prediction = data_evaluator.post_processing(prediction)
            score = data_evaluator.score(processed_prediction, answer)
            elapsed_time = time.time() - start_time

            result_data = {
                "question_id": index,
                "answer": answer,
                "prediction": processed_prediction,
                "real_prediction": prediction,
                "score": score,
                "time": elapsed_time
            }

            # Take the lock when writing results to keep it thread-safe
            with file_lock:
                with open(result_file, 'a', encoding='utf-8') as f:
                    json.dump(result_data, f, ensure_ascii=False, indent=4)
                    f.write("\n")
            return result_data
        except Exception as e:
            print(f"Error processing request {index}: {e}")
            return None

    # Process in batches, at most 10 tasks per batch
    for batch_start in range(0, len(data_subset), batch_size):
        batch = data_subset[batch_start: batch_start + batch_size]
        with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
            futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)]
            for future in concurrent.futures.as_completed(futures):
                res = future.result()
                if res is not None:
                    results.append(res)
                    total_score += res['score']

    total_time = time.time() - start_total_time
    throughput = len(data_subset) / total_time if total_time > 0 else 0

    with open(log_file, 'a', encoding='utf-8') as log_f:
        log_f.write(f"Total Time: {total_time:.2f} seconds\n")
        log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
        average_score = total_score / len(data_subset) if data_subset else 0
        log_f.write(f"Average Score: {average_score}\n")
        log_f.write('-' * 40 + '\n')

    print(f"Results saved to {result_file}")
    print(f"Log saved to {log_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="API Generate Tester")
    parser.add_argument("--concurrent", type=int, default=1000, help="total number of instances to test")
    parser.add_argument("--file", type=str, default="cais/mmlu", help="path to the dataset")
    parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="path to save the results file")
    parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="path to save the log file")
    parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="model name or path")
    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
    args = parser.parse_args()

    data_evaluator = DataEvaluator()
    data_evaluator.load_data(args.file)
    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
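
Reviewer note: the grader converts MMLU's integer answer index to a letter and compares it against the last character of the model's reply. A quick sanity check of that pipeline, run in the same module (or after importing DataEvaluator from the script):

evaluator = DataEvaluator()
print(chr(2 + 65))                                    # 'C' -- answer index 2 maps to letter C
print(evaluator.post_processing("The answer is\nC"))  # 'C' -- keeps only the last character of the last line
print(evaluator.score('C', 'C'))                      # 1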

ktransformers/tests/test_client.py (new file, mode 100644)

import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep

decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
# Test prompts (in Chinese): "Please introduce Qin Shi Huang", "Which is larger, 3.9 or 3.11",
# "Any tips for anti-aging", "Tell me a story"
prompt_list = ['请你介绍下秦始皇', '3.9 和 3.11 哪个大', '抗衰老有何妙招', '给我讲个故事']

async def fetch_event_stream(session, request_id):
    try:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt_list[request_id]}
            ],
            "model": "DeepSeek-V3",
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True  # enable streaming output
        }

        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
        }

        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:
            print(f"Request {request_id}: Connected, status {response.status}")

            if response.status != 200:
                print(f"Request {request_id}: Error, status {response.status}")
                return

            output_text = ""          # accumulates all tokens of the current response
            total_tokens = 0          # total token count
            decode_start_time = None  # time when the decode phase starts
            decode_end_time = None    # time when decoding ends

            async for line in response.content:
                try:
                    decoded_line = line.decode("utf-8").strip()

                    # skip empty lines
                    if not decoded_line or not decoded_line.startswith("data: "):
                        continue

                    decoded_line = decoded_line[6:].strip()  # drop the `data: ` prefix

                    # make sure the JSON payload is non-empty
                    if not decoded_line:
                        continue

                    response_data = json.loads(decoded_line)  # parse the JSON

                    # make sure choices exists
                    choices = response_data.get("choices", [])
                    if not choices:
                        continue

                    delta = choices[0].get("delta", {})
                    token = delta.get("content", "")

                    if token:
                        if decode_start_time is None:
                            decode_start_time = time.time()  # record when decoding starts

                        output_text += token                 # append the token
                        sys.stdout.write(token)              # print the token directly
                        sys.stdout.flush()                   # flush so the token shows up in the terminal immediately
                        total_tokens += 1                    # bump the token count
                        decode_end_time = time.time()        # update the decode end time on every token

                    # check whether the stream is done
                    finish_reason = choices[0].get("finish_reason", None)
                    if finish_reason:
                        # print(f"\nRequest {request_id}: Done")
                        break  # stop streaming

                except json.JSONDecodeError as e:
                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                except IndexError:
                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
                except Exception as e:
                    print(f"\nRequest {request_id}: Error parsing stream - {e}")

            # compute the decode speed
            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")

    except Exception as e:
        print(f"\nRequest {request_id}: Exception - {e}")


async def main(prompt_id):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_event_stream(session, prompt_id)]
        await asyncio.gather(*tasks)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--question_id", type=int, default=0, required=False)
    args = parser.parse_args()
    output_file = "ktransformer_test_results.txt"
    asyncio.run(main(args.question_id))
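
Reviewer note: the client consumes the OpenAI-style SSE stream line by line, stripping the `data: ` prefix before parsing JSON. A self-contained sketch of that parsing step on a single fabricated event line:

import json

line = b'data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}'
decoded = line.decode("utf-8").strip()
if decoded.startswith("data: "):
    payload = json.loads(decoded[6:].strip())
    token = payload["choices"][0]["delta"].get("content", "")
    print(token)                                      # Hello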

ktransformers/tests/test_speed.py (new file, mode 100644)

import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep

decodesz = 128
# Server URL (replace with your server URL)
decodesz_list = [128]
# Roughly 1024-token Chinese benchmark prompt; kept verbatim because its token length is what the prefill test measures.
ktansformer_prompt1024 = """在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。想。
请简述这个故事的内涵 故事的内涵这个故事的内涵写10000个字"""

async def fetch_event_stream(session, request_id, prompt):
    try:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt}
            ],
            "model": "DeepSeek-V3",
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True  # enable streaming output
        }

        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
        }

        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:
            print(f"Request {request_id}: Connected, status {response.status}")

            if response.status != 200:
                print(f"Request {request_id}: Error, status {response.status}")
                return

            output_text = ""          # accumulates all tokens of the current response
            total_tokens = 0          # total token count
            decode_start_time = None  # time when the decode phase starts
            decode_end_time = None    # time when decoding ends

            async for line in response.content:
                try:
                    decoded_line = line.decode("utf-8").strip()

                    # skip empty lines
                    if not decoded_line or not decoded_line.startswith("data: "):
                        continue

                    decoded_line = decoded_line[6:].strip()  # drop the `data: ` prefix

                    # make sure the JSON payload is non-empty
                    if not decoded_line:
                        continue

                    response_data = json.loads(decoded_line)  # parse the JSON

                    # make sure choices exists
                    choices = response_data.get("choices", [])
                    if not choices:
                        continue

                    delta = choices[0].get("delta", {})
                    token = delta.get("content", "")

                    if token:
                        if decode_start_time is None:
                            decode_start_time = time.time()  # record when decoding starts

                        output_text += token                 # append the token
                        sys.stdout.write(str(request_id))
                        sys.stdout.write(token)              # print the token directly
                        sys.stdout.flush()                   # flush so the token shows up in the terminal immediately
                        total_tokens += 1                    # bump the token count
                        decode_end_time = time.time()        # update the decode end time on every token

                    # check whether the stream is done
                    finish_reason = choices[0].get("finish_reason", None)
                    if finish_reason:
                        # print(f"\nRequest {request_id}: Done")
                        break  # stop streaming

                except json.JSONDecodeError as e:
                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                except IndexError:
                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
                except Exception as e:
                    print(f"\nRequest {request_id}: Error parsing stream - {e}")

            # compute the decode speed
            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")

    except Exception as e:
        print(f"\nRequest {request_id}: Exception - {e}")


async def main(concurrent_requests, prompt):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_event_stream(session, i, prompt) for i in range(concurrent_requests)]
        await asyncio.gather(*tasks)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
    args = parser.parse_args()
    SERVER_URL = args.api_url
    if args.prompt_lens == 1024:
        prompt = ktansformer_prompt1024
    elif args.prompt_lens == 2048:
        prompt = ktansformer_prompt1024 * 2
    asyncio.run(main(args.concurrent, prompt))

ktransformers/util/utils.py

@@ -18,9 +18,26 @@ from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket

warm_uped = False

def get_free_ports(n: int, continue_prot: list):
    sockets = []
    ports = []
    for _ in range(n):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        port = s.getsockname()[1]
        if port in continue_prot:
            s.close()
            continue
        ports.append(port)
        sockets.append(s)
    for s in sockets:
        s.close()
    return ports

def get_compute_capability(device: torch.device = None):
    if torch.cuda.is_available():
        if device is None:
@@ -110,7 +127,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
            module.load()

def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False):
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False, num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -202,11 +219,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        chunk_start = 0
        while chunk_start < seq_length:
            chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
            chunk_end = min(chunk_start + chunk_size, seq_length)
            if past_key_values != None:
                past_key_values.cur_idx = cache_position[chunk_start:chunk_end]
            logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
            chunk_start += chunk_prefill_size
            chunk_start += chunk_size

        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
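
Reviewer note: get_free_ports binds throwaway sockets to port 0 so the OS picks free ports, skipping anything in continue_prot (ports already claimed elsewhere); because a skipped candidate is simply dropped rather than retried, the function can return fewer than n ports in that rare case. A rough usage sketch:

from ktransformers.util.utils import get_free_ports

already_used = [10002]                  # ports already assigned elsewhere
ports = get_free_ports(3, already_used)
print(ports)                            # e.g. [45211, 38747, 50321]; actual values are OS-chosen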

merge_tensors/merge_safetensor_gguf.py

@@ -3,7 +3,7 @@
import os
# insert the path of the project
import sys
sys.path.insert(0, "/home/azure/ktransformers")
# sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf

requirements-local_chat.txt

@@ -6,4 +6,4 @@ packaging
cpufeature
protobuf
tiktoken
blobfile
\ No newline at end of file
blobfile

setup.py

@@ -35,6 +35,8 @@ try:
    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
    MUSA_HOME = None

with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"

class CpuInstructInfo:
    CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -212,7 +214,7 @@ class VersionInfo:
        cpu_instruct = self.get_cpu_instruct()
        backend_version = ""
        if CUDA_HOME is not None:
            backend_version = f""
            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
        elif MUSA_HOME is not None:
            backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
        elif ROCM_HOME is not None:
@@ -274,11 +276,10 @@ PLAT_TO_CMAKE = {
class CMakeExtension(Extension):
    def __init__(self, name: str, sourcedir: str = "") -> None:
    def __init__(self, name: str, sourcedir: str) -> None:
        super().__init__(name, sources=[])
        self.sourcedir = os.fspath(Path(sourcedir).resolve() / "ktransformers" / "ktransformers_ext")
        print(name, sourcedir)
        self.sourcedir = sourcedir

class CMakeBuild(BuildExtension):
@@ -342,16 +343,17 @@ class CMakeBuild(BuildExtension):
                f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"
            ]
            if self.compiler.compiler_type != "msvc":
                if not cmake_generator or cmake_generator == "Ninja":
                    try:
                        import ninja
                        ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
                        cmake_args += [
                            "-GNinja",
                            f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
                        ]
                    except ImportError:
                        pass
                    pass
                    # try:
                    #     import ninja
                    #     ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
                    #     cmake_args += [
                    #         "-GNinja",
                    #         f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
                    #     ]
                    # except ImportError:
                    #     pass
            else:
                # Single config generators are handled "normally"
@@ -387,10 +389,12 @@ class CMakeBuild(BuildExtension):
            build_args += [f"--parallel={cpu_count}"]
        print("CMake args:", cmake_args)
        build_temp = Path(ext.sourcedir) / "build"
        print("build_temp:", build_temp)
        if not build_temp.exists():
            build_temp.mkdir(parents=True)
        result = subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True, text=True
        )
        print("Standard output:", result.stdout)
        print("Standard error:", result.stderr)
@@ -400,9 +404,9 @@ class CMakeBuild(BuildExtension):
if CUDA_HOME is not None or ROCM_HOME is not None:
    ops_module = CUDAExtension('KTransformersOps', [
        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
        'ktransformers/ktransformers_ext/cuda/binding.cpp',
        'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
        'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu',
        'csrc/ktransformers_ext/cuda/binding.cpp',
        'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
        ],
        extra_compile_args={
            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
@@ -415,7 +419,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None:
        }
    )
elif MUSA_HOME is not None:
    SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
    SimplePorting(cuda_dir_path="csrc/ktransformers_ext/cuda", mapping_rule={
        # Common rules
        "at::cuda": "at::musa",
        "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
@@ -423,10 +427,10 @@ elif MUSA_HOME is not None:
        "nv_bfloat16": "mt_bfloat16",
    }).run()
    ops_module = MUSAExtension('KTransformersOps', [
        'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
        'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
        'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
        'csrc/ktransformers_ext/cuda_musa/binding.cpp',
        # TODO: Add Marlin support for MUSA.
        # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
        # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
    ],
    extra_compile_args={
        'cxx': ['force_mcc'],
@@ -440,12 +444,30 @@ elif MUSA_HOME is not None:
else:
    raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

ext_modules = [
    CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
    ops_module,
    CUDAExtension('vLLMMarlin', [
        'csrc/custom_marlin/binding.cpp',
        'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
        'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
    ],
    extra_compile_args={
        'cxx': ['-O3'],
        'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
    },
    )
]
if with_balance:
    print("using balance_serve")
    ext_modules.append(
        CMakeExtension("balance_serve", os.fspath(Path("").resolve() / "csrc" / "balance_serve"))
    )

setup(
    name=VersionInfo.PACKAGE_NAME,
    version=VersionInfo().get_package_version(),
    cmdclass={"bdist_wheel": BuildWheelsCommand, "build_ext": CMakeBuild},
    ext_modules=[
        CMakeExtension("cpuinfer_ext"),
        ops_module,
    ]
    ext_modules=ext_modules
)
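
Reviewer note: the balance_serve extension is opt-in at build time; setup.py only appends the balance_serve CMake extension when USE_BALANCE_SERVE=1 is set in the environment. A sketch of one way to drive such a build from Python (the pip invocation is an assumption, not a command documented by this commit):

import os
import subprocess

env = dict(os.environ, USE_BALANCE_SERVE="1")   # enables the with_balance branch in setup.py
subprocess.run(["pip", "install", ".", "--no-build-isolation"], env=env, check=True)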