open-webui · Commit 54a4b7db (unverified)

Authored Apr 14, 2024 by Timothy Jaeryang Baek; committed by GitHub on Apr 14, 2024.

Merge pull request #1554 from open-webui/external-embeddings

feat: external embeddings

Parents: 2e0def73, 741ed5dc

Showing 6 changed files with 274 additions and 87 deletions:

- backend/apps/ollama/main.py (+2 −2)
- backend/apps/rag/main.py (+102 −49)
- backend/apps/rag/utils.py (+93 −18)
- backend/main.py (+4 −0)
- src/lib/apis/rag/index.ts (+6 −0)
- src/lib/components/documents/Settings/General.svelte (+67 −18)
backend/apps/ollama/main.py (+2 −2)

@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
     url_idx: Optional[int] = None,
 ):
-    log.info("generate_ollama_embeddings", form_data)
+    log.info(f"generate_ollama_embeddings {form_data}")
     if url_idx == None:
         model = form_data.model

@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
     data = r.json()
-    log.info("generate_ollama_embeddings", data)
+    log.info(f"generate_ollama_embeddings {data}")
     if "embedding" in data:
         return data["embedding"]
backend/apps/rag/main.py (+102 −49)

@@ -53,6 +53,7 @@ from apps.rag.utils import (
     query_collection,
     query_embeddings_collection,
     get_embedding_model_path,
+    generate_openai_embeddings,
 )

 from utils.misc import (

@@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
+app.state.RAG_OPENAI_API_KEY = ""
 app.state.PDF_EXTRACT_IMAGES = False

@@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "status": True,
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "openai_config": {
+            "url": app.state.RAG_OPENAI_API_BASE_URL,
+            "key": app.state.RAG_OPENAI_API_KEY,
+        },
     }

+class OpenAIConfigForm(BaseModel):
+    url: str
+    key: str
+
 class EmbeddingModelUpdateForm(BaseModel):
+    openai_config: Optional[OpenAIConfigForm] = None
     embedding_engine: str
     embedding_model: str

@@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
 async def update_embedding_config(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
     log.info(
         f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
     try:
         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine

-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
             app.state.sentence_transformer_ef = None
+
+            if form_data.openai_config != None:
+                app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
+                app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
         else:
             sentence_transformer_ef = (
                 embedding_functions.SentenceTransformerEmbeddingFunction(

@@ -183,6 +198,10 @@ async def update_embedding_config(
             "status": True,
             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+            "openai_config": {
+                "url": app.state.RAG_OPENAI_API_BASE_URL,
+                "key": app.state.RAG_OPENAI_API_KEY,
+            },
         }
     except Exception as e:

@@ -275,28 +294,37 @@ def query_doc_handler(
 ):
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-            return query_embeddings_doc(
-                collection_name=form_data.collection_name,
-                query_embeddings=query_embeddings,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-            )
-        else:
-            return query_doc(
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+            return query_doc(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
                 k=form_data.k if form_data.k else app.state.TOP_K,
                 embedding_function=app.state.sentence_transformer_ef,
             )
+        else:
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                query_embeddings = generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{
+                            "model": app.state.RAG_EMBEDDING_MODEL,
+                            "prompt": form_data.query,
+                        }
+                    )
+                )
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                query_embeddings = generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=form_data.query,
+                    key=app.state.RAG_OPENAI_API_KEY,
+                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                )
+
+            return query_embeddings_doc(
+                collection_name=form_data.collection_name,
+                query_embeddings=query_embeddings,
+                k=form_data.k if form_data.k else app.state.TOP_K,
+            )
     except Exception as e:
         log.exception(e)
         raise HTTPException(

@@ -317,28 +345,38 @@ def query_collection_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-            return query_embeddings_collection(
-                collection_names=form_data.collection_names,
-                query_embeddings=query_embeddings,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-            )
-        else:
-            return query_collection(
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+            return query_collection(
                 collection_names=form_data.collection_names,
                 query=form_data.query,
                 k=form_data.k if form_data.k else app.state.TOP_K,
                 embedding_function=app.state.sentence_transformer_ef,
             )
+        else:
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                query_embeddings = generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{
+                            "model": app.state.RAG_EMBEDDING_MODEL,
+                            "prompt": form_data.query,
+                        }
+                    )
+                )
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                query_embeddings = generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=form_data.query,
+                    key=app.state.RAG_OPENAI_API_KEY,
+                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                )
+
+            return query_embeddings_collection(
+                collection_names=form_data.collection_names,
+                query_embeddings=query_embeddings,
+                k=form_data.k if form_data.k else app.state.TOP_K,
+            )
     except Exception as e:
         log.exception(e)
         raise HTTPException(

@@ -383,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
     docs = text_splitter.split_documents(data)
     if len(docs) > 0:
-        log.info("store_data_in_vector_db", "store_docs_in_vector_db")
+        log.info(f"store_data_in_vector_db {docs}")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

@@ -402,7 +440,7 @@ def store_text_in_vector_db(
 def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
-    log.info("store_docs_in_vector_db", docs, collection_name)
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]

@@ -414,39 +452,54 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
             log.info(f"deleting existing collection {collection_name}")
             CHROMA_CLIENT.delete_collection(name=collection_name)
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
-            for batch in create_batches(
-                api=CHROMA_CLIENT,
-                ids=[str(uuid.uuid1()) for _ in texts],
-                metadatas=metadatas,
-                embeddings=[
-                    generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{"model": RAG_EMBEDDING_MODEL, "prompt": text}
-                        )
-                    )
-                    for text in texts
-                ],
-            ):
-                collection.add(*batch)
-        else:
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-            for batch in create_batches(
-                api=CHROMA_CLIENT,
-                ids=[str(uuid.uuid1()) for _ in texts],
-                metadatas=metadatas,
-                documents=texts,
-            ):
-                collection.add(*batch)
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+            collection = CHROMA_CLIENT.create_collection(
+                name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
+            )
+            for batch in create_batches(
+                api=CHROMA_CLIENT,
+                ids=[str(uuid.uuid1()) for _ in texts],
+                metadatas=metadatas,
+                documents=texts,
+            ):
+                collection.add(*batch)
+        else:
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                embeddings = [
+                    generate_ollama_embeddings(
+                        GenerateEmbeddingsForm(
+                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
+                        )
+                    )
+                    for text in texts
+                ]
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                embeddings = [
+                    generate_openai_embeddings(
+                        model=app.state.RAG_EMBEDDING_MODEL,
+                        text=text,
+                        key=app.state.RAG_OPENAI_API_KEY,
+                        url=app.state.RAG_OPENAI_API_BASE_URL,
+                    )
+                    for text in texts
+                ]
+            for batch in create_batches(
+                api=CHROMA_CLIENT,
+                ids=[str(uuid.uuid1()) for _ in texts],
+                metadatas=metadatas,
+                embeddings=embeddings,
+                documents=texts,
+            ):
+                collection.add(*batch)
         return True
     except Exception as e:
         log.exception(e)
         if e.__class__.__name__ == "UniqueConstraintError":
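Taken together, these handler changes mean the admin embedding-config endpoint now accepts an optional openai_config object alongside the engine and model fields. A minimal client-side sketch of the new payload shape follows; everything except the body shape (embedding_engine, embedding_model, optional openai_config) is an illustrative assumption — the host, the /rag/embedding/update route (its decorator sits outside the visible hunks), and the token are placeholders, not taken from this diff.

import requests

payload = {
    "embedding_engine": "openai",  # "" (SentenceTransformer), "ollama", or "openai"
    "embedding_model": "text-embedding-3-small",
    "openai_config": {
        "url": "https://api.openai.com",
        "key": "sk-...",  # placeholder key
    },
}

r = requests.post(
    "http://localhost:8080/rag/embedding/update",  # assumed route and host
    headers={"Authorization": "Bearer <admin-token>"},  # placeholder admin token
    json=payload,
)
print(r.json())  # on success: status, embedding_engine, embedding_model, openai_config

Per the handler above, openai_config is only read when the engine is "ollama" or "openai"; selecting the default engine ("") rebuilds the SentenceTransformer embedding function instead.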
backend/apps/rag/utils.py (+93 −18)

@@ -6,9 +6,12 @@ import requests
 from huggingface_hub import snapshot_download
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
-        log.info("query_embeddings_doc", query_embeddings)
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )

@@ -40,6 +43,8 @@ def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
             query_embeddings=[query_embeddings],
             n_results=k,
         )
+        log.info(f"query_embeddings_doc:result {result}")
         return result
     except Exception as e:
         raise e

@@ -118,7 +123,7 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
     results = []
-    log.info("query_embeddings_collection", query_embeddings)
+    log.info(f"query_embeddings_collection {query_embeddings}")
     for collection_name in collection_names:
         try:

@@ -141,8 +146,20 @@ def rag_template(template: str, context: str, query: str):
     return template

-def rag_messages(docs, messages, template, k, embedding_function):
-    log.debug(f"docs: {docs}")
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
+    log.debug(
+        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
+    )
     last_user_message_idx = None
     for i in range(len(messages) - 1, -1, -1):

@@ -175,22 +192,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
         context = None
         try:
-            if doc["type"] == "collection":
-                context = query_collection(
-                    collection_names=doc["collection_names"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
-            elif doc["type"] == "text":
+            if doc["type"] == "text":
                 context = doc["content"]
             else:
-                context = query_doc(
-                    collection_name=doc["collection_name"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
+                if embedding_engine == "":
+                    if doc["type"] == "collection":
+                        context = query_collection(
+                            collection_names=doc["collection_names"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+                    else:
+                        context = query_doc(
+                            collection_name=doc["collection_name"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+                else:
+                    if embedding_engine == "ollama":
+                        query_embeddings = generate_ollama_embeddings(
+                            GenerateEmbeddingsForm(
+                                **{
+                                    "model": embedding_model,
+                                    "prompt": query,
+                                }
+                            )
+                        )
+                    elif embedding_engine == "openai":
+                        query_embeddings = generate_openai_embeddings(
+                            model=embedding_model,
+                            text=query,
+                            key=openai_key,
+                            url=openai_url,
+                        )
+
+                    if doc["type"] == "collection":
+                        context = query_embeddings_collection(
+                            collection_names=doc["collection_names"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+                    else:
+                        context = query_embeddings_doc(
+                            collection_name=doc["collection_name"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
         except Exception as e:
             log.exception(e)
             context = None

@@ -269,3 +321,26 @@ def get_embedding_model_path(
     except Exception as e:
         log.exception(f"Cannot determine embedding model snapshot path: {e}")
         return embedding_model

+def generate_openai_embeddings(
+    model: str, text: str, key: str, url: str = "https://api.openai.com"
+):
+    try:
+        r = requests.post(
+            f"{url}/v1/embeddings",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+            },
+            json={"input": text, "model": model},
+        )
+        r.raise_for_status()
+        data = r.json()
+        if "data" in data:
+            return data["data"][0]["embedding"]
+        else:
+            raise "Something went wrong :/"
+    except Exception as e:
+        print(e)
+        return None
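The new generate_openai_embeddings helper is a thin wrapper around the OpenAI-compatible POST {url}/v1/embeddings endpoint, so any server implementing that API shape can act as the external engine. A small usage sketch, with placeholder key and model values:

from apps.rag.utils import generate_openai_embeddings

# Returns the first embedding vector from the response, or None on any failure.
embedding = generate_openai_embeddings(
    model="text-embedding-3-small",   # placeholder model name
    text="What is in this document?",
    key="sk-...",                     # placeholder API key
    url="https://api.openai.com",     # default; any /v1/embeddings-compatible base URL works
)

if embedding is None:
    print("embedding request failed")          # errors are swallowed and reported via print()
else:
    print(f"got a {len(embedding)}-dimensional vector")

One caveat worth flagging: raise "Something went wrong :/" raises a string, which Python rejects with a TypeError; the surrounding except then catches that TypeError, so the practical effect is still a None return on a malformed response.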
backend/main.py (+4 −0)

@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 data["messages"],
                 rag_app.state.RAG_TEMPLATE,
                 rag_app.state.TOP_K,
+                rag_app.state.RAG_EMBEDDING_ENGINE,
+                rag_app.state.RAG_EMBEDDING_MODEL,
                 rag_app.state.sentence_transformer_ef,
+                rag_app.state.RAG_OPENAI_API_KEY,
+                rag_app.state.RAG_OPENAI_API_BASE_URL,
             )
             del data["docs"]
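RAGMiddleware passes these values positionally, so their order must mirror the new rag_messages signature in backend/apps/rag/utils.py. An annotated sketch of how the arguments line up; data["docs"] as the first argument is an assumption, since it sits above the visible hunk:

# Parameter names on the right come from the new rag_messages signature.
rag_messages(
    data["docs"],                            # docs (assumed; outside the shown hunk)
    data["messages"],                        # messages
    rag_app.state.RAG_TEMPLATE,              # template
    rag_app.state.TOP_K,                     # k
    rag_app.state.RAG_EMBEDDING_ENGINE,      # embedding_engine
    rag_app.state.RAG_EMBEDDING_MODEL,       # embedding_model
    rag_app.state.sentence_transformer_ef,   # embedding_function
    rag_app.state.RAG_OPENAI_API_KEY,        # openai_key
    rag_app.state.RAG_OPENAI_API_BASE_URL,   # openai_url
)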
src/lib/apis/rag/index.ts (+6 −0)

@@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
 	return res;
 };

+type OpenAIConfigForm = {
+	key: string;
+	url: string;
+};
+
 type EmbeddingModelUpdateForm = {
+	openai_config?: OpenAIConfigForm;
 	embedding_engine: string;
 	embedding_model: string;
 };
src/lib/components/documents/Settings/General.svelte (+67 −18)

@@ -29,6 +29,9 @@
 	let embeddingEngine = '';
 	let embeddingModel = '';
+	let openAIKey = '';
+	let openAIUrl = '';
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;

@@ -50,7 +53,15 @@
 	};

 	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel === '') {
+		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+		if (embeddingEngine === 'ollama' && embeddingModel === '') {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'

@@ -59,7 +70,7 @@
 			return;
 		}
-		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
+		if (embeddingEngine === 'openai' && embeddingModel === '') {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'

@@ -68,20 +79,28 @@
 			return;
 		}
+		if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
+			toast.error($i18n.t('OpenAI URL/Key required.'));
+			return;
+		}
 		console.log('Update embedding model attempt:', embeddingModel);
 		updateEmbeddingModelLoading = true;
 		const res = await updateEmbeddingConfig(localStorage.token, {
 			embedding_engine: embeddingEngine,
-			embedding_model: embeddingModel
+			embedding_model: embeddingModel,
+			...(embeddingEngine === 'openai'
+				? {
+						openai_config: {
+							key: openAIKey,
+							url: openAIUrl
+						}
+					}
+				: {})
 		}).catch(async (error) => {
 			toast.error(error);
-			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-			if (embeddingConfig) {
-				embeddingEngine = embeddingConfig.embedding_engine;
-				embeddingModel = embeddingConfig.embedding_model;
-			}
+			await setEmbeddingConfig();
 			return null;
 		});
 		updateEmbeddingModelLoading = false;

@@ -89,7 +108,7 @@
 		if (res) {
 			console.log('embeddingModelUpdateHandler:', res);
 			if (res.status === true) {
-				toast.success($i18n.t('Model {{embedding_model}} update complete!', res), {
+				toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), {
 					duration: 1000 * 10
 				});
 			}

@@ -107,6 +126,18 @@
 		querySettings = await updateQuerySettings(localStorage.token, querySettings);
 	};

+	const setEmbeddingConfig = async () => {
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+			openAIKey = embeddingConfig.openai_config.key;
+			openAIUrl = embeddingConfig.openai_config.url;
+		}
+	};
+
 	onMount(async () => {
 		const res = await getRAGConfig(localStorage.token);

@@ -117,12 +148,7 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}
-		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-		if (embeddingConfig) {
-			embeddingEngine = embeddingConfig.embedding_engine;
-			embeddingModel = embeddingConfig.embedding_model;
-		}
+		await setEmbeddingConfig();
 		querySettings = await getQuerySettings(localStorage.token);
 	});

@@ -146,15 +172,38 @@
 					class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
 					bind:value={embeddingEngine}
 					placeholder="Select an embedding engine"
-					on:change={() => {
-						embeddingModel = '';
+					on:change={(e) => {
+						if (e.target.value === 'ollama') {
+							embeddingModel = '';
+						} else if (e.target.value === 'openai') {
+							embeddingModel = 'text-embedding-3-small';
+						}
 					}}
 				>
 					<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
 					<option value="ollama">{$i18n.t('Ollama')}</option>
+					<option value="openai">{$i18n.t('OpenAI')}</option>
 				</select>
 			</div>
 		</div>
+	{#if embeddingEngine === 'openai'}
+		<div class="mt-1 flex gap-2">
+			<input
+				class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+				placeholder={$i18n.t('API Base URL')}
+				bind:value={openAIUrl}
+				required
+			/>
+			<input
+				class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+				placeholder={$i18n.t('API Key')}
+				bind:value={openAIKey}
+				required
+			/>
+		</div>
+	{/if}
 	</div>

 	<div class="space-y-2">