ox696c / ktransformers · Commit 48558801 (unverified)
Authored Apr 22, 2025 by wang jiahao; committed by GitHub on Apr 22, 2025.

Merge pull request #1177 from kvcache-ai/update_param

Update param

Parents: a1162eea, f5287e90

Showing 11 changed files with 149 additions and 163 deletions (+149 −163):
- ktransformers/models/custom_cache.py (+5 −2)
- ktransformers/server/api/openai/endpoints/chat.py (+2 −2)
- ktransformers/server/api/openai/legacy/completions.py (+3 −3)
- ktransformers/server/backend/interfaces/balance_serve.py (+19 −10)
- ktransformers/server/backend/interfaces/ktransformers.py (+13 −6)
- ktransformers/server/backend/interfaces/transformers.py (+21 −119)
- ktransformers/server/schemas/endpoints/chat.py (+3 −3)
- ktransformers/server/schemas/legacy/completions.py (+6 −5)
- ktransformers/tests/function_call_test.py (new file, +45 −0)
- ktransformers/tests/test_client.py (+31 −12)
- third_party/llamafile/iqk_mul_mat.inc (+1 −1)
ktransformers/models/custom_cache.py

@@ -12,7 +12,10 @@ import torch.nn as nn
 import transformers
 from transformers import Cache, PretrainedConfig
 from typing import List, Optional, Dict, Any, Tuple
-from ktransformers.server.balance_serve.settings import sched_ext
+try:
+    from ktransformers.server.balance_serve.settings import sched_ext
+except:
+    print("no balance_serve")

 class StaticCache(transformers.StaticCache):
     """
     Static Cache class to be used with `torch.compile(model)`.

@@ -210,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
         self.v_caches = []

-    def load(self, inference_context: sched_ext.InferenceContext):
+    def load(self, inference_context: "sched_ext.InferenceContext"):
         for i in range(self.config.num_hidden_layers):
             self.k_caches.append(
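The two hunks above work together: the scheduler-extension import becomes optional, and the annotation on load() becomes a string so the class still imports when sched_ext is absent. A minimal standalone sketch of the same pattern; "optional_backend" and "CacheLoader" are only illustrative names, not part of the codebase:

```python
# Sketch of the optional-import pattern used in custom_cache.py (names are placeholders).
try:
    import optional_backend
except ImportError:
    optional_backend = None
    print("no optional_backend")


class CacheLoader:
    # Quoting the annotation defers its evaluation, so this class definition still
    # works when optional_backend (and its InferenceContext type) is unavailable.
    def load(self, inference_context: "optional_backend.InferenceContext"):
        if optional_backend is None:
            raise RuntimeError("optional_backend is not installed")
        return inference_context
```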
ktransformers/server/api/openai/endpoints/chat.py

@@ -207,7 +207,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             "<tools▁end>": "<|tool▁calls▁end|>"
         }
         # Use check_client_connected for early stopping
-        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 # Final return on utilization
                 raw_usage = res

@@ -371,7 +371,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             "<tool▁end>": "<|tool▁call▁end|>",
             "<tools▁end>": "<|tool▁calls▁end|>"
         }
-        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 raw_usage = res
                 usage = CompletionUsage(
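With these two call sites updated, the max_tokens / max_completion_tokens values a client sends now reach the backend instead of being dropped. A hedged example of exercising the endpoint with the OpenAI Python client; the base URL and model name are placeholders for a locally running ktransformers server:

```python
from openai import OpenAI

# Assumes a ktransformers server is listening locally; adjust base_url/model as needed.
client = OpenAI(api_key="placeholder", base_url="http://localhost:10002/v1")

resp = client.chat.completions.create(
    model="DeepSeek-V3",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    max_completion_tokens=16,  # forwarded to interface.inference(...) by the endpoint above
)
print(resp.choices[0].message.content)
```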
ktransformers/server/api/openai/legacy/completions.py

@@ -11,7 +11,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage
 router = APIRouter()

 @router.post("/completions", tags=['openai'])
-async def create_completion(request:Request,create:CompletionCreate):
+async def create_completion(request: Request, create: CompletionCreate):
     id = str(uuid4())
     interface = get_interface()

@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
                 if isinstance(res, RawUsage):
                     raw_usage = res
                 else:

@@ -32,7 +32,7 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 raw_usage = res
             else:
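The legacy /completions route gets the same treatment. A sketch of a non-streaming request against it, again with placeholder endpoint and model names:

```python
from openai import OpenAI

# Placeholder connection details for a local ktransformers server.
client = OpenAI(api_key="placeholder", base_url="http://localhost:10002/v1")

resp = client.completions.create(
    model="DeepSeek-V3",
    prompt="Python is",
    max_tokens=32,      # picked up via create.max_tokens in the handler above
    temperature=0.6,
)
print(resp.choices[0].text)
```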
ktransformers/server/backend/interfaces/balance_serve.py

@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
         query_updates[i].generated_token = generated_tokens[i].item()
         if not query_manager.query_map[query_updates[i].id].is_prefill:
             pos = query_updates[i].active_position
-            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
+            if pos < query_manager.query_map[query_updates[i].id].max_length:
+                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

 def report_last_time_performance(profiler: Profiler):
     try:

@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):
             start_event.wait()

-    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
         """Get sampling parameters and handle default values and edge cases"""
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_completion_tokens = self.args.max_new_tokens
+        else:
+            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if temperature is None:
-            temperature = Config().temperature
+            temperature = self.args.temperature
         if top_p is None:
-            top_p = Config().top_p
+            top_p = self.args.top_p
         if temperature == 0:
             temperature = 0.0001
         if top_p == 0:
             top_p = 0.0001
-        return temperature, top_p
+        return temperature, top_p, max_completion_tokens

     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()

@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                        max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         profiler = Profiler()
         profiler.create_and_start_timer("tokenize")

@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False), self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
-        temperature, top_p = self.get_sampling_params(temperature, top_p)
+        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
         query_add.sample_options.temperature = temperature
         query_add.sample_options.top_p = top_p
-        query_add.estimated_length = min(self.args.cache_lens, query_length + self.args.max_new_tokens)
+        query_add.estimated_length = min(self.args.cache_lens, query_length + max_new_tokens)
         if query_add.estimated_length < query_add.query_length:
             raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')
         query_id = self.sched_client.add_query(query_add)
-        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+        queue = asyncio.Queue(maxsize=max_new_tokens)
         self.queue_map[query_id] = queue
         self.thread_map[thread_id] = query_id
         is_first_token = True

@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
                     profiler.pause_timer("decode")
                     report_last_time_performance(profiler)
                     yield self.streamer.end(), None
-                    if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+                    if profiler.get_counter('decode') >= max_new_tokens - 1:
                         yield "", "length"
                     else:
                         yield "", "stop"
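The net effect of get_params is that a per-request token budget now caps the scheduler queue size, the estimated length, and the "length" finish reason, instead of always using the server-wide max_new_tokens. A standalone sketch of the resolution order, with resolve_max_new_tokens and server_max_new_tokens as illustrative names:

```python
from typing import Optional


def resolve_max_new_tokens(server_max_new_tokens: int,
                           max_tokens: Optional[int] = None,
                           max_completion_tokens: Optional[int] = None) -> int:
    # max_tokens (the older OpenAI field) takes precedence when both are given.
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    # No request-level cap: fall back to the server-wide limit.
    if max_completion_tokens is None:
        return server_max_new_tokens
    # A request-level cap can only tighten the server-wide limit, never exceed it.
    return min(server_max_new_tokens, max_completion_tokens)


assert resolve_max_new_tokens(2048) == 2048
assert resolve_max_new_tokens(2048, max_tokens=100) == 100
assert resolve_max_new_tokens(2048, max_completion_tokens=5000) == 2048
```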
ktransformers/server/backend/interfaces/ktransformers.py

@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if (input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
             self.seq_length = input_ids_length

@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )

@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(

@@ -222,16 +228,17 @@ class KTransformersInterface(TransformersInterface):
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
                 yield v
         # return this inference raw usage
ktransformers/server/backend/interfaces/transformers.py

@@ -262,10 +262,15 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if is_new:
             self.ever_generated_ids.clear()
             same_prefix = 0

@@ -274,7 +279,7 @@ class TransformersInterface(BackendInterfaceBase):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )

@@ -301,7 +306,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = self.seq_length + max_new_tokens + 1
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(

@@ -330,17 +335,16 @@ class TransformersInterface(BackendInterfaceBase):
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @torch.no_grad
     def generate(self):
-        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
-        logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if (self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
             yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
         for i in range(1, self.max_new_tokens):

@@ -378,17 +382,15 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
-        # Check if tools are present
-        has_tools = tools is not None and len(tools) > 0
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
             #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")

@@ -399,6 +401,7 @@ class TransformersInterface(BackendInterfaceBase):
         )
         self.profiler.pause_timer("tokenize")

         self.profiler.create_and_start_timer("prefill")
         if Config().user_force_think:

@@ -406,119 +409,18 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="", flush=True)
             yield think, None

-        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-
-        # Handle tool calling
-        if has_tools:
-            # Start collecting tokens until we detect a tool call
-            collected_tokens = ""
-            is_collecting_tool_call = False
-            is_function_name_collected = False
-            function_name = ""
-            collected_arguments = ""
-            brackets_count = 0
-
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="", flush=True)
-                    collected_tokens += t
-
-                    # Check if we're starting a tool call
-                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
-                        is_collecting_tool_call = True
-
-                        # Generate a unique tool call ID
-                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
-
-                        # Send first tool call info
-                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
-                            # If tools are provided, use the first one's name
-                            recommended_function = tools[0].function.name
-                        else:
-                            # Otherwise try to extract from context
-                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                            recommended_function = function_match.group(1) if function_match else ""
-
-                        yield {'tool_call': {'id': tool_call_id, 'type': 'function', 'index': 0, 'function': {'name': recommended_function, 'arguments': ""}}, 'first_chunk': True}
-
-                    # Extract function name if we're collecting tool call
-                    if is_collecting_tool_call and not is_function_name_collected:
-                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                        if name_match:
-                            function_name = name_match.group(1)
-                            is_function_name_collected = True
-
-                    # Track argument collection
-                    if is_collecting_tool_call and is_function_name_collected:
-                        args_position = collected_tokens.find('"arguments"')
-                        if args_position > -1:
-                            # Find the start of the JSON object after "arguments":
-                            json_start = collected_tokens.find('{', args_position)
-                            if json_start > -1:
-                                for i in range(json_start, len(collected_tokens)):
-                                    char = collected_tokens[i]
-                                    collected_arguments += char
-                                    if char == '{':
-                                        brackets_count += 1
-                                    elif char == '}':
-                                        brackets_count -= 1
-
-                                    # Check if we've completed the arguments JSON
-                                    if brackets_count == 0:
-                                        # Send argument chunk
-                                        yield {'tool_call': {'id': tool_call_id, 'type': 'function', 'function': {'name': function_name, 'arguments': collected_arguments}}, 'argument_chunk': collected_arguments, 'last_chunk': True, 'prompt_tokens': 176, 'completion_tokens': 20}
-                                        # Reset for next potential tool call
-                                        collected_tokens = ""
-                                        is_collecting_tool_call = False
-                                        is_function_name_collected = False
-                                        function_name = ""
-                                        collected_arguments = ""
-                                        brackets_count = 0
-                                        break
-
-                # Handle finish reason
-                if finish_reason is not None:
-                    yield "", finish_reason
-
-            print("")
-        else:
-            # Regular text generation (no tools)
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="", flush=True)
-                    yield t, finish_reason
-            print("")
+        for t, finish_reason in self.generate():
+            if t is not None:
+                print(t, end="", flush=True)
+                yield t, finish_reason
+        print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
ktransformers/server/schemas/endpoints/chat.py

 from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 from pydantic import BaseModel, Field
 from ktransformers.server.schemas.base import Object

@@ -11,7 +10,6 @@ from openai.types.chat.chat_completion_chunk import Choice
 from uuid import uuid4
-from pydantic import BaseModel, Field

 class Role(Enum):
     system = 'system'

@@ -67,7 +65,9 @@ class ChatCompletionCreate(BaseModel):
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
ktransformers/server/schemas/legacy/completions.py

 from typing import List, Optional
 from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from ..base import Object

@@ -9,9 +8,11 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=0.6)
+    top_p: Optional[float] = Field(default=1)
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt, List):
             self.get_tokenizer_messages('\n'.join(self.prompt))
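Both request schemas now default max_tokens and max_completion_tokens to 50, so a request that omits them is still clamped by the backend logic above. A quick pydantic check of those defaults, using a trimmed stand-in model (CompletionCreateSketch is not part of the codebase):

```python
from typing import Optional
from pydantic import BaseModel, Field


class CompletionCreateSketch(BaseModel):
    prompt: str
    temperature: Optional[float] = Field(default=0.6)
    top_p: Optional[float] = Field(default=1)
    max_tokens: Optional[int] = Field(default=50)
    max_completion_tokens: Optional[int] = Field(default=50)


req = CompletionCreateSketch(prompt="hello")
print(req.temperature, req.top_p, req.max_tokens, req.max_completion_tokens)
# Expected output: 0.6 1.0 50 50
```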
ktransformers/tests/function_call_test.py (new file, 0 → 100644)

+from openai import OpenAI
+
+def send_messages(messages):
+    response = client.chat.completions.create(
+        model="deepseek-chat",
+        messages=messages,
+        tools=tools
+    )
+    return response.choices[0].message
+
+client = OpenAI(
+    api_key="placeholder",
+    base_url="http://0.0.0.0:10002/v1",
+)
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get weather of an location, the user shoud supply a location first",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    }
+                },
+                "required": ["location"]
+            },
+        }
+    },
+]
+
+messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}]
+message = send_messages(messages)
+print(f"User>\t{messages[0]['content']}")
+print(message)
+
+tool = message.tool_calls[0]
+messages.append(message)
+messages.append({"role": "tool", "tool_call_id": tool.id, "content": "24℃"})
+message = send_messages(messages)
+print(f"Model>\t{message.content}")
\ No newline at end of file
ktransformers/tests/test_client.py

@@ -15,18 +15,9 @@ SERVER_URL = "http://localhost:10002/v1/chat/completions"
 bf_list = [1]
 decodesz_list = [128]
 prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']

-async def fetch_event_stream(session, request_id):
+async def fetch_event_stream(session, payload, request_id):
     try:
-        payload = {
-            "messages": [
-                {"role": "system", "content": ""},
-                {"role": "user", "content": prompt_list[request_id]}
-            ],
-            "model": "DeepSeek-V3",
-            "temperature": 0.3,
-            "top_p": 1.0,
-            "stream": True  # enable streaming output
-        }
         headers = {
             'accept': 'application/json',

@@ -103,7 +94,35 @@ async def fetch_event_stream(session, request_id):
 async def main(prompt_id):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, prompt_id)]
+        payload = {
+            "messages": [
+                {"role": "system", "content": ""},
+                {"role": "user", "content": prompt_list[prompt_id]}
+            ],
+            "model": "DeepSeek-V3",
+            "stream": True,
+            "max_completion_tokens": 2,
+            # "temperature": 0.3,
+            # "top_p": 1.0,
+            # "max_tokens" : 20,
+        }
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
         await asyncio.gather(*tasks)
+        payload["temperature"] = 0.3
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["top_p"] = 1
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["max_tokens"] = 200
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["stream"] = False
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)

 if __name__ == "__main__":
third_party/llamafile/iqk_mul_mat.inc

@@ -3326,7 +3326,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         default:
             {
-                printf("case:%d", typeA);
+                // printf("case:%d",typeA);
                 return false;
             }