Commit deec41d2 authored by Michael Poluektov's avatar Michael Poluektov
Browse files

fix: function early returns

parent 3978efd7
...@@ -291,12 +291,7 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module): ...@@ -291,12 +291,7 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module):
return params return params
async def generate_function_chat_completion(form_data, user): def get_extra_params(metadata: dict):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None)
__event_emitter__ = __event_call__ = __task__ = None __event_emitter__ = __event_call__ = __task__ = None
if metadata: if metadata:
...@@ -305,57 +300,66 @@ async def generate_function_chat_completion(form_data, user): ...@@ -305,57 +300,66 @@ async def generate_function_chat_completion(form_data, user):
__event_call__ = get_event_call(metadata) __event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None) __task__ = metadata.get("task", None)
if not model_info: return {
return "__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
if model_info.base_model_id: "__task__": __task__,
form_data["model"] = model_info.base_model_id }
params = model_info.params.model_dump()
if params:
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
system = params.get("system", None)
if not system:
return
if user: async def generate_function_chat_completion(form_data, user):
template_params = { print("entry point")
"user_name": user.name, model_id = form_data.get("model")
"user_location": user.info.get("location") if user.info else None, model_info = Models.get_model_by_id(model_id)
}
else:
template_params = {}
system = prompt_template(system, **template_params) metadata = form_data.pop("metadata", None)
extra_params = get_extra_params(metadata)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [
bytes(s, "utf-8").decode("unicode_escape") for s in x
],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
system = params.get("system", None)
if system:
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
# Check if the payload already has a system message system = prompt_template(system, **template_params)
# If not, add a system message to the payload
for message in form_data.get("messages", []):
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
if form_data.get("messages"):
form_data["messages"].insert(0, {"role": "system", "content": system})
extra_params = { # Check if the payload already has a system message
"__event_emitter__": __event_emitter__, # If not, add a system message to the payload
"__event_call__": __event_call__, for message in form_data.get("messages", []):
"__task__": __task__, if message.get("role") == "system":
} message["content"] = system + message["content"]
break
else:
if form_data.get("messages"):
form_data["messages"].insert(
0, {"role": "system", "content": system}
)
async def job(): async def job():
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment