Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
c9c96604
Commit
c9c96604
authored
Apr 25, 2024
by
Steven Kreitzer
Browse files
fix: address comment in pr #1687
parent
d5f60b11
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
44 deletions
+93
-44
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+0
-4
backend/apps/rag/main.py
backend/apps/rag/main.py
+44
-38
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+41
-2
backend/config.py
backend/config.py
+8
-0
No files found.
backend/apps/ollama/main.py
View file @
c9c96604
...
@@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
...
@@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
def
get_ollama_endpoint
(
url_idx
:
int
=
0
):
return
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
class
UrlUpdateForm
(
BaseModel
):
class
UrlUpdateForm
(
BaseModel
):
urls
:
List
[
str
]
urls
:
List
[
str
]
...
...
backend/apps/rag/main.py
View file @
c9c96604
...
@@ -39,8 +39,6 @@ import json
...
@@ -39,8 +39,6 @@ import json
import
sentence_transformers
import
sentence_transformers
from
apps.ollama.main
import
generate_ollama_embeddings
,
GenerateEmbeddingsForm
from
apps.web.models.documents
import
(
from
apps.web.models.documents
import
(
Documents
,
Documents
,
DocumentForm
,
DocumentForm
,
...
@@ -48,6 +46,7 @@ from apps.web.models.documents import (
...
@@ -48,6 +46,7 @@ from apps.web.models.documents import (
)
)
from
apps.rag.utils
import
(
from
apps.rag.utils
import
(
get_model_path
,
query_embeddings_doc
,
query_embeddings_doc
,
query_embeddings_function
,
query_embeddings_function
,
query_embeddings_collection
,
query_embeddings_collection
,
...
@@ -60,6 +59,7 @@ from utils.misc import (
...
@@ -60,6 +59,7 @@ from utils.misc import (
extract_folders_after_data_docs
,
extract_folders_after_data_docs
,
)
)
from
utils.utils
import
get_current_user
,
get_admin_user
from
utils.utils
import
get_current_user
,
get_admin_user
from
config
import
(
from
config
import
(
SRC_LOG_LEVELS
,
SRC_LOG_LEVELS
,
UPLOAD_DIR
,
UPLOAD_DIR
,
...
@@ -68,8 +68,10 @@ from config import (
...
@@ -68,8 +68,10 @@ from config import (
RAG_RELEVANCE_THRESHOLD
,
RAG_RELEVANCE_THRESHOLD
,
RAG_EMBEDDING_ENGINE
,
RAG_EMBEDDING_ENGINE
,
RAG_EMBEDDING_MODEL
,
RAG_EMBEDDING_MODEL
,
RAG_EMBEDDING_MODEL_AUTO_UPDATE
,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
,
RAG_RERANKING_MODEL
,
RAG_RERANKING_MODEL
,
RAG_RERANKING_MODEL_AUTO_UPDATE
,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE
,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE
,
RAG_OPENAI_API_BASE_URL
,
RAG_OPENAI_API_BASE_URL
,
RAG_OPENAI_API_KEY
,
RAG_OPENAI_API_KEY
,
...
@@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
...
@@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app
=
FastAPI
()
app
=
FastAPI
()
app
.
state
.
TOP_K
=
RAG_TOP_K
app
.
state
.
TOP_K
=
RAG_TOP_K
app
.
state
.
RELEVANCE_THRESHOLD
=
RAG_RELEVANCE_THRESHOLD
app
.
state
.
RELEVANCE_THRESHOLD
=
RAG_RELEVANCE_THRESHOLD
app
.
state
.
CHUNK_SIZE
=
CHUNK_SIZE
app
.
state
.
CHUNK_SIZE
=
CHUNK_SIZE
app
.
state
.
CHUNK_OVERLAP
=
CHUNK_OVERLAP
app
.
state
.
CHUNK_OVERLAP
=
CHUNK_OVERLAP
app
.
state
.
RAG_EMBEDDING_ENGINE
=
RAG_EMBEDDING_ENGINE
app
.
state
.
RAG_EMBEDDING_ENGINE
=
RAG_EMBEDDING_ENGINE
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
RAG_RERANKING_MODEL
=
RAG_RERANKING_MODEL
app
.
state
.
RAG_RERANKING_MODEL
=
RAG_RERANKING_MODEL
...
@@ -104,27 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
...
@@ -104,27 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app
.
state
.
PDF_EXTRACT_IMAGES
=
False
app
.
state
.
PDF_EXTRACT_IMAGES
=
False
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
""
:
app
.
state
.
sentence_transformer_ef
=
sentence_transformers
.
SentenceTransformer
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
device
=
DEVICE_TYPE
,
trust_remote_code
=
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
,
)
else
:
app
.
state
.
sentence_transformer_ef
=
None
if
not
app
.
state
.
RAG_RERANKING_MODEL
==
""
:
app
.
state
.
sentence_transformer_rf
=
sentence_transformers
.
CrossEncoder
(
app
.
state
.
RAG_RERANKING_MODEL
,
device
=
DEVICE_TYPE
,
trust_remote_code
=
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE
,
)
else
:
app
.
state
.
sentence_transformer_rf
=
None
def
update_embedding_model
(
embedding_model
:
str
,
update_model
:
bool
=
False
,
):
if
embedding_model
and
app
.
state
.
RAG_EMBEDDING_ENGINE
==
""
:
app
.
state
.
sentence_transformer_ef
=
sentence_transformers
.
SentenceTransformer
(
get_model_path
(
embedding_model
,
update_model
),
device
=
DEVICE_TYPE
,
trust_remote_code
=
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
,
)
else
:
app
.
state
.
sentence_transformer_ef
=
None
def
update_reranking_model
(
reranking_model
:
str
,
update_model
:
bool
=
False
,
):
if
reranking_model
:
app
.
state
.
sentence_transformer_rf
=
sentence_transformers
.
CrossEncoder
(
get_model_path
(
reranking_model
,
update_model
),
device
=
DEVICE_TYPE
,
trust_remote_code
=
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE
,
)
else
:
app
.
state
.
sentence_transformer_rf
=
None
update_embedding_model
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
RAG_EMBEDDING_MODEL_AUTO_UPDATE
,
)
update_reranking_model
(
app
.
state
.
RAG_RERANKING_MODEL
,
RAG_RERANKING_MODEL_AUTO_UPDATE
,
)
origins
=
[
"*"
]
origins
=
[
"*"
]
app
.
add_middleware
(
app
.
add_middleware
(
CORSMiddleware
,
CORSMiddleware
,
allow_origins
=
origins
,
allow_origins
=
origins
,
...
@@ -200,15 +221,7 @@ async def update_embedding_config(
...
@@ -200,15 +221,7 @@ async def update_embedding_config(
app
.
state
.
OPENAI_API_BASE_URL
=
form_data
.
openai_config
.
url
app
.
state
.
OPENAI_API_BASE_URL
=
form_data
.
openai_config
.
url
app
.
state
.
OPENAI_API_KEY
=
form_data
.
openai_config
.
key
app
.
state
.
OPENAI_API_KEY
=
form_data
.
openai_config
.
key
app
.
state
.
sentence_transformer_ef
=
None
update_embedding_model
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
True
)
else
:
app
.
state
.
sentence_transformer_ef
=
(
sentence_transformers
.
SentenceTransformer
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
device
=
DEVICE_TYPE
,
trust_remote_code
=
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
,
)
)
return
{
return
{
"status"
:
True
,
"status"
:
True
,
...
@@ -219,7 +232,6 @@ async def update_embedding_config(
...
@@ -219,7 +232,6 @@ async def update_embedding_config(
"key"
:
app
.
state
.
OPENAI_API_KEY
,
"key"
:
app
.
state
.
OPENAI_API_KEY
,
},
},
}
}
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
f
"Problem updating embedding model:
{
e
}
"
)
log
.
exception
(
f
"Problem updating embedding model:
{
e
}
"
)
raise
HTTPException
(
raise
HTTPException
(
...
@@ -242,13 +254,7 @@ async def update_reranking_config(
...
@@ -242,13 +254,7 @@ async def update_reranking_config(
try
:
try
:
app
.
state
.
RAG_RERANKING_MODEL
=
form_data
.
reranking_model
app
.
state
.
RAG_RERANKING_MODEL
=
form_data
.
reranking_model
if
app
.
state
.
RAG_RERANKING_MODEL
==
""
:
update_reranking_model
(
app
.
state
.
RAG_RERANKING_MODEL
,
True
)
app
.
state
.
sentence_transformer_rf
=
None
else
:
app
.
state
.
sentence_transformer_rf
=
sentence_transformers
.
CrossEncoder
(
app
.
state
.
RAG_RERANKING_MODEL
,
device
=
DEVICE_TYPE
,
)
return
{
return
{
"status"
:
True
,
"status"
:
True
,
...
...
backend/apps/rag/utils.py
View file @
c9c96604
import
os
import
logging
import
logging
import
requests
import
requests
...
@@ -8,6 +9,8 @@ from apps.ollama.main import (
...
@@ -8,6 +9,8 @@ from apps.ollama.main import (
GenerateEmbeddingsForm
,
GenerateEmbeddingsForm
,
)
)
from
huggingface_hub
import
snapshot_download
from
langchain_core.documents
import
Document
from
langchain_core.documents
import
Document
from
langchain_community.retrievers
import
BM25Retriever
from
langchain_community.retrievers
import
BM25Retriever
from
langchain.retrievers
import
(
from
langchain.retrievers
import
(
...
@@ -282,8 +285,6 @@ def rag_messages(
...
@@ -282,8 +285,6 @@ def rag_messages(
extracted_collections
.
extend
(
collection
)
extracted_collections
.
extend
(
collection
)
log
.
debug
(
f
"relevant_contexts:
{
relevant_contexts
}
"
)
context_string
=
""
context_string
=
""
for
context
in
relevant_contexts
:
for
context
in
relevant_contexts
:
items
=
context
[
"documents"
][
0
]
items
=
context
[
"documents"
][
0
]
...
@@ -319,6 +320,44 @@ def rag_messages(
...
@@ -319,6 +320,44 @@ def rag_messages(
return
messages
return
messages
def
get_model_path
(
model
:
str
,
update_model
:
bool
=
False
):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir
=
os
.
getenv
(
"SENTENCE_TRANSFORMERS_HOME"
)
local_files_only
=
not
update_model
snapshot_kwargs
=
{
"cache_dir"
:
cache_dir
,
"local_files_only"
:
local_files_only
,
}
log
.
debug
(
f
"embedding_model:
{
model
}
"
)
log
.
debug
(
f
"snapshot_kwargs:
{
snapshot_kwargs
}
"
)
# Inspiration from upstream sentence_transformers
if
(
os
.
path
.
exists
(
model
)
or
(
"
\\
"
in
model
or
model
.
count
(
"/"
)
>
1
)
and
local_files_only
):
# If fully qualified path exists, return input, else set repo_id
return
model
elif
"/"
not
in
model
:
# Set valid repo_id for model short-name
model
=
"sentence-transformers"
+
"/"
+
model
snapshot_kwargs
[
"repo_id"
]
=
model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try
:
model_repo_path
=
snapshot_download
(
**
snapshot_kwargs
)
log
.
debug
(
f
"model_repo_path:
{
model_repo_path
}
"
)
return
model_repo_path
except
Exception
as
e
:
log
.
exception
(
f
"Cannot determine model snapshot path:
{
e
}
"
)
return
model
def
generate_openai_embeddings
(
def
generate_openai_embeddings
(
model
:
str
,
text
:
str
,
key
:
str
,
url
:
str
=
"https://api.openai.com/v1"
model
:
str
,
text
:
str
,
key
:
str
,
url
:
str
=
"https://api.openai.com/v1"
):
):
...
...
backend/config.py
View file @
c9c96604
...
@@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
...
@@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
)
)
log
.
info
(
f
"Embedding model set:
{
RAG_EMBEDDING_MODEL
}
"
),
log
.
info
(
f
"Embedding model set:
{
RAG_EMBEDDING_MODEL
}
"
),
RAG_EMBEDDING_MODEL_AUTO_UPDATE
=
(
os
.
environ
.
get
(
"RAG_EMBEDDING_MODEL_AUTO_UPDATE"
,
""
).
lower
()
==
"true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
=
(
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
=
(
os
.
environ
.
get
(
"RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE"
,
""
).
lower
()
==
"true"
os
.
environ
.
get
(
"RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE"
,
""
).
lower
()
==
"true"
)
)
...
@@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
...
@@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
if
not
RAG_RERANKING_MODEL
==
""
:
if
not
RAG_RERANKING_MODEL
==
""
:
log
.
info
(
f
"Reranking model set:
{
RAG_RERANKING_MODEL
}
"
),
log
.
info
(
f
"Reranking model set:
{
RAG_RERANKING_MODEL
}
"
),
RAG_RERANKING_MODEL_AUTO_UPDATE
=
(
os
.
environ
.
get
(
"RAG_RERANKING_MODEL_AUTO_UPDATE"
,
""
).
lower
()
==
"true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE
=
(
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE
=
(
os
.
environ
.
get
(
"RAG_RERANKING_MODEL_TRUST_REMOTE_CODE"
,
""
).
lower
()
==
"true"
os
.
environ
.
get
(
"RAG_RERANKING_MODEL_TRUST_REMOTE_CODE"
,
""
).
lower
()
==
"true"
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment