Commit 86991091 authored by sean.su's avatar sean.su

Refactor the chat interface to support tool calling and parameter processing

Defined new data structures in chat.py to replace OpenAI's original implementation, adding support for tool calling.

Implemented logic for extracting and processing tool calls, enabling dynamic function invocation during conversations.

Added a get_sampling_params method in balance_serve.py to resolve sampling parameters, handling default values and edge cases.

Updated ktransformers.py and transformers.py to pass tool parameters through the inference interfaces.

Changed the default value of top_p in config.py from 0.8 to 1.0 to increase generation diversity.

Extended the Message model in chat.py to carry tool call information.

These changes enhance the system's flexibility and functionality, enabling more complex interaction patterns.
parent 038db30e
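
For context, a request that exercises the new tool-calling path would look roughly like this. This is a hypothetical client-side sketch: the endpoint, port, model name, and the get_weather function are illustrative, not taken from this commit.

    # Hypothetical client sketch: stream a chat completion that offers one tool.
    import requests

    payload = {
        "model": "deepseek-v3",  # illustrative model name
        "stream": True,
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Look up current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
    }
    resp = requests.post("http://localhost:10002/v1/chat/completions",
                         json=payload, stream=True)
    for line in resp.iter_lines():
        if line:
            print(line.decode())  # SSE chunks: "data: {...}"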
@@ -264,6 +264,7 @@ class BalanceServeInterface(BackendInterfaceBase):
    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        self.queue_map: dict[int, asyncio.Queue] = {}
@@ -282,7 +283,21 @@ class BalanceServeInterface(BackendInterfaceBase):
            p.start()
            processes.append(p)
        start_event.wait()

+    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+        """Get sampling parameters, handling default values and edge cases."""
+        if temperature is None:
+            temperature = Config().temperature
+        if top_p is None:
+            top_p = Config().top_p
+        # Exact zeros would not be handled by the sampler, so clamp to a small epsilon
+        if temperature == 0:
+            temperature = 0.0001
+        if top_p == 0:
+            top_p = 0.0001
+        return temperature, top_p

    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
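
The clamping in get_sampling_params can be sanity-checked in isolation. A minimal standalone sketch follows; the defaults mirror the new 0.6/1.0 values elsewhere in this commit but are hard-coded here rather than read from Config:

    def clamp_sampling(temperature=None, top_p=None, default_t=0.6, default_p=1.0):
        # Same rule as get_sampling_params above: fill in defaults, then nudge
        # exact zeros to a small epsilon the sampler can handle.
        temperature = default_t if temperature is None else temperature
        top_p = default_p if top_p is None else top_p
        temperature = 0.0001 if temperature == 0 else temperature
        top_p = 0.0001 if top_p == 0 else top_p
        return temperature, top_p

    assert clamp_sampling(0, 0) == (0.0001, 0.0001)  # edge case: zeros clamped
    assert clamp_sampling() == (0.6, 1.0)            # falls back to defaults
    assert clamp_sampling(0.7, 0.9) == (0.7, 0.9)    # passthrough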
@@ -342,7 +357,6 @@ class BalanceServeInterface(BackendInterfaceBase):
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")
@@ -352,12 +366,9 @@ class BalanceServeInterface(BackendInterfaceBase):
                [input_ids, token_thinks], dim=1
            )
        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")
        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
@@ -367,11 +378,10 @@ class BalanceServeInterface(BackendInterfaceBase):
        # @TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False), self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria
-        if temperature == 0:
-            temperature = 0.0001
+        temperature, top_p = self.get_sampling_params(temperature, top_p)
        query_add.sample_options.temperature = temperature
-        if top_p == 0:
-            top_p = 0.0001
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length + self.args.max_new_tokens)
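
The stop_criteria above are lists of token-id sequences rather than strings, so multi-token stop markers are matched as whole sequences. A hedged sketch of the same construction; the model name is an assumption for illustration only:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # illustrative model
    stop_criteria = [
        tok.encode(tok.eos_token, add_special_tokens=False),
        tok.encode("<|im_end|>", add_special_tokens=False),
    ]
    # For Qwen-style tokenizers both entries may collapse to the same id,
    # e.g. [[151645], [151645]]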
......
import torch
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import (
@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                yield v

    # return this inference raw usage
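
The _infer_lock pattern above serializes concurrent requests: the lock is held for the lifetime of the async generator, so two streams never interleave tokens. A standalone sketch of the pattern, with illustrative names:

    import asyncio

    class SerializedInference:
        def __init__(self):
            self._infer_lock = asyncio.Lock()

        async def _decode(self, n):
            for i in range(n):
                await asyncio.sleep(0)  # stand-in for one decode step
                yield i

        async def inference(self, n):
            # Held until the generator is exhausted, not just until the first yield
            async with self._infer_lock:
                async for v in self._decode(n):
                    yield v

A second caller simply waits until the first stream finishes, which is why the wrapper forwards tokens with an inner async-for loop instead of returning the inner generator directly.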
......
from typing import Any, List, Optional, Set
import re
import json
import uuid
from transformers import (
    LlamaTokenizer,
    AutoTokenizer,
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
        self.last_request_id = thread_id
        return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        self.streamer.reset()
        self.profiler.create_and_start_timer("tokenize")
+        # Check if tools are present
+        has_tools = tools is not None and len(tools) > 0
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
-            #input_ids = torch.tensor([[6366]], device=input_ids.device)
        else:
            raise ValueError("local_messages should be List or str")
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
            )
        self.profiler.pause_timer("tokenize")

        self.profiler.create_and_start_timer("prefill")
        if Config().user_force_think:
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
            yield think, None

        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
            # output think token after prefill done
            if t is not None:
                print(t, end="", flush=True)
                yield t, None
        self.profiler.pause_timer("prefill")

        self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="", flush=True)
-            yield t, finish_reason
-        print("")
+        # Handle tool calling
+        if has_tools:
+            # Collect tokens until we detect a tool call
+            collected_tokens = ""
+            is_collecting_tool_call = False
+            is_function_name_collected = False
+            function_name = ""
+            collected_arguments = ""
+            brackets_count = 0
+            tool_call_id = ""  # assigned when a tool call is first detected
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    collected_tokens += t
+
+                    # Check whether a tool call is starting
+                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
+                        is_collecting_tool_call = True
+                        # Generate a unique tool call ID
+                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
+
+                        # Send the first tool-call chunk
+                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
+                            # If tools are provided, use the first one's name
+                            recommended_function = tools[0].function.name
+                        else:
+                            # Otherwise try to extract the name from context
+                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                            recommended_function = function_match.group(1) if function_match else ""
+
+                        yield {
+                            'tool_call': {
+                                'id': tool_call_id,
+                                'type': 'function',
+                                'index': 0,
+                                'function': {
+                                    'name': recommended_function,
+                                    'arguments': ""
+                                }
+                            },
+                            'first_chunk': True
+                        }
+
+                    # Extract the function name while collecting a tool call
+                    if is_collecting_tool_call and not is_function_name_collected:
+                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                        if name_match:
+                            function_name = name_match.group(1)
+                            is_function_name_collected = True
+
+                    # Track argument collection
+                    if is_collecting_tool_call and is_function_name_collected:
+                        args_position = collected_tokens.find('"arguments"')
+                        if args_position > -1:
+                            # Find the start of the JSON object after "arguments":
+                            json_start = collected_tokens.find('{', args_position)
+                            if json_start > -1:
+                                # Rescan from the start of the object on each new token,
+                                # so reset the accumulator and brace counter first
+                                collected_arguments = ""
+                                brackets_count = 0
+                                for i in range(json_start, len(collected_tokens)):
+                                    char = collected_tokens[i]
+                                    collected_arguments += char
+                                    if char == '{':
+                                        brackets_count += 1
+                                    elif char == '}':
+                                        brackets_count -= 1
+
+                                    # Check if we've completed the arguments JSON
+                                    if brackets_count == 0:
+                                        # Send the argument chunk
+                                        yield {
+                                            'tool_call': {
+                                                'id': tool_call_id,
+                                                'type': 'function',
+                                                'function': {
+                                                    'name': function_name,
+                                                    'arguments': collected_arguments
+                                                }
+                                            },
+                                            'argument_chunk': collected_arguments,
+                                            'last_chunk': True,
+                                            # FIXME: hard-coded placeholder usage counts
+                                            'prompt_tokens': 176,
+                                            'completion_tokens': 20
+                                        }
+                                        # Reset state for the next potential tool call
+                                        collected_tokens = ""
+                                        is_collecting_tool_call = False
+                                        is_function_name_collected = False
+                                        function_name = ""
+                                        collected_arguments = ""
+                                        brackets_count = 0
+                                        break
+
+                # Handle finish reason
+                if finish_reason is not None:
+                    yield "", finish_reason
+            print("")
+        else:
+            # Regular text generation (no tools)
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                yield t, finish_reason
+            print("")
        self.profiler.pause_timer("decode")
        self.report_last_time_performance()
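
The inline detection and brace counting above could be factored into a pure helper, which is easier to unit-test. A hedged sketch of the same logic follows; this helper is not part of this commit:

    import re

    def extract_tool_call(text):
        # Returns (name, arguments_json) once a complete
        # {"name": ..., "arguments": {...}} fragment exists in text, else None.
        name_match = re.search(r'"name":\s*"([^"]+)"', text)
        if not name_match:
            return None
        args_pos = text.find('"arguments"')
        start = text.find('{', args_pos) if args_pos > -1 else -1
        if start == -1:
            return None
        depth = 0
        for i in range(start, len(text)):
            if text[i] == '{':
                depth += 1
            elif text[i] == '}':
                depth -= 1
                if depth == 0:
                    return name_match.group(1), text[start:i + 1]
        return None  # arguments object still incomplete

    # Prints ('get_weather', '{"city": "Paris"}')
    print(extract_tool_call('{"name": "get_weather", "arguments": {"city": "Paris"}}'))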
@@ -133,7 +133,7 @@ class Config(metaclass=Singleton):
        self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
        self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
        self.top_k = self.model.get("top_k", 50)
-        self.top_p = self.model.get("top_p", 0.8)
+        self.top_p = self.model.get("top_p", 1.0)
        self.top_a = self.model.get("top_a", 0.0)
        self.skew = self.model.get("skew", 0.0)
        self.typical = self.model.get("typical", 0.0)
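
On the default change above: top_p keeps the smallest set of tokens whose cumulative probability exceeds the threshold, so 1.0 disables the truncation entirely while 0.8 discards the long tail. A generic sketch of nucleus filtering, the standard technique rather than this repo's sampler:

    import torch

    def top_p_filter(logits, top_p):
        # Zero out tokens that lie entirely past the nucleus, then renormalize.
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        mask = cumulative - sorted_probs > top_p  # prob mass before this token
        sorted_probs[mask] = 0.0
        filtered = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
        return filtered / filtered.sum(dim=-1, keepdim=True)

With top_p = 1.0 the mask is never set and the full distribution is sampled, which is the sense in which the new default increases diversity.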
......
-from typing import List, Optional
+from typing import List, Optional, Union, Dict, Any
from typing_extensions import Literal
from enum import Enum
@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object
from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_chunk import Choice
from uuid import uuid4
from pydantic import BaseModel, Field

class Role(Enum):
    system = 'system'
@@ -17,26 +20,57 @@ class Role(Enum):
    tool = 'tool'
    function = 'function'

class Message(BaseModel):
-    content: str
-    role:Role
+    content: Optional[str] = None
+    role: Role
+    name: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None

    def to_tokenizer_message(self):
-        return {'content':self.content,'role':self.role.value}
+        message = {'role': self.role.value}
+        if self.content is not None:
+            message['content'] = self.content
+        if self.name is not None:
+            message['name'] = self.name
+        if self.tool_calls is not None:
+            message['tool_calls'] = self.tool_calls
+        if self.tool_call_id is not None:
+            message['tool_call_id'] = self.tool_call_id
+        return message

+class FunctionParameters(BaseModel):
+    type: str = "object"
+    properties: Dict[str, Any] = {}
+    required: Optional[List[str]] = None

+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: FunctionParameters = Field(default_factory=FunctionParameters)

+class ToolFunction(BaseModel):
+    function: FunctionDefinition

+class Tool(BaseModel):
+    type: Literal["function"]
+    function: FunctionDefinition

class ChatCompletionCreate(BaseModel):
    messages: List[Message]
-    model : str
-    stream : bool = False
-    temperature: Optional[float] = Field(default=1.0)
+    model: str
+    stream: bool = False
+    temperature: Optional[float] = Field(default=0.6)
+    top_p: Optional[float] = Field(default=1.0)
+    tools: Optional[List[Tool]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    stream_options: Optional[Dict[str, Any]] = None
+    frequency_penalty: float = 0
+    presence_penalty: float = 0

    def get_tokenizer_messages(self):
        return [m.to_tokenizer_message() for m in self.messages]

class ChatCompletionChunk(BaseModel):
    id: str
    choices: List[Choice]
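
A hedged example of how the extended Message model round-trips a tool call. The ids and values are illustrative, and it assumes Role has the standard assistant member:

    assistant = Message(
        role=Role.assistant,
        tool_calls=[{
            "id": "call_abc123",  # illustrative id
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }],
    )
    tool_reply = Message(role=Role.tool, content='{"temp_c": 21}',
                         tool_call_id="call_abc123")

    # to_tokenizer_message() now omits None fields, so the assistant turn
    # serializes without a 'content' key:
    # {'role': 'assistant', 'tool_calls': [...]}
    print(assistant.to_tokenizer_message())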
@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):
    system_fingerprint: Optional[str] = None
    usage: Optional[CompletionUsage] = None

    def to_stream_reply(self):
        return f"data: {self.model_dump_json()}\n\n"

class RawUsage(BaseModel):
    tokenize_time: float
    prefill_time: float
    decode_time: float
    prefill_count: int
-    decode_count: int
+    decode_count: int
\ No newline at end of file