Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
b48e73fa
Commit
b48e73fa
authored
Apr 14, 2024
by
Timothy J. Baek
Browse files
feat: openai embeddings support
parent
36ce1579
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
48 deletions
+121
-48
backend/apps/rag/main.py
backend/apps/rag/main.py
+98
-48
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+23
-0
No files found.
backend/apps/rag/main.py
View file @
b48e73fa
...
@@ -53,6 +53,7 @@ from apps.rag.utils import (
...
@@ -53,6 +53,7 @@ from apps.rag.utils import (
query_collection
,
query_collection
,
query_embeddings_collection
,
query_embeddings_collection
,
get_embedding_model_path
,
get_embedding_model_path
,
generate_openai_embeddings
,
)
)
from
utils.misc
import
(
from
utils.misc
import
(
...
@@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
...
@@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
RAG_TEMPLATE
=
RAG_TEMPLATE
app
.
state
.
RAG_TEMPLATE
=
RAG_TEMPLATE
app
.
state
.
RAG_OPENAI_API_BASE_URL
=
"https://api.openai.com"
app
.
state
.
RAG_OPENAI_API_KEY
=
""
app
.
state
.
PDF_EXTRACT_IMAGES
=
False
app
.
state
.
PDF_EXTRACT_IMAGES
=
False
...
@@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
...
@@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status"
:
True
,
"status"
:
True
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"openai_config"
:
{
"url"
:
app
.
state
.
RAG_OPENAI_API_BASE_URL
,
"key"
:
app
.
state
.
RAG_OPENAI_API_KEY
,
},
}
}
class
OpenAIConfigForm
(
BaseModel
):
url
:
str
key
:
str
class
EmbeddingModelUpdateForm
(
BaseModel
):
class
EmbeddingModelUpdateForm
(
BaseModel
):
openai_config
:
Optional
[
OpenAIConfigForm
]
=
None
embedding_engine
:
str
embedding_engine
:
str
embedding_model
:
str
embedding_model
:
str
...
@@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
...
@@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
async
def
update_embedding_config
(
async
def
update_embedding_config
(
form_data
:
EmbeddingModelUpdateForm
,
user
=
Depends
(
get_admin_user
)
form_data
:
EmbeddingModelUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
):
log
.
info
(
log
.
info
(
f
"Updating embedding model:
{
app
.
state
.
RAG_EMBEDDING_MODEL
}
to
{
form_data
.
embedding_model
}
"
f
"Updating embedding model:
{
app
.
state
.
RAG_EMBEDDING_MODEL
}
to
{
form_data
.
embedding_model
}
"
)
)
try
:
try
:
app
.
state
.
RAG_EMBEDDING_ENGINE
=
form_data
.
embedding_engine
app
.
state
.
RAG_EMBEDDING_ENGINE
=
form_data
.
embedding_engine
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
in
[
"ollama"
,
"openai"
]
:
app
.
state
.
RAG_EMBEDDING_MODEL
=
form_data
.
embedding_model
app
.
state
.
RAG_EMBEDDING_MODEL
=
form_data
.
embedding_model
app
.
state
.
sentence_transformer_ef
=
None
app
.
state
.
sentence_transformer_ef
=
None
if
form_data
.
openai_config
!=
None
:
app
.
state
.
RAG_OPENAI_API_BASE_URL
=
form_data
.
openai_config
.
url
app
.
state
.
RAG_OPENAI_API_KEY
=
form_data
.
openai_config
.
key
else
:
else
:
sentence_transformer_ef
=
(
sentence_transformer_ef
=
(
embedding_functions
.
SentenceTransformerEmbeddingFunction
(
embedding_functions
.
SentenceTransformerEmbeddingFunction
(
...
@@ -183,6 +198,10 @@ async def update_embedding_config(
...
@@ -183,6 +198,10 @@ async def update_embedding_config(
"status"
:
True
,
"status"
:
True
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"openai_config"
:
{
"url"
:
app
.
state
.
RAG_OPENAI_API_BASE_URL
,
"key"
:
app
.
state
.
RAG_OPENAI_API_KEY
,
},
}
}
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -275,28 +294,37 @@ def query_doc_handler(
...
@@ -275,28 +294,37 @@ def query_doc_handler(
):
):
try
:
try
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
""
:
query_embeddings
=
generate_ollama_embeddings
(
return
query_doc
(
GenerateEmbeddingsForm
(
**
{
"model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"prompt"
:
form_data
.
query
,
}
)
)
return
query_embeddings_doc
(
collection_name
=
form_data
.
collection_name
,
collection_name
=
form_data
.
collection_name
,
query
_embeddings
=
query_embeddings
,
query
=
form_data
.
query
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
)
else
:
else
:
return
query_doc
(
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
query_embeddings
=
generate_ollama_embeddings
(
GenerateEmbeddingsForm
(
**
{
"model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"prompt"
:
form_data
.
query
,
}
)
)
elif
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"openai"
:
query_embeddings
=
generate_openai_embeddings
(
model
=
app
.
state
.
RAG_EMBEDDING_MODEL
,
text
=
form_data
.
query
,
key
=
app
.
state
.
RAG_OPENAI_API_KEY
,
url
=
app
.
state
.
RAG_OPENAI_API_BASE_URL
,
)
return
query_embeddings_doc
(
collection_name
=
form_data
.
collection_name
,
collection_name
=
form_data
.
collection_name
,
query
=
form_data
.
query
,
query
_embeddings
=
query_embeddings
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
...
@@ -317,28 +345,38 @@ def query_collection_handler(
...
@@ -317,28 +345,38 @@ def query_collection_handler(
user
=
Depends
(
get_current_user
),
user
=
Depends
(
get_current_user
),
):
):
try
:
try
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
""
:
query_embeddings
=
generate_ollama_embeddings
(
return
query_collection
(
GenerateEmbeddingsForm
(
**
{
"model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"prompt"
:
form_data
.
query
,
}
)
)
return
query_embeddings_collection
(
collection_names
=
form_data
.
collection_names
,
collection_names
=
form_data
.
collection_names
,
query
_embeddings
=
query_embeddings
,
query
=
form_data
.
query
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
)
else
:
else
:
return
query_collection
(
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
query_embeddings
=
generate_ollama_embeddings
(
GenerateEmbeddingsForm
(
**
{
"model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"prompt"
:
form_data
.
query
,
}
)
)
elif
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"openai"
:
query_embeddings
=
generate_openai_embeddings
(
model
=
app
.
state
.
RAG_EMBEDDING_MODEL
,
text
=
form_data
.
query
,
key
=
app
.
state
.
RAG_OPENAI_API_KEY
,
url
=
app
.
state
.
RAG_OPENAI_API_BASE_URL
,
)
return
query_embeddings_collection
(
collection_names
=
form_data
.
collection_names
,
collection_names
=
form_data
.
collection_names
,
query
=
form_data
.
query
,
query
_embeddings
=
query_embeddings
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
...
@@ -414,39 +452,51 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
...
@@ -414,39 +452,51 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log
.
info
(
f
"deleting existing collection
{
collection_name
}
"
)
log
.
info
(
f
"deleting existing collection
{
collection_name
}
"
)
CHROMA_CLIENT
.
delete_collection
(
name
=
collection_name
)
CHROMA_CLIENT
.
delete_collection
(
name
=
collection_name
)
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
""
:
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
)
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
for
batch
in
create_batches
(
for
batch
in
create_batches
(
api
=
CHROMA_CLIENT
,
api
=
CHROMA_CLIENT
,
ids
=
[
str
(
uuid
.
uuid1
())
for
_
in
texts
],
ids
=
[
str
(
uuid
.
uuid1
())
for
_
in
texts
],
metadatas
=
metadatas
,
metadatas
=
metadatas
,
embeddings
=
[
documents
=
texts
,
):
collection
.
add
(
*
batch
)
else
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
embeddings
=
[
generate_ollama_embeddings
(
generate_ollama_embeddings
(
GenerateEmbeddingsForm
(
GenerateEmbeddingsForm
(
**
{
"model"
:
RAG_EMBEDDING_MODEL
,
"prompt"
:
text
}
**
{
"model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"prompt"
:
text
}
)
)
)
)
for
text
in
texts
for
text
in
texts
],
]
):
elif
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"openai"
:
collection
.
add
(
*
batch
)
embeddings
=
[
else
:
generate_openai_embeddings
(
model
=
app
.
state
.
RAG_EMBEDDING_MODEL
,
collection
=
CHROMA_CLIENT
.
create_collection
(
text
=
text
,
name
=
collection_name
,
key
=
app
.
state
.
RAG_OPENAI_API_KEY
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
url
=
app
.
state
.
RAG_OPENAI_API_BASE_URL
,
)
)
for
text
in
texts
]
for
batch
in
create_batches
(
for
batch
in
create_batches
(
api
=
CHROMA_CLIENT
,
api
=
CHROMA_CLIENT
,
ids
=
[
str
(
uuid
.
uuid1
())
for
_
in
texts
],
ids
=
[
str
(
uuid
.
uuid1
())
for
_
in
texts
],
metadatas
=
metadatas
,
metadatas
=
metadatas
,
documents
=
text
s
,
embeddings
=
embedding
s
,
):
):
collection
.
add
(
*
batch
)
collection
.
add
(
*
batch
)
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
if
e
.
__class__
.
__name__
==
"UniqueConstraintError"
:
if
e
.
__class__
.
__name__
==
"UniqueConstraintError"
:
...
...
backend/apps/rag/utils.py
View file @
b48e73fa
...
@@ -269,3 +269,26 @@ def get_embedding_model_path(
...
@@ -269,3 +269,26 @@ def get_embedding_model_path(
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
f
"Cannot determine embedding model snapshot path:
{
e
}
"
)
log
.
exception
(
f
"Cannot determine embedding model snapshot path:
{
e
}
"
)
return
embedding_model
return
embedding_model
def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com"
):
    """Fetch an embedding vector for ``text`` from an OpenAI-compatible API.

    Args:
        model: Name of the embedding model (e.g. ``"text-embedding-ada-002"``).
        text: The input text to embed.
        key: API key, sent as a ``Bearer`` token in the Authorization header.
        url: Base URL of the API server; ``/v1/embeddings`` is appended.
            Defaults to the official OpenAI endpoint.

    Returns:
        The embedding vector (a list of floats) on success, or ``None`` on
        any failure (HTTP error, malformed response, network problem).
    """
    try:
        r = requests.post(
            f"{url}/v1/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return data["data"][0]["embedding"]
        # BUGFIX: the original did `raise "Something went wrong :/"` —
        # raising a plain string is a TypeError in Python 3, which masked
        # the real problem. Raise a proper exception so the handler below
        # logs a meaningful message instead.
        raise ValueError(f"Unexpected embeddings response: {data}")
    except Exception as e:
        # Use the module logger (consistent with the rest of this file)
        # rather than a bare print(); callers treat None as "no embedding".
        log.exception(e)
        return None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment