Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
c49491e5
Commit
c49491e5
authored
Mar 08, 2024
by
Timothy J. Baek
Browse files
refac: rag to backend
parent
6ba62cf2
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
50 deletions
+113
-50
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+8
-0
backend/main.py
backend/main.py
+86
-0
src/lib/apis/rag/index.ts
src/lib/apis/rag/index.ts
+1
-1
src/routes/(app)/+page.svelte
src/routes/(app)/+page.svelte
+18
-49
No files found.
backend/apps/rag/utils.py
View file @
c49491e5
import
re
from
typing
import
List
from
typing
import
List
from
config
import
CHROMA_CLIENT
from
config
import
CHROMA_CLIENT
...
@@ -87,3 +88,10 @@ def query_collection(
...
@@ -87,3 +88,10 @@ def query_collection(
pass
pass
return
merge_and_sort_query_results
(
results
,
k
)
return
merge_and_sort_query_results
(
results
,
k
)
def rag_template(template: str, context: str, query: str) -> str:
    """Render a RAG prompt by substituting the retrieved context and the
    user's query into *template*.

    The template uses the literal placeholders ``[context]`` and
    ``[query]``. ``[context]`` is replaced first, matching the original
    substitution order.

    Bug fix: the previous implementation used ``re.sub``, which interprets
    backslashes in the *replacement* string as escape sequences — user or
    document text containing ``\\1`` or a stray ``\\`` would corrupt the
    output or raise ``re.error``. Placeholders are fixed literals, so plain
    ``str.replace`` is both correct and simpler.
    """
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    return template
backend/main.py
View file @
c49491e5
...
@@ -12,6 +12,7 @@ from fastapi import HTTPException
...
@@ -12,6 +12,7 @@ from fastapi import HTTPException
from
fastapi.middleware.wsgi
import
WSGIMiddleware
from
fastapi.middleware.wsgi
import
WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
apps.ollama.main
import
app
as
ollama_app
from
apps.ollama.main
import
app
as
ollama_app
...
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
...
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
from
apps.web.main
import
app
as
webui_app
from
apps.web.main
import
app
as
webui_app
from
apps.rag.utils
import
query_doc
,
query_collection
,
rag_template
from
config
import
WEBUI_NAME
,
ENV
,
VERSION
,
CHANGELOG
,
FRONTEND_BUILD_DIR
from
config
import
WEBUI_NAME
,
ENV
,
VERSION
,
CHANGELOG
,
FRONTEND_BUILD_DIR
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
...
@@ -56,6 +59,89 @@ async def on_startup():
...
@@ -56,6 +59,89 @@ async def on_startup():
await
litellm_app_startup
()
await
litellm_app_startup
()
class RAGMiddleware(BaseHTTPMiddleware):
    """ASGI middleware that injects retrieved document context into chat
    completion requests.

    For POST requests whose JSON body contains a ``"docs"`` key, it runs a
    RAG lookup for each referenced doc/collection and rewrites the last
    user message to include the retrieved context (via the app's RAG
    template) before the request reaches the downstream app.
    """

    async def dispatch(self, request: Request, call_next):
        print(request.url.path)
        if request.method == "POST":
            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON (empty body -> empty dict)
            data = json.loads(body_str) if body_str else {}

            # Example: Add a new key-value pair or modify existing ones
            # data["modified"] = True  # Example modification
            if "docs" in data:
                docs = data["docs"]
                print(docs)

                # Scan backwards to find the most recent "user" message —
                # that is the message whose content gets the RAG context.
                last_user_message_idx = None
                for i in range(len(data["messages"]) - 1, -1, -1):
                    if data["messages"][i]["role"] == "user":
                        last_user_message_idx = i
                        break

                # NOTE(review): if no user message exists this stays None and
                # the indexing below raises TypeError — presumably callers
                # always send at least one user message; confirm.
                query = data["messages"][last_user_message_idx]["content"]

                # Retrieve relevant chunks for every attached doc/collection.
                relevant_contexts = []
                for doc in docs:
                    context = None
                    if doc["type"] == "collection":
                        context = query_collection(
                            collection_names=doc["collection_names"],
                            query=query,
                            k=rag_app.state.TOP_K,
                            embedding_function=rag_app.state.sentence_transformer_ef,
                        )
                    else:
                        context = query_doc(
                            collection_name=doc["collection_name"],
                            query=query,
                            k=rag_app.state.TOP_K,
                            embedding_function=rag_app.state.sentence_transformer_ef,
                        )
                    relevant_contexts.append(context)

                # Flatten the retrieved chunks into one context string;
                # falsy results (failed lookups) are skipped.
                context_string = ""
                for context in relevant_contexts:
                    if context:
                        context_string += " ".join(context["documents"][0]) + "\n"

                # Render the RAG prompt template and replace the last user
                # message's content with the augmented prompt.
                content = rag_template(
                    template=rag_app.state.RAG_TEMPLATE,
                    context=context_string,
                    query=query,
                )

                new_user_message = {
                    **data["messages"][last_user_message_idx],
                    "content": content,
                }
                data["messages"][last_user_message_idx] = new_user_message

                # "docs" is consumed here; downstream apps must not see it.
                del data["docs"]

                print("DATAAAAAAAAAAAAAAAAAA")
                print(data)

                modified_body_bytes = json.dumps(data).encode("utf-8")

                # Create a new request with the modified body
                scope = request.scope
                scope["body"] = modified_body_bytes
                # Rebuild the Request with a receive callable that replays
                # the modified body as a single ASGI http.request event.
                request = Request(scope, receive=lambda: self._receive(modified_body_bytes))

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        # Minimal ASGI receive event: the full (already buffered) body,
        # with more_body=False so downstream reads terminate.
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(RAGMiddleware)
@
app
.
middleware
(
"http"
)
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
async
def
check_url
(
request
:
Request
,
call_next
):
start_time
=
int
(
time
.
time
())
start_time
=
int
(
time
.
time
())
...
...
src/lib/apis/rag/index.ts
View file @
c49491e5
...
@@ -252,7 +252,7 @@ export const queryCollection = async (
...
@@ -252,7 +252,7 @@ export const queryCollection = async (
token
:
string
,
token
:
string
,
collection_names
:
string
,
collection_names
:
string
,
query
:
string
,
query
:
string
,
k
:
number
k
:
number
|
null
=
null
)
=>
{
)
=>
{
let
error
=
null
;
let
error
=
null
;
...
...
src/routes/(app)/+page.svelte
View file @
c49491e5
...
@@ -232,53 +232,6 @@
...
@@ -232,53 +232,6 @@
const
sendPrompt
=
async
(
prompt
,
parentId
)
=>
{
const
sendPrompt
=
async
(
prompt
,
parentId
)
=>
{
const
_chatId
=
JSON
.
parse
(
JSON
.
stringify
($
chatId
));
const
_chatId
=
JSON
.
parse
(
JSON
.
stringify
($
chatId
));
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
console
.
log
(
docs
);
if
(
docs
.
length
>
0
)
{
processing
=
'Reading'
;
const
query
=
history
.
messages
[
parentId
].
content
;
let
relevantContexts
=
await
Promise
.
all
(
docs
.
map
(
async
(
doc
)
=>
{
if
(
doc
.
type
===
'collection'
)
{
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
).
catch
(
(
error
)
=>
{
console
.
log
(
error
);
return
null
;
}
);
}
else
{
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
).
catch
((
error
)
=>
{
console
.
log
(
error
);
return
null
;
});
}
})
);
relevantContexts
=
relevantContexts
.
filter
((
context
)
=>
context
);
const
contextString
=
relevantContexts
.
reduce
((
a
,
context
,
i
,
arr
)
=>
{
return
`${
a
}${
context
.
documents
.
join
(
' '
)}\
n
`;
},
''
);
console
.
log
(
contextString
);
history
.
messages
[
parentId
].
raContent
=
await
RAGTemplate
(
localStorage
.
token
,
contextString
,
query
);
history
.
messages
[
parentId
].
contexts
=
relevantContexts
;
await
tick
();
processing
=
''
;
}
await
Promise
.
all
(
await
Promise
.
all
(
selectedModels
.
map
(
async
(
modelId
)
=>
{
selectedModels
.
map
(
async
(
modelId
)
=>
{
const
model
=
$
models
.
filter
((
m
)
=>
m
.
id
===
modelId
).
at
(
0
);
const
model
=
$
models
.
filter
((
m
)
=>
m
.
id
===
modelId
).
at
(
0
);
...
@@ -368,6 +321,13 @@
...
@@ -368,6 +321,13 @@
}
}
});
});
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
const
[
res
,
controller
]
=
await
generateChatCompletion
(
localStorage
.
token
,
{
const
[
res
,
controller
]
=
await
generateChatCompletion
(
localStorage
.
token
,
{
model
:
model
,
model
:
model
,
messages
:
messagesBody
,
messages
:
messagesBody
,
...
@@ -375,7 +335,8 @@
...
@@ -375,7 +335,8 @@
...($
settings
.
options
??
{})
...($
settings
.
options
??
{})
},
},
format
:
$
settings
.
requestFormat
??
undefined
,
format
:
$
settings
.
requestFormat
??
undefined
,
keep_alive
:
$
settings
.
keepAlive
??
undefined
keep_alive
:
$
settings
.
keepAlive
??
undefined
,
docs
:
docs
});
});
if
(
res
&&
res
.
ok
)
{
if
(
res
&&
res
.
ok
)
{
...
@@ -535,6 +496,13 @@
...
@@ -535,6 +496,13 @@
const
responseMessage
=
history
.
messages
[
responseMessageId
];
const
responseMessage
=
history
.
messages
[
responseMessageId
];
scrollToBottom
();
scrollToBottom
();
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
const
res
=
await
generateOpenAIChatCompletion
(
const
res
=
await
generateOpenAIChatCompletion
(
localStorage
.
token
,
localStorage
.
token
,
{
{
...
@@ -583,7 +551,8 @@
...
@@ -583,7 +551,8 @@
top_p
:
$
settings
?.
options
?.
top_p
??
undefined
,
top_p
:
$
settings
?.
options
?.
top_p
??
undefined
,
num_ctx
:
$
settings
?.
options
?.
num_ctx
??
undefined
,
num_ctx
:
$
settings
?.
options
?.
num_ctx
??
undefined
,
frequency_penalty
:
$
settings
?.
options
?.
repeat_penalty
??
undefined
,
frequency_penalty
:
$
settings
?.
options
?.
repeat_penalty
??
undefined
,
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
,
docs
:
docs
},
},
model
.
source
===
'litellm'
?
`${
LITELLM_API_BASE_URL
}/
v1
`
:
`${
OPENAI_API_BASE_URL
}`
model
.
source
===
'litellm'
?
`${
LITELLM_API_BASE_URL
}/
v1
`
:
`${
OPENAI_API_BASE_URL
}`
);
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment