Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
b1b72441
Commit
b1b72441
authored
Apr 14, 2024
by
Timothy J. Baek
Browse files
feat: openai embeddings integration
parent
b48e73fa
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
148 additions
and
39 deletions
+148
-39
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+2
-2
backend/apps/rag/main.py
backend/apps/rag/main.py
+4
-2
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+65
-17
backend/main.py
backend/main.py
+4
-0
src/lib/apis/rag/index.ts
src/lib/apis/rag/index.ts
+6
-0
src/lib/components/documents/Settings/General.svelte
src/lib/components/documents/Settings/General.svelte
+67
-18
No files found.
backend/apps/ollama/main.py
View file @
b1b72441
...
@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
...
@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
):
):
log
.
info
(
"generate_ollama_embeddings
"
,
form_data
)
log
.
info
(
f
"generate_ollama_embeddings
{
form_data
}
"
)
if
url_idx
==
None
:
if
url_idx
==
None
:
model
=
form_data
.
model
model
=
form_data
.
model
...
@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
...
@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
data
=
r
.
json
()
data
=
r
.
json
()
log
.
info
(
"generate_ollama_embeddings
"
,
data
)
log
.
info
(
f
"generate_ollama_embeddings
{
data
}
"
)
if
"embedding"
in
data
:
if
"embedding"
in
data
:
return
data
[
"embedding"
]
return
data
[
"embedding"
]
...
...
backend/apps/rag/main.py
View file @
b1b72441
...
@@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
...
@@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
docs
=
text_splitter
.
split_documents
(
data
)
docs
=
text_splitter
.
split_documents
(
data
)
if
len
(
docs
)
>
0
:
if
len
(
docs
)
>
0
:
log
.
info
(
"store_data_in_vector_db
"
,
"store_docs_in_vector_db
"
)
log
.
info
(
f
"store_data_in_vector_db
{
docs
}
"
)
return
store_docs_in_vector_db
(
docs
,
collection_name
,
overwrite
),
None
return
store_docs_in_vector_db
(
docs
,
collection_name
,
overwrite
),
None
else
:
else
:
raise
ValueError
(
ERROR_MESSAGES
.
EMPTY_CONTENT
)
raise
ValueError
(
ERROR_MESSAGES
.
EMPTY_CONTENT
)
...
@@ -440,7 +440,7 @@ def store_text_in_vector_db(
...
@@ -440,7 +440,7 @@ def store_text_in_vector_db(
def
store_docs_in_vector_db
(
docs
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
def
store_docs_in_vector_db
(
docs
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
log
.
info
(
"store_docs_in_vector_db
"
,
docs
,
collection_name
)
log
.
info
(
f
"store_docs_in_vector_db
{
docs
}
{
collection_name
}
"
)
texts
=
[
doc
.
page_content
for
doc
in
docs
]
texts
=
[
doc
.
page_content
for
doc
in
docs
]
metadatas
=
[
doc
.
metadata
for
doc
in
docs
]
metadatas
=
[
doc
.
metadata
for
doc
in
docs
]
...
@@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
...
@@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection
.
add
(
*
batch
)
collection
.
add
(
*
batch
)
else
:
else
:
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
)
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
if
app
.
state
.
RAG_EMBEDDING_ENGINE
==
"ollama"
:
embeddings
=
[
embeddings
=
[
generate_ollama_embeddings
(
generate_ollama_embeddings
(
...
...
backend/apps/rag/utils.py
View file @
b1b72441
...
@@ -6,9 +6,12 @@ import requests
...
@@ -6,9 +6,12 @@ import requests
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
apps.ollama.main
import
generate_ollama_embeddings
,
GenerateEmbeddingsForm
from
config
import
SRC_LOG_LEVELS
,
CHROMA_CLIENT
from
config
import
SRC_LOG_LEVELS
,
CHROMA_CLIENT
log
=
logging
.
getLogger
(
__name__
)
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
...
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
...
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
def
query_embeddings_doc
(
collection_name
:
str
,
query_embeddings
,
k
:
int
):
def
query_embeddings_doc
(
collection_name
:
str
,
query_embeddings
,
k
:
int
):
try
:
try
:
# if you use docker use the model from the environment variable
# if you use docker use the model from the environment variable
log
.
info
(
"query_embeddings_doc
"
,
query_embeddings
)
log
.
info
(
f
"query_embeddings_doc
{
query_embeddings
}
"
)
collection
=
CHROMA_CLIENT
.
get_collection
(
collection
=
CHROMA_CLIENT
.
get_collection
(
name
=
collection_name
,
name
=
collection_name
,
)
)
...
@@ -118,7 +121,7 @@ def query_collection(
...
@@ -118,7 +121,7 @@ def query_collection(
def
query_embeddings_collection
(
collection_names
:
List
[
str
],
query_embeddings
,
k
:
int
):
def
query_embeddings_collection
(
collection_names
:
List
[
str
],
query_embeddings
,
k
:
int
):
results
=
[]
results
=
[]
log
.
info
(
"query_embeddings_collection
"
,
query_embeddings
)
log
.
info
(
f
"query_embeddings_collection
{
query_embeddings
}
"
)
for
collection_name
in
collection_names
:
for
collection_name
in
collection_names
:
try
:
try
:
...
@@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
...
@@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
return
template
return
template
def
rag_messages
(
docs
,
messages
,
template
,
k
,
embedding_function
):
def
rag_messages
(
docs
,
messages
,
template
,
k
,
embedding_engine
,
embedding_model
,
embedding_function
,
openai_key
,
openai_url
,
):
log
.
debug
(
f
"docs:
{
docs
}
"
)
log
.
debug
(
f
"docs:
{
docs
}
"
)
last_user_message_idx
=
None
last_user_message_idx
=
None
...
@@ -175,6 +188,11 @@ def rag_messages(docs, messages, template, k, embedding_function):
...
@@ -175,6 +188,11 @@ def rag_messages(docs, messages, template, k, embedding_function):
context
=
None
context
=
None
try
:
try
:
if
doc
[
"type"
]
==
"text"
:
context
=
doc
[
"content"
]
else
:
if
embedding_engine
==
""
:
if
doc
[
"type"
]
==
"collection"
:
if
doc
[
"type"
]
==
"collection"
:
context
=
query_collection
(
context
=
query_collection
(
collection_names
=
doc
[
"collection_names"
],
collection_names
=
doc
[
"collection_names"
],
...
@@ -182,8 +200,6 @@ def rag_messages(docs, messages, template, k, embedding_function):
...
@@ -182,8 +200,6 @@ def rag_messages(docs, messages, template, k, embedding_function):
k
=
k
,
k
=
k
,
embedding_function
=
embedding_function
,
embedding_function
=
embedding_function
,
)
)
elif
doc
[
"type"
]
==
"text"
:
context
=
doc
[
"content"
]
else
:
else
:
context
=
query_doc
(
context
=
query_doc
(
collection_name
=
doc
[
"collection_name"
],
collection_name
=
doc
[
"collection_name"
],
...
@@ -191,6 +207,38 @@ def rag_messages(docs, messages, template, k, embedding_function):
...
@@ -191,6 +207,38 @@ def rag_messages(docs, messages, template, k, embedding_function):
k
=
k
,
k
=
k
,
embedding_function
=
embedding_function
,
embedding_function
=
embedding_function
,
)
)
else
:
if
embedding_engine
==
"ollama"
:
query_embeddings
=
generate_ollama_embeddings
(
GenerateEmbeddingsForm
(
**
{
"model"
:
embedding_model
,
"prompt"
:
query
,
}
)
)
elif
embedding_engine
==
"openai"
:
query_embeddings
=
generate_openai_embeddings
(
model
=
embedding_model
,
text
=
query
,
key
=
openai_key
,
url
=
openai_url
,
)
if
doc
[
"type"
]
==
"collection"
:
context
=
query_embeddings_collection
(
collection_names
=
doc
[
"collection_names"
],
query_embeddings
=
query_embeddings
,
k
=
k
,
)
else
:
context
=
query_embeddings_doc
(
collection_name
=
doc
[
"collection_name"
],
query_embeddings
=
query_embeddings
,
k
=
k
,
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
context
=
None
context
=
None
...
...
backend/main.py
View file @
b1b72441
...
@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
...
@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
data
[
"messages"
],
data
[
"messages"
],
rag_app
.
state
.
RAG_TEMPLATE
,
rag_app
.
state
.
RAG_TEMPLATE
,
rag_app
.
state
.
TOP_K
,
rag_app
.
state
.
TOP_K
,
rag_app
.
state
.
RAG_EMBEDDING_ENGINE
,
rag_app
.
state
.
RAG_EMBEDDING_MODEL
,
rag_app
.
state
.
sentence_transformer_ef
,
rag_app
.
state
.
sentence_transformer_ef
,
rag_app
.
state
.
RAG_OPENAI_API_KEY
,
rag_app
.
state
.
RAG_OPENAI_API_BASE_URL
,
)
)
del
data
[
"docs"
]
del
data
[
"docs"
]
...
...
src/lib/apis/rag/index.ts
View file @
b1b72441
...
@@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
...
@@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
return
res
;
return
res
;
};
};
type
OpenAIConfigForm
=
{
key
:
string
;
url
:
string
;
};
type
EmbeddingModelUpdateForm
=
{
type
EmbeddingModelUpdateForm
=
{
openai_config
?:
OpenAIConfigForm
;
embedding_engine
:
string
;
embedding_engine
:
string
;
embedding_model
:
string
;
embedding_model
:
string
;
};
};
...
...
src/lib/components/documents/Settings/General.svelte
View file @
b1b72441
...
@@ -29,6 +29,9 @@
...
@@ -29,6 +29,9 @@
let embeddingEngine = '';
let embeddingEngine = '';
let embeddingModel = '';
let embeddingModel = '';
let openAIKey = '';
let openAIUrl = '';
let chunkSize = 0;
let chunkSize = 0;
let chunkOverlap = 0;
let chunkOverlap = 0;
let pdfExtractImages = true;
let pdfExtractImages = true;
...
@@ -50,7 +53,15 @@
...
@@ -50,7 +53,15 @@
};
};
const embeddingModelUpdateHandler = async () => {
const embeddingModelUpdateHandler = async () => {
if (embeddingModel === '') {
if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
toast.error(
$i18n.t(
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
)
);
return;
}
if (embeddingEngine === 'ollama' && embeddingModel === '') {
toast.error(
toast.error(
$i18n.t(
$i18n.t(
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
...
@@ -59,7 +70,7 @@
...
@@ -59,7 +70,7 @@
return;
return;
}
}
if (embeddingEngine === '' && embeddingModel
.split('/').length - 1 > 1
) {
if (embeddingEngine === '
openai
' && embeddingModel
=== ''
) {
toast.error(
toast.error(
$i18n.t(
$i18n.t(
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
...
@@ -68,20 +79,28 @@
...
@@ -68,20 +79,28 @@
return;
return;
}
}
if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
toast.error($i18n.t('OpenAI URL/Key required.'));
return;
}
console.log('Update embedding model attempt:', embeddingModel);
console.log('Update embedding model attempt:', embeddingModel);
updateEmbeddingModelLoading = true;
updateEmbeddingModelLoading = true;
const res = await updateEmbeddingConfig(localStorage.token, {
const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine,
embedding_engine: embeddingEngine,
embedding_model: embeddingModel
embedding_model: embeddingModel,
...(embeddingEngine === 'openai'
? {
openai_config: {
key: openAIKey,
url: openAIUrl
}
}
: {})
}).catch(async (error) => {
}).catch(async (error) => {
toast.error(error);
toast.error(error);
await setEmbeddingConfig();
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
}
return null;
return null;
});
});
updateEmbeddingModelLoading = false;
updateEmbeddingModelLoading = false;
...
@@ -89,7 +108,7 @@
...
@@ -89,7 +108,7 @@
if (res) {
if (res) {
console.log('embeddingModelUpdateHandler:', res);
console.log('embeddingModelUpdateHandler:', res);
if (res.status === true) {
if (res.status === true) {
toast.success($i18n.t('
Model
{{embedding_model}}
update complete!
', res), {
toast.success($i18n.t('
Embedding model set to "
{{embedding_model}}
"
', res), {
duration: 1000 * 10
duration: 1000 * 10
});
});
}
}
...
@@ -107,6 +126,18 @@
...
@@ -107,6 +126,18 @@
querySettings = await updateQuerySettings(localStorage.token, querySettings);
querySettings = await updateQuerySettings(localStorage.token, querySettings);
};
};
const setEmbeddingConfig = async () => {
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
openAIKey = embeddingConfig.openai_config.key;
openAIUrl = embeddingConfig.openai_config.url;
}
};
onMount(async () => {
onMount(async () => {
const res = await getRAGConfig(localStorage.token);
const res = await getRAGConfig(localStorage.token);
...
@@ -117,12 +148,7 @@
...
@@ -117,12 +148,7 @@
chunkOverlap = res.chunk.chunk_overlap;
chunkOverlap = res.chunk.chunk_overlap;
}
}
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
await setEmbeddingConfig();
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
}
querySettings = await getQuerySettings(localStorage.token);
querySettings = await getQuerySettings(localStorage.token);
});
});
...
@@ -146,15 +172,38 @@
...
@@ -146,15 +172,38 @@
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
bind:value={embeddingEngine}
bind:value={embeddingEngine}
placeholder="Select an embedding engine"
placeholder="Select an embedding engine"
on:change={() => {
on:change={(e) => {
if (e.target.value === 'ollama') {
embeddingModel = '';
embeddingModel = '';
} else if (e.target.value === 'openai') {
embeddingModel = 'text-embedding-3-small';
}
}}
}}
>
>
<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
<option value="ollama">{$i18n.t('Ollama')}</option>
<option value="ollama">{$i18n.t('Ollama')}</option>
<option value="openai">{$i18n.t('OpenAI')}</option>
</select>
</select>
</div>
</div>
</div>
</div>
{#if embeddingEngine === 'openai'}
<div class="mt-1 flex gap-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={openAIUrl}
required
/>
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Key')}
bind:value={openAIKey}
required
/>
</div>
{/if}
</div>
</div>
<div class="space-y-2">
<div class="space-y-2">
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment