Commit a7e8d7c1 in ox696c/ktransformers
Authored Apr 13, 2025 by Creeper-MZ
Commit message: update function_call
Parent: 038db30e
Showing 4 changed files with 554 additions and 89 deletions.
ktransformers/server/api/openai/endpoints/chat.py (+392, -65)
ktransformers/server/backend/interfaces/ktransformers.py (+3, -2)
ktransformers/server/backend/interfaces/transformers.py (+115, -10)
ktransformers/server/schemas/endpoints/chat.py (+44, -12)
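For context, this commit wires OpenAI-style function calling ("tools") through the /chat/completions endpoint. The snippet below is a minimal client-side sketch of a request that exercises the new path; the host, port, URL prefix, model name, and the get_weather tool are illustrative assumptions, not part of the commit.

import requests

payload = {
    "model": "DeepSeek-V3",  # assumed model name; any served model works
    "stream": False,
    "messages": [
        {"role": "user", "content": "What's the weather in Beijing right now?"},
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string", "description": "City name"}},
                "required": ["city"],
            },
        },
    }],
}

# Assumed server address and path prefix; adjust to wherever the ktransformers server listens.
resp = requests.post("http://localhost:10002/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"].get("tool_calls"))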
ktransformers/server/api/openai/endpoints/chat.py

import json
from time import time
from uuid import uuid4
from typing import Dict, List, Optional, Any, Literal, Union
from pydantic import BaseModel, Field
import re
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk

# Define own data structure instead of importing from OpenAI
# (replaces the previous imports of ChatCompletion from openai.types.chat and
#  CompletionUsage from openai.types.completion_usage)
class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None

class Choice(BaseModel):
    index: int
    message: Optional[Dict[str, Any]] = None
    finish_reason: Optional[str] = None
    logprobs: Optional[Any] = None
    delta: Optional[Dict[str, Any]] = None
    content_filter_results: Optional[Dict[str, Any]] = None

class ChatCompletion(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[CompletionUsage] = None
    system_fingerprint: Optional[str] = None
    prompt_filter_results: Optional[List[Dict[str, Any]]] = None

# Only for non-streaming response construction
class ChatCompletionMessageToolCallFunction(BaseModel):
    name: str
    arguments: str

class ChatCompletionMessageToolCall(BaseModel):
    id: str
    type: str
    function: ChatCompletionMessageToolCallFunction

class ChatCompletionMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None

router = APIRouter()

...

@@ -21,90 +63,375 @@ router = APIRouter()

async def list_models():
    return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}


def getTools(buffer):
    tool_calls_begin_marker = "<|tool▁calls▁begin|>"
    tool_call_begin_marker = "<|tool▁call▁begin|>"
    tool_sep_marker = "<|tool▁sep|>"
    tool_call_end_marker = "<|tool▁call▁end|>"
    tool_calls_end_marker = "<|tool▁calls▁end|>"

    extracted_tools = []
    working_buffer = buffer

    # Iterate over all function calls
    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
        # Find a complete function call
        start_index = working_buffer.find(tool_call_begin_marker)
        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)

        if start_index == -1 or end_index == -1 or start_index > end_index:
            logger.warning("Not a function")
            break

        # Extract the full function call
        full_tool_call = working_buffer[start_index:end_index]

        # Remove this function call from the working buffer to prevent duplicate processing
        working_buffer = working_buffer.replace(full_tool_call, "", 1)

        # Extract the function name
        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
        function_name_end = full_tool_call.find("\n", function_name_start)
        function_name = full_tool_call[function_name_start:function_name_end].strip()

        # Extract JSON parameters
        json_pattern = r'```json\s*(.*?)\s*```'
        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)

        if json_match:
            arguments_str = json_match.group(1).strip()
            # Generate tool call IDs
            tool_call_id = f"call_{uuid4().hex[:24]}"
            # Add to tool call list
            extracted_tools.append({
                "id": tool_call_id,
                "type": "function",
                "function": {"name": function_name, "arguments": arguments_str}
            })
            logger.info(f"Get Function: {function_name}")
        else:
            logger.warning(f"Unable to get function, function_name: {function_name}")

    logger.info(f"Total {len(extracted_tools)} Functions")
    return extracted_tools


@router.post('/chat/completions', tags=['openai'])
async def chat_completion(request: Request, create: ChatCompletionCreate):
    id = str(uuid4().hex)  # previously str(uuid4())

    # 1. Use system prompts to let models know how to use tools
    enhanced_messages = list(create.messages)

    # If tools are supplied and the first message is a system (or user) message,
    # append instructions on how to use the tools to that message.
    if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
        # The Chinese prompt built below lists each tool's name, description and parameters,
        # tells the model to call tools only when needed (preferring tools for time-sensitive
        # information), and shows the exact <|tool▁...|> marker plus ```json format it must
        # emit (with an example), without explaining what it is doing.
        tool_instructions = "你可以使用function_call,函数调用功能,目前,你可以使用以下工具\n\n"
        for tool in create.tools:
            tool_instructions += f"\"function\":{{\"name\":{tool.function.name},\"description\":{tool.function.description},\"parameters\":{tool.function.parameters}}}\n"
        # Modify tool usage guidelines to encourage JSON output
        tool_instructions += "name为函数名称,description为函数功能的描述,parameters中含有函数需要使用的参数和参数的描述, 其中required为必要参数\n"
        tool_instructions += "工具仅在用户明确提出,或者你认为需要调用工具的时候调用,注意,当需要高度实时性的信息比如时间或者最近的事情等,优先调用工具来获取!。当确实调用工具的关键信息时,你可以先向用户索取关键信息再调用工具\n"
        tool_instructions += "\n当你需要使用工具时,请以下列格式输出,格式为:\n"
        tool_instructions += '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>name\n```json {"参数名": "参数值","参数名2": "参数值2"...}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
        tool_instructions += '示例:\n<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>the_functnion_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
        tool_instructions += "这样可以调用名为\"the_functnion_name_will_be_called\",并将value1和value2传入参数arg1,arg2\n"
        tool_instructions += "不要尝试解释你在做什么,直接输出工具函数调用即可。确保函数调用语句格式正确且完整。"
        enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions

    # Requests processed
    interface: BackendInterfaceBase = get_interface()
    # input_ids = interface.format_and_tokenize_input_ids(id, messages=create.get_tokenizer_messages())
    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]  # previously built from create.messages

    if Config().api_key != '':
        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key

    if create.stream:
        async def inner():
            chunk = ChatCompletionChunk(
                id=id,
                choices=[],
                object='chat.completion.chunk',
                created=int(time()),
                model=Config().model_name,
                system_fingerprint=f"fp_{uuid4().hex[:12]}",
            )

            # Collect the full output of the model, with special handling for tool calls
            # (replaces the previous per-token Choice/ChoiceDelta construction from openai.types).
            full_content = ""
            buffer = ""  # Used to temporarily store the current block of text
            tool_call_mode = False  # Mark if a tool call is being processed
            tool_calls = []  # Store all detected tool calls

            # Customize model special tokens
            tool_calls_begin_marker = "<|tool▁calls▁begin|>"
            tool_call_begin_marker = "<|tool▁call▁begin|>"
            tool_sep_marker = "<|tool▁sep|>"
            tool_call_end_marker = "<|tool▁call▁end|>"
            tool_calls_end_marker = "<|tool▁calls▁end|>"

            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
                if isinstance(res, RawUsage):
                    # at the end of inference, interface.inference() will return the usage of inference
                    raw_usage = res
                    chunk.choices = []
                    chunk.usage = CompletionUsage(
                        prompt_tokens=raw_usage.prefill_count,
                        completion_tokens=raw_usage.decode_count,
                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                    )
                    yield chunk
                elif isinstance(res, tuple) and len(res) == 2:
                    token, finish_reason = res

                    # Detecting model-specific formatting tool call starts
                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:
                        tool_call_mode = True

                        # Adjust full_content to remove tool call section
                        if buffer.endswith(tool_calls_begin_marker):
                            full_content = full_content[:-len(tool_calls_begin_marker)]
                        elif tool_calls_begin_marker in (buffer + token):
                            idx = (buffer + token).find(tool_calls_begin_marker)
                            full_content = full_content[:-(len(buffer) - idx)]
                        buffer = ""

                        # Send the current cumulative text content (if any)
                        if full_content:
                            chunk.choices = [{
                                "index": 0,
                                "delta": {"content": full_content},
                                "finish_reason": None
                            }]
                            yield chunk
                            full_content = ""

                    # Accumulation of content in non-tool call mode
                    if not tool_call_mode:
                        full_content += token
                        buffer += token
                        # Keep the buffer at a reasonable size
                        if len(buffer) > 200:
                            buffer = buffer[-200:]
                    else:
                        # In tool call mode, continue to collect tool call related text
                        buffer += token

                        # If the tool call end marker is found
                        if tool_calls_end_marker in buffer:
                            try:
                                # Parse the collected text and extract tool call information
                                tool_calls = getTools(buffer)
                                if len(tool_calls):
                                    # reset state
                                    tool_call_mode = False
                                    buffer = ""
                                    # Send tool call events
                                    for idx, tool_call in enumerate(tool_calls):
                                        # First tool call message
                                        chunk.choices = [{
                                            "index": 0,
                                            "delta": {
                                                "role": "assistant",
                                                "content": None,
                                                "tool_calls": [{
                                                    "index": idx,
                                                    "id": tool_call["id"],
                                                    "type": "function",
                                                    "function": {"name": tool_call["function"]["name"], "arguments": ""}
                                                }]
                                            },
                                            "finish_reason": None
                                        }]
                                        yield chunk

                                        # Sending Parameters
                                        chunk.choices = [{
                                            "index": 0,
                                            "delta": {
                                                "tool_calls": [{
                                                    "index": idx,
                                                    "function": {"arguments": tool_call["function"]["arguments"]}
                                                }]
                                            },
                                            "finish_reason": None
                                        }]
                                        yield chunk

                                    # Send Completion Message
                                    chunk.choices = [{
                                        "index": 0,
                                        "delta": {},
                                        "finish_reason": "tool_calls"
                                    }]
                                    yield chunk
                                    # No further processing after return
                                    return
                                else:
                                    # JSON extraction failed, probably incomplete formatting
                                    logger.warning("Failed to extract JSON from tool call")
                                    tool_call_mode = False
                                    buffer = ""
                            except Exception as e:
                                logger.error(f"Error processing tool call: {e}")
                                tool_call_mode = False
                                buffer = ""

                    # Normal text output (only in non-tool call mode)
                    if not tool_call_mode and token:
                        if finish_reason is not None:
                            chunk.choices = [{
                                "index": 0,
                                "delta": {},
                                "finish_reason": finish_reason
                            }]
                            yield chunk
                        else:
                            if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]):
                                pass
                            else:
                                chunk.choices = [{
                                    "index": 0,
                                    "delta": {"content": token},
                                    "finish_reason": None
                                }]
                                yield chunk

            # If we got this far without returning, the full tool call was not detected.
            # Send Routine Completion Message
            if not tool_call_mode:
                chunk.choices = [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
                yield chunk

        return chat_stream_response(request, inner())
    else:
        # non streaming response processing
        # (replaces the previous implementation that returned a ChatCompletion built from
        #  openai.types Choice/ChatCompletionMessage objects)
        full_content = ""
        finish_reason = None
        tool_calls = []
        buffer = ""
        tool_call_mode = False

        # Custom model special markers
        tool_calls_begin_marker = "<|tool▁calls▁begin|>"
        tool_call_begin_marker = "<|tool▁call▁begin|>"
        tool_sep_marker = "<|tool▁sep|>"
        tool_call_end_marker = "<|tool▁call▁end|>"
        tool_calls_end_marker = "<|tool▁calls▁end|>"

        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
            if isinstance(res, RawUsage):
                raw_usage = res
                usage = CompletionUsage(
                    prompt_tokens=raw_usage.prefill_count,
                    completion_tokens=raw_usage.decode_count,
                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                )
            elif isinstance(res, tuple) and len(res) == 2:
                token, finish_reason = res

                # Detecting the start of model-specific formatting tool calls
                if not tool_call_mode and tool_calls_begin_marker in buffer + token:
                    tool_call_mode = True

                    # Adjust full_content to remove tool call section
                    if buffer.endswith(tool_calls_begin_marker):
                        full_content = full_content[:-len(tool_calls_begin_marker)]
                    elif tool_calls_begin_marker in (buffer + token):
                        idx = (buffer + token).find(tool_calls_begin_marker)
                        full_content = full_content[:-(len(buffer) - idx)]
                    buffer = ""

                # Accumulation of content in non-tool call mode
                if not tool_call_mode:
                    full_content += token
                    buffer += token
                    # Keep the buffer at a reasonable size
                    if len(buffer) > 200:
                        buffer = buffer[-200:]
                else:
                    # In tool call mode, continue to collect tool call related text
                    buffer += token

                    # If the tool call end marker is found
                    if tool_calls_end_marker in buffer:
                        try:
                            # Parse the collected text and extract tool call information
                            full_tool_call = buffer

                            # Extract function name
                            function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
                            function_name_end = full_tool_call.find("\n", function_name_start)
                            function_name = full_tool_call[function_name_start:function_name_end].strip()

                            # Extract JSON Parameters - extracts the content between ```json and ```
                            json_pattern = r'```json\s*(.*?)\s*```'
                            json_match = re.search(json_pattern, full_tool_call, re.DOTALL)

                            if json_match:
                                arguments_str = json_match.group(1).strip()
                                # Generate tool call IDs
                                tool_call_id = f"call_{uuid4().hex[:24]}"
                                # Add to tool call list
                                tool_calls.append({
                                    "id": tool_call_id,
                                    "index": 0,
                                    "type": "function",
                                    "function": {"name": function_name, "arguments": arguments_str}
                                })
                                # If the tool call is successfully parsed, set the reason for completion
                                finish_reason = "tool_calls"
                                # reset state
                                tool_call_mode = False
                                buffer = ""
                            else:
                                # JSON extraction failed, probably incomplete formatting
                                logger.warning("Failed to extract JSON from tool call")
                                tool_call_mode = False
                                buffer = ""
                        except Exception as e:
                            logger.error(f"Error processing tool call: {e}")
                            tool_call_mode = False
                            buffer = ""

        # Build Response
        response = {
            "id": id,
            "object": "chat.completion",
            "created": int(time()),
            "model": Config().model_name,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": None if tool_calls else full_content,
                    "tool_calls": tool_calls if tool_calls else None
                },
                "finish_reason": finish_reason or "stop"
            }],
            "usage": usage.__dict__,
            "system_fingerprint": f"fp_{uuid4().hex[:12]}"
        }

        return response
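To make the new parsing concrete, here is a small illustrative sketch (not part of the commit) of the marker format that getTools consumes and the structure it returns; the function name and arguments are invented.

# Hypothetical model output following the prompted DeepSeek-style marker format:
sample = (
    "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
    '```json {"city": "Beijing"}\n'
    "```<|tool▁call▁end|><|tool▁calls▁end|>"
)
# getTools(sample) would return a list shaped like:
# [{"id": "call_<24 hex chars>",
#   "type": "function",
#   "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'}}]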
ktransformers/server/backend/interfaces/ktransformers.py

import torch
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import (

...

@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):

        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        return torch.tensor([self.seq_length - 1], device=device)

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        async with self._infer_lock:
            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                yield v
            # return this inference raw usage

...
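Since the inference signature now threads a tools list down to the base class, a caller consumes the generator the same way as before. The sketch below is illustrative only; constructing the interface itself is elided, and the message and tool values are invented.

async def drain(interface, messages, tools):
    # Each yielded item is a (token, finish_reason) tuple, a tool-call dict
    # (when tools are in play), or a RawUsage object at the very end.
    async for item in interface.inference(messages, "thread-1", temperature=0.6, top_p=1.0, tools=tools):
        print(item)

# With a live interface instance this would be driven by, e.g., asyncio.run(drain(interface, msgs, tools)).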
ktransformers/server/backend/interfaces/transformers.py

from typing import Any, List, Optional, Set
import re
import json
import uuid
from transformers import (
    LlamaTokenizer,
    AutoTokenizer,

...

@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):

            self.last_request_id = thread_id
            return True

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        self.streamer.reset()
        self.profiler.create_and_start_timer("tokenize")

        # Check if tools are present
        has_tools = tools is not None and len(tools) > 0

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
            #input_ids = torch.tensor([[6366]], device=input_ids.device)
        else:
            raise ValueError("local_messages should be List or str")

...

@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):

        )
        self.profiler.pause_timer("tokenize")
        self.profiler.create_and_start_timer("prefill")
        if Config().user_force_think:

...

@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):

            yield think, None

        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
            # output think token after prefill done
            if t is not None:
                print(t, end="", flush=True)
                yield t, None
        self.profiler.pause_timer("prefill")

        self.profiler.create_and_start_timer("decode")

        # Handle tool calling
        if has_tools:
            # Start collecting tokens until we detect a tool call
            collected_tokens = ""
            is_collecting_tool_call = False
            is_function_name_collected = False
            function_name = ""
            collected_arguments = ""
            brackets_count = 0

            for t, finish_reason in self.generate():
                if t is not None:
                    print(t, end="", flush=True)
                    collected_tokens += t

                    # Check if we're starting a tool call
                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
                        is_collecting_tool_call = True

                        # Generate a unique tool call ID
                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"

                        # Send first tool call info
                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
                            # If tools are provided, use the first one's name
                            recommended_function = tools[0].function.name
                        else:
                            # Otherwise try to extract from context
                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
                            recommended_function = function_match.group(1) if function_match else ""

                        yield {
                            'tool_call': {
                                'id': tool_call_id,
                                'type': 'function',
                                'index': 0,
                                'function': {
                                    'name': recommended_function,
                                    'arguments': ""
                                }
                            },
                            'first_chunk': True
                        }

                    # Extract function name if we're collecting tool call
                    if is_collecting_tool_call and not is_function_name_collected:
                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
                        if name_match:
                            function_name = name_match.group(1)
                            is_function_name_collected = True

                    # Track argument collection
                    if is_collecting_tool_call and is_function_name_collected:
                        args_position = collected_tokens.find('"arguments"')
                        if args_position > -1:
                            # Find the start of the JSON object after "arguments":
                            json_start = collected_tokens.find('{', args_position)
                            if json_start > -1:
                                for i in range(json_start, len(collected_tokens)):
                                    char = collected_tokens[i]
                                    collected_arguments += char

                                    if char == '{':
                                        brackets_count += 1
                                    elif char == '}':
                                        brackets_count -= 1

                                    # Check if we've completed the arguments JSON
                                    if brackets_count == 0:
                                        # Send argument chunk
                                        yield {
                                            'tool_call': {
                                                'id': tool_call_id,
                                                'type': 'function',
                                                'function': {
                                                    'name': function_name,
                                                    'arguments': collected_arguments
                                                }
                                            },
                                            'argument_chunk': collected_arguments,
                                            'last_chunk': True,
                                            'prompt_tokens': 176,
                                            'completion_tokens': 20
                                        }
                                        # Reset for next potential tool call
                                        collected_tokens = ""
                                        is_collecting_tool_call = False
                                        is_function_name_collected = False
                                        function_name = ""
                                        collected_arguments = ""
                                        brackets_count = 0
                                        break

                # Handle finish reason
                if finish_reason is not None:
                    yield "", finish_reason

            print("")
        else:
            # Regular text generation (no tools)
            for t, finish_reason in self.generate():
                if t is not None:
                    print(t, end="", flush=True)
                yield t, finish_reason
            print("")

        self.profiler.pause_timer("decode")
        self.report_last_time_performance()
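The argument extraction above relies on brace counting to find where the JSON object after "arguments" ends. Below is a standalone sketch of the same technique, for illustration only; it is not part of the commit.

def extract_json_object(text: str, start: int) -> str:
    # Walk forward from the first '{', tracking nesting depth, and stop when it returns to zero.
    depth = 0
    out = []
    for ch in text[start:]:
        out.append(ch)
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                break
    return "".join(out)

s = '"arguments": {"city": "Beijing", "units": {"temp": "C"}}'
print(extract_json_object(s, s.find('{')))
# -> {"city": "Beijing", "units": {"temp": "C"}}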
ktransformers/server/schemas/endpoints/chat.py

from typing import List, Optional, Union, Dict, Any
from typing_extensions import Literal
from enum import Enum

...

@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object

from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_chunk import Choice
from uuid import uuid4
from pydantic import BaseModel, Field


class Role(Enum):
    system = 'system'

...

@@ -17,26 +20,57 @@ class Role(Enum):

    tool = 'tool'
    function = 'function'


class Message(BaseModel):
    content: Optional[str] = None  # previously a required str
    role: Role
    name: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None

    def to_tokenizer_message(self):
        # Previously returned {'content': self.content, 'role': self.role.value};
        # now only includes the fields that are actually set.
        message = {'role': self.role.value}
        if self.content is not None:
            message['content'] = self.content
        if self.name is not None:
            message['name'] = self.name
        if self.tool_calls is not None:
            message['tool_calls'] = self.tool_calls
        if self.tool_call_id is not None:
            message['tool_call_id'] = self.tool_call_id
        return message


class FunctionParameters(BaseModel):
    type: str = "object"
    properties: Dict[str, Any] = {}
    required: Optional[List[str]] = None


class FunctionDefinition(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: FunctionParameters = Field(default_factory=FunctionParameters)


class ToolFunction(BaseModel):
    function: FunctionDefinition


class Tool(BaseModel):
    type: Literal["function"]
    function: FunctionDefinition


class ChatCompletionCreate(BaseModel):
    messages: List[Message]
    model: str
    stream: bool = False
    temperature: Optional[float] = Field(default=0.6)  # default was 1.0 before this commit
    top_p: Optional[float] = Field(default=1.0)
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    stream_options: Optional[Dict[str, Any]] = None
    frequency_penalty: float = 0
    presence_penalty: float = 0

    def get_tokenizer_messages(self):
        return [m.to_tokenizer_message() for m in self.messages]


class ChatCompletionChunk(BaseModel):
    id: str
    choices: List[Choice]

...

@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):

    system_fingerprint: Optional[str] = None
    usage: Optional[CompletionUsage] = None

    def to_stream_reply(self):
        return f"data: {self.model_dump_json()}\n\n"


class RawUsage(BaseModel):
    tokenize_time: float
    prefill_time: float
    decode_time: float
    prefill_count: int
    decode_count: int

\ No newline at end of file
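As a quick illustration of the new request schema, a ChatCompletionCreate carrying a tool can be built and flattened for the tokenizer as sketched below. The model name and the get_time tool are invented values for illustration, not part of the commit.

from ktransformers.server.schemas.endpoints.chat import (
    ChatCompletionCreate, Message, Role, Tool, FunctionDefinition, FunctionParameters,
)

create = ChatCompletionCreate(
    model="DeepSeek-V3",  # assumed model name
    messages=[Message(role=Role.user, content="What time is it in Tokyo?")],
    tools=[Tool(
        type="function",
        function=FunctionDefinition(
            name="get_time",
            description="Return the current time for a timezone",
            parameters=FunctionParameters(
                properties={"tz": {"type": "string"}},
                required=["tz"],
            ),
        ),
    )],
)

print(create.get_tokenizer_messages())
# -> [{'role': 'user', 'content': 'What time is it in Tokyo?'}]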