Unverified commit 48558801, authored by wang jiahao, committed by GitHub

Merge pull request #1177 from kvcache-ai/update_param

Update param
parents a1162eea f5287e90
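
The hunks below thread the OpenAI-style `max_tokens` / `max_completion_tokens` request fields from the chat and completions endpoints down to the backend interfaces. As a minimal client sketch (not part of the commit), assuming the server is running on port 10002 as in the test scripts further down, the new field can be set like this:

```python
import requests  # sketch only; endpoint, model name and prompt are taken from the test scripts in this diff

payload = {
    "model": "DeepSeek-V3",
    "messages": [{"role": "user", "content": "Please introduce Harry Potter."}],
    "temperature": 0.3,
    "top_p": 1.0,
    # caps the number of generated tokens; per the get_params hunk below,
    # "max_tokens" takes precedence over "max_completion_tokens" if both are set
    "max_completion_tokens": 200,
    "stream": False,
}
resp = requests.post("http://localhost:10002/v1/chat/completions", json=payload, timeout=600)
# assuming an OpenAI-compatible response body
print(resp.json()["choices"][0]["message"]["content"])
```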
@@ -12,7 +12,10 @@ import torch.nn as nn
 import transformers
 from transformers import Cache, PretrainedConfig
 from typing import List, Optional, Dict, Any, Tuple
-from ktransformers.server.balance_serve.settings import sched_ext
+try:
+    from ktransformers.server.balance_serve.settings import sched_ext
+except:
+    print("no balance_serve")
 class StaticCache(transformers.StaticCache):
     """
     Static Cache class to be used with `torch.compile(model)`.
@@ -210,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
         self.v_caches = []
-    def load(self, inference_context: sched_ext.InferenceContext):
+    def load(self, inference_context: "sched_ext.InferenceContext"):
         for i in range(self.config.num_hidden_layers):
             self.k_caches.append(
...
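
The guarded import above lets the module load on builds without balance_serve, and the string annotation in the second hunk keeps the `load` signature importable even when `sched_ext` is missing. A standalone sketch of the pattern (the module name is a placeholder, not a real ktransformers package):

```python
try:
    import optional_backend  # placeholder for a dependency that may be absent
except ImportError:
    optional_backend = None
    print("no optional_backend")

class Loader:
    # the quoted annotation is stored as a string and never evaluated at
    # definition time, so this class still imports when optional_backend is None
    def load(self, ctx: "optional_backend.InferenceContext") -> None:
        ...
```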
@@ -207,7 +207,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                 "<tools▁end>":"<|tool▁calls▁end|>"
             }
         # Use check_client_connected for early stopping
-        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 # Final return on utilization
                 raw_usage = res
@@ -371,7 +371,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                 "<tool▁end>":"<|tool▁call▁end|>",
                 "<tools▁end>":"<|tool▁calls▁end|>"
             }
-        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 raw_usage = res
                 usage = CompletionUsage(
...
@@ -11,7 +11,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage
 router = APIRouter()
 @router.post("/completions",tags=['openai'])
-async def create_completion(request:Request,create:CompletionCreate):
+async def create_completion(request:Request, create:CompletionCreate):
     id = str(uuid4())
     interface = get_interface()
@@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
     if create.stream:
         async def inner():
-            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):
                 if isinstance(res, RawUsage):
                     raw_usage = res
                 else:
@@ -32,7 +32,7 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p, create.max_tokens, create.max_completion_tokens):
             if isinstance(res, RawUsage):
                 raw_usage = res
             else:
...
@@ -80,6 +80,7 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
         query_updates[i].generated_token = generated_tokens[i].item()
         if not query_manager.query_map[query_updates[i].id].is_prefill:
             pos = query_updates[i].active_position
+            if pos < query_manager.query_map[query_updates[i].id].max_length:
                 query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

 def report_last_time_performance(profiler: Profiler):
@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):
            start_event.wait()

-    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
         """Get sampling parameters and handle default values and edge cases"""
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_completion_tokens = self.args.max_new_tokens
+        else:
+            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if temperature is None:
-            temperature = Config().temperature
+            temperature = self.args.temperature
         if top_p is None:
-            top_p = Config().top_p
+            top_p = self.args.top_p
         if temperature == 0:
             temperature = 0.0001
         if top_p == 0:
             top_p = 0.0001
-        return temperature, top_p
+        return temperature, top_p, max_completion_tokens

     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()
@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                        max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         profiler = Profiler()
         profiler.create_and_start_timer("tokenize")
@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
-        temperature, top_p = self.get_sampling_params(temperature, top_p)
+        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
         query_add.sample_options.temperature = temperature
         query_add.sample_options.top_p = top_p
-        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
+        query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)
         if query_add.estimated_length < query_add.query_length:
             raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')
         query_id = self.sched_client.add_query(query_add)
-        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+        queue = asyncio.Queue(maxsize=max_new_tokens)
         self.queue_map[query_id] = queue
         self.thread_map[thread_id] = query_id
         is_first_token = True
@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
             profiler.pause_timer("decode")
             report_last_time_performance(profiler)
             yield self.streamer.end(), None
-            if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+            if profiler.get_counter('decode') >= max_new_tokens - 1:
                 yield "", "length"
             else:
                 yield "", "stop"
...
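
For reference, a standalone sketch (not the project's exact code) of how the new `get_params` resolution behaves, assuming the same precedence as the hunk above: `max_tokens` overrides `max_completion_tokens`, the result is capped by the server-side `max_new_tokens`, and zero temperature or top_p is nudged to a small positive value so the samplers never receive an exact zero:

```python
from typing import Optional

def resolve_params(temperature: Optional[float], top_p: Optional[float],
                   max_tokens: Optional[int], max_completion_tokens: Optional[int],
                   default_temperature: float = 0.6, default_top_p: float = 1.0,
                   server_max_new_tokens: int = 2048):
    """Illustrative re-sketch of the resolution logic; the defaults here are assumptions."""
    # max_tokens, if given, takes precedence over max_completion_tokens
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    if max_completion_tokens is None:
        max_new_tokens = server_max_new_tokens
    else:
        # a request can only lower the server-side cap, never raise it
        max_new_tokens = min(server_max_new_tokens, max_completion_tokens)
    temperature = default_temperature if temperature is None else temperature
    top_p = default_top_p if top_p is None else top_p
    if temperature == 0:
        temperature = 0.0001
    if top_p == 0:
        top_p = 0.0001
    return temperature, top_p, max_new_tokens

# e.g. resolve_params(None, 0, max_tokens=4096, max_completion_tokens=None)
# -> (0.6, 0.0001, 2048): the request asked for more tokens than the server allows
```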
@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
             self.seq_length = input_ids_length
@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )
@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
@@ -222,6 +228,7 @@ class KTransformersInterface(TransformersInterface):
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @property
@@ -229,9 +236,9 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
                 yield v
         # return this inference raw usage
...
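
Both `prefill()` implementations now derive the decode budget from the request rather than the server-wide setting, and then clamp it to the space left in the KV cache. A small worked sketch of that arithmetic (the numbers are illustrative, not taken from the commit):

```python
# Sketch of the decode-budget formula set at the end of prefill():
#   self.max_new_tokens = min(max_new_tokens, cache_lens - seq_length) - 1
cache_lens = 8192       # assumed total cache length
seq_length = 1500       # prompt length after prefill
max_new_tokens = 300    # resolved from max_tokens / max_completion_tokens / args.max_new_tokens

decode_budget = min(max_new_tokens, cache_lens - seq_length) - 1
print(decode_budget)    # 299: the per-request limit wins

# with a long prompt the cache remainder wins instead:
print(min(300, 8192 - 8000) - 1)  # 191
```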
@@ -262,10 +262,15 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if is_new:
             self.ever_generated_ids.clear()
             same_prefix = 0
@@ -274,7 +279,7 @@ class TransformersInterface(BackendInterfaceBase):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )
@@ -301,7 +306,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = self.seq_length + max_new_tokens + 1
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
@@ -330,17 +335,16 @@ class TransformersInterface(BackendInterfaceBase):
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)

     @torch.no_grad
     def generate(self):
-        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
             yield self.streamer.end(), "length"
             return
-        logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
         for i in range(1, self.max_new_tokens):
@@ -378,17 +382,15 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
-        # Check if tools are present
-        has_tools = tools is not None and len(tools) > 0
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
-            #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
@@ -399,6 +401,7 @@ class TransformersInterface(BackendInterfaceBase):
             )
         self.profiler.pause_timer("tokenize")
         self.profiler.create_and_start_timer("prefill")

         if Config().user_force_think:
@@ -406,119 +409,18 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="",flush=True)
             yield think, None

-        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
-            # output think token after prefill done
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):
             if t is not None:
                 print(t, end="",flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        # Handle tool calling
-        if has_tools:
-            # Start collecting tokens until we detect a tool call
-            collected_tokens = ""
-            is_collecting_tool_call = False
-            is_function_name_collected = False
-            function_name = ""
-            collected_arguments = ""
-            brackets_count = 0
-            for t, finish_reason in self.generate():
-                if t is not None:
-                    print(t, end="", flush=True)
-                    collected_tokens += t
-                    # Check if we're starting a tool call
-                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
-                        is_collecting_tool_call = True
-                        # Generate a unique tool call ID
-                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
-                        # Send first tool call info
-                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
-                            # If tools are provided, use the first one's name
-                            recommended_function = tools[0].function.name
-                        else:
-                            # Otherwise try to extract from context
-                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                            recommended_function = function_match.group(1) if function_match else ""
-                        yield {
-                            'tool_call': {
-                                'id': tool_call_id,
-                                'type': 'function',
-                                'index': 0,
-                                'function': {
-                                    'name': recommended_function,
-                                    'arguments': ""
-                                }
-                            },
-                            'first_chunk': True
-                        }
-                    # Extract function name if we're collecting tool call
-                    if is_collecting_tool_call and not is_function_name_collected:
-                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
-                        if name_match:
-                            function_name = name_match.group(1)
-                            is_function_name_collected = True
-                    # Track argument collection
-                    if is_collecting_tool_call and is_function_name_collected:
-                        args_position = collected_tokens.find('"arguments"')
-                        if args_position > -1:
-                            # Find the start of the JSON object after "arguments":
-                            json_start = collected_tokens.find('{', args_position)
-                            if json_start > -1:
-                                for i in range(json_start, len(collected_tokens)):
-                                    char = collected_tokens[i]
-                                    collected_arguments += char
-                                    if char == '{':
-                                        brackets_count += 1
-                                    elif char == '}':
-                                        brackets_count -= 1
-                                    # Check if we've completed the arguments JSON
-                                    if brackets_count == 0:
-                                        # Send argument chunk
-                                        yield {
-                                            'tool_call': {
-                                                'id': tool_call_id,
-                                                'type': 'function',
-                                                'function': {
-                                                    'name': function_name,
-                                                    'arguments': collected_arguments
-                                                }
-                                            },
-                                            'argument_chunk': collected_arguments,
-                                            'last_chunk': True,
-                                            'prompt_tokens': 176,
-                                            'completion_tokens': 20
-                                        }
-                                        # Reset for next potential tool call
-                                        collected_tokens = ""
-                                        is_collecting_tool_call = False
-                                        is_function_name_collected = False
-                                        function_name = ""
-                                        collected_arguments = ""
-                                        brackets_count = 0
-                                        break
-                # Handle finish reason
-                if finish_reason is not None:
-                    yield "", finish_reason
-            print("")
-        else:
-            # Regular text generation (no tools)
         for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
             yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
 from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 from pydantic import BaseModel, Field
 from ktransformers.server.schemas.base import Object
@@ -11,7 +10,6 @@ from openai.types.chat.chat_completion_chunk import Choice
 from uuid import uuid4
-from pydantic import BaseModel, Field

 class Role(Enum):
     system = 'system'
@@ -67,6 +65,8 @@ class ChatCompletionCreate(BaseModel):
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
...
 from typing import List, Optional
 from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from ..base import Object
@@ -9,8 +8,10 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=0.6)
+    top_p: Optional[float] = Field(default=1)
+    max_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=50)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
...
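
A standalone pydantic sketch (the class is illustrative; only the field names and defaults mirror the schema hunks above) showing how the `Field` defaults behave when a client omits the new parameters:

```python
from typing import Optional
from pydantic import BaseModel, Field

class CompletionCreateSketch(BaseModel):
    """Illustrative stand-in for the request schema above."""
    prompt: str
    temperature: Optional[float] = Field(default=0.6)
    top_p: Optional[float] = Field(default=1)
    max_tokens: Optional[int] = Field(default=50)
    max_completion_tokens: Optional[int] = Field(default=50)

req = CompletionCreateSketch(prompt="hello")
print(req.max_tokens, req.max_completion_tokens)  # 50 50: defaults applied when the fields are omitted
req = CompletionCreateSketch(prompt="hello", max_tokens=200)
print(req.max_tokens)                             # 200: an explicit value overrides the default
```

Given the get_params and prefill hunks above, a request that omits both fields appears to end up capped at min(args.max_new_tokens, 50) generated tokens, since the schema forwards the default of 50 rather than None.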
from openai import OpenAI

def send_messages(messages):
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=messages,
        tools=tools
    )
    return response.choices[0].message

client = OpenAI(
    api_key="placeholder",
    base_url="http://0.0.0.0:10002/v1",
)

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather of an location, the user shoud supply a location first",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
                "required": ["location"]
            },
        }
    },
]

messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}]
message = send_messages(messages)
print(f"User>\t {messages[0]['content']}")
print(message)

tool = message.tool_calls[0]
messages.append(message)
messages.append({"role": "tool", "tool_call_id": tool.id, "content": "24℃"})
message = send_messages(messages)
print(f"Model>\t {message.content}")
\ No newline at end of file
@@ -15,18 +15,9 @@ SERVER_URL = "http://localhost:10002/v1/chat/completions"
 bf_list = [1]
 decodesz_list = [128]
 prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']

-async def fetch_event_stream(session, request_id):
+async def fetch_event_stream(session, payload, request_id):
     try:
-        payload = {
-            "messages": [
-                {"role": "system", "content": ""},
-                {"role": "user", "content": prompt_list[request_id]}
-            ],
-            "model": "DeepSeek-V3",
-            "temperature": 0.3,
-            "top_p": 1.0,
-            "stream": True  # enable streaming output
-        }

         headers = {
             'accept': 'application/json',
@@ -103,7 +94,35 @@ async def fetch_event_stream(session, request_id):

 async def main(prompt_id):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, prompt_id)]
+        payload = {
+            "messages": [
+                {"role": "system", "content": ""},
+                {"role": "user", "content": prompt_list[prompt_id]}
+            ],
+            "model": "DeepSeek-V3",
+            "stream": True,
+            "max_completion_tokens": 2,
+            # "temperature": 0.3,
+            # "top_p": 1.0,
+            # "max_tokens" : 20,
+        }
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["temperature"] = 0.3
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["top_p"] = 1
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["max_tokens"] = 200
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        await asyncio.gather(*tasks)
+
+        payload["stream"] = False
+        tasks = [fetch_event_stream(session, payload, prompt_id)]
         await asyncio.gather(*tasks)

 if __name__ == "__main__":
...
@@ -3326,7 +3326,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         default:
             {
-                printf("case:%d",typeA);
+                // printf("case:%d",typeA);
                 return false;
             }
...