Commit 006fc349 authored by Michael Poluektov's avatar Michael Poluektov
Browse files

fix: stream_message_template

parent 29a3b823
...@@ -287,37 +287,42 @@ def get_extra_params(metadata: dict): ...@@ -287,37 +287,42 @@ def get_extra_params(metadata: dict):
} }
async def generate_function_chat_completion(form_data, user): def add_model_params(params: dict, form_data: dict) -> dict:
print("entry point") if not params:
model_id = form_data.get("model") return form_data
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None)
extra_params = get_extra_params(metadata)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
mappings = { mappings = {
"temperature": float, "temperature": float,
"top_p": int, "top_p": int,
"max_tokens": int, "max_tokens": int,
"frequency_penalty": int, "frequency_penalty": int,
"seed": lambda x: x, "seed": lambda x: x,
"stop": lambda x: [ "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
bytes(s, "utf-8").decode("unicode_escape") for s in x
],
} }
for key, cast_func in mappings.items(): for key, cast_func in mappings.items():
if (value := params.get(key)) is not None: if (value := params.get(key)) is not None:
form_data[key] = cast_func(value) form_data[key] = cast_func(value)
return form_data
async def generate_function_chat_completion(form_data, user):
print("entry point")
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None)
extra_params = get_extra_params(metadata)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
system = params.get("system", None) system = params.get("system", None)
form_data = add_model_params(params, form_data)
if system: if system:
if user: if user:
template_params = { template_params = {
...@@ -381,7 +386,7 @@ async def generate_function_chat_completion(form_data, user): ...@@ -381,7 +386,7 @@ async def generate_function_chat_completion(form_data, user):
yield process_line(form_data, line) yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator): if isinstance(res, str) or isinstance(res, Generator):
finish_message = stream_message_template(form_data, "") finish_message = stream_message_template(form_data["model"], "")
finish_message["choices"][0]["finish_reason"] = "stop" finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n" yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]" yield "data: [DONE]"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment