ox696c / ktransformers

Commit 03a65d6b, authored Apr 21, 2025 by qiyuxinlin
Parent: a1162eea

    roll back ktransformers backend, add max_tokens, max_completion_tokens param

Showing 10 changed files with 144 additions and 161 deletions.
ktransformers/server/api/openai/endpoints/chat.py         +2   -2
ktransformers/server/api/openai/legacy/completions.py     +3   -3
ktransformers/server/backend/interfaces/balance_serve.py  +19  -10
ktransformers/server/backend/interfaces/ktransformers.py  +13  -6
ktransformers/server/backend/interfaces/transformers.py   +21  -119
ktransformers/server/schemas/endpoints/chat.py            +3   -3
ktransformers/server/schemas/legacy/completions.py        +6   -5
ktransformers/tests/function_call_test.py                 +45  -0
ktransformers/tests/test_client.py                        +31  -12
third_party/llamafile/iqk_mul_mat.inc                     +1   -1
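For context, a minimal client-side sketch of the new request parameters. This is not part of the commit: the port, base URL, and model name are copied from the commit's own test scripts (ktransformers/tests/test_client.py and function_call_test.py), and a reasonably recent openai Python SDK that accepts both max_tokens and max_completion_tokens is assumed.

from openai import OpenAI

# Endpoint and model name follow the commit's test scripts; adjust to your deployment.
client = OpenAI(api_key="placeholder", base_url="http://localhost:10002/v1")

response = client.chat.completions.create(
    model="DeepSeek-V3",
    messages=[{"role": "user", "content": "Please tell me a joke"}],
    temperature=0.3,
    top_p=1.0,
    # Either field caps the number of generated tokens; when both are set,
    # the server resolves max_tokens first (see get_params in balance_serve.py below).
    max_tokens=200,
    max_completion_tokens=200,
)
print(response.choices[0].message.content)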
ktransformers/server/api/openai/endpoints/chat.py

@@ -207,7 +207,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             "<tools▁end>": "<|tool▁calls▁end|>"
         }
     # Use check_client_connected for early stopping
-    async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+    async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
         if isinstance(res, RawUsage):
             # Final return on utilization
             raw_usage = res

@@ -371,7 +371,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             "<tool▁end>": "<|tool▁call▁end|>",
             "<tools▁end>": "<|tool▁calls▁end|>"
         }
-    async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+    async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
         if isinstance(res, RawUsage):
             raw_usage = res
             usage = CompletionUsage(
ktransformers/server/api/openai/legacy/completions.py

@@ -11,7 +11,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage
 router = APIRouter()

 @router.post("/completions", tags=['openai'])
-async def create_completion(request:Request,create:CompletionCreate):
+async def create_completion(request: Request, create: CompletionCreate):
     id = str(uuid4())
     interface = get_interface()

@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
                 if isinstance(res, RawUsage):
                     raw_usage = res
                 else:

@@ -32,7 +32,7 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request, inner())
     else:
         comp = CompletionObject(id=id, object='text_completion', created=int(time()))
-        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p):
+        async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 raw_usage = res
             else:
ktransformers/server/backend/interfaces/balance_serve.py

@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
         query_updates[i].generated_token = generated_tokens[i].item()
         if not query_manager.query_map[query_updates[i].id].is_prefill:
             pos = query_updates[i].active_position
-            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
+            if pos < query_manager.query_map[query_updates[i].id].max_length:
+                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

 def report_last_time_performance(profiler: Profiler):
     try:

@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):
             start_event.wait()

-    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
         """Get sampling parameters and handle default values and edge cases"""
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_completion_tokens = self.args.max_new_tokens
+        else:
+            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if temperature is None:
-            temperature = Config().temperature
+            temperature = self.args.temperature
         if top_p is None:
-            top_p = Config().top_p
+            top_p = self.args.top_p
         if temperature == 0:
             temperature = 0.0001
         if top_p == 0:
             top_p = 0.0001
-        return temperature, top_p
+        return temperature, top_p, max_completion_tokens

     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()

@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         profiler = Profiler()
         profiler.create_and_start_timer("tokenize")

@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False), self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
-        temperature, top_p = self.get_sampling_params(temperature, top_p)
+        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
         query_add.sample_options.temperature = temperature
         query_add.sample_options.top_p = top_p
-        query_add.estimated_length = min(self.args.cache_lens, query_length + self.args.max_new_tokens)
+        query_add.estimated_length = min(self.args.cache_lens, query_length + max_new_tokens)
         if query_add.estimated_length < query_add.query_length:
             raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')
         query_id = self.sched_client.add_query(query_add)
-        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+        queue = asyncio.Queue(maxsize=max_new_tokens)
         self.queue_map[query_id] = queue
         self.thread_map[thread_id] = query_id
         is_first_token = True

@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
                     profiler.pause_timer("decode")
                     report_last_time_performance(profiler)
                     yield self.streamer.end(), None
-                    if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+                    if profiler.get_counter('decode') >= max_new_tokens - 1:
                         yield "", "length"
                     else:
                         yield "", "stop"
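The precedence that get_params (and the matching logic in the prefill methods of the other backends below) implements can be summarized in isolation. The following is an illustrative sketch, not code from the commit; resolve_budget and args_max_new_tokens are hypothetical names standing in for the method and self.args.max_new_tokens.

from typing import Optional

def resolve_budget(args_max_new_tokens: int,
                   max_tokens: Optional[int] = None,
                   max_completion_tokens: Optional[int] = None) -> int:
    # max_tokens, when given, overrides max_completion_tokens;
    # the result is always capped by the server-side max_new_tokens.
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    if max_completion_tokens is None:
        return args_max_new_tokens
    return min(args_max_new_tokens, max_completion_tokens)

assert resolve_budget(1024) == 1024                               # no client cap -> server default
assert resolve_budget(1024, max_completion_tokens=200) == 200
assert resolve_budget(1024, max_tokens=50, max_completion_tokens=200) == 50
assert resolve_budget(1024, max_tokens=4096) == 1024              # clamped to the server limit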
ktransformers/server/backend/interfaces/ktransformers.py

@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
             self.seq_length = input_ids_length

@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )

@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(

@@ -222,16 +228,17 @@ class KTransformersInterface(TransformersInterface):
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
                 yield v
         # return this inference raw usage
ktransformers/server/backend/interfaces/transformers.py

@@ -262,10 +262,15 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if is_new:
             self.ever_generated_ids.clear()
             same_prefix = 0

@@ -274,7 +279,7 @@ class TransformersInterface(BackendInterfaceBase):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )

@@ -301,7 +306,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = self.seq_length + max_new_tokens + 1
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(

@@ -330,17 +335,16 @@ class TransformersInterface(BackendInterfaceBase):
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @torch.no_grad
     def generate(self):
-        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
             yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
         for i in range(1, self.max_new_tokens):

@@ -378,17 +382,15 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
-        # Check if tools are present
-        has_tools = tools is not None and len(tools) > 0
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
             #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")

@@ -399,6 +401,7 @@ class TransformersInterface(BackendInterfaceBase):
         )
         self.profiler.pause_timer("tokenize")
+
         self.profiler.create_and_start_timer("prefill")

         if Config().user_force_think:

@@ -406,119 +409,18 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="", flush=True)
             yield think, None

-        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-
-        # Handle tool calling
-        if has_tools:
-            # Start collecting tokens until we detect a tool call
-            collected_tokens = ""
-            is_collecting_tool_call = False
-            is_function_name_collected = False
-            function_name = ""
-            collected_arguments = ""
-            brackets_count = 0
-
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="", flush=True)
-                    collected_tokens += t
-
-                    # Check if we're starting a tool call
-                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
-                        is_collecting_tool_call = True
-
-                        # Generate a unique tool call ID
-                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
-
-                        # Send first tool call info
-                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
-                            # If tools are provided, use the first one's name
-                            recommended_function = tools[0].function.name
-                        else:
-                            # Otherwise try to extract from context
-                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                            recommended_function = function_match.group(1) if function_match else ""
-
-                        yield {
-                            'tool_call': {
-                                'id': tool_call_id,
-                                'type': 'function',
-                                'index': 0,
-                                'function': {
-                                    'name': recommended_function,
-                                    'arguments': ""
-                                }
-                            },
-                            'first_chunk': True
-                        }
-
-                    # Extract function name if we're collecting tool call
-                    if is_collecting_tool_call and not is_function_name_collected:
-                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                        if name_match:
-                            function_name = name_match.group(1)
-                            is_function_name_collected = True
-
-                    # Track argument collection
-                    if is_collecting_tool_call and is_function_name_collected:
-                        args_position = collected_tokens.find('"arguments"')
-                        if args_position > -1:
-                            # Find the start of the JSON object after "arguments":
-                            json_start = collected_tokens.find('{', args_position)
-                            if json_start > -1:
-                                for i in range(json_start, len(collected_tokens)):
-                                    char = collected_tokens[i]
-                                    collected_arguments += char
-                                    if char == '{':
-                                        brackets_count += 1
-                                    elif char == '}':
-                                        brackets_count -= 1
-
-                                    # Check if we've completed the arguments JSON
-                                    if brackets_count == 0:
-                                        # Send argument chunk
-                                        yield {
-                                            'tool_call': {
-                                                'id': tool_call_id,
-                                                'type': 'function',
-                                                'function': {
-                                                    'name': function_name,
-                                                    'arguments': collected_arguments
-                                                }
-                                            },
-                                            'argument_chunk': collected_arguments,
-                                            'last_chunk': True,
-                                            'prompt_tokens': 176,
-                                            'completion_tokens': 20
-                                        }
-                                        # Reset for next potential tool call
-                                        collected_tokens = ""
-                                        is_collecting_tool_call = False
-                                        is_function_name_collected = False
-                                        function_name = ""
-                                        collected_arguments = ""
-                                        brackets_count = 0
-                                        break
-
-                # Handle finish reason
-                if finish_reason is not None:
-                    yield "", finish_reason
-
-            print("")
-        else:
-            # Regular text generation (no tools)
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="", flush=True)
-                    yield t, finish_reason
-            print("")
+        for t, finish_reason in self.generate():
+            if t is not None:
+                print(t, end="", flush=True)
+                yield t, finish_reason
+        print("")

         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
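As a quick illustration of the per-request decode budget that prefill now sets (self.max_new_tokens = min(max_new_tokens, cache_lens - seq_length) - 1), here is a sketch with made-up numbers; it is not code from the commit.

cache_lens = 8192     # hypothetical KV-cache capacity (self.args.cache_lens)
seq_length = 8100     # tokens already occupying the cache after this prompt
max_new_tokens = 200  # budget resolved from max_tokens / max_completion_tokens

decode_budget = min(max_new_tokens, cache_lens - seq_length) - 1
print(decode_budget)  # 91 -- the remaining cache, not the request, is the binding limit here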
ktransformers/server/schemas/endpoints/chat.py

 from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
+from pydantic import BaseModel, Field
 from ktransformers.server.schemas.base import Object

@@ -11,7 +10,6 @@ from openai.types.chat.chat_completion_chunk import Choice
 from uuid import uuid4
-from pydantic import BaseModel, Field

 class Role(Enum):
     system = 'system'

@@ -67,7 +65,9 @@ class ChatCompletionCreate(BaseModel):
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
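A minimal sketch of how the new schema fields behave under Pydantic (illustrative only; the real model in ktransformers/server/schemas/endpoints/chat.py has many more fields):

from typing import Optional
from pydantic import BaseModel, Field

class ChatCompletionCreateSketch(BaseModel):
    model: str
    max_tokens: Optional[int] = Field(default=50)
    max_completion_tokens: Optional[int] = Field(default=50)

req = ChatCompletionCreateSketch(model="DeepSeek-V3")
print(req.max_tokens, req.max_completion_tokens)  # 50 50 -- both default to 50 when omitted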
ktransformers/server/schemas/legacy/completions.py

 from typing import List, Optional
 from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from ..base import Object

@@ -9,9 +8,11 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=0.6)
+    top_p: Optional[float] = Field(default=1)
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt, List):
             self.get_tokenizer_messages('\n'.join(self.prompt))
ktransformers/tests/function_call_test.py (new file, mode 100644)

from openai import OpenAI

def send_messages(messages):
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=messages,
        tools=tools
    )
    return response.choices[0].message

client = OpenAI(
    api_key="placeholder",
    base_url="http://0.0.0.0:10002/v1",
)

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather of an location, the user shoud supply a location first",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
                "required": ["location"]
            },
        }
    },
]

messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}]
message = send_messages(messages)
print(f"User>\t{messages[0]['content']}")
print(message)

tool = message.tool_calls[0]
messages.append(message)

messages.append({"role": "tool", "tool_call_id": tool.id, "content": "24℃"})
message = send_messages(messages)
print(f"Model>\t{message.content}")
\ No newline at end of file
ktransformers/tests/test_client.py

@@ -15,18 +15,9 @@ SERVER_URL = "http://localhost:10002/v1/chat/completions"
 bf_list = [1]
 decodesz_list = [128]
 prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']

-async def fetch_event_stream(session, request_id):
+async def fetch_event_stream(session, payload, request_id):
     try:
-        payload = {
-            "messages": [
-                {"role": "system", "content": ""},
-                {"role": "user", "content": prompt_list[request_id]}
-            ],
-            "model": "DeepSeek-V3",
-            "temperature": 0.3,
-            "top_p": 1.0,
-            "stream": True  # enable streaming output
-        }
         headers = {
             'accept': 'application/json',

@@ -103,7 +94,35 @@ async def fetch_event_stream(session, request_id):
 async def main(prompt_id):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, prompt_id)]
+        payload = {
+            "messages": [
+                {"role": "system", "content": ""},
+                {"role": "user", "content": prompt_list[prompt_id]}
+            ],
+            "model": "DeepSeek-V3",
+            "stream": True,
+            "max_completion_tokens": 2,
+            # "temperature": 0.3,
+            # "top_p": 1.0,
+            # "max_tokens" : 20,
+        }
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
         await asyncio.gather(*tasks)
+        payload["temperature"] = 0.3
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["top_p"] = 1
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["max_tokens"] = 200
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+        payload["stream"] = False
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)

 if __name__ == "__main__":
third_party/llamafile/iqk_mul_mat.inc

@@ -3326,7 +3326,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         default:
             {
-                printf("case:%d", typeA);
+                // printf("case:%d",typeA);
                 return false;
             }