Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
d9911451
Unverified
Commit
d9911451
authored
Mar 09, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Mar 09, 2024
Browse files
Merge pull request #1113 from open-webui/rag
feat: rag api
parents
6ba62cf2
784ee6f5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
209 additions
and
117 deletions
+209
-117
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+8
-0
backend/main.py
backend/main.py
+121
-0
src/lib/apis/rag/index.ts
src/lib/apis/rag/index.ts
+1
-1
src/routes/(app)/+page.svelte
src/routes/(app)/+page.svelte
+39
-58
src/routes/(app)/c/[id]/+page.svelte
src/routes/(app)/c/[id]/+page.svelte
+40
-58
No files found.
backend/apps/rag/utils.py
View file @
d9911451
import
re
from
typing
import
List
from
config
import
CHROMA_CLIENT
...
...
@@ -87,3 +88,10 @@ def query_collection(
pass
return
merge_and_sort_query_results
(
results
,
k
)
def rag_template(template: str, context: str, query: str):
    """Fill the ``[context]`` and ``[query]`` placeholders of *template*.

    Args:
        template: Prompt template containing the literal markers
            ``[context]`` and/or ``[query]``.
        context: Text substituted for every ``[context]`` marker.
        query: Text substituted for every ``[query]`` marker.

    Returns:
        The template with both kinds of marker replaced.
    """
    replacements = {"[context]": context, "[query]": query}

    # A callable replacement is inserted verbatim. Passing the text as a
    # replacement *string* would make ``re.sub`` interpret backslashes in it
    # (``\1`` raises, ``\n`` becomes a newline). Substituting both markers in
    # one pass also keeps a literal "[query]" inside the context untouched.
    return re.sub(
        r"\[context\]|\[query\]",
        lambda match: replacements[match.group(0)],
        template,
    )
backend/main.py
View file @
d9911451
...
...
@@ -12,6 +12,7 @@ from fastapi import HTTPException
from
fastapi.middleware.wsgi
import
WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
apps.ollama.main
import
app
as
ollama_app
...
...
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
from
apps.web.main
import
app
as
webui_app
from
apps.rag.utils
import
query_doc
,
query_collection
,
rag_template
from
config
import
WEBUI_NAME
,
ENV
,
VERSION
,
CHANGELOG
,
FRONTEND_BUILD_DIR
from
constants
import
ERROR_MESSAGES
...
...
@@ -56,6 +59,124 @@ async def on_startup():
await
litellm_app_startup
()
class RAGMiddleware(BaseHTTPMiddleware):
    """Inject retrieved document context into chat-completion requests.

    For POST requests whose path contains ``/api/chat`` or
    ``/chat/completions`` the JSON body is inspected for a ``docs`` entry.
    When present, every listed document/collection is queried with the text
    of the last user message, the hits are rendered through ``rag_template``
    into that message, and ``docs`` is removed before forwarding.
    """

    async def dispatch(self, request: Request, call_next):
        """Rewrite matching chat requests, then hand over to the next handler.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The response produced by ``call_next``.
        """
        path = request.url.path
        if request.method == "POST" and (
            "/api/chat" in path or "/chat/completions" in path
        ):
            print(path)

            # The body is read here, so the (possibly rewritten) bytes are
            # handed to a replacement Request that replays them.
            original_body = await request.body()
            modified_body_bytes = self._rewrite_body(original_body)

            scope = request.scope
            scope["body"] = modified_body_bytes
            # Keep Content-Length in step with the bytes actually forwarded;
            # the rewritten body rarely has the original length.
            scope["headers"] = self._with_content_length(
                scope.get("headers", []), len(modified_body_bytes)
            )
            # NOTE(review): some Starlette releases may not take the downstream
            # body from the request object passed to call_next -- confirm
            # against the pinned Starlette version.
            request = Request(
                scope, receive=lambda: self._receive(modified_body_bytes)
            )

        response = await call_next(request)
        return response

    def _rewrite_body(self, body: bytes) -> bytes:
        """Return the body to forward: re-serialized JSON, or *body* as-is."""
        try:
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError both derive from
            # ValueError. Forward the bytes untouched so the downstream
            # handler reports the malformed payload itself.
            return body

        if isinstance(data, dict) and "docs" in data:
            docs = data.pop("docs")
            print(docs)
            self._inject_context(data, docs)

        return json.dumps(data).encode("utf-8")

    def _inject_context(self, data: dict, docs) -> None:
        """Replace the last user message in *data* with its RAG-augmented form."""
        messages = data.get("messages") or []
        last_user_message_idx = next(
            (
                idx
                for idx in range(len(messages) - 1, -1, -1)
                if messages[idx].get("role") == "user"
            ),
            None,
        )
        if last_user_message_idx is None:
            # No user message to augment; indexing with None would raise.
            return

        user_message = messages[last_user_message_idx]
        content = user_message.get("content")

        if isinstance(content, list):
            # Handle list content input: the first text item is the query.
            content_type = "list"
            query = next(
                (item["text"] for item in content if item["type"] == "text"),
                "",
            )
        elif isinstance(content, str):
            # Handle text content input
            content_type = "text"
            query = content
        else:
            # Fallback in case the input does not match expected types
            content_type = None
            query = ""

        relevant_contexts = [self._retrieve_context(doc, query) for doc in docs]
        context_string = "".join(
            " ".join(context["documents"][0]) + "\n"
            for context in relevant_contexts
            if context
        )

        ra_content = rag_template(
            template=rag_app.state.RAG_TEMPLATE,
            context=context_string,
            query=query,
        )

        if content_type == "list":
            # Text items receive the augmented prompt; other item types
            # are kept as they are.
            new_content = [
                {"type": "text", "text": ra_content}
                if item["type"] == "text"
                else item
                for item in content
            ]
        else:
            new_content = ra_content

        messages[last_user_message_idx] = {**user_message, "content": new_content}

    def _retrieve_context(self, doc, query: str):
        """Query one ``docs`` entry; return its result or None on failure."""
        try:
            if doc["type"] == "collection":
                return query_collection(
                    collection_names=doc["collection_names"],
                    query=query,
                    k=rag_app.state.TOP_K,
                    embedding_function=rag_app.state.sentence_transformer_ef,
                )
            return query_doc(
                collection_name=doc["collection_name"],
                query=query,
                k=rag_app.state.TOP_K,
                embedding_function=rag_app.state.sentence_transformer_ef,
            )
        except Exception as e:
            # Best effort: a failed lookup contributes no context instead of
            # failing the whole chat request.
            print(e)
            return None

    @staticmethod
    def _with_content_length(headers, length: int):
        """Return *headers* with Content-Length replaced by *length*."""
        updated = [
            (name, value)
            for name, value in headers
            if name.lower() != b"content-length"
        ]
        updated.append((b"content-length", str(length).encode("latin-1")))
        return updated

    async def _receive(self, body: bytes):
        """Build the single ``http.request`` message that carries *body*."""
        return {"type": "http.request", "body": body, "more_body": False}
# Register RAGMiddleware on the application.
app.add_middleware(RAGMiddleware)
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
start_time
=
int
(
time
.
time
())
...
...
src/lib/apis/rag/index.ts
View file @
d9911451
...
...
@@ -252,7 +252,7 @@ export const queryCollection = async (
token
:
string
,
collection_names
:
string
,
query
:
string
,
k
:
number
k
:
number
|
null
=
null
)
=>
{
let
error
=
null
;
...
...
src/routes/(app)/+page.svelte
View file @
d9911451
...
...
@@ -232,53 +232,6 @@
const
sendPrompt
=
async
(
prompt
,
parentId
)
=>
{
const
_chatId
=
JSON
.
parse
(
JSON
.
stringify
($
chatId
));
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
console
.
log
(
docs
);
if
(
docs
.
length
>
0
)
{
processing
=
'Reading'
;
const
query
=
history
.
messages
[
parentId
].
content
;
let
relevantContexts
=
await
Promise
.
all
(
docs
.
map
(
async
(
doc
)
=>
{
if
(
doc
.
type
===
'collection'
)
{
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
).
catch
(
(
error
)
=>
{
console
.
log
(
error
);
return
null
;
}
);
}
else
{
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
).
catch
((
error
)
=>
{
console
.
log
(
error
);
return
null
;
});
}
})
);
relevantContexts
=
relevantContexts
.
filter
((
context
)
=>
context
);
const
contextString
=
relevantContexts
.
reduce
((
a
,
context
,
i
,
arr
)
=>
{
return
`${
a
}${
context
.
documents
.
join
(
' '
)}\
n
`;
},
''
);
console
.
log
(
contextString
);
history
.
messages
[
parentId
].
raContent
=
await
RAGTemplate
(
localStorage
.
token
,
contextString
,
query
);
history
.
messages
[
parentId
].
contexts
=
relevantContexts
;
await
tick
();
processing
=
''
;
}
await
Promise
.
all
(
selectedModels
.
map
(
async
(
modelId
)
=>
{
const
model
=
$
models
.
filter
((
m
)
=>
m
.
id
===
modelId
).
at
(
0
);
...
...
@@ -342,15 +295,25 @@
...
messages
]
.
filter
((
message
)
=>
message
)
.
map
((
message
,
idx
,
arr
)
=>
({
role
:
message
.
role
,
content
:
arr
.
length
-
2
!== idx ? message.content : message?.raContent ?? message.content,
...(
message
.
files
&&
{
images
:
message
.
files
.
filter
((
file
)
=>
file
.
type
===
'image'
)
.
map
((
file
)
=>
file
.
url
.
slice
(
file
.
url
.
indexOf
(
','
)
+
1
))
})
}));
.
map
((
message
,
idx
,
arr
)
=>
{
//
Prepare
the
base
message
object
const
baseMessage
=
{
role
:
message
.
role
,
content
:
arr
.
length
-
2
!== idx ? message.content : message?.raContent ?? message.content
};
//
Extract
and
format
image
URLs
if
any
exist
const
imageUrls
=
message
.
files
?.
filter
((
file
)
=>
file
.
type
===
'image'
)
.
map
((
file
)
=>
file
.
url
.
slice
(
file
.
url
.
indexOf
(
','
)
+
1
));
//
Add
images
array
only
if
it
contains
elements
if
(
imageUrls
&&
imageUrls
.
length
>
0
)
{
baseMessage
.
images
=
imageUrls
;
}
return
baseMessage
;
});
let
lastImageIndex
=
-
1
;
...
...
@@ -368,6 +331,13 @@
}
});
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
const
[
res
,
controller
]
=
await
generateChatCompletion
(
localStorage
.
token
,
{
model
:
model
,
messages
:
messagesBody
,
...
...
@@ -375,7 +345,8 @@
...($
settings
.
options
??
{})
},
format
:
$
settings
.
requestFormat
??
undefined
,
keep_alive
:
$
settings
.
keepAlive
??
undefined
keep_alive
:
$
settings
.
keepAlive
??
undefined
,
docs
:
docs
.
length
>
0
?
docs
:
undefined
});
if
(
res
&&
res
.
ok
)
{
...
...
@@ -535,6 +506,15 @@
const
responseMessage
=
history
.
messages
[
responseMessageId
];
scrollToBottom
();
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
console
.
log
(
docs
);
const
res
=
await
generateOpenAIChatCompletion
(
localStorage
.
token
,
{
...
...
@@ -583,7 +563,8 @@
top_p
:
$
settings
?.
options
?.
top_p
??
undefined
,
num_ctx
:
$
settings
?.
options
?.
num_ctx
??
undefined
,
frequency_penalty
:
$
settings
?.
options
?.
repeat_penalty
??
undefined
,
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
,
docs
:
docs
.
length
>
0
?
docs
:
undefined
},
model
.
source
===
'litellm'
?
`${
LITELLM_API_BASE_URL
}/
v1
`
:
`${
OPENAI_API_BASE_URL
}`
);
...
...
src/routes/(app)/c/[id]/+page.svelte
View file @
d9911451
...
...
@@ -245,53 +245,6 @@
const
sendPrompt
=
async
(
prompt
,
parentId
)
=>
{
const
_chatId
=
JSON
.
parse
(
JSON
.
stringify
($
chatId
));
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
console
.
log
(
docs
);
if
(
docs
.
length
>
0
)
{
processing
=
'Reading'
;
const
query
=
history
.
messages
[
parentId
].
content
;
let
relevantContexts
=
await
Promise
.
all
(
docs
.
map
(
async
(
doc
)
=>
{
if
(
doc
.
type
===
'collection'
)
{
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
).
catch
(
(
error
)
=>
{
console
.
log
(
error
);
return
null
;
}
);
}
else
{
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
).
catch
((
error
)
=>
{
console
.
log
(
error
);
return
null
;
});
}
})
);
relevantContexts
=
relevantContexts
.
filter
((
context
)
=>
context
);
const
contextString
=
relevantContexts
.
reduce
((
a
,
context
,
i
,
arr
)
=>
{
return
`${
a
}${
context
.
documents
.
join
(
' '
)}\
n
`;
},
''
);
console
.
log
(
contextString
);
history
.
messages
[
parentId
].
raContent
=
await
RAGTemplate
(
localStorage
.
token
,
contextString
,
query
);
history
.
messages
[
parentId
].
contexts
=
relevantContexts
;
await
tick
();
processing
=
''
;
}
await
Promise
.
all
(
selectedModels
.
map
(
async
(
modelId
)
=>
{
const
model
=
$
models
.
filter
((
m
)
=>
m
.
id
===
modelId
).
at
(
0
);
...
...
@@ -355,15 +308,25 @@
...
messages
]
.
filter
((
message
)
=>
message
)
.
map
((
message
,
idx
,
arr
)
=>
({
role
:
message
.
role
,
content
:
arr
.
length
-
2
!== idx ? message.content : message?.raContent ?? message.content,
...(
message
.
files
&&
{
images
:
message
.
files
.
filter
((
file
)
=>
file
.
type
===
'image'
)
.
map
((
file
)
=>
file
.
url
.
slice
(
file
.
url
.
indexOf
(
','
)
+
1
))
})
}));
.
map
((
message
,
idx
,
arr
)
=>
{
//
Prepare
the
base
message
object
const
baseMessage
=
{
role
:
message
.
role
,
content
:
arr
.
length
-
2
!== idx ? message.content : message?.raContent ?? message.content
};
//
Extract
and
format
image
URLs
if
any
exist
const
imageUrls
=
message
.
files
?.
filter
((
file
)
=>
file
.
type
===
'image'
)
.
map
((
file
)
=>
file
.
url
.
slice
(
file
.
url
.
indexOf
(
','
)
+
1
));
//
Add
images
array
only
if
it
contains
elements
if
(
imageUrls
&&
imageUrls
.
length
>
0
)
{
baseMessage
.
images
=
imageUrls
;
}
return
baseMessage
;
});
let
lastImageIndex
=
-
1
;
...
...
@@ -381,6 +344,13 @@
}
});
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
const
[
res
,
controller
]
=
await
generateChatCompletion
(
localStorage
.
token
,
{
model
:
model
,
messages
:
messagesBody
,
...
...
@@ -388,7 +358,8 @@
...($
settings
.
options
??
{})
},
format
:
$
settings
.
requestFormat
??
undefined
,
keep_alive
:
$
settings
.
keepAlive
??
undefined
keep_alive
:
$
settings
.
keepAlive
??
undefined
,
docs
:
docs
.
length
>
0
?
docs
:
undefined
});
if
(
res
&&
res
.
ok
)
{
...
...
@@ -548,6 +519,15 @@
const
responseMessage
=
history
.
messages
[
responseMessageId
];
scrollToBottom
();
const
docs
=
messages
.
filter
((
message
)
=>
message
?.
files
??
null
)
.
map
((
message
)
=>
message
.
files
.
filter
((
item
)
=>
item
.
type
===
'doc'
||
item
.
type
===
'collection'
)
)
.
flat
(
1
);
console
.
log
(
docs
);
const
res
=
await
generateOpenAIChatCompletion
(
localStorage
.
token
,
{
...
...
@@ -596,7 +576,8 @@
top_p
:
$
settings
?.
options
?.
top_p
??
undefined
,
num_ctx
:
$
settings
?.
options
?.
num_ctx
??
undefined
,
frequency_penalty
:
$
settings
?.
options
?.
repeat_penalty
??
undefined
,
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
max_tokens
:
$
settings
?.
options
?.
num_predict
??
undefined
,
docs
:
docs
.
length
>
0
?
docs
:
undefined
},
model
.
source
===
'litellm'
?
`${
LITELLM_API_BASE_URL
}/
v1
`
:
`${
OPENAI_API_BASE_URL
}`
);
...
...
@@ -710,6 +691,7 @@
await
setChatTitle
(
_chatId
,
userPrompt
);
}
};
const
stopResponse
=
()
=>
{
stopResponseFlag
=
true
;
console
.
log
(
'stopResponse'
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment