"...lm-evaluation-harness.git" did not exist on "a2cada5d8e1534c636a6bd43ae98eb16eb6240a7"
Commit 29a3b823 authored by Michael Poluektov

refac: reuse stream_message_template

parent 22a5e196
@@ -19,7 +19,7 @@ from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
-from utils.misc import stream_message_template
+from utils.misc import stream_message_template, whole_message_template
 from utils.task import prompt_template
@@ -203,7 +203,7 @@ async def execute_pipe(pipe, params):
     return pipe(**params)
 
-async def get_message(res: str | Generator | AsyncGenerator) -> str:
+async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
     if isinstance(res, str):
         return res
     if isinstance(res, Generator):
@@ -212,28 +212,6 @@ async def get_message(res: str | Generator | AsyncGenerator) -> str:
         return "".join([str(stream) async for stream in res])
 
-def get_final_message(form_data: dict, message: str | None = None) -> dict:
-    choice = {
-        "index": 0,
-        "logprobs": None,
-        "finish_reason": "stop",
-    }
-
-    # If message is None, we're dealing with a chunk
-    if not message:
-        choice["delta"] = {}
-    else:
-        choice["message"] = {"role": "assistant", "content": message}
-
-    return {
-        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-        "created": int(time.time()),
-        "model": form_data["model"],
-        "object": "chat.completion" if message is not None else "chat.completion.chunk",
-        "choices": [choice],
-    }
-
 def process_line(form_data: dict, line):
     if isinstance(line, BaseModel):
         line = line.model_dump_json()
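The renamed get_message_content is split across the two hunks above, so its full body never appears in one place. A minimal, self-contained sketch of what it plausibly looks like after this commit, reconstructed from the visible lines (the body of the sync-Generator branch is not shown in the diff and is an assumption here):

from collections.abc import AsyncGenerator, Generator

async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
    # Plain strings pass through unchanged.
    if isinstance(res, str):
        return res
    # Sync generators: drain and join (branch body assumed, not visible in the hunk).
    if isinstance(res, Generator):
        return "".join(str(chunk) for chunk in res)
    # Async generators: drain with an async comprehension (visible at the top of the second hunk).
    if isinstance(res, AsyncGenerator):
        return "".join([str(stream) async for stream in res])
    raise TypeError(f"unexpected pipe result type: {type(res)}")  # defensive guard, an addition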
@@ -292,7 +270,9 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module):
 
 def get_extra_params(metadata: dict):
-    __event_emitter__ = __event_call__ = __task__ = None
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
 
     if metadata:
         if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
@@ -401,7 +381,8 @@ async def generate_function_chat_completion(form_data, user):
             yield process_line(form_data, line)
 
         if isinstance(res, str) or isinstance(res, Generator):
-            finish_message = get_final_message(form_data)
+            finish_message = stream_message_template(form_data["model"], "")
+            finish_message["choices"][0]["finish_reason"] = "stop"
             yield f"data: {json.dumps(finish_message)}\n\n"
             yield "data: [DONE]"
@@ -419,7 +400,7 @@ async def generate_function_chat_completion(form_data, user):
         if isinstance(res, BaseModel):
             return res.model_dump()
 
-        message = await get_message(res)
-        return get_final_message(form_data, message)
+        message = await get_message_content(res)
+        return whole_message_template(form_data["model"], message)
 
     return await job()
@@ -87,23 +87,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
-def stream_message_template(model: str, message: str):
+def message_template(model: str):
     return {
         "id": f"{model}-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
         "created": int(time.time()),
         "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": message},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
+        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
     }
 
+def stream_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion.chunk"
+    template["choices"][0]["delta"] = {"content": message}
+    return template
+
+def whole_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion"
+    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    template["choices"][0]["finish_reason"] = "stop"
+    return template
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
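A short usage sketch of the refactored helpers (assumes it runs from the backend source tree so that utils.misc resolves; the model name is hypothetical). Both wrappers start from the shared dict built by message_template and differ only in the "object" field and the per-choice payload:

from utils.misc import stream_message_template, whole_message_template

chunk = stream_message_template("example-model", "Hel")
assert chunk["object"] == "chat.completion.chunk"
assert chunk["choices"][0]["delta"] == {"content": "Hel"}
assert chunk["choices"][0]["finish_reason"] is None

final = whole_message_template("example-model", "Hello!")
assert final["object"] == "chat.completion"
assert final["choices"][0]["message"] == {"content": "Hello!", "role": "assistant"}
assert final["choices"][0]["finish_reason"] == "stop"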