Unverified Commit d2cf8142 authored by Chengyu Qiu, committed by GitHub

Merge pull request #1135 from Creeper-MZ/function_call

Feat: Add Function call support
parents fcbd41e1 a7e8d7c1
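
This merge adds OpenAI-style function calling to the chat-completions server. For orientation, here is a minimal sketch of the kind of request the feature enables; the URL, port, model name, and the get_weather tool are illustrative assumptions, not values taken from this diff.

# Illustrative request sketch (assumed endpoint/port/model; hypothetical tool).
import requests

payload = {
    "model": "DeepSeek-V3",
    "stream": True,
    "messages": [{"role": "user", "content": "What's the weather in Beijing?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}
resp = requests.post("http://localhost:10002/v1/chat/completions", json=payload, stream=True)
for line in resp.iter_lines():
    if line:
        print(line.decode())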
import torch
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import (
@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        return torch.tensor([self.seq_length - 1], device=device)

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        async with self._infer_lock:
            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                yield v
        # return this inference raw usage
...
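
The wrapper above serializes concurrent requests: `_infer_lock` lets only one generation run at a time while each caller still receives a token stream. A minimal sketch of that pattern, with hypothetical names rather than the actual ktransformers classes:

# Minimal sketch of the serialization pattern (hypothetical names).
import asyncio

class SerializedEngine:
    def __init__(self):
        self._infer_lock = asyncio.Lock()

    async def _generate(self, prompt):
        for tok in prompt.split():      # stand-in for real decoding
            await asyncio.sleep(0)
            yield tok

    async def inference(self, prompt):
        async with self._infer_lock:    # one request decodes at a time
            async for tok in self._generate(prompt):
                yield tok

async def main():
    eng = SerializedEngine()
    async def run(p):
        return [t async for t in eng.inference(p)]
    # Two concurrent callers; the lock makes them run back-to-back.
    print(await asyncio.gather(run("hello world"), run("foo bar")))

asyncio.run(main())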
from typing import Any, List, Optional, Set
import re
import json
import uuid
from transformers import (
    LlamaTokenizer,
    AutoTokenizer,
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
        self.last_request_id = thread_id
        return True

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        self.streamer.reset()
        self.profiler.create_and_start_timer("tokenize")

        # Check if tools are present
        has_tools = tools is not None and len(tools) > 0

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
        )
        self.profiler.pause_timer("tokenize")

        self.profiler.create_and_start_timer("prefill")

        if Config().user_force_think:
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
            yield think, None

        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
            if t is not None:
                print(t, end="", flush=True)
                yield t, None
        self.profiler.pause_timer("prefill")

        self.profiler.create_and_start_timer("decode")
        # Handle tool calling
        if has_tools:
            # Start collecting tokens until we detect a tool call
            collected_tokens = ""
            is_collecting_tool_call = False
            is_function_name_collected = False
            function_name = ""
            collected_arguments = ""
            brackets_count = 0
            for t, finish_reason in self.generate():
                if t is not None:
                    print(t, end="", flush=True)
                    collected_tokens += t

                    # Check if we're starting a tool call
                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
                        is_collecting_tool_call = True

                        # Generate a unique tool call ID
                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"

                        # Send first tool call info
                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
                            # If tools are provided, use the first one's name
                            recommended_function = tools[0].function.name
                        else:
                            # Otherwise try to extract from context
                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
                            recommended_function = function_match.group(1) if function_match else ""

                        yield {
                            'tool_call': {
                                'id': tool_call_id,
                                'type': 'function',
                                'index': 0,
                                'function': {
                                    'name': recommended_function,
                                    'arguments': ""
                                }
                            },
                            'first_chunk': True
                        }

                    # Extract function name if we're collecting tool call
                    if is_collecting_tool_call and not is_function_name_collected:
                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
                        if name_match:
                            function_name = name_match.group(1)
                            is_function_name_collected = True

                    # Track argument collection
                    if is_collecting_tool_call and is_function_name_collected:
                        args_position = collected_tokens.find('"arguments"')
                        if args_position > -1:
                            # Find the start of the JSON object after "arguments":
                            json_start = collected_tokens.find('{', args_position)
                            if json_start > -1:
                                # Rescan the span from scratch on each new token so the
                                # accumulators do not double-count partial passes
                                collected_arguments = ""
                                brackets_count = 0
                                for i in range(json_start, len(collected_tokens)):
                                    char = collected_tokens[i]
                                    collected_arguments += char
                                    if char == '{':
                                        brackets_count += 1
                                    elif char == '}':
                                        brackets_count -= 1

                                    # Check if we've completed the arguments JSON
                                    if brackets_count == 0:
                                        # Send argument chunk
                                        yield {
                                            'tool_call': {
                                                'id': tool_call_id,
                                                'type': 'function',
                                                'function': {
                                                    'name': function_name,
                                                    'arguments': collected_arguments
                                                }
                                            },
                                            'argument_chunk': collected_arguments,
                                            'last_chunk': True,
                                            # FIXME: hardcoded placeholder usage values;
                                            # real token counts are not computed here
                                            'prompt_tokens': 176,
                                            'completion_tokens': 20
                                        }
                                        # Reset for next potential tool call
                                        collected_tokens = ""
                                        is_collecting_tool_call = False
                                        is_function_name_collected = False
                                        function_name = ""
                                        collected_arguments = ""
                                        brackets_count = 0
                                        break

                # Handle finish reason
                if finish_reason is not None:
                    yield "", finish_reason

            print("")
        else:
            # Regular text generation (no tools)
            for t, finish_reason in self.generate():
                if t is not None:
                    print(t, end="", flush=True)
                    yield t, finish_reason
            print("")
self.profiler.pause_timer("decode") self.profiler.pause_timer("decode")
self.report_last_time_performance() self.report_last_time_performance()
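
Note that inference() now yields values of two shapes: (text, finish_reason) tuples for ordinary tokens and bare dicts for tool-call chunks, so consumers must branch on the type. A hedged sketch of such a consumer (argument names assumed, not from this diff):

# Sketch of a consumer for the mixed stream; `interface` is any object
# exposing the inference() generator above.
async def consume(interface, messages, thread_id, tools):
    async for chunk in interface.inference(messages, thread_id, tools=tools):
        if isinstance(chunk, dict):  # tool-call chunks arrive as plain dicts
            call = chunk["tool_call"]
            if chunk.get("first_chunk"):
                print("tool call started:", call["function"]["name"])
            if chunk.get("last_chunk"):
                print("arguments:", call["function"]["arguments"])
        else:  # ordinary tokens arrive as (text, finish_reason) tuples
            text, finish_reason = chunk
            if text:
                print(text, end="")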
from typing import List, Optional, Union, Dict, Any
from typing_extensions import Literal
from enum import Enum
@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object
from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_chunk import Choice
from uuid import uuid4
from pydantic import BaseModel, Field
class Role(Enum):
    system = 'system'
@@ -17,26 +20,57 @@ class Role(Enum):
    tool = 'tool'
    function = 'function'
class Message(BaseModel):
    content: Optional[str] = None
    role: Role
    name: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None

    def to_tokenizer_message(self):
        message = {'role': self.role.value}
        if self.content is not None:
            message['content'] = self.content
        if self.name is not None:
            message['name'] = self.name
        if self.tool_calls is not None:
            message['tool_calls'] = self.tool_calls
        if self.tool_call_id is not None:
            message['tool_call_id'] = self.tool_call_id
        return message
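
The optional fields above let a full tool round trip serialize cleanly through to_tokenizer_message(). An example, assuming the collapsed enum lines also define Role.assistant (the call id and function name are invented):

# Example round trip (invented id/function; assumes Role.assistant exists).
assistant = Message(
    role=Role.assistant,
    tool_calls=[{
        "id": "call_abc123",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
    }],
)
tool_reply = Message(role=Role.tool, content='{"temp_c": 21}', tool_call_id="call_abc123")
print(assistant.to_tokenizer_message())   # no 'content' key, since content is None
print(tool_reply.to_tokenizer_message())  # includes 'content' and 'tool_call_id'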
class FunctionParameters(BaseModel):
    type: str = "object"
    properties: Dict[str, Any] = {}
    required: Optional[List[str]] = None

class FunctionDefinition(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: FunctionParameters = Field(default_factory=FunctionParameters)

class ToolFunction(BaseModel):
    function: FunctionDefinition

class Tool(BaseModel):
    type: Literal["function"]
    function: FunctionDefinition
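
These Pydantic models mirror the OpenAI tools schema, so a tools entry from a request body validates directly. A small sketch with an invented function definition:

# Sketch: building a tool definition with the models above (invented schema).
tool = Tool(
    type="function",
    function=FunctionDefinition(
        name="get_weather",
        description="Look up current weather for a city",
        parameters=FunctionParameters(
            properties={"city": {"type": "string"}},
            required=["city"],
        ),
    ),
)
print(tool.function.name)  # "get_weather"
print(tool.model_dump())   # nested dict matching the OpenAI wire format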
class ChatCompletionCreate(BaseModel):
    messages: List[Message]
    model: str
    stream: bool = False
    temperature: Optional[float] = Field(default=0.6)
    top_p: Optional[float] = Field(default=1.0)
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    stream_options: Optional[Dict[str, Any]] = None
    frequency_penalty: float = 0
    presence_penalty: float = 0

    def get_tokenizer_messages(self):
        return [m.to_tokenizer_message() for m in self.messages]
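
With the new fields, an OpenAI-style request body parses straight into ChatCompletionCreate. A sketch using the pydantic v2 API (consistent with model_dump_json below; values invented):

# Sketch: parsing a request body (invented values; pydantic v2 API).
body = {
    "model": "DeepSeek-V3",
    "stream": True,
    "messages": [{"role": "user", "content": "What's the weather in Beijing?"}],
    "tools": [{"type": "function",
               "function": {"name": "get_weather",
                            "parameters": {"type": "object",
                                           "properties": {"city": {"type": "string"}}}}}],
}
req = ChatCompletionCreate.model_validate(body)
print(req.temperature)               # 0.6, the new default
print(req.tools[0].function.name)    # "get_weather"
print(req.get_tokenizer_messages())  # plain dicts ready for the chat template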
class ChatCompletionChunk(BaseModel):
    id: str
    choices: List[Choice]
@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):
    system_fingerprint: Optional[str] = None
    usage: Optional[CompletionUsage] = None

    def to_stream_reply(self):
        return f"data: {self.model_dump_json()}\n\n"
class RawUsage(BaseModel):
    tokenize_time: float
    prefill_time: float
    decode_time: float
    prefill_count: int
    decode_count: int
\ No newline at end of file