Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
591cd993
Commit
591cd993
authored
Jun 09, 2024
by
Timothy J. Baek
Browse files
refac: search query task
parent
aa1bb4fb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
223 additions
and
29 deletions
+223
-29
backend/config.py
backend/config.py
+25
-0
backend/main.py
backend/main.py
+113
-3
backend/utils/task.py
backend/utils/task.py
+42
-0
src/lib/apis/index.ts
src/lib/apis/index.ts
+40
-0
src/lib/components/chat/Chat.svelte
src/lib/components/chat/Chat.svelte
+3
-26
No files found.
backend/config.py
View file @
591cd993
...
@@ -618,6 +618,18 @@ ADMIN_EMAIL = PersistentConfig(
...
@@ -618,6 +618,18 @@ ADMIN_EMAIL = PersistentConfig(
)
)
TASK_MODEL
=
PersistentConfig
(
"TASK_MODEL"
,
"task.model.default"
,
os
.
environ
.
get
(
"TASK_MODEL"
,
""
),
)
TASK_MODEL_EXTERNAL
=
PersistentConfig
(
"TASK_MODEL_EXTERNAL"
,
"task.model.external"
,
os
.
environ
.
get
(
"TASK_MODEL_EXTERNAL"
,
""
),
)
TITLE_GENERATION_PROMPT_TEMPLATE
=
PersistentConfig
(
TITLE_GENERATION_PROMPT_TEMPLATE
=
PersistentConfig
(
"TITLE_GENERATION_PROMPT_TEMPLATE"
,
"TITLE_GENERATION_PROMPT_TEMPLATE"
,
"task.title.prompt_template"
,
"task.title.prompt_template"
,
...
@@ -639,6 +651,19 @@ Artificial Intelligence in Healthcare
...
@@ -639,6 +651,19 @@ Artificial Intelligence in Healthcare
)
)
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
=
PersistentConfig
(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE"
,
"task.search.prompt_template"
,
os
.
environ
.
get
(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE"
,
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
Question:
{{prompt:end:4000}}"""
,
),
)
####################################
####################################
# WEBUI_SECRET_KEY
# WEBUI_SECRET_KEY
####################################
####################################
...
...
backend/main.py
View file @
591cd993
...
@@ -53,7 +53,7 @@ from utils.utils import (
...
@@ -53,7 +53,7 @@ from utils.utils import (
get_current_user
,
get_current_user
,
get_http_authorization_cred
,
get_http_authorization_cred
,
)
)
from
utils.task
import
title_generation_template
from
utils.task
import
title_generation_template
,
search_query_generation_template
from
apps.rag.utils
import
rag_messages
from
apps.rag.utils
import
rag_messages
...
@@ -77,7 +77,10 @@ from config import (
...
@@ -77,7 +77,10 @@ from config import (
WEBHOOK_URL
,
WEBHOOK_URL
,
ENABLE_ADMIN_EXPORT
,
ENABLE_ADMIN_EXPORT
,
WEBUI_BUILD_HASH
,
WEBUI_BUILD_HASH
,
TASK_MODEL
,
TASK_MODEL_EXTERNAL
,
TITLE_GENERATION_PROMPT_TEMPLATE
,
TITLE_GENERATION_PROMPT_TEMPLATE
,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
,
AppConfig
,
AppConfig
,
)
)
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
...
@@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
...
@@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
TASK_MODEL
=
TASK_MODEL
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
=
TASK_MODEL_EXTERNAL
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
=
TITLE_GENERATION_PROMPT_TEMPLATE
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
=
TITLE_GENERATION_PROMPT_TEMPLATE
app
.
state
.
config
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
=
(
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app
.
state
.
MODELS
=
{}
app
.
state
.
MODELS
=
{}
...
@@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
...
@@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
return
{
"data"
:
models
}
return
{
"data"
:
models
}
@
app
.
get
(
"/api/task/config"
)
async
def
get_task_config
(
user
=
Depends
(
get_verified_user
)):
return
{
"TASK_MODEL"
:
app
.
state
.
config
.
TASK_MODEL
,
"TASK_MODEL_EXTERNAL"
:
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
,
"TITLE_GENERATION_PROMPT_TEMPLATE"
:
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE"
:
app
.
state
.
config
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
,
}
class
TaskConfigForm
(
BaseModel
):
TASK_MODEL
:
Optional
[
str
]
TASK_MODEL_EXTERNAL
:
Optional
[
str
]
TITLE_GENERATION_PROMPT_TEMPLATE
:
str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
:
str
@
app
.
post
(
"/api/task/config/update"
)
async
def
update_task_config
(
form_data
:
TaskConfigForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
config
.
TASK_MODEL
=
form_data
.
TASK_MODEL
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
=
form_data
.
TASK_MODEL_EXTERNAL
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
=
(
form_data
.
TITLE_GENERATION_PROMPT_TEMPLATE
)
app
.
state
.
config
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
=
(
form_data
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
return
{
"TASK_MODEL"
:
app
.
state
.
config
.
TASK_MODEL
,
"TASK_MODEL_EXTERNAL"
:
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
,
"TITLE_GENERATION_PROMPT_TEMPLATE"
:
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE"
:
app
.
state
.
config
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
,
}
@
app
.
post
(
"/api/task/title/completions"
)
@
app
.
post
(
"/api/task/title/completions"
)
async
def
generate_title
(
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)):
async
def
generate_title
(
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)):
print
(
"generate_title"
)
print
(
"generate_title"
)
model_id
=
form_data
[
"model"
]
model_id
=
form_data
[
"model"
]
if
model_id
not
in
app
.
state
.
MODELS
:
if
model_id
not
in
app
.
state
.
MODELS
:
raise
HTTPException
(
raise
HTTPException
(
...
@@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
...
@@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
detail
=
"Model not found"
,
detail
=
"Model not found"
,
)
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if
app
.
state
.
MODELS
[
model_id
][
"owned_by"
]
==
"ollama"
:
if
app
.
state
.
config
.
TASK_MODEL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
else
:
if
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
template
=
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
...
@@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
...
@@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
return
await
generate_openai_chat_completion
(
payload
,
user
=
user
)
return
await
generate_openai_chat_completion
(
payload
,
user
=
user
)
@
app
.
post
(
"/api/task/query/completions"
)
async
def
generate_search_query
(
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)):
print
(
"generate_search_query"
)
model_id
=
form_data
[
"model"
]
if
model_id
not
in
app
.
state
.
MODELS
:
raise
HTTPException
(
status_code
=
status
.
HTTP_404_NOT_FOUND
,
detail
=
"Model not found"
,
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if
app
.
state
.
MODELS
[
model_id
][
"owned_by"
]
==
"ollama"
:
if
app
.
state
.
config
.
TASK_MODEL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
else
:
if
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
app
.
state
.
config
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content
=
search_query_generation_template
(
template
,
form_data
[
"prompt"
],
user
.
model_dump
()
)
payload
=
{
"model"
:
model_id
,
"messages"
:
[{
"role"
:
"user"
,
"content"
:
content
}],
"stream"
:
False
,
"max_tokens"
:
30
,
}
print
(
payload
)
payload
=
filter_pipeline
(
payload
,
user
)
if
model
[
"owned_by"
]
==
"ollama"
:
return
await
generate_ollama_chat_completion
(
OpenAIChatCompletionForm
(
**
payload
),
user
=
user
)
else
:
return
await
generate_openai_chat_completion
(
payload
,
user
=
user
)
@
app
.
post
(
"/api/chat/completions"
)
@
app
.
post
(
"/api/chat/completions"
)
async
def
generate_chat_completions
(
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)):
async
def
generate_chat_completions
(
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)):
model_id
=
form_data
[
"model"
]
model_id
=
form_data
[
"model"
]
...
@@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
...
@@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
)
)
model
=
app
.
state
.
MODELS
[
model_id
]
model
=
app
.
state
.
MODELS
[
model_id
]
print
(
model
)
print
(
model
)
if
model
[
"owned_by"
]
==
"ollama"
:
if
model
[
"owned_by"
]
==
"ollama"
:
...
...
backend/utils/task.py
View file @
591cd993
...
@@ -68,3 +68,45 @@ def title_generation_template(
...
@@ -68,3 +68,45 @@ def title_generation_template(
)
)
return
template
return
template
def
search_query_generation_template
(
template
:
str
,
prompt
:
str
,
user
:
Optional
[
dict
]
=
None
)
->
str
:
def
replacement_function
(
match
):
full_match
=
match
.
group
(
0
)
start_length
=
match
.
group
(
1
)
end_length
=
match
.
group
(
2
)
middle_length
=
match
.
group
(
3
)
if
full_match
==
"{{prompt}}"
:
return
prompt
elif
start_length
is
not
None
:
return
prompt
[:
int
(
start_length
)]
elif
end_length
is
not
None
:
return
prompt
[
-
int
(
end_length
)
:]
elif
middle_length
is
not
None
:
middle_length
=
int
(
middle_length
)
if
len
(
prompt
)
<=
middle_length
:
return
prompt
start
=
prompt
[:
math
.
ceil
(
middle_length
/
2
)]
end
=
prompt
[
-
math
.
floor
(
middle_length
/
2
)
:]
return
f
"
{
start
}
...
{
end
}
"
return
""
template
=
re
.
sub
(
r
"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}"
,
replacement_function
,
template
,
)
template
=
prompt_template
(
template
,
**
(
{
"user_name"
:
user
.
get
(
"name"
),
"current_location"
:
user
.
get
(
"location"
)}
if
user
else
{}
),
)
return
template
src/lib/apis/index.ts
View file @
591cd993
...
@@ -144,6 +144,46 @@ export const generateTitle = async (
...
@@ -144,6 +144,46 @@ export const generateTitle = async (
return
res
?.
choices
[
0
]?.
message
?.
content
.
replace
(
/
[
"'
]
/g
,
''
)
??
'
New Chat
'
;
return
res
?.
choices
[
0
]?.
message
?.
content
.
replace
(
/
[
"'
]
/g
,
''
)
??
'
New Chat
'
;
};
};
export
const
generateSearchQuery
=
async
(
token
:
string
=
''
,
model
:
string
,
messages
:
object
[],
prompt
:
string
)
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_BASE_URL
}
/api/task/query/completions`
,
{
method
:
'
POST
'
,
headers
:
{
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
Authorization
:
`Bearer
${
token
}
`
},
body
:
JSON
.
stringify
({
model
:
model
,
messages
:
messages
,
prompt
:
prompt
})
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
catch
((
err
)
=>
{
console
.
log
(
err
);
if
(
'
detail
'
in
err
)
{
error
=
err
.
detail
;
}
return
null
;
});
if
(
error
)
{
throw
error
;
}
return
res
?.
choices
[
0
]?.
message
?.
content
.
replace
(
/
[
"'
]
/g
,
''
)
??
prompt
;
};
export
const
getPipelinesList
=
async
(
token
:
string
=
''
)
=>
{
export
const
getPipelinesList
=
async
(
token
:
string
=
''
)
=>
{
let
error
=
null
;
let
error
=
null
;
...
...
src/lib/components/chat/Chat.svelte
View file @
591cd993
...
@@ -44,12 +44,12 @@
...
@@ -44,12 +44,12 @@
getTagsById
,
getTagsById
,
updateChatById
updateChatById
}
from
'$lib/apis/chats'
;
}
from
'$lib/apis/chats'
;
import
{
generateOpenAIChatCompletion
,
generateSearchQuery
}
from
'$lib/apis/openai'
;
import
{
generateOpenAIChatCompletion
}
from
'$lib/apis/openai'
;
import
{
runWebSearch
}
from
'$lib/apis/rag'
;
import
{
runWebSearch
}
from
'$lib/apis/rag'
;
import
{
createOpenAITextStream
}
from
'$lib/apis/streaming'
;
import
{
createOpenAITextStream
}
from
'$lib/apis/streaming'
;
import
{
queryMemory
}
from
'$lib/apis/memories'
;
import
{
queryMemory
}
from
'$lib/apis/memories'
;
import
{
getUserSettings
}
from
'$lib/apis/users'
;
import
{
getUserSettings
}
from
'$lib/apis/users'
;
import
{
chatCompleted
,
generateTitle
}
from
'$lib/apis'
;
import
{
chatCompleted
,
generateTitle
,
generateSearchQuery
}
from
'$lib/apis'
;
import
Banner
from
'../common/Banner.svelte'
;
import
Banner
from
'../common/Banner.svelte'
;
import
MessageInput
from
'$lib/components/chat/MessageInput.svelte'
;
import
MessageInput
from
'$lib/components/chat/MessageInput.svelte'
;
...
@@ -508,7 +508,7 @@
...
@@ -508,7 +508,7 @@
const
prompt
=
history
.
messages
[
parentId
].
content
;
const
prompt
=
history
.
messages
[
parentId
].
content
;
let
searchQuery
=
prompt
;
let
searchQuery
=
prompt
;
if
(
prompt
.
length
>
100
)
{
if
(
prompt
.
length
>
100
)
{
searchQuery
=
await
generate
Chat
SearchQuery
(
model
,
prompt
);
searchQuery
=
await
generateSearchQuery
(
localStorage
.
token
,
model
,
messages
,
prompt
);
if
(
!searchQuery) {
if
(
!searchQuery) {
toast
.
warning
($
i18n
.
t
(
'No search query generated'
));
toast
.
warning
($
i18n
.
t
(
'No search query generated'
));
responseMessage
.
status
=
{
responseMessage
.
status
=
{
...
@@ -1129,29 +1129,6 @@
...
@@ -1129,29 +1129,6 @@
}
}
};
};
const generateChatSearchQuery = async (modelId: string, prompt: string) => {
const model = $models.find((model) => model.id === modelId);
const taskModelId =
model?.owned_by === '
openai
' ?? false
? $settings?.title?.modelExternal ?? modelId
: $settings?.title?.model ?? modelId;
const taskModel = $models.find((model) => model.id === taskModelId);
const previousMessages = messages
.filter((message) => message.role === '
user
')
.map((message) => message.content);
return await generateSearchQuery(
localStorage.token,
taskModelId,
previousMessages,
prompt,
taskModel?.owned_by === '
openai
' ?? false
? `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
);
};
const setChatTitle = async (_chatId, _title) => {
const setChatTitle = async (_chatId, _title) => {
if (_chatId === $chatId) {
if (_chatId === $chatId) {
title = _title;
title = _title;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment