Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
c49491e5
Commit
c49491e5
authored
Mar 08, 2024
by
Timothy J. Baek
Browse files
refac: rag to backend
parent
6ba62cf2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
50 deletions
+113
-50
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+8
-0
backend/main.py
backend/main.py
+86
-0
src/lib/apis/rag/index.ts
src/lib/apis/rag/index.ts
+1
-1
src/routes/(app)/+page.svelte
src/routes/(app)/+page.svelte
+18
-49
No files found.
backend/apps/rag/utils.py
View file @
c49491e5
import
re
from
typing
import
List
from
config
import
CHROMA_CLIENT
...
...
@@ -87,3 +88,10 @@ def query_collection(
pass
return
merge_and_sort_query_results
(
results
,
k
)
def rag_template(template: str, context: str, query: str) -> str:
    """Fill the RAG prompt template.

    Replaces the literal placeholders ``[context]`` and ``[query]`` in
    *template* with the given values and returns the result.

    Plain ``str.replace`` is used instead of ``re.sub`` on purpose:
    ``re.sub`` interprets backslashes and group references (e.g. ``\\1``,
    ``\\n``) inside the *replacement* string, so user-supplied context or
    query text containing backslashes would be corrupted or raise
    ``re.error: invalid group reference``.
    """
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    return template
backend/main.py
View file @
c49491e5
...
...
@@ -12,6 +12,7 @@ from fastapi import HTTPException
from
fastapi.middleware.wsgi
import
WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
apps.ollama.main
import
app
as
ollama_app
...
...
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
from
apps.web.main
import
app
as
webui_app
from
apps.rag.utils
import
query_doc
,
query_collection
,
rag_template
from
config
import
WEBUI_NAME
,
ENV
,
VERSION
,
CHANGELOG
,
FRONTEND_BUILD_DIR
from
constants
import
ERROR_MESSAGES
...
...
@@ -56,6 +59,89 @@ async def on_startup():
await
litellm_app_startup
()
class RAGMiddleware(BaseHTTPMiddleware):
    """Augment chat POST requests with retrieved document context (RAG).

    When a POST body contains a ``docs`` field, the referenced documents /
    collections are queried with the last user message as the search query,
    the retrieved passages are injected into the configured RAG prompt
    template, and the last user message content is replaced with the result.
    The ``docs`` field is then stripped before the request continues
    downstream (the model apps do not understand it).
    """

    async def dispatch(self, request: Request, call_next):
        if request.method == "POST":
            # Read and parse the original request body. An empty body is
            # treated as an empty JSON object.
            body = await request.body()
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}

            if "docs" in data:
                docs = data["docs"]

                # Locate the most recent user message; its content is the
                # retrieval query.
                messages = data.get("messages", [])
                last_user_message_idx = None
                for i in range(len(messages) - 1, -1, -1):
                    if messages[i]["role"] == "user":
                        last_user_message_idx = i
                        break

                # Guard: without a user message there is nothing to rewrite
                # (the original code would crash indexing with None here).
                if last_user_message_idx is not None:
                    query = messages[last_user_message_idx]["content"]

                    # Query each referenced doc/collection. A failure on one
                    # doc must not abort the whole request — record None and
                    # continue with whatever context we can retrieve.
                    relevant_contexts = []
                    for doc in docs:
                        context = None
                        try:
                            if doc["type"] == "collection":
                                context = query_collection(
                                    collection_names=doc["collection_names"],
                                    query=query,
                                    k=rag_app.state.TOP_K,
                                    embedding_function=rag_app.state.sentence_transformer_ef,
                                )
                            else:
                                context = query_doc(
                                    collection_name=doc["collection_name"],
                                    query=query,
                                    k=rag_app.state.TOP_K,
                                    embedding_function=rag_app.state.sentence_transformer_ef,
                                )
                        except Exception as e:
                            print(e)
                        relevant_contexts.append(context)

                    # Concatenate the retrieved passages into one string.
                    context_string = ""
                    for context in relevant_contexts:
                        if context:
                            context_string += " ".join(context["documents"][0]) + "\n"

                    # Render the RAG template and swap it in for the user's
                    # original message content.
                    content = rag_template(
                        template=rag_app.state.RAG_TEMPLATE,
                        context=context_string,
                        query=query,
                    )
                    data["messages"][last_user_message_idx] = {
                        **messages[last_user_message_idx],
                        "content": content,
                    }

                # Strip the docs field regardless; downstream apps reject it.
                del data["docs"]

            modified_body_bytes = json.dumps(data).encode("utf-8")

            # The original body stream has been consumed; rebuild the request
            # with a receive() that replays the (possibly modified) body.
            scope = request.scope
            scope["body"] = modified_body_bytes
            request = Request(
                scope, receive=lambda: self._receive(modified_body_bytes)
            )

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        # Minimal ASGI receive() replacement: yield the buffered body once,
        # with no further body chunks to follow.
        return {"type": "http.request", "body": body, "more_body": False}
# Register the RAG middleware application-wide so every POST request can be
# augmented with retrieved document context before reaching the model apps.
app.add_middleware(RAGMiddleware)
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
start_time
=
int
(
time
.
time
())
...
...
src/lib/apis/rag/index.ts
View file @
c49491e5
...
...
@@ -252,7 +252,7 @@ export const queryCollection = async (
token
:
string
,
collection_names
:
string
,
query
:
string
,
k
:
number
k
:
number
|
null
=
null
)
=>
{
let
error
=
null
;
...
...
src/routes/(app)/+page.svelte
View file @
c49491e5
...
...
@@ -232,53 +232,6 @@
const
sendPrompt
=
async
(
prompt
,
parentId
)
=>
{
const
_chatId
=
JSON
.
parse
(
JSON
.
stringify
($
chatId
));
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
console
.
log
(
docs
);
if
(
docs
.
length
>
0
)
{
processing
=
'Reading'
;
const
query
=
history
.
messages
[
parentId
].
content
;
let
relevantContexts
=
await
Promise
.
all
(
docs
.
map
(
async
(
doc
)
=>
{
if
(
doc
.
type
===
'collection'
)
{
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
).
catch
(
(
error
)
=>
{
console
.
log
(
error
);
return
null
;
}
);
}
else
{
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
).
catch
((
error
)
=>
{
console
.
log
(
error
);
return
null
;
});
}
})
);
relevantContexts
=
relevantContexts
.
filter
((
context
)
=>
context
);
const
contextString
=
relevantContexts
.
reduce
((
a
,
context
,
i
,
arr
)
=>
{
return
`${
a
}${
context
.
documents
.
join
(
' '
)}\
n
`;
},
''
);
console
.
log
(
contextString
);
history
.
messages
[
parentId
].
raContent
=
await
RAGTemplate
(
localStorage
.
token
,
contextString
,
query
);
history
.
messages
[
parentId
].
contexts
=
relevantContexts
;
await
tick
();
processing
=
''
;
}
await
Promise
.
all
(
selectedModels
.
map
(
async
(
modelId
)
=>
{
const
model
=
$
models
.
filter
((
m
)
=>
m
.
id
===
modelId
).
at
(
0
);
...
...
@@ -368,6 +321,13 @@
}
});
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
const
[
res
,
controller
]
=
await
generateChatCompletion
(
localStorage
.
token
,
{
model
:
model
,
messages
:
messagesBody
,
...
...
@@ -375,7 +335,8 @@
...($
settings
.
options
??
{})
},
format
:
$
settings
.
requestFormat
??
undefined
,
keep_alive
:
$
settings
.
keepAlive
??
undefined
keep_alive
:
$
settings
.
keepAlive
??
undefined
,
docs
:
docs
});
if
(
res
&&
res
.
ok
)
{
...
...
@@ -535,6 +496,13 @@
const
responseMessage
=
history
.
messages
[
responseMessageId
];
scrollToBottom
();
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
const
res
=
await
generateOpenAIChatCompletion
(
localStorage
.
token
,
{
...
...
@@ -583,7 +551,8 @@
top_p
:
$
settings
?.
options
?.
top_p
??
undefined
,
num_ctx
:
$
settings
?.
options
?.
num_ctx
??
undefined
,
frequency_penalty
:
$
settings
?.
options
?.
repeat_penalty
??
undefined
,
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
,
docs
:
docs
},
model
.
source
===
'litellm'
?
`${
LITELLM_API_BASE_URL
}/
v1
`
:
`${
OPENAI_API_BASE_URL
}`
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment