Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
ce9a5d12
Commit
ce9a5d12
authored
Apr 27, 2024
by
Timothy J. Baek
Browse files
refac: rag pipeline
parent
8f1563a7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
177 additions
and
152 deletions
+177
-152
backend/apps/rag/main.py
backend/apps/rag/main.py
+55
-45
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+114
-95
backend/main.py
backend/main.py
+8
-12
No files found.
backend/apps/rag/main.py
View file @
ce9a5d12
...
@@ -47,9 +47,11 @@ from apps.web.models.documents import (
...
@@ -47,9 +47,11 @@ from apps.web.models.documents import (
from
apps.rag.utils
import
(
from
apps.rag.utils
import
(
get_model_path
,
get_model_path
,
query_embeddings_doc
,
get_embedding_function
,
get_embeddings_function
,
query_doc
,
query_embeddings_collection
,
query_doc_with_hybrid_search
,
query_collection
,
query_collection_with_hybrid_search
,
)
)
from
utils.misc
import
(
from
utils.misc
import
(
...
@@ -147,6 +149,15 @@ update_reranking_model(
...
@@ -147,6 +149,15 @@ update_reranking_model(
RAG_RERANKING_MODEL_AUTO_UPDATE
,
RAG_RERANKING_MODEL_AUTO_UPDATE
,
)
)
app
.
state
.
EMBEDDING_FUNCTION
=
get_embedding_function
(
app
.
state
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
sentence_transformer_ef
,
app
.
state
.
OPENAI_API_KEY
,
app
.
state
.
OPENAI_API_BASE_URL
,
)
origins
=
[
"*"
]
origins
=
[
"*"
]
...
@@ -227,6 +238,14 @@ async def update_embedding_config(
...
@@ -227,6 +238,14 @@ async def update_embedding_config(
update_embedding_model
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
True
)
update_embedding_model
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
True
)
app
.
state
.
EMBEDDING_FUNCTION
=
get_embedding_function
(
app
.
state
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
sentence_transformer_ef
,
app
.
state
.
OPENAI_API_KEY
,
app
.
state
.
OPENAI_API_BASE_URL
,
)
return
{
return
{
"status"
:
True
,
"status"
:
True
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
...
@@ -367,27 +386,22 @@ def query_doc_handler(
...
@@ -367,27 +386,22 @@ def query_doc_handler(
user
=
Depends
(
get_current_user
),
user
=
Depends
(
get_current_user
),
):
):
try
:
try
:
embeddings_function
=
get_embeddings_function
(
if
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
return
query_doc_with_hybrid_search
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
collection_name
=
form_data
.
collection_name
,
app
.
state
.
sentence_transformer_ef
,
query
=
form_data
.
query
,
app
.
state
.
OPENAI_API_KEY
,
embeddings_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
app
.
state
.
OPENAI_API_BASE_URL
,
reranking_function
=
app
.
state
.
sentence_transformer_rf
,
)
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
r
=
form_data
.
r
if
form_data
.
r
else
app
.
state
.
RELEVANCE_THRESHOLD
,
return
query_embeddings_doc
(
)
collection_name
=
form_data
.
collection_name
,
else
:
query
=
form_data
.
query
,
return
query_doc
(
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
collection_name
=
form_data
.
collection_name
,
r
=
form_data
.
r
if
form_data
.
r
else
app
.
state
.
RELEVANCE_THRESHOLD
,
query
=
form_data
.
query
,
embeddings_function
=
embeddings_function
,
embeddings_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
reranking_function
=
app
.
state
.
sentence_transformer_rf
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
hybrid_search
=
(
)
form_data
.
hybrid
if
form_data
.
hybrid
else
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
),
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
...
@@ -410,27 +424,23 @@ def query_collection_handler(
...
@@ -410,27 +424,23 @@ def query_collection_handler(
user
=
Depends
(
get_current_user
),
user
=
Depends
(
get_current_user
),
):
):
try
:
try
:
embeddings_function
=
get_embeddings_function
(
if
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
return
query_collection_with_hybrid_search
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
collection_names
=
form_data
.
collection_names
,
app
.
state
.
sentence_transformer_ef
,
query
=
form_data
.
query
,
app
.
state
.
OPENAI_API_KEY
,
embeddings_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
app
.
state
.
OPENAI_API_BASE_URL
,
reranking_function
=
app
.
state
.
sentence_transformer_rf
,
)
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
r
=
form_data
.
r
if
form_data
.
r
else
app
.
state
.
RELEVANCE_THRESHOLD
,
)
else
:
return
query_collection
(
collection_names
=
form_data
.
collection_names
,
query
=
form_data
.
query
,
embeddings_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
)
return
query_embeddings_collection
(
collection_names
=
form_data
.
collection_names
,
query
=
form_data
.
query
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
r
=
form_data
.
r
if
form_data
.
r
else
app
.
state
.
RELEVANCE_THRESHOLD
,
embeddings_function
=
embeddings_function
,
reranking_function
=
app
.
state
.
sentence_transformer_rf
,
hybrid_search
=
(
form_data
.
hybrid
if
form_data
.
hybrid
else
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
),
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
...
@@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
...
@@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
)
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
)
embedding_func
=
get_embedding
s
_function
(
embedding_func
=
get_embedding_function
(
app
.
state
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
sentence_transformer_ef
,
app
.
state
.
sentence_transformer_ef
,
...
...
backend/apps/rag/utils.py
View file @
ce9a5d12
...
@@ -26,61 +26,72 @@ log = logging.getLogger(__name__)
...
@@ -26,61 +26,72 @@ log = logging.getLogger(__name__)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
def
query_
embeddings_
doc
(
def
query_doc
(
collection_name
:
str
,
collection_name
:
str
,
query
:
str
,
query
:
str
,
embeddings_function
,
embedding_function
,
reranking_function
,
k
:
int
,
k
:
int
,
r
:
int
,
hybrid_search
:
bool
,
):
):
try
:
try
:
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
collection_name
)
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
collection_name
)
query_embeddings
=
embedding_function
(
query
)
result
=
collection
.
query
(
query_embeddings
=
[
query_embeddings
],
n_results
=
k
,
)
if
hybrid_search
:
log
.
info
(
f
"query_doc:result
{
result
}
"
)
documents
=
collection
.
get
()
# get all documents
return
result
bm25_retriever
=
BM25Retriever
.
from_texts
(
except
Exception
as
e
:
texts
=
documents
.
get
(
"documents"
),
raise
e
metadatas
=
documents
.
get
(
"metadatas"
),
)
bm25_retriever
.
k
=
k
chroma_retriever
=
ChromaRetriever
(
collection
=
collection
,
embeddings_function
=
embeddings_function
,
top_n
=
k
,
)
ensemble_retriever
=
EnsembleRetriever
(
def
query_doc_with_hybrid_search
(
retrievers
=
[
bm25_retriever
,
chroma_retriever
],
weights
=
[
0.5
,
0.5
]
collection_name
:
str
,
)
query
:
str
,
embedding_function
,
k
:
int
,
reranking_function
,
r
:
int
,
):
try
:
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
collection_name
)
documents
=
collection
.
get
()
# get all documents
compressor
=
RerankCompressor
(
bm25_retriever
=
BM25Retriever
.
from_texts
(
embeddings_function
=
embeddings_function
,
texts
=
documents
.
get
(
"documents"
),
reranking_function
=
reranking_function
,
metadatas
=
documents
.
get
(
"metadatas"
),
r_score
=
r
,
)
top_n
=
k
,
bm25_retriever
.
k
=
k
)
compression_retriever
=
ContextualCompressionRetriever
(
chroma_retriever
=
ChromaRetriever
(
base_compressor
=
compressor
,
base_retriever
=
ensemble_retriever
collection
=
collection
,
)
embedding_function
=
embedding_function
,
top_n
=
k
,
)
result
=
compression_retriever
.
invoke
(
query
)
ensemble_retriever
=
EnsembleRetriever
(
result
=
{
retrievers
=
[
bm25_retriever
,
chroma_retriever
],
weights
=
[
0.5
,
0.5
]
"distances"
:
[[
d
.
metadata
.
get
(
"score"
)
for
d
in
result
]],
)
"documents"
:
[[
d
.
page_content
for
d
in
result
]],
"metadatas"
:
[[
d
.
metadata
for
d
in
result
]],
compressor
=
RerankCompressor
(
}
embedding_function
=
embedding_function
,
else
:
reranking_function
=
reranking_function
,
query_embeddings
=
embeddings_function
(
query
)
r_score
=
r
,
result
=
collection
.
query
(
top_n
=
k
,
query_embeddings
=
[
query_embeddings
],
)
n_results
=
k
,
)
compression_retriever
=
ContextualCompressionRetriever
(
base_compressor
=
compressor
,
base_retriever
=
ensemble_retriever
)
log
.
info
(
f
"query_embeddings_doc:result
{
result
}
"
)
result
=
compression_retriever
.
invoke
(
query
)
result
=
{
"distances"
:
[[
d
.
metadata
.
get
(
"score"
)
for
d
in
result
]],
"documents"
:
[[
d
.
page_content
for
d
in
result
]],
"metadatas"
:
[[
d
.
metadata
for
d
in
result
]],
}
log
.
info
(
f
"query_doc_with_hybrid_search:result
{
result
}
"
)
return
result
return
result
except
Exception
as
e
:
except
Exception
as
e
:
raise
e
raise
e
...
@@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
...
@@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
return
result
return
result
def
query_
embeddings_
collection
(
def
query_collection
(
collection_names
:
List
[
str
],
collection_names
:
List
[
str
],
query
:
str
,
query
:
str
,
embedding_function
,
k
:
int
,
):
results
=
[]
for
collection_name
in
collection_names
:
try
:
result
=
query_doc
(
collection_name
=
collection_name
,
query
=
query
,
k
=
k
,
embedding_function
=
embedding_function
,
)
results
.
append
(
result
)
except
:
pass
return
merge_and_sort_query_results
(
results
,
k
=
k
)
def
query_collection_with_hybrid_search
(
collection_names
:
List
[
str
],
query
:
str
,
embedding_function
,
k
:
int
,
k
:
int
,
r
:
float
,
embeddings_function
,
reranking_function
,
reranking_function
,
hybrid_search
:
bool
,
r
:
float
,
):
):
results
=
[]
results
=
[]
for
collection_name
in
collection_names
:
for
collection_name
in
collection_names
:
try
:
try
:
result
=
query_
embeddings_doc
(
result
=
query_
doc_with_hybrid_search
(
collection_name
=
collection_name
,
collection_name
=
collection_name
,
query
=
query
,
query
=
query
,
embedding_function
=
embedding_function
,
k
=
k
,
k
=
k
,
r
=
r
,
embeddings_function
=
embeddings_function
,
reranking_function
=
reranking_function
,
reranking_function
=
reranking_function
,
hybrid_search
=
hybrid_search
,
r
=
r
,
)
)
results
.
append
(
result
)
results
.
append
(
result
)
except
:
except
:
pass
pass
reverse
=
hybrid_search
and
reranking_function
is
not
None
return
merge_and_sort_query_results
(
results
,
k
=
k
,
reverse
=
True
)
return
merge_and_sort_query_results
(
results
,
k
=
k
,
reverse
=
reverse
)
def
rag_template
(
template
:
str
,
context
:
str
,
query
:
str
):
def
rag_template
(
template
:
str
,
context
:
str
,
query
:
str
):
...
@@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str):
...
@@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str):
return
template
return
template
def
get_embedding
s
_function
(
def
get_embedding_function
(
embedding_engine
,
embedding_engine
,
embedding_model
,
embedding_model
,
embedding_function
,
embedding_function
,
...
@@ -204,19 +232,13 @@ def rag_messages(
...
@@ -204,19 +232,13 @@ def rag_messages(
docs
,
docs
,
messages
,
messages
,
template
,
template
,
embedding_function
,
k
,
k
,
reranking_function
,
r
,
r
,
hybrid_search
,
hybrid_search
,
embedding_engine
,
embedding_model
,
embedding_function
,
reranking_function
,
openai_key
,
openai_url
,
):
):
log
.
debug
(
log
.
debug
(
f
"docs:
{
docs
}
{
messages
}
{
embedding_function
}
{
reranking_function
}
"
)
f
"docs:
{
docs
}
{
messages
}
{
embedding_engine
}
{
embedding_model
}
{
embedding_function
}
{
reranking_function
}
{
openai_key
}
{
openai_url
}
"
)
last_user_message_idx
=
None
last_user_message_idx
=
None
for
i
in
range
(
len
(
messages
)
-
1
,
-
1
,
-
1
):
for
i
in
range
(
len
(
messages
)
-
1
,
-
1
,
-
1
):
...
@@ -243,14 +265,6 @@ def rag_messages(
...
@@ -243,14 +265,6 @@ def rag_messages(
content_type
=
None
content_type
=
None
query
=
""
query
=
""
embeddings_function
=
get_embeddings_function
(
embedding_engine
,
embedding_model
,
embedding_function
,
openai_key
,
openai_url
,
)
extracted_collections
=
[]
extracted_collections
=
[]
relevant_contexts
=
[]
relevant_contexts
=
[]
...
@@ -271,26 +285,31 @@ def rag_messages(
...
@@ -271,26 +285,31 @@ def rag_messages(
try
:
try
:
if
doc
[
"type"
]
==
"text"
:
if
doc
[
"type"
]
==
"text"
:
context
=
doc
[
"content"
]
context
=
doc
[
"content"
]
elif
doc
[
"type"
]
==
"collection"
:
context
=
query_embeddings_collection
(
collection_names
=
doc
[
"collection_names"
],
query
=
query
,
k
=
k
,
r
=
r
,
embeddings_function
=
embeddings_function
,
reranking_function
=
reranking_function
,
hybrid_search
=
hybrid_search
,
)
else
:
else
:
context
=
query_embeddings_doc
(
if
hybrid_search
:
collection_name
=
doc
[
"collection_name"
],
context
=
query_collection_with_hybrid_search
(
query
=
query
,
collection_names
=
(
k
=
k
,
doc
[
"collection_names"
]
r
=
r
,
if
doc
[
"type"
]
==
"collection"
embeddings_function
=
embeddings_function
,
else
[
doc
[
"collection_name"
]]
reranking_function
=
reranking_function
,
),
hybrid_search
=
hybrid_search
,
query
=
query
,
)
embedding_function
=
embedding_function
,
k
=
k
,
reranking_function
=
reranking_function
,
r
=
r
,
)
else
:
context
=
query_collection
(
collection_names
=
(
doc
[
"collection_names"
]
if
doc
[
"type"
]
==
"collection"
else
[
doc
[
"collection_name"
]]
),
query
=
query
,
embedding_function
=
embedding_function
,
k
=
k
,
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
context
=
None
context
=
None
...
@@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
...
@@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
class
ChromaRetriever
(
BaseRetriever
):
class
ChromaRetriever
(
BaseRetriever
):
collection
:
Any
collection
:
Any
embedding
s
_function
:
Any
embedding_function
:
Any
top_n
:
int
top_n
:
int
def
_get_relevant_documents
(
def
_get_relevant_documents
(
...
@@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever):
...
@@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever):
*
,
*
,
run_manager
:
CallbackManagerForRetrieverRun
,
run_manager
:
CallbackManagerForRetrieverRun
,
)
->
List
[
Document
]:
)
->
List
[
Document
]:
query_embeddings
=
self
.
embedding
s
_function
(
query
)
query_embeddings
=
self
.
embedding_function
(
query
)
results
=
self
.
collection
.
query
(
results
=
self
.
collection
.
query
(
query_embeddings
=
[
query_embeddings
],
query_embeddings
=
[
query_embeddings
],
...
@@ -445,7 +464,7 @@ from sentence_transformers import util
...
@@ -445,7 +464,7 @@ from sentence_transformers import util
class
RerankCompressor
(
BaseDocumentCompressor
):
class
RerankCompressor
(
BaseDocumentCompressor
):
embedding
s
_function
:
Any
embedding_function
:
Any
reranking_function
:
Any
reranking_function
:
Any
r_score
:
float
r_score
:
float
top_n
:
int
top_n
:
int
...
@@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor):
...
@@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(
query
,
doc
.
page_content
)
for
doc
in
documents
]
[(
query
,
doc
.
page_content
)
for
doc
in
documents
]
)
)
else
:
else
:
query_embedding
=
self
.
embedding
s
_function
(
query
)
query_embedding
=
self
.
embedding_function
(
query
)
document_embedding
=
self
.
embedding
s
_function
(
document_embedding
=
self
.
embedding_function
(
[
doc
.
page_content
for
doc
in
documents
]
[
doc
.
page_content
for
doc
in
documents
]
)
)
scores
=
util
.
cos_sim
(
query_embedding
,
document_embedding
)[
0
]
scores
=
util
.
cos_sim
(
query_embedding
,
document_embedding
)[
0
]
...
...
backend/main.py
View file @
ce9a5d12
...
@@ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
...
@@ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
if
"docs"
in
data
:
if
"docs"
in
data
:
data
=
{
**
data
}
data
=
{
**
data
}
data
[
"messages"
]
=
rag_messages
(
data
[
"messages"
]
=
rag_messages
(
data
[
"docs"
],
docs
=
data
[
"docs"
],
data
[
"messages"
],
messages
=
data
[
"messages"
],
rag_app
.
state
.
RAG_TEMPLATE
,
template
=
rag_app
.
state
.
RAG_TEMPLATE
,
rag_app
.
state
.
TOP_K
,
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
rag_app
.
state
.
RELEVANCE_THRESHOLD
,
k
=
rag_app
.
state
.
TOP_K
,
rag_app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
rag_app
.
state
.
RAG_EMBEDDING_ENGINE
,
r
=
rag_app
.
state
.
RELEVANCE_THRESHOLD
,
rag_app
.
state
.
RAG_EMBEDDING_MODEL
,
hybrid_search
=
rag_app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
,
rag_app
.
state
.
sentence_transformer_ef
,
rag_app
.
state
.
sentence_transformer_rf
,
rag_app
.
state
.
OPENAI_API_KEY
,
rag_app
.
state
.
OPENAI_API_BASE_URL
,
)
)
del
data
[
"docs"
]
del
data
[
"docs"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment