Commit 1aaa2e82 authored by Timothy J. Baek
Browse files

fix: ollama rag issue workaround

parent e6918397
...@@ -79,6 +79,7 @@ from utils.task import ( ...@@ -79,6 +79,7 @@ from utils.task import (
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
add_or_update_system_message, add_or_update_system_message,
prepend_to_first_user_message_content,
parse_duration, parse_duration,
) )
...@@ -686,6 +687,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -686,6 +687,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(contexts) > 0: if len(contexts) > 0:
context_string = "/n".join(contexts).strip() context_string = "/n".join(contexts).strip()
prompt = get_last_user_message(body["messages"]) prompt = get_last_user_message(body["messages"])
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
body["messages"] = prepend_to_first_user_message_content(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
else:
body["messages"] = add_or_update_system_message( body["messages"] = add_or_update_system_message(
rag_template( rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt rag_app.state.config.RAG_TEMPLATE, context_string, prompt
......
...@@ -53,6 +53,21 @@ def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]: ...@@ -53,6 +53,21 @@ def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
return get_system_message(messages), remove_system_message(messages) return get_system_message(messages), remove_system_message(messages)
def prepend_to_first_user_message_content(
    content: str, messages: List[dict]
) -> List[dict]:
    """
    Prepends the given content to the first user message in the messages list.

    Only the first message whose role is "user" is touched. If that message's
    content is a plain string, it becomes ``f"{content}\\n{original}"``. If it
    is a list of content parts, the text is prepended to the first part whose
    type is "text"; when the list holds no text part, a new text part carrying
    ``content`` is inserted at the front so the content is never dropped.

    :param content: The text to place before the user's message.
    :param messages: The list of message dictionaries; modified in place.
    :return: The same ``messages`` list that was passed in.
    """
    for message in messages:
        if message.get("role") != "user":
            continue

        body = message["content"]
        if isinstance(body, list):
            for part in body:
                if isinstance(part, dict) and part.get("type") == "text":
                    part["text"] = f"{content}\n{part.get('text', '')}"
                    # Stop here: prepending to every text part would repeat
                    # the injected content once per part.
                    break
            else:
                # No text part at all (e.g. an image-only message): add one
                # instead of silently discarding the content.
                body.insert(0, {"type": "text", "text": content})
        else:
            message["content"] = f"{content}\n{body}"

        # Only the first user message is modified.
        break

    return messages
def add_or_update_system_message(content: str, messages: List[dict]): def add_or_update_system_message(content: str, messages: List[dict]):
""" """
Adds a new system message at the beginning of the messages list Adds a new system message at the beginning of the messages list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment