ktransformers (OpenDAS) · Commits · d2cf8142

Unverified commit d2cf8142, authored Apr 16, 2025 by Chengyu Qiu; committed via GitHub on Apr 16, 2025.

Merge pull request #1135 from Creeper-MZ/function_call
Feat: Add Function call support

Parents: fcbd41e1, a7e8d7c1
Changes: 4 changed files with 554 additions and 89 deletions (+554 −89)

- ktransformers/server/api/openai/endpoints/chat.py (+392 −65)
- ktransformers/server/backend/interfaces/ktransformers.py (+3 −2)
- ktransformers/server/backend/interfaces/transformers.py (+115 −10)
- ktransformers/server/schemas/endpoints/chat.py (+44 −12)
ktransformers/server/api/openai/endpoints/chat.py @ d2cf8142

This diff (+392 −65) is collapsed in the page capture and not shown here.
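The collapsed endpoint diff is where the new `tools` field is parsed and OpenAI-style `tool_calls` are emitted; the schema and backend changes below show its shape. A minimal, hypothetical client-side sketch of what the feature enables — the base URL, API key, and model name are assumptions for illustration, not taken from this commit:

```python
# Hypothetical client call against a local ktransformers server; base_url,
# api_key, and model are placeholders, not values from this diff.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:10002/v1", api_key="not-needed")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool
        "description": "Look up current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

resp = client.chat.completions.create(
    model="DeepSeek-V3",  # placeholder model name
    messages=[{"role": "user", "content": "What's the weather in Berlin?"}],
    tools=tools,
)
print(resp.choices[0].message.tool_calls)
```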
ktransformers/server/backend/interfaces/ktransformers.py @ d2cf8142

```diff
 import torch
+from typing import Optional, List
 import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
 from ktransformers.server.backend.interfaces.transformers import (
...
@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                 yield v

             # return this inference raw usage
...
```
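This override only threads the new `tools` argument through to the parent implementation while serializing requests behind the backend's inference lock. A minimal standalone sketch of that delegation pattern (not ktransformers code; class names are illustrative):

```python
# Standalone sketch: an async-generator override that forwards an extra
# argument to the parent while holding a lock, mirroring the diff above.
import asyncio
from typing import List, Optional


class Base:
    async def inference(self, messages, tools: Optional[List] = None):
        for token in ["hello", " world"]:
            await asyncio.sleep(0)  # stand-in for real model work
            yield token


class Locked(Base):
    def __init__(self):
        self._infer_lock = asyncio.Lock()

    async def inference(self, messages, tools: Optional[List] = None):
        async with self._infer_lock:  # one inference at a time
            async for v in super().inference(messages, tools):
                yield v


async def main():
    backend = Locked()
    async for v in backend.inference([{"role": "user", "content": "hi"}]):
        print(v, end="")

asyncio.run(main())
```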
ktransformers/server/backend/interfaces/transformers.py @ d2cf8142

```diff
 from typing import Any, List, Optional, Set
+import re
+import json
+import uuid
 from transformers import (
     LlamaTokenizer,
     AutoTokenizer,
...
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
             self.last_request_id = thread_id
             return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
+
+        # Check if tools are present
+        has_tools = tools is not None and len(tools) > 0
+
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
             #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
...
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
             )
         self.profiler.pause_timer("tokenize")
         self.profiler.create_and_start_timer("prefill")
         if Config().user_force_think:
...
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
             yield think, None

         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="", flush=True)
-                yield t, finish_reason
-        print("")
+
+        # Handle tool calling
+        if has_tools:
+            # Start collecting tokens until we detect a tool call
+            collected_tokens = ""
+            is_collecting_tool_call = False
+            is_function_name_collected = False
+            function_name = ""
+            collected_arguments = ""
+            brackets_count = 0
+
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    collected_tokens += t
+
+                    # Check if we're starting a tool call
+                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
+                        is_collecting_tool_call = True
+
+                        # Generate a unique tool call ID
+                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
+
+                        # Send first tool call info
+                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
+                            # If tools are provided, use the first one's name
+                            recommended_function = tools[0].function.name
+                        else:
+                            # Otherwise try to extract from context
+                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                            recommended_function = function_match.group(1) if function_match else ""
+
+                        yield {
+                            'tool_call': {
+                                'id': tool_call_id,
+                                'type': 'function',
+                                'index': 0,
+                                'function': {
+                                    'name': recommended_function,
+                                    'arguments': ""
+                                }
+                            },
+                            'first_chunk': True
+                        }
+
+                    # Extract function name if we're collecting tool call
+                    if is_collecting_tool_call and not is_function_name_collected:
+                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                        if name_match:
+                            function_name = name_match.group(1)
+                            is_function_name_collected = True
+
+                    # Track argument collection
+                    if is_collecting_tool_call and is_function_name_collected:
+                        args_position = collected_tokens.find('"arguments"')
+                        if args_position > -1:
+                            # Find the start of the JSON object after "arguments":
+                            json_start = collected_tokens.find('{', args_position)
+                            if json_start > -1:
+                                for i in range(json_start, len(collected_tokens)):
+                                    char = collected_tokens[i]
+                                    collected_arguments += char
+                                    if char == '{':
+                                        brackets_count += 1
+                                    elif char == '}':
+                                        brackets_count -= 1
+                                        # Check if we've completed the arguments JSON
+                                        if brackets_count == 0:
+                                            # Send argument chunk
+                                            yield {
+                                                'tool_call': {
+                                                    'id': tool_call_id,
+                                                    'type': 'function',
+                                                    'function': {
+                                                        'name': function_name,
+                                                        'arguments': collected_arguments
+                                                    }
+                                                },
+                                                'argument_chunk': collected_arguments,
+                                                'last_chunk': True,
+                                                'prompt_tokens': 176,
+                                                'completion_tokens': 20
+                                            }
+                                            # Reset for next potential tool call
+                                            collected_tokens = ""
+                                            is_collecting_tool_call = False
+                                            is_function_name_collected = False
+                                            function_name = ""
+                                            collected_arguments = ""
+                                            brackets_count = 0
+                                            break
+
+                # Handle finish reason
+                if finish_reason is not None:
+                    yield "", finish_reason
+
+            print("")
+        else:
+            # Regular text generation (no tools)
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    yield t, finish_reason
+            print("")
+
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
```
ktransformers/server/schemas/endpoints/chat.py @ d2cf8142

```diff
-from typing import List, Optional
+from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
...
@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object
 from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion_chunk import Choice

+from uuid import uuid4
+from pydantic import BaseModel, Field

 class Role(Enum):
     system = 'system'
...
@@ -17,26 +20,57 @@ class Role(Enum):
     tool = 'tool'
     function = 'function'

 class Message(BaseModel):
-    content: str
+    content: Optional[str] = None
     role: Role
     name: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None

     def to_tokenizer_message(self):
-        return {'content': self.content, 'role': self.role.value}
+        message = {'role': self.role.value}
+        if self.content is not None:
+            message['content'] = self.content
+        if self.name is not None:
+            message['name'] = self.name
+        if self.tool_calls is not None:
+            message['tool_calls'] = self.tool_calls
+        if self.tool_call_id is not None:
+            message['tool_call_id'] = self.tool_call_id
+        return message
+
+class FunctionParameters(BaseModel):
+    type: str = "object"
+    properties: Dict[str, Any] = {}
+    required: Optional[List[str]] = None
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: FunctionParameters = Field(default_factory=FunctionParameters)
+
+class ToolFunction(BaseModel):
+    function: FunctionDefinition
+
+class Tool(BaseModel):
+    type: Literal["function"]
+    function: FunctionDefinition

 class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
-    temperature: Optional[float] = Field(default=1.0)
+    temperature: Optional[float] = Field(default=0.6)
     top_p: Optional[float] = Field(default=1.0)
+    tools: Optional[List[Tool]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    stream_options: Optional[Dict[str, Any]] = None
+    frequency_penalty: float = 0
+    presence_penalty: float = 0

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]

 class ChatCompletionChunk(BaseModel):
     id: str
     choices: List[Choice]
...
@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):
     system_fingerprint: Optional[str] = None
     usage: Optional[CompletionUsage] = None

     def to_stream_reply(self):
         return f"data: {self.model_dump_json()}\n\n"

 class RawUsage(BaseModel):
     tokenize_time: float
     prefill_time: float
     decode_time: float
     prefill_count: int
     decode_count: int
\ No newline at end of file
```
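A quick sketch of the new schema in use, validating a tools-bearing request body with the models above; the payload values are made up for illustration:

```python
# Illustrative use of the new Pydantic schema; payload values are placeholders.
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate

payload = {
    "model": "DeepSeek-V3",  # placeholder
    "messages": [{"role": "user", "content": "What's the weather in Berlin?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}

req = ChatCompletionCreate.model_validate(payload)
print(req.tools[0].function.name)    # get_weather
print(req.get_tokenizer_messages())  # [{'role': 'user', 'content': "What's ..."}]
```

Making `content` optional and emitting only the fields that are set keeps `to_tokenizer_message` compatible with assistant messages that carry `tool_calls` instead of text, and with `tool`-role messages that carry a `tool_call_id`.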