OpenDAS / ktransformers · Commit a7e8d7c1
Authored Apr 13, 2025 by Creeper-MZ

update function_call

Parent: 038db30e
Showing 4 changed files with 554 additions and 89 deletions (+554 −89):
    ktransformers/server/api/openai/endpoints/chat.py         +392 −65
    ktransformers/server/backend/interfaces/ktransformers.py    +3  −2
    ktransformers/server/backend/interfaces/transformers.py    +115 −10
    ktransformers/server/schemas/endpoints/chat.py              +44 −12
ktransformers/server/api/openai/endpoints/chat.py (view file @ a7e8d7c1)
import json
from time import time
from uuid import uuid4
from typing import Dict, List, Optional, Any, Literal, Union
from pydantic import BaseModel, Field
import re
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
-from ktransformers.server.schemas.endpoints.chat import RawUsage
+from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
from openai.types.chat import ChatCompletion
from openai.types.completion_usage import CompletionUsage

# Define our own data structures instead of importing them from OpenAI
class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None

class Choice(BaseModel):
    index: int
    message: Optional[Dict[str, Any]] = None
    finish_reason: Optional[str] = None
    logprobs: Optional[Any] = None
    delta: Optional[Dict[str, Any]] = None
    content_filter_results: Optional[Dict[str, Any]] = None

class ChatCompletion(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[CompletionUsage] = None
    system_fingerprint: Optional[str] = None
    prompt_filter_results: Optional[List[Dict[str, Any]]] = None

# Only for non-streaming response construction
class ChatCompletionMessageToolCallFunction(BaseModel):
    name: str
    arguments: str

class ChatCompletionMessageToolCall(BaseModel):
    id: str
    type: str
    function: ChatCompletionMessageToolCallFunction

class ChatCompletionMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None

router = APIRouter()
...
...
@@ -21,90 +63,375 @@ router = APIRouter()
async def list_models():
    return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}
def getTools(buffer):
    tool_calls_begin_marker = "<|tool▁calls▁begin|>"
    tool_call_begin_marker = "<|tool▁call▁begin|>"
    tool_sep_marker = "<|tool▁sep|>"
    tool_call_end_marker = "<|tool▁call▁end|>"
    tool_calls_end_marker = "<|tool▁calls▁end|>"
    extracted_tools = []
    working_buffer = buffer

    # Iterate over all function calls
    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
        # Find a complete function call
        start_index = working_buffer.find(tool_call_begin_marker)
        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)
        if start_index == -1 or end_index == -1 or start_index > end_index:
            logger.warning("Not a function")
            break
        # Extract the full function call
        full_tool_call = working_buffer[start_index:end_index]
        # Remove this function call from the working buffer to prevent duplicate processing
        working_buffer = working_buffer.replace(full_tool_call, "", 1)
        # Extract the function name
        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
        function_name_end = full_tool_call.find("\n", function_name_start)
        function_name = full_tool_call[function_name_start:function_name_end].strip()
        # Extract the JSON parameters
        json_pattern = r'```json\s*(.*?)\s*```'
        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
        if json_match:
            arguments_str = json_match.group(1).strip()
            # Generate a tool call ID
            tool_call_id = f"call_{uuid4().hex[:24]}"
            # Add to the tool call list
            extracted_tools.append({
                "id": tool_call_id,
                "type": "function",
                "function": {"name": function_name, "arguments": arguments_str}
            })
            logger.info(f"Got function: {function_name}")
        else:
            logger.warning(f"Unable to get function, function_name: {function_name}")
    logger.info(f"Total {len(extracted_tools)} functions")
    return extracted_tools
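
For reference, a minimal sketch of the buffer format getTools expects, using a hypothetical get_weather function (the buffer here is hand-built for illustration, not real model output):

    sample_buffer = (
        "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
        '```json {"city": "Beijing"}\n'
        "```<|tool▁call▁end|><|tool▁calls▁end|>"
    )
    # getTools(sample_buffer) would return one entry shaped like:
    # [{"id": "call_<24 hex chars>", "type": "function",
    #   "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'}}]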
@router.post('/chat/completions', tags=['openai'])
async def chat_completion(request: Request, create: ChatCompletionCreate):
-    id = str(uuid4())
+    id = str(uuid4().hex)
    # 1. Use the system prompt to let the model know how to use tools
    enhanced_messages = list(create.messages)

    # If tools are present and the first message is a system (or user) message,
    # append instructions on how to use the tools
    if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
        tool_instructions = "You can use function_call (function calling). Currently, the following tools are available:\n\n"
        for tool in create.tools:
            tool_instructions += f"\"function\":{{\"name\": {tool.function.name}, \"description\": {tool.function.description}, \"parameters\": {tool.function.parameters}}}\n"
        # Tool usage guidelines that encourage JSON output
        tool_instructions += "name is the function's name, description describes what it does, and parameters contains the arguments the function takes and their descriptions; required lists the mandatory arguments.\n"
        tool_instructions += "Only call a tool when the user explicitly asks for it, or when you judge a call to be necessary. Note: when highly up-to-date information is needed, such as the current time or recent events, prefer calling a tool. When key information for a tool call is missing, you may first ask the user for it before making the call.\n"
        tool_instructions += "\nWhen you need to use a tool, output it in the following format:\n"
        tool_instructions += '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>name\n```json {"arg_name": "arg_value","arg_name_2": "arg_value_2"...}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
        tool_instructions += 'Example:\n<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>the_function_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
        tool_instructions += "This calls the function named \"the_function_name_will_be_called\", passing value1 and value2 as the arguments arg1 and arg2.\n"
        tool_instructions += "Do not try to explain what you are doing; output the tool call directly. Make sure the call statement is correctly formatted and complete."
        enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions
    # Process the request
    interface: BackendInterfaceBase = get_interface()
    # input_ids = interface.format_and_tokenize_input_ids(id, messages=create.get_tokenizer_messages())
-    input_message = [json.loads(m.model_dump_json()) for m in create.messages]
+    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]

    if Config().api_key != '':
        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
    if create.stream:
        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta

        async def inner():
            chunk = ChatCompletionChunk(
                id=id,
                choices=[],
                object='chat.completion.chunk',
                created=int(time()),
                model=Config().model_name,
+                system_fingerprint=f"fp_{uuid4().hex[:12]}",
            )
            # Collect the full output of the model, with special handling for tool calls
            full_content = ""
            buffer = ""  # Temporarily stores the current block of text
            tool_call_mode = False  # Whether a tool call is currently being processed
            tool_calls = []  # Stores all detected tool calls

            # Model-specific special tokens
            tool_calls_begin_marker = "<|tool▁calls▁begin|>"
            tool_call_begin_marker = "<|tool▁call▁begin|>"
            tool_sep_marker = "<|tool▁sep|>"
            tool_call_end_marker = "<|tool▁call▁end|>"
            tool_calls_end_marker = "<|tool▁calls▁end|>"

            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
                if isinstance(res, RawUsage):
                    # At the end of inference, interface.inference() returns the raw usage statistics
                    raw_usage = res
                    chunk.choices = []
                    chunk.usage = CompletionUsage(
                        prompt_tokens=raw_usage.prefill_count,
                        completion_tokens=raw_usage.decode_count,
                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                    )
                    yield chunk
-                else:
+                elif isinstance(res, tuple) and len(res) == 2:
                    token, finish_reason = res
-                    choice = Choice(
-                        index=0,
-                        delta=ChoiceDelta(content=token, role=None, tool_calls=None),
-                        finish_reason=finish_reason,
-                        logprobs=None,
-                    )
-                    chunk.choices = [choice]
-                    yield chunk
                    # Detect the model-specific start-of-tool-calls marker
                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:
                        tool_call_mode = True
                        # Trim the tool call section out of full_content
                        if buffer.endswith(tool_calls_begin_marker):
                            full_content = full_content[:-len(tool_calls_begin_marker)]
                        elif tool_calls_begin_marker in (buffer + token):
                            idx = (buffer + token).find(tool_calls_begin_marker)
                            full_content = full_content[:-(len(buffer) - idx)]
                        buffer = ""
                        # Send the accumulated text content, if any
                        if full_content:
                            chunk.choices = [{"index": 0, "delta": {"content": full_content}, "finish_reason": None}]
                            yield chunk
                            full_content = ""
                    # Accumulate content while not in tool call mode
                    if not tool_call_mode:
                        full_content += token
                        buffer += token
                        # Keep the buffer at a reasonable size
                        if len(buffer) > 200:
                            buffer = buffer[-200:]
                    else:
                        # In tool call mode, keep collecting tool-call-related text
                        buffer += token
                        # If the end-of-tool-calls marker is found
                        if tool_calls_end_marker in buffer:
                            try:
                                # Parse the collected text and extract the tool call information
                                tool_calls = getTools(buffer)
                                if len(tool_calls):
                                    # Reset state
                                    tool_call_mode = False
                                    buffer = ""
                                    # Send the tool call events
                                    for idx, tool_call in enumerate(tool_calls):
                                        # First tool call message
                                        chunk.choices = [{
                                            "index": 0,
                                            "delta": {
                                                "role": "assistant",
                                                "content": None,
                                                "tool_calls": [{
                                                    "index": idx,
                                                    "id": tool_call["id"],
                                                    "type": "function",
                                                    "function": {"name": tool_call["function"]["name"], "arguments": ""}
                                                }]
                                            },
                                            "finish_reason": None
                                        }]
                                        yield chunk
                                        # Send the arguments
                                        chunk.choices = [{
                                            "index": 0,
                                            "delta": {
                                                "tool_calls": [{
                                                    "index": idx,
                                                    "function": {"arguments": tool_call["function"]["arguments"]}
                                                }]
                                            },
                                            "finish_reason": None
                                        }]
                                        yield chunk
                                    # Send the completion message
                                    chunk.choices = [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}]
                                    yield chunk
                                    # No further processing after returning
                                    return
                                else:
                                    # JSON extraction failed, probably incomplete formatting
                                    logger.warning("Failed to extract JSON from tool call")
                                    tool_call_mode = False
                                    buffer = ""
                            except Exception as e:
                                logger.error(f"Error processing tool call: {e}")
                                tool_call_mode = False
                                buffer = ""
                    # Normal text output (only in non-tool-call mode)
                    if not tool_call_mode and token:
                        if finish_reason is not None:
                            chunk.choices = [{"index": 0, "delta": {}, "finish_reason": finish_reason}]
                            yield chunk
                        else:
                            if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]):
                                pass
                            else:
                                chunk.choices = [{"index": 0, "delta": {"content": token}, "finish_reason": None}]
                                yield chunk
            # If we got this far without returning, no complete tool call was detected,
            # so send the regular completion message
            if not tool_call_mode:
                chunk.choices = [{"index": 0, "delta": {}, "finish_reason": "stop"}]
                yield chunk

        return chat_stream_response(request, inner())
    else:
-        from openai.types.chat.chat_completion import Choice
-        from openai.types.chat.chat_completion_message import ChatCompletionMessage
-        content = ""
+        # Non-streaming response processing
+        full_content = ""
        finish_reason = None
+        tool_calls = []
+        buffer = ""
+        tool_call_mode = False
+        # Model-specific special tokens
+        tool_calls_begin_marker = "<|tool▁calls▁begin|>"
+        tool_call_begin_marker = "<|tool▁call▁begin|>"
+        tool_sep_marker = "<|tool▁sep|>"
+        tool_call_end_marker = "<|tool▁call▁end|>"
+        tool_calls_end_marker = "<|tool▁calls▁end|>"

        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
            if isinstance(res, RawUsage):
                raw_usage = res
                usage = CompletionUsage(
                    prompt_tokens=raw_usage.prefill_count,
                    completion_tokens=raw_usage.decode_count,
                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                )
-            else:
+            elif isinstance(res, tuple) and len(res) == 2:
                token, finish_reason = res
-                content = content + token
-                finish_reason = finish_reason
-        choice = Choice(index=0, finish_reason=finish_reason,
-                        message=ChatCompletionMessage(content=content, role="assistant"))
-        chat_completion = ChatCompletion(id=id, choices=[choice], created=int(time()),
-                                         model=Config().model_name, object='chat.completion', usage=usage)
-        return chat_completion
                # Detect the model-specific start-of-tool-calls marker
                if not tool_call_mode and tool_calls_begin_marker in buffer + token:
                    tool_call_mode = True
                    # Trim the tool call section out of full_content
                    if buffer.endswith(tool_calls_begin_marker):
                        full_content = full_content[:-len(tool_calls_begin_marker)]
                    elif tool_calls_begin_marker in (buffer + token):
                        idx = (buffer + token).find(tool_calls_begin_marker)
                        full_content = full_content[:-(len(buffer) - idx)]
                    buffer = ""
                # Accumulate content while not in tool call mode
                if not tool_call_mode:
                    full_content += token
                    buffer += token
                    # Keep the buffer at a reasonable size
                    if len(buffer) > 200:
                        buffer = buffer[-200:]
                else:
                    # In tool call mode, keep collecting tool-call-related text
                    buffer += token
                    # If the end-of-tool-calls marker is found
                    if tool_calls_end_marker in buffer:
                        try:
                            # Parse the collected text and extract the tool call information
                            full_tool_call = buffer
                            # Extract the function name
                            function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
                            function_name_end = full_tool_call.find("\n", function_name_start)
                            function_name = full_tool_call[function_name_start:function_name_end].strip()
                            # Extract the JSON parameters (the content between ```json and ```)
                            json_pattern = r'```json\s*(.*?)\s*```'
                            json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
                            if json_match:
                                arguments_str = json_match.group(1).strip()
                                # Generate a tool call ID
                                tool_call_id = f"call_{uuid4().hex[:24]}"
                                # Add to the tool call list
                                tool_calls.append({
                                    "id": tool_call_id,
                                    "index": 0,
                                    "type": "function",
                                    "function": {"name": function_name, "arguments": arguments_str}
                                })
                                # The tool call parsed successfully, so set the finish reason
                                finish_reason = "tool_calls"
                                # Reset state
                                tool_call_mode = False
                                buffer = ""
                            else:
                                # JSON extraction failed, probably incomplete formatting
                                logger.warning("Failed to extract JSON from tool call")
                                tool_call_mode = False
                                buffer = ""
                        except Exception as e:
                            logger.error(f"Error processing tool call: {e}")
                            tool_call_mode = False
                            buffer = ""

        # Build the response
        response = {
            "id": id,
            "object": "chat.completion",
            "created": int(time()),
            "model": Config().model_name,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": None if tool_calls else full_content,
                    "tool_calls": tool_calls if tool_calls else None
                },
                "finish_reason": finish_reason or "stop"
            }],
            "usage": usage.__dict__,
            "system_fingerprint": f"fp_{uuid4().hex[:12]}"
        }
        return response
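
End to end, the endpoint can be exercised with the official openai Python client; in the streaming case a detected tool call arrives as a name chunk, then an arguments chunk, then a finish_reason of "tool_calls". A hypothetical client-side sketch (the base URL, API key, model name, and get_weather function are placeholders, not part of this commit):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="placeholder")

    stream = client.chat.completions.create(
        model="placeholder-model-name",
        messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
        tools=[{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
        stream=True,
    )
    for chunk in stream:
        for choice in chunk.choices:
            if choice.delta.tool_calls:
                print(choice.delta.tool_calls)   # name chunk, then arguments chunk
            elif choice.delta.content:
                print(choice.delta.content, end="")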
ktransformers/server/backend/interfaces/ktransformers.py (view file @ a7e8d7c1)
import torch
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import (
...
...
@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                yield v
            # return this inference raw usage
...
...
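
The only change here is threading the new tools parameter through to the parent class while keeping the _infer_lock guard, which serializes concurrent requests so a single generation owns the GPU at a time. A minimal, self-contained sketch of that lock-around-an-async-generator pattern (the class and method names below are illustrative, not from the repo):

    import asyncio

    class SerializedWorker:
        # Illustrative stand-in for KTransformersInterface: one lock, one GPU.
        def __init__(self):
            self._infer_lock = asyncio.Lock()

        async def _generate(self, prompt: str):
            for tok in prompt.split():
                await asyncio.sleep(0)  # stand-in for real decode steps
                yield tok

        async def inference(self, prompt: str, tools=None):
            async with self._infer_lock:  # a second caller waits here
                async for tok in self._generate(prompt):
                    yield tok

    async def main():
        w = SerializedWorker()
        async def consume(p):
            return [tok async for tok in w.inference(p)]
        # Two concurrent requests: decoding never interleaves on the worker.
        print(await asyncio.gather(consume("hello world"), consume("lorem ipsum")))

    asyncio.run(main())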
ktransformers/server/backend/interfaces/transformers.py (view file @ a7e8d7c1)
from typing import Any, List, Optional, Set
import re
import json
import uuid
from transformers import (
    LlamaTokenizer,
    AutoTokenizer,
...
...
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
        self.last_request_id = thread_id
        return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        self.streamer.reset()
        self.profiler.create_and_start_timer("tokenize")
+        # Check if tools are present
+        has_tools = tools is not None and len(tools) > 0
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
            #input_ids = torch.tensor([[6366]], device=input_ids.device)
        else:
            raise ValueError("local_messages should be List or str")
...
...
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
        )
        self.profiler.pause_timer("tokenize")

        self.profiler.create_and_start_timer("prefill")
        if Config().user_force_think:
...
...
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
            yield think, None
        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
            # output think token after prefill done
            if t is not None:
                print(t, end="", flush=True)
                yield t, None
        self.profiler.pause_timer("prefill")

        self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="", flush=True)
-                yield t, finish_reason
-        print("")
+        # Handle tool calling
+        if has_tools:
            # Collect tokens until we detect a tool call
            collected_tokens = ""
            is_collecting_tool_call = False
            is_function_name_collected = False
            function_name = ""
            collected_arguments = ""
            brackets_count = 0
            for t, finish_reason in self.generate():
                if t is not None:
                    print(t, end="", flush=True)
                    collected_tokens += t
                    # Check whether we are starting a tool call
                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
                        is_collecting_tool_call = True
                        # Generate a unique tool call ID
                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
                        # Send the first tool call info
                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
                            # If tools are provided, use the first one's name
                            recommended_function = tools[0].function.name
                        else:
                            # Otherwise try to extract it from context
                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
                            recommended_function = function_match.group(1) if function_match else ""
                        yield {
                            'tool_call': {
                                'id': tool_call_id,
                                'type': 'function',
                                'index': 0,
                                'function': {'name': recommended_function, 'arguments': ""}
                            },
                            'first_chunk': True
                        }
                    # Extract the function name while collecting a tool call
                    if is_collecting_tool_call and not is_function_name_collected:
                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
                        if name_match:
                            function_name = name_match.group(1)
                            is_function_name_collected = True
                    # Track argument collection
                    if is_collecting_tool_call and is_function_name_collected:
                        args_position = collected_tokens.find('"arguments"')
                        if args_position > -1:
                            # Find the start of the JSON object after "arguments":
                            json_start = collected_tokens.find('{', args_position)
                            if json_start > -1:
                                for i in range(json_start, len(collected_tokens)):
                                    char = collected_tokens[i]
                                    collected_arguments += char
                                    if char == '{':
                                        brackets_count += 1
                                    elif char == '}':
                                        brackets_count -= 1
                                    # Check if we've completed the arguments JSON
                                    if brackets_count == 0:
                                        # Send the argument chunk
                                        yield {
                                            'tool_call': {
                                                'id': tool_call_id,
                                                'type': 'function',
                                                'function': {'name': function_name, 'arguments': collected_arguments}
                                            },
                                            'argument_chunk': collected_arguments,
                                            'last_chunk': True,
                                            'prompt_tokens': 176,
                                            'completion_tokens': 20
                                        }
                                        # Reset for the next potential tool call
                                        collected_tokens = ""
                                        is_collecting_tool_call = False
                                        is_function_name_collected = False
                                        function_name = ""
                                        collected_arguments = ""
                                        brackets_count = 0
                                        break
                # Handle the finish reason
                if finish_reason is not None:
                    yield "", finish_reason
            print("")
        else:
            # Regular text generation (no tools)
            for t, finish_reason in self.generate():
                if t is not None:
                    print(t, end="", flush=True)
                    yield t, finish_reason
            print("")
        self.profiler.pause_timer("decode")
        self.report_last_time_performance()
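
The argument extraction above finds the end of the "arguments" object by counting braces rather than parsing JSON incrementally. A standalone sketch of that technique (a hypothetical helper, not part of the commit); like the code above, it does not special-case braces inside string values, so it is a heuristic rather than a full JSON scanner:

    import json
    from typing import Optional

    def extract_json_object(text: str, start: int) -> Optional[str]:
        # Scan forward from the first '{', tracking nesting depth; depth
        # returning to zero means the object just closed.
        depth = 0
        for i in range(start, len(text)):
            if text[i] == '{':
                depth += 1
            elif text[i] == '}':
                depth -= 1
                if depth == 0:
                    return text[start:i + 1]
        return None  # incomplete: keep collecting tokens

    s = '"arguments": {"city": "Beijing", "unit": "C"} trailing text'
    obj = extract_json_object(s, s.find('{'))
    assert json.loads(obj) == {"city": "Beijing", "unit": "C"}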
ktransformers/server/schemas/endpoints/chat.py (view file @ a7e8d7c1)
-from typing import List, Optional
+from typing import List, Optional, Union, Dict, Any
from typing_extensions import Literal
from enum import Enum
...
...
@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object
from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_chunk import Choice
+from uuid import uuid4
+from pydantic import BaseModel, Field

class Role(Enum):
    system = 'system'
...
...
@@ -17,26 +20,57 @@ class Role(Enum):
    tool = 'tool'
    function = 'function'

class Message(BaseModel):
-    content: str
+    content: Optional[str] = None
    role: Role
+    name: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None

    def to_tokenizer_message(self):
-        return {'content': self.content, 'role': self.role.value}
+        message = {'role': self.role.value}
+        if self.content is not None:
+            message['content'] = self.content
+        if self.name is not None:
+            message['name'] = self.name
+        if self.tool_calls is not None:
+            message['tool_calls'] = self.tool_calls
+        if self.tool_call_id is not None:
+            message['tool_call_id'] = self.tool_call_id
+        return message

+class FunctionParameters(BaseModel):
+    type: str = "object"
+    properties: Dict[str, Any] = {}
+    required: Optional[List[str]] = None

+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: FunctionParameters = Field(default_factory=FunctionParameters)

+class ToolFunction(BaseModel):
+    function: FunctionDefinition

+class Tool(BaseModel):
+    type: Literal["function"]
+    function: FunctionDefinition

class ChatCompletionCreate(BaseModel):
    messages: List[Message]
    model: str
    stream: bool = False
-    temperature: Optional[float] = Field(default=1.0)
+    temperature: Optional[float] = Field(default=0.6)
    top_p: Optional[float] = Field(default=1.0)
+    tools: Optional[List[Tool]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    stream_options: Optional[Dict[str, Any]] = None
+    frequency_penalty: float = 0
+    presence_penalty: float = 0

    def get_tokenizer_messages(self):
        return [m.to_tokenizer_message() for m in self.messages]

class ChatCompletionChunk(BaseModel):
    id: str
    choices: List[Choice]
...
...
@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):
    system_fingerprint: Optional[str] = None
    usage: Optional[CompletionUsage] = None

    def to_stream_reply(self):
        return f"data: {self.model_dump_json()}\n\n"

class RawUsage(BaseModel):
    tokenize_time: float
    prefill_time: float
    decode_time: float
    prefill_count: int
    decode_count: int
\ No newline at end of file
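
With the new Tool and FunctionDefinition models and the optional Message fields, an OpenAI-style request body validates directly into ChatCompletionCreate. A small sketch (the payload, model name, and get_weather function are hypothetical):

    payload = {
        "model": "placeholder-model-name",
        "stream": True,
        "messages": [{"role": "user", "content": "What's the weather in Beijing?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
    }
    create = ChatCompletionCreate(**payload)
    assert create.tools[0].function.name == "get_weather"
    assert create.temperature == 0.6  # the new default introduced by this commit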