Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
bd5a8567
Commit
bd5a8567
authored
Jun 11, 2024
by
Timothy J. Baek
Browse files
refac: tools & rag
parent
fc465329
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
34 deletions
+26
-34
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+2
-12
backend/main.py
backend/main.py
+24
-22
No files found.
backend/apps/rag/utils.py
View file @
bd5a8567
...
@@ -236,10 +236,9 @@ def get_embedding_function(
...
@@ -236,10 +236,9 @@ def get_embedding_function(
return
lambda
query
:
generate_multiple
(
query
,
func
)
return
lambda
query
:
generate_multiple
(
query
,
func
)
def
rag_messages
(
def
get_rag_context
(
docs
,
docs
,
messages
,
messages
,
template
,
embedding_function
,
embedding_function
,
k
,
k
,
reranking_function
,
reranking_function
,
...
@@ -318,16 +317,7 @@ def rag_messages(
...
@@ -318,16 +317,7 @@ def rag_messages(
context_string
=
context_string
.
strip
()
context_string
=
context_string
.
strip
()
ra_content
=
rag_template
(
return
context_string
,
citations
template
=
template
,
context
=
context_string
,
query
=
query
,
)
log
.
debug
(
f
"ra_content:
{
ra_content
}
"
)
messages
=
add_or_update_system_message
(
ra_content
,
messages
)
return
messages
,
citations
def
get_model_path
(
model
:
str
,
update_model
:
bool
=
False
):
def
get_model_path
(
model
:
str
,
update_model
:
bool
=
False
):
...
...
backend/main.py
View file @
bd5a8567
...
@@ -64,7 +64,7 @@ from utils.task import (
...
@@ -64,7 +64,7 @@ from utils.task import (
)
)
from
utils.misc
import
get_last_user_message
,
add_or_update_system_message
from
utils.misc
import
get_last_user_message
,
add_or_update_system_message
from
apps.rag.utils
import
rag_messages
,
rag_template
from
apps.rag.utils
import
get_rag_context
,
rag_template
from
config
import
(
from
config
import
(
CONFIG_DATA
,
CONFIG_DATA
,
...
@@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
# Parse string to JSON
data
=
json
.
loads
(
body_str
)
if
body_str
else
{}
data
=
json
.
loads
(
body_str
)
if
body_str
else
{}
user
=
get_current_user
(
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
))
)
# Remove the citations from the body
# Remove the citations from the body
return_citations
=
data
.
get
(
"citations"
,
False
)
return_citations
=
data
.
get
(
"citations"
,
False
)
if
"citations"
in
data
:
if
"citations"
in
data
:
...
@@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
):
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
if
"tool_ids"
in
data
:
user
=
get_current_user
(
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
))
)
prompt
=
get_last_user_message
(
data
[
"messages"
])
context
=
""
context
=
""
# If tool_ids field is present, call the functions
if
"tool_ids"
in
data
:
prompt
=
get_last_user_message
(
data
[
"messages"
])
for
tool_id
in
data
[
"tool_ids"
]:
for
tool_id
in
data
[
"tool_ids"
]:
print
(
tool_id
)
print
(
tool_id
)
response
=
await
get_function_call_response
(
response
=
await
get_function_call_response
(
...
@@ -295,37 +297,37 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -295,37 +297,37 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if
response
:
if
response
:
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
response
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
response
if
context
!=
""
:
system_prompt
=
rag_template
(
rag_app
.
state
.
config
.
RAG_TEMPLATE
,
context
,
prompt
)
print
(
system_prompt
)
data
[
"messages"
]
=
add_or_update_system_message
(
f
"
\n
{
system_prompt
}
"
,
data
[
"messages"
]
)
del
data
[
"tool_ids"
]
del
data
[
"tool_ids"
]
# If docs field is present, generate RAG completions
# If docs field is present, generate RAG completions
if
"docs"
in
data
:
if
"docs"
in
data
:
data
=
{
**
data
}
data
=
{
**
data
}
data
[
"messages"
],
citations
=
rag_messages
(
rag_context
,
citations
=
get_rag_context
(
docs
=
data
[
"docs"
],
docs
=
data
[
"docs"
],
messages
=
data
[
"messages"
],
messages
=
data
[
"messages"
],
template
=
rag_app
.
state
.
config
.
RAG_TEMPLATE
,
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
k
=
rag_app
.
state
.
config
.
TOP_K
,
k
=
rag_app
.
state
.
config
.
TOP_K
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
r
=
rag_app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
r
=
rag_app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
hybrid_search
=
rag_app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
hybrid_search
=
rag_app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
)
)
if
rag_context
:
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
rag_context
del
data
[
"docs"
]
del
data
[
"docs"
]
log
.
debug
(
log
.
debug
(
f
"rag_context:
{
rag_context
}
, citations:
{
citations
}
"
)
f
"data['messages']:
{
data
[
'messages'
]
}
, citations:
{
citations
}
"
if
context
!=
""
:
system_prompt
=
rag_template
(
rag_app
.
state
.
config
.
RAG_TEMPLATE
,
context
,
prompt
)
print
(
system_prompt
)
data
[
"messages"
]
=
add_or_update_system_message
(
f
"
\n
{
system_prompt
}
"
,
data
[
"messages"
]
)
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment