Unverified Commit b0551323 authored by wang jiahao, committed by GitHub

Merge pull request #1154 from 344303947/features/add-function-calling

Fix the error caused by the client not passing temperature and top_p, which then arrive empty
parents 3efb6621 c8db24d5
@@ -264,6 +264,7 @@ class BalanceServeInterface(BackendInterfaceBase):
     # thread_related
     last_request_id: Optional[str] = None
     ever_generated_ids: Set[int] = set()

     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
         self.queue_map:dict[int,asyncio.Queue] = {}
@@ -283,6 +284,20 @@ class BalanceServeInterface(BackendInterfaceBase):
         processes.append(p)
         start_event.wait()

+    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+        """Get sampling parameters and handle default values and edge cases"""
+        if temperature is None:
+            temperature = Config().temperature
+        if top_p is None:
+            top_p = Config().top_p
+        if temperature == 0:
+            temperature = 0.0001
+        if top_p == 0:
+            top_p = 0.0001
+        return temperature, top_p
     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
@@ -342,7 +357,6 @@ class BalanceServeInterface(BackendInterfaceBase):
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
         else:
             raise ValueError("local_messages should be List or str")
@@ -352,13 +366,10 @@ class BalanceServeInterface(BackendInterfaceBase):
                 [input_ids, token_thinks], dim=1
             )
         profiler.pause_timer("tokenize")
         profiler.create_and_start_timer("prefill")
         query_add = sched_ext.QueryAdd()
         query_add.query_token = input_ids[0].tolist()
         query_length = input_ids[0].shape[0]
@@ -367,11 +378,10 @@ class BalanceServeInterface(BackendInterfaceBase):
         #@TODO add server
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
-        if temperature == 0:
-            temperature = 0.0001
+        temperature, top_p = self.get_sampling_params(temperature, top_p)
         query_add.sample_options.temperature = temperature
-        if top_p == 0:
-            top_p = 0.0001
         query_add.sample_options.top_p = top_p
         query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
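For readers skimming the change: the new get_sampling_params helper consolidates two concerns that were previously inlined at the call site. A client that omits temperature or top_p sends None, which now falls back to the server-side Config() defaults instead of being assigned to sample_options directly; and an explicit 0 is clamped to 0.0001, since samplers typically scale logits by 1/temperature and a zero nucleus (top_p == 0) would admit no tokens. The sketch below replays that logic standalone for illustration; DEFAULT_TEMPERATURE and DEFAULT_TOP_P are hypothetical stand-ins for the repo's Config() values, which are not shown in this diff.

    from typing import Optional

    # Hypothetical stand-ins for Config().temperature / Config().top_p;
    # the real defaults live in the repo's Config singleton.
    DEFAULT_TEMPERATURE = 0.6
    DEFAULT_TOP_P = 0.95

    def get_sampling_params(temperature: Optional[float] = None,
                            top_p: Optional[float] = None) -> tuple[float, float]:
        # None means the client omitted the field: fall back to server defaults.
        if temperature is None:
            temperature = DEFAULT_TEMPERATURE
        if top_p is None:
            top_p = DEFAULT_TOP_P
        # Clamp exact zeros to a small epsilon: dividing logits by a zero
        # temperature or sampling from an empty nucleus is undefined.
        if temperature == 0:
            temperature = 0.0001
        if top_p == 0:
            top_p = 0.0001
        return temperature, top_p

    # Omitted fields fall back to defaults; zeros are clamped; ordinary values pass through.
    assert get_sampling_params() == (DEFAULT_TEMPERATURE, DEFAULT_TOP_P)
    assert get_sampling_params(0, 0) == (0.0001, 0.0001)
    assert get_sampling_params(0.7, 0.9) == (0.7, 0.9)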