Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
7c127c35
Commit
7c127c35
authored
Feb 19, 2024
by
Timothy J. Baek
Browse files
feat: dynamic embedding model load
parent
ab104d59
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
36 deletions
+56
-36
backend/apps/rag/main.py
backend/apps/rag/main.py
+56
-36
No files found.
backend/apps/rag/main.py
View file @
7c127c35
...
...
@@ -35,6 +35,8 @@ from pydantic import BaseModel
from
typing
import
Optional
import
mimetypes
import
uuid
import
json
from
apps.web.models.documents
import
(
Documents
,
...
...
@@ -63,24 +65,26 @@ from config import (
from
constants
import
ERROR_MESSAGES
#
#if RAG_EMBEDDING_MODEL:
#
if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
if
RAG_EMBEDDING_MODEL
:
sentence_transformer_ef
=
embedding_functions
.
SentenceTransformerEmbeddingFunction
(
model_name
=
RAG_EMBEDDING_MODEL
,
device
=
RAG_EMBEDDING_MODEL_DEVICE_TYPE
,
)
app
=
FastAPI
()
app
.
state
.
CHUNK_SIZE
=
CHUNK_SIZE
app
.
state
.
CHUNK_OVERLAP
=
CHUNK_OVERLAP
app
.
state
.
RAG_TEMPLATE
=
RAG_TEMPLATE
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
sentence_transformer_ef
=
(
embedding_functions
.
SentenceTransformerEmbeddingFunction
(
model_name
=
app
.
state
.
RAG_EMBEDDING_MODEL
,
device
=
RAG_EMBEDDING_MODEL_DEVICE_TYPE
,
)
)
origins
=
[
"*"
]
...
...
@@ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas
=
[
doc
.
metadata
for
doc
in
docs
]
try
:
if
RAG_EMBEDDING_MODEL
:
# if you use docker use the model from the environment variable
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
,
embedding_function
=
sentence_transformer_ef
)
else
:
# for local development use the default model
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
)
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
collection
.
add
(
documents
=
texts
,
metadatas
=
metadatas
,
ids
=
[
str
(
uuid
.
uuid1
())
for
_
in
texts
]
...
...
@@ -139,6 +139,38 @@ async def get_status():
"status"
:
True
,
"chunk_size"
:
app
.
state
.
CHUNK_SIZE
,
"chunk_overlap"
:
app
.
state
.
CHUNK_OVERLAP
,
"template"
:
app
.
state
.
RAG_TEMPLATE
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
}
@app.get("/embedding/model")
async def get_embedding_model(user=Depends(get_admin_user)):
    """Report which embedding model name is currently stored on ``app.state``.

    Args:
        user: Resolved through the ``get_admin_user`` dependency; not used
            in the body.

    Returns:
        A dict with ``"status"`` set to ``True`` and ``"embedding_model"``
        set to ``app.state.RAG_EMBEDDING_MODEL``.
    """
    active_model = app.state.RAG_EMBEDDING_MODEL
    payload = {"status": True}
    payload["embedding_model"] = active_model
    return payload
# Request body for POST /embedding/model/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Name of the embedding model to switch to; it is passed as ``model_name``
    # to SentenceTransformerEmbeddingFunction by update_embedding_model.
    embedding_model: str
@app.post("/embedding/model/update")
async def update_embedding_model(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Replace the embedding model held on ``app.state`` at runtime.

    A new ``SentenceTransformerEmbeddingFunction`` is constructed for the
    requested model and, only after that succeeds, the model name and the
    embedding function are stored on ``app.state`` together.

    Args:
        form_data: Request body carrying the new ``embedding_model`` name.
        user: Resolved through the ``get_admin_user`` dependency; not used
            in the body.

    Returns:
        A dict with ``"status"`` set to ``True`` and ``"embedding_model"``
        set to the model name now stored on ``app.state``.

    Raises:
        Any exception raised while constructing the embedding function is
        propagated unchanged; in that case ``app.state`` is left untouched.
    """
    requested_model = form_data.embedding_model

    # Build first, commit afterwards: if construction fails, the previously
    # stored model name and embedding function stay consistent with each
    # other instead of the name pointing at a model that never loaded.
    # NOTE(review): this call runs synchronously inside an async handler;
    # presumably it can be slow for large models - confirm whether it should
    # be moved off the event loop.
    new_embedding_function = (
        embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=requested_model,
            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
        )
    )

    app.state.RAG_EMBEDDING_MODEL = requested_model
    app.state.sentence_transformer_ef = new_embedding_function

    return {
        "status": True,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
    }
...
...
@@ -203,17 +235,11 @@ def query_doc(
user
=
Depends
(
get_current_user
),
):
try
:
if
RAG_EMBEDDING_MODEL
:
# if you use docker use the model from the environment variable
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
form_data
.
collection_name
,
embedding_function
=
sentence_transformer_ef
,
)
else
:
# for local development use the default model
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
form_data
.
collection_name
,
)
# if you use docker use the model from the environment variable
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
form_data
.
collection_name
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
result
=
collection
.
query
(
query_texts
=
[
form_data
.
query
],
n_results
=
form_data
.
k
)
return
result
except
Exception
as
e
:
...
...
@@ -284,17 +310,11 @@ def query_collection(
for
collection_name
in
form_data
.
collection_names
:
try
:
if
RAG_EMBEDDING_MODEL
:
# if you use docker use the model from the environment variable
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
collection_name
,
embedding_function
=
sentence_transformer_ef
,
)
else
:
# for local development use the default model
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
collection_name
,
)
# if you use docker use the model from the environment variable
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
collection_name
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
result
=
collection
.
query
(
query_texts
=
[
form_data
.
query
],
n_results
=
form_data
.
k
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment