Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
47a05a47
"...locales/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "64ae3d733668613fde719dede468e85d8b0417c0"
Commit
47a05a47
authored
Mar 02, 2024
by
Timothy J. Baek
Browse files
feat: add rag top k value setting
parent
9694c656
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
123 additions
and
38 deletions
+123
-38
backend/apps/rag/main.py
backend/apps/rag/main.py
+34
-14
src/lib/apis/rag/index.ts
src/lib/apis/rag/index.ts
+36
-4
src/lib/components/documents/Settings/General.svelte
src/lib/components/documents/Settings/General.svelte
+43
-6
src/routes/(app)/+page.svelte
src/routes/(app)/+page.svelte
+5
-7
src/routes/(app)/c/[id]/+page.svelte
src/routes/(app)/c/[id]/+page.svelte
+5
-7
No files found.
backend/apps/rag/main.py
View file @
47a05a47
...
@@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
...
@@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
app
.
state
.
CHUNK_OVERLAP
=
CHUNK_OVERLAP
app
.
state
.
CHUNK_OVERLAP
=
CHUNK_OVERLAP
app
.
state
.
RAG_TEMPLATE
=
RAG_TEMPLATE
app
.
state
.
RAG_TEMPLATE
=
RAG_TEMPLATE
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
TOP_K
=
4
app
.
state
.
sentence_transformer_ef
=
(
app
.
state
.
sentence_transformer_ef
=
(
embedding_functions
.
SentenceTransformerEmbeddingFunction
(
embedding_functions
.
SentenceTransformerEmbeddingFunction
(
model_name
=
app
.
state
.
RAG_EMBEDDING_MODEL
,
model_name
=
app
.
state
.
RAG_EMBEDDING_MODEL
,
...
@@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)):
...
@@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)):
}
}
class
RAGTemplateForm
(
BaseModel
):
@
app
.
get
(
"/query/settings"
)
template
:
str
async
def
get_query_settings
(
user
=
Depends
(
get_admin_user
)):
return
{
"status"
:
True
,
"template"
:
app
.
state
.
RAG_TEMPLATE
,
"k"
:
app
.
state
.
TOP_K
,
}
@
app
.
post
(
"/template/update"
)
class
QuerySettingsForm
(
BaseModel
):
async
def
update_rag_template
(
form_data
:
RAGTemplateForm
,
user
=
Depends
(
get_admin_user
)):
k
:
Optional
[
int
]
=
None
# TODO: check template requirements
template
:
Optional
[
str
]
=
None
app
.
state
.
RAG_TEMPLATE
=
(
form_data
.
template
if
form_data
.
template
!=
""
else
RAG_TEMPLATE
)
@
app
.
post
(
"/query/settings/update"
)
async
def
update_query_settings
(
form_data
:
QuerySettingsForm
,
user
=
Depends
(
get_admin_user
)
):
app
.
state
.
RAG_TEMPLATE
=
form_data
.
template
if
form_data
.
template
else
RAG_TEMPLATE
app
.
state
.
TOP_K
=
form_data
.
k
if
form_data
.
k
else
4
return
{
"status"
:
True
,
"template"
:
app
.
state
.
RAG_TEMPLATE
}
return
{
"status"
:
True
,
"template"
:
app
.
state
.
RAG_TEMPLATE
}
class
QueryDocForm
(
BaseModel
):
class
QueryDocForm
(
BaseModel
):
collection_name
:
str
collection_name
:
str
query
:
str
query
:
str
k
:
Optional
[
int
]
=
4
k
:
Optional
[
int
]
=
None
@
app
.
post
(
"/query/doc"
)
@
app
.
post
(
"/query/doc"
)
...
@@ -240,7 +252,10 @@ def query_doc(
...
@@ -240,7 +252,10 @@ def query_doc(
name
=
form_data
.
collection_name
,
name
=
form_data
.
collection_name
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
)
result
=
collection
.
query
(
query_texts
=
[
form_data
.
query
],
n_results
=
form_data
.
k
)
result
=
collection
.
query
(
query_texts
=
[
form_data
.
query
],
n_results
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
)
return
result
return
result
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
...
@@ -253,7 +268,7 @@ def query_doc(
...
@@ -253,7 +268,7 @@ def query_doc(
class
QueryCollectionsForm
(
BaseModel
):
class
QueryCollectionsForm
(
BaseModel
):
collection_names
:
List
[
str
]
collection_names
:
List
[
str
]
query
:
str
query
:
str
k
:
Optional
[
int
]
=
4
k
:
Optional
[
int
]
=
None
def
merge_and_sort_query_results
(
query_results
,
k
):
def
merge_and_sort_query_results
(
query_results
,
k
):
...
@@ -317,13 +332,16 @@ def query_collection(
...
@@ -317,13 +332,16 @@ def query_collection(
)
)
result
=
collection
.
query
(
result
=
collection
.
query
(
query_texts
=
[
form_data
.
query
],
n_results
=
form_data
.
k
query_texts
=
[
form_data
.
query
],
n_results
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
)
)
results
.
append
(
result
)
results
.
append
(
result
)
except
:
except
:
pass
pass
return
merge_and_sort_query_results
(
results
,
form_data
.
k
)
return
merge_and_sort_query_results
(
results
,
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
)
@
app
.
post
(
"/web"
)
@
app
.
post
(
"/web"
)
...
@@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
...
@@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
,
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
,
]
or
file_ext
in
[
"xls"
,
"xlsx"
]:
]
or
file_ext
in
[
"xls"
,
"xlsx"
]:
loader
=
UnstructuredExcelLoader
(
file_path
)
loader
=
UnstructuredExcelLoader
(
file_path
)
elif
file_ext
in
known_source_ext
or
(
file_content_type
and
file_content_type
.
find
(
"text/"
)
>=
0
):
elif
file_ext
in
known_source_ext
or
(
file_content_type
and
file_content_type
.
find
(
"text/"
)
>=
0
):
loader
=
TextLoader
(
file_path
)
loader
=
TextLoader
(
file_path
)
else
:
else
:
loader
=
TextLoader
(
file_path
)
loader
=
TextLoader
(
file_path
)
...
...
src/lib/apis/rag/index.ts
View file @
47a05a47
...
@@ -85,17 +85,49 @@ export const getRAGTemplate = async (token: string) => {
...
@@ -85,17 +85,49 @@ export const getRAGTemplate = async (token: string) => {
return
res
?.
template
??
''
;
return
res
?.
template
??
''
;
};
};
export
const
updateRAGTemplate
=
async
(
token
:
string
,
template
:
string
)
=>
{
export
const
getQuerySettings
=
async
(
token
:
string
)
=>
{
let
error
=
null
;
let
error
=
null
;
const
res
=
await
fetch
(
`
${
RAG_API_BASE_URL
}
/template/update`
,
{
const
res
=
await
fetch
(
`
${
RAG_API_BASE_URL
}
/query/settings`
,
{
method
:
'
GET
'
,
headers
:
{
'
Content-Type
'
:
'
application/json
'
,
Authorization
:
`Bearer
${
token
}
`
}
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
catch
((
err
)
=>
{
console
.
log
(
err
);
error
=
err
.
detail
;
return
null
;
});
if
(
error
)
{
throw
error
;
}
return
res
;
};
type
QuerySettings
=
{
k
:
number
|
null
;
template
:
string
|
null
;
};
export
const
updateQuerySettings
=
async
(
token
:
string
,
settings
:
QuerySettings
)
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
RAG_API_BASE_URL
}
/query/settings/update`
,
{
method
:
'
POST
'
,
method
:
'
POST
'
,
headers
:
{
headers
:
{
'
Content-Type
'
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
Authorization
:
`Bearer
${
token
}
`
Authorization
:
`Bearer
${
token
}
`
},
},
body
:
JSON
.
stringify
({
body
:
JSON
.
stringify
({
template
:
template
...
settings
})
})
})
})
.
then
(
async
(
res
)
=>
{
.
then
(
async
(
res
)
=>
{
...
@@ -183,7 +215,7 @@ export const queryDoc = async (
...
@@ -183,7 +215,7 @@ export const queryDoc = async (
token
:
string
,
token
:
string
,
collection_name
:
string
,
collection_name
:
string
,
query
:
string
,
query
:
string
,
k
:
number
k
:
number
|
null
=
null
)
=>
{
)
=>
{
let
error
=
null
;
let
error
=
null
;
...
...
src/lib/components/documents/Settings/General.svelte
View file @
47a05a47
...
@@ -2,10 +2,10 @@
...
@@ -2,10 +2,10 @@
import { getDocs } from '$lib/apis/documents';
import { getDocs } from '$lib/apis/documents';
import {
import {
getChunkParams,
getChunkParams,
get
RAGTemplate
,
get
QuerySettings
,
scanDocs,
scanDocs,
updateChunkParams,
updateChunkParams,
update
RAGTemplate
update
QuerySettings
} from '$lib/apis/rag';
} from '$lib/apis/rag';
import { documents } from '$lib/stores';
import { documents } from '$lib/stores';
import { onMount } from 'svelte';
import { onMount } from 'svelte';
...
@@ -18,7 +18,10 @@
...
@@ -18,7 +18,10 @@
let chunkSize = 0;
let chunkSize = 0;
let chunkOverlap = 0;
let chunkOverlap = 0;
let template = '';
let querySettings = {
template: '',
k: 4
};
const scanHandler = async () => {
const scanHandler = async () => {
loading = true;
loading = true;
...
@@ -33,7 +36,7 @@
...
@@ -33,7 +36,7 @@
const submitHandler = async () => {
const submitHandler = async () => {
const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
await updateRAGTemplate
(localStorage.token,
template
);
querySettings = await updateQuerySettings
(localStorage.token,
querySettings
);
};
};
onMount(async () => {
onMount(async () => {
...
@@ -44,7 +47,7 @@
...
@@ -44,7 +47,7 @@
chunkOverlap = res.chunk_overlap;
chunkOverlap = res.chunk_overlap;
}
}
template = await getRAGTemplate
(localStorage.token);
querySettings = await getQuerySettings
(localStorage.token);
});
});
</script>
</script>
...
@@ -156,10 +159,44 @@
...
@@ -156,10 +159,44 @@
</div>
</div>
</div>
</div>
<div class=" text-sm font-medium">Query Params</div>
<div class=" flex">
<div class=" flex w-full justify-between">
<div class="self-center text-xs font-medium flex-1">Top K</div>
<div class="self-center p-3">
<input
class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
type="number"
placeholder="Enter Top K"
bind:value={querySettings.k}
autocomplete="off"
min="0"
/>
</div>
</div>
<!-- <div class="flex w-full">
<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
<div class="self-center p-3">
<input
class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
type="number"
placeholder="Enter Chunk Overlap"
bind:value={chunkOverlap}
autocomplete="off"
min="0"
/>
</div>
</div> -->
</div>
<div>
<div>
<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
<textarea
<textarea
bind:value={template}
bind:value={
querySettings.
template}
class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
rows="4"
rows="4"
/>
/>
...
...
src/routes/(app)/+page.svelte
View file @
47a05a47
...
@@ -248,19 +248,17 @@
...
@@ -248,19 +248,17 @@
let
relevantContexts
=
await
Promise
.
all
(
let
relevantContexts
=
await
Promise
.
all
(
docs
.
map
(
async
(
doc
)
=>
{
docs
.
map
(
async
(
doc
)
=>
{
if
(
doc
.
type
===
'collection'
)
{
if
(
doc
.
type
===
'collection'
)
{
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
,
4
).
catch
(
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
).
catch
(
(
error
)
=>
{
(
error
)
=>
{
console
.
log
(
error
);
console
.
log
(
error
);
return
null
;
return
null
;
}
}
);
);
}
else
{
}
else
{
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
,
4
).
catch
(
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
).
catch
((
error
)
=>
{
(
error
)
=>
{
console
.
log
(
error
);
console
.
log
(
error
);
return
null
;
return
null
;
});
}
);
}
}
})
})
);
);
...
...
src/routes/(app)/c/[id]/+page.svelte
View file @
47a05a47
...
@@ -261,19 +261,17 @@
...
@@ -261,19 +261,17 @@
let
relevantContexts
=
await
Promise
.
all
(
let
relevantContexts
=
await
Promise
.
all
(
docs
.
map
(
async
(
doc
)
=>
{
docs
.
map
(
async
(
doc
)
=>
{
if
(
doc
.
type
===
'collection'
)
{
if
(
doc
.
type
===
'collection'
)
{
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
,
4
).
catch
(
return
await
queryCollection
(
localStorage
.
token
,
doc
.
collection_names
,
query
).
catch
(
(
error
)
=>
{
(
error
)
=>
{
console
.
log
(
error
);
console
.
log
(
error
);
return
null
;
return
null
;
}
}
);
);
}
else
{
}
else
{
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
,
4
).
catch
(
return
await
queryDoc
(
localStorage
.
token
,
doc
.
collection_name
,
query
).
catch
((
error
)
=>
{
(
error
)
=>
{
console
.
log
(
error
);
console
.
log
(
error
);
return
null
;
return
null
;
});
}
);
}
}
})
})
);
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment