Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
8df6b137
Commit
8df6b137
authored
Mar 10, 2024
by
Timothy J. Baek
Browse files
fix: rag
parent
88d324b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
99 deletions
+121
-99
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+86
-0
backend/main.py
backend/main.py
+35
-99
No files found.
backend/apps/rag/utils.py
View file @
8df6b137
...
...
@@ -95,3 +95,89 @@ def rag_template(template: str, context: str, query: str):
template
=
re
.
sub
(
r
"\[query\]"
,
query
,
template
)
return
template
def
rag_messages
(
docs
,
messages
,
template
,
k
,
embedding_function
):
print
(
docs
)
last_user_message_idx
=
None
for
i
in
range
(
len
(
messages
)
-
1
,
-
1
,
-
1
):
if
messages
[
i
][
"role"
]
==
"user"
:
last_user_message_idx
=
i
break
user_message
=
messages
[
last_user_message_idx
]
if
isinstance
(
user_message
[
"content"
],
list
):
# Handle list content input
content_type
=
"list"
query
=
""
for
content_item
in
user_message
[
"content"
]:
if
content_item
[
"type"
]
==
"text"
:
query
=
content_item
[
"text"
]
break
elif
isinstance
(
user_message
[
"content"
],
str
):
# Handle text content input
content_type
=
"text"
query
=
user_message
[
"content"
]
else
:
# Fallback in case the input does not match expected types
content_type
=
None
query
=
""
relevant_contexts
=
[]
for
doc
in
docs
:
context
=
None
try
:
if
doc
[
"type"
]
==
"collection"
:
context
=
query_collection
(
collection_names
=
doc
[
"collection_names"
],
query
=
query
,
k
=
k
,
embedding_function
=
embedding_function
,
)
else
:
context
=
query_doc
(
collection_name
=
doc
[
"collection_name"
],
query
=
query
,
k
=
k
,
embedding_function
=
embedding_function
,
)
except
Exception
as
e
:
print
(
e
)
context
=
None
relevant_contexts
.
append
(
context
)
context_string
=
""
for
context
in
relevant_contexts
:
if
context
:
context_string
+=
" "
.
join
(
context
[
"documents"
][
0
])
+
"
\n
"
ra_content
=
rag_template
(
template
=
template
,
context
=
context_string
,
query
=
query
,
)
if
content_type
==
"list"
:
new_content
=
[]
for
content_item
in
user_message
[
"content"
]:
if
content_item
[
"type"
]
==
"text"
:
# Update the text item's content with ra_content
new_content
.
append
({
"type"
:
"text"
,
"text"
:
ra_content
})
else
:
# Keep other types of content as they are
new_content
.
append
(
content_item
)
new_user_message
=
{
**
user_message
,
"content"
:
new_content
}
else
:
new_user_message
=
{
**
user_message
,
"content"
:
ra_content
,
}
messages
[
last_user_message_idx
]
=
new_user_message
return
messages
backend/main.py
View file @
8df6b137
...
...
@@ -28,7 +28,7 @@ from typing import List
from
utils.utils
import
get_admin_user
from
apps.rag.utils
import
query_doc
,
query_collection
,
rag_template
from
apps.rag.utils
import
rag_messages
from
config
import
(
WEBUI_NAME
,
...
...
@@ -60,19 +60,6 @@ app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
origins
=
[
"*"
]
app
.
add_middleware
(
CORSMiddleware
,
allow_origins
=
origins
,
allow_credentials
=
True
,
allow_methods
=
[
"*"
],
allow_headers
=
[
"*"
],
)
@
app
.
on_event
(
"startup"
)
async
def
on_startup
():
await
litellm_app_startup
()
class
RAGMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
...
...
@@ -91,98 +78,33 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
if
"docs"
in
data
:
docs
=
data
[
"docs"
]
print
(
docs
)
last_user_message_idx
=
None
for
i
in
range
(
len
(
data
[
"messages"
])
-
1
,
-
1
,
-
1
):
if
data
[
"messages"
][
i
][
"role"
]
==
"user"
:
last_user_message_idx
=
i
break
user_message
=
data
[
"messages"
][
last_user_message_idx
]
if
isinstance
(
user_message
[
"content"
],
list
):
# Handle list content input
content_type
=
"list"
query
=
""
for
content_item
in
user_message
[
"content"
]:
if
content_item
[
"type"
]
==
"text"
:
query
=
content_item
[
"text"
]
break
elif
isinstance
(
user_message
[
"content"
],
str
):
# Handle text content input
content_type
=
"text"
query
=
user_message
[
"content"
]
else
:
# Fallback in case the input does not match expected types
content_type
=
None
query
=
""
relevant_contexts
=
[]
for
doc
in
docs
:
context
=
None
try
:
if
doc
[
"type"
]
==
"collection"
:
context
=
query_collection
(
collection_names
=
doc
[
"collection_names"
],
query
=
query
,
k
=
rag_app
.
state
.
TOP_K
,
embedding_function
=
rag_app
.
state
.
sentence_transformer_ef
,
)
else
:
context
=
query_doc
(
collection_name
=
doc
[
"collection_name"
],
query
=
query
,
k
=
rag_app
.
state
.
TOP_K
,
embedding_function
=
rag_app
.
state
.
sentence_transformer_ef
,
)
except
Exception
as
e
:
print
(
e
)
context
=
None
relevant_contexts
.
append
(
context
)
context_string
=
""
for
context
in
relevant_contexts
:
if
context
:
context_string
+=
" "
.
join
(
context
[
"documents"
][
0
])
+
"
\n
"
ra_content
=
rag_template
(
template
=
rag_app
.
state
.
RAG_TEMPLATE
,
context
=
context_string
,
query
=
query
,
)
if
content_type
==
"list"
:
new_content
=
[]
for
content_item
in
user_message
[
"content"
]:
if
content_item
[
"type"
]
==
"text"
:
# Update the text item's content with ra_content
new_content
.
append
({
"type"
:
"text"
,
"text"
:
ra_content
})
else
:
# Keep other types of content as they are
new_content
.
append
(
content_item
)
new_user_message
=
{
**
user_message
,
"content"
:
new_content
}
else
:
new_user_message
=
{
**
user_message
,
"content"
:
ra_content
,
}
data
[
"messages"
][
last_user_message_idx
]
=
new_user_message
data
=
{
**
data
}
data
[
"messages"
]
=
rag_messages
(
data
[
"docs"
],
data
[
"messages"
],
rag_app
.
state
.
RAG_TEMPLATE
,
rag_app
.
state
.
TOP_K
,
rag_app
.
state
.
sentence_transformer_ef
,
)
del
data
[
"docs"
]
print
(
data
[
"messages"
])
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
# Create a new request with the modified body
scope
=
request
.
scope
scope
[
"body"
]
=
modified_body_bytes
request
=
Request
(
scope
,
receive
=
lambda
:
self
.
_receive
(
modified_body_bytes
))
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
# Set custom header to ensure content-length matches new body length
request
.
headers
.
__dict__
[
"_list"
]
=
[
(
b
"content-length"
,
str
(
len
(
modified_body_bytes
)).
encode
(
"utf-8"
)),
*
[
(
k
,
v
)
for
k
,
v
in
request
.
headers
.
raw
if
k
.
lower
()
!=
b
"content-length"
],
]
response
=
await
call_next
(
request
)
return
response
...
...
@@ -194,6 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
app
.
add_middleware
(
RAGMiddleware
)
app
.
add_middleware
(
CORSMiddleware
,
allow_origins
=
origins
,
allow_credentials
=
True
,
allow_methods
=
[
"*"
],
allow_headers
=
[
"*"
],
)
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
start_time
=
int
(
time
.
time
())
...
...
@@ -204,6 +135,11 @@ async def check_url(request: Request, call_next):
return
response
@
app
.
on_event
(
"startup"
)
async
def
on_startup
():
await
litellm_app_startup
()
app
.
mount
(
"/api/v1"
,
webui_app
)
app
.
mount
(
"/litellm/api"
,
litellm_app
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment