Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
edbd07f8
Commit
edbd07f8
authored
Jun 27, 2024
by
Timothy J. Baek
Browse files
feat: global filter
parent
c8c85ba7
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
212 additions
and
29 deletions
+212
-29
backend/apps/webui/internal/migrations/018_add_function_is_global.py
...s/webui/internal/migrations/018_add_function_is_global.py
+49
-0
backend/apps/webui/models/functions.py
backend/apps/webui/models/functions.py
+13
-0
backend/apps/webui/routers/functions.py
backend/apps/webui/routers/functions.py
+27
-0
backend/main.py
backend/main.py
+24
-25
src/lib/apis/functions/index.ts
src/lib/apis/functions/index.ts
+32
-0
src/lib/components/icons/GlobeAlt.svelte
src/lib/components/icons/GlobeAlt.svelte
+19
-0
src/lib/components/workspace/Functions.svelte
src/lib/components/workspace/Functions.svelte
+24
-1
src/lib/components/workspace/Functions/FunctionMenu.svelte
src/lib/components/workspace/Functions/FunctionMenu.svelte
+24
-3
No files found.
backend/apps/webui/internal/migrations/018_add_function_is_global.py
0 → 100644
View file @
edbd07f8
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from
contextlib
import
suppress
import
peewee
as
pw
from
peewee_migrate
import
Migrator
with
suppress
(
ImportError
):
import
playhouse.postgres_ext
as
pw_pext
def
migrate
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your migrations here."""
migrator
.
add_fields
(
"function"
,
is_global
=
pw
.
BooleanField
(
default
=
False
),
)
def
rollback
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your rollback migrations here."""
migrator
.
remove_fields
(
"function"
,
"is_global"
)
backend/apps/webui/models/functions.py
View file @
edbd07f8
...
@@ -30,6 +30,7 @@ class Function(Model):
...
@@ -30,6 +30,7 @@ class Function(Model):
meta
=
JSONField
()
meta
=
JSONField
()
valves
=
JSONField
()
valves
=
JSONField
()
is_active
=
BooleanField
(
default
=
False
)
is_active
=
BooleanField
(
default
=
False
)
is_global
=
BooleanField
(
default
=
False
)
updated_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
...
@@ -50,6 +51,7 @@ class FunctionModel(BaseModel):
...
@@ -50,6 +51,7 @@ class FunctionModel(BaseModel):
content
:
str
content
:
str
meta
:
FunctionMeta
meta
:
FunctionMeta
is_active
:
bool
=
False
is_active
:
bool
=
False
is_global
:
bool
=
False
updated_at
:
int
# timestamp in epoch
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
...
@@ -66,6 +68,7 @@ class FunctionResponse(BaseModel):
...
@@ -66,6 +68,7 @@ class FunctionResponse(BaseModel):
name
:
str
name
:
str
meta
:
FunctionMeta
meta
:
FunctionMeta
is_active
:
bool
is_active
:
bool
is_global
:
bool
updated_at
:
int
# timestamp in epoch
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
...
@@ -144,6 +147,16 @@ class FunctionsTable:
...
@@ -144,6 +147,16 @@ class FunctionsTable:
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
)
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
)
]
]
def
get_global_filter_functions
(
self
)
->
List
[
FunctionModel
]:
return
[
FunctionModel
(
**
model_to_dict
(
function
))
for
function
in
Function
.
select
().
where
(
Function
.
type
==
"filter"
,
Function
.
is_active
==
True
,
Function
.
is_global
==
True
,
)
]
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
try
:
function
=
Function
.
get
(
Function
.
id
==
id
)
function
=
Function
.
get
(
Function
.
id
==
id
)
...
...
backend/apps/webui/routers/functions.py
View file @
edbd07f8
...
@@ -147,6 +147,33 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
...
@@ -147,6 +147,33 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
)
)
############################
# ToggleGlobalById
############################
@
router
.
post
(
"/id/{id}/toggle/global"
,
response_model
=
Optional
[
FunctionModel
])
async
def
toggle_global_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)):
function
=
Functions
.
get_function_by_id
(
id
)
if
function
:
function
=
Functions
.
update_function_by_id
(
id
,
{
"is_global"
:
not
function
.
is_global
}
)
if
function
:
return
function
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
"Error updating function"
),
)
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
NOT_FOUND
,
)
############################
############################
# UpdateFunctionById
# UpdateFunctionById
############################
############################
...
...
backend/main.py
View file @
edbd07f8
...
@@ -416,8 +416,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -416,8 +416,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
)
return
0
return
0
filter_ids
=
[]
filter_ids
=
[
function
.
id
for
function
in
Functions
.
get_global_filter_functions
()
]
if
"info"
in
model
and
"meta"
in
model
[
"info"
]:
if
"info"
in
model
and
"meta"
in
model
[
"info"
]:
filter_ids
.
extend
(
model
[
"info"
][
"meta"
].
get
(
"filterIds"
,
[]))
filter_ids
=
list
(
set
(
filter_ids
))
enabled_filter_ids
=
[
enabled_filter_ids
=
[
function
.
id
function
.
id
for
function
in
Functions
.
get_functions_by_type
(
for
function
in
Functions
.
get_functions_by_type
(
...
@@ -425,11 +430,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -425,11 +430,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
)
]
]
filter_ids
=
[
filter_ids
=
[
filter_id
filter_id
for
filter_id
in
filter_ids
if
filter_id
in
enabled_filter_ids
for
filter_id
in
enabled_filter_ids
if
filter_id
in
model
[
"info"
][
"meta"
].
get
(
"filterIds"
,
[])
]
]
filter_ids
=
list
(
set
(
filter_ids
))
filter_ids
.
sort
(
key
=
get_priority
)
filter_ids
.
sort
(
key
=
get_priority
)
for
filter_id
in
filter_ids
:
for
filter_id
in
filter_ids
:
...
@@ -919,7 +921,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
...
@@ -919,7 +921,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
)
)
model
=
app
.
state
.
MODELS
[
model_id
]
model
=
app
.
state
.
MODELS
[
model_id
]
print
(
model
)
pipe
=
model
.
get
(
"pipe"
)
pipe
=
model
.
get
(
"pipe"
)
if
pipe
:
if
pipe
:
...
@@ -1010,20 +1011,18 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
...
@@ -1010,20 +1011,18 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
return
(
function
.
valves
if
function
.
valves
else
{}).
get
(
"priority"
,
0
)
return
(
function
.
valves
if
function
.
valves
else
{}).
get
(
"priority"
,
0
)
return
0
return
0
filter_ids
=
[]
filter_ids
=
[
function
.
id
for
function
in
Functions
.
get_global_filter_functions
()
]
if
"info"
in
model
and
"meta"
in
model
[
"info"
]:
if
"info"
in
model
and
"meta"
in
model
[
"info"
]:
filter_ids
.
extend
(
model
[
"info"
][
"meta"
].
get
(
"filterIds"
,
[]))
filter_ids
=
list
(
set
(
filter_ids
))
enabled_filter_ids
=
[
enabled_filter_ids
=
[
function
.
id
function
.
id
for
function
in
Functions
.
get_functions_by_type
(
for
function
in
Functions
.
get_functions_by_type
(
"filter"
,
active_only
=
True
)
"filter"
,
active_only
=
True
)
]
]
filter_ids
=
[
filter_ids
=
[
filter_id
filter_id
for
filter_id
in
filter_ids
if
filter_id
in
enabled_filter_ids
for
filter_id
in
enabled_filter_ids
if
filter_id
in
model
[
"info"
][
"meta"
].
get
(
"filterIds"
,
[])
]
]
filter_ids
=
list
(
set
(
filter_ids
))
# Sort filter_ids by priority, using the get_priority function
# Sort filter_ids by priority, using the get_priority function
filter_ids
.
sort
(
key
=
get_priority
)
filter_ids
.
sort
(
key
=
get_priority
)
...
...
src/lib/apis/functions/index.ts
View file @
edbd07f8
...
@@ -224,6 +224,38 @@ export const toggleFunctionById = async (token: string, id: string) => {
...
@@ -224,6 +224,38 @@ export const toggleFunctionById = async (token: string, id: string) => {
return
res
;
return
res
;
};
};
export
const
toggleGlobalById
=
async
(
token
:
string
,
id
:
string
)
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/functions/id/
${
id
}
/toggle/global`
,
{
method
:
'
POST
'
,
headers
:
{
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
authorization
:
`Bearer
${
token
}
`
}
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
then
((
json
)
=>
{
return
json
;
})
.
catch
((
err
)
=>
{
error
=
err
.
detail
;
console
.
log
(
err
);
return
null
;
});
if
(
error
)
{
throw
error
;
}
return
res
;
};
export
const
getFunctionValvesById
=
async
(
token
:
string
,
id
:
string
)
=>
{
export
const
getFunctionValvesById
=
async
(
token
:
string
,
id
:
string
)
=>
{
let
error
=
null
;
let
error
=
null
;
...
...
src/lib/components/icons/GlobeAlt.svelte
0 → 100644
View file @
edbd07f8
<script lang="ts">
export let className = 'w-4 h-4';
export let strokeWidth = '1.5';
</script>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width={strokeWidth}
stroke="currentColor"
class={className}
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 21a9.004 9.004 0 0 0 8.716-6.747M12 21a9.004 9.004 0 0 1-8.716-6.747M12 21c2.485 0 4.5-4.03 4.5-9S14.485 3 12 3m0 18c-2.485 0-4.5-4.03-4.5-9S9.515 3 12 3m0 0a8.997 8.997 0 0 1 7.843 4.582M12 3a8.997 8.997 0 0 0-7.843 4.582m15.686 0A11.953 11.953 0 0 1 12 10.5c-2.998 0-5.74-1.1-7.843-2.918m15.686 0A8.959 8.959 0 0 1 21 12c0 .778-.099 1.533-.284 2.253m0 0A17.919 17.919 0 0 1 12 16.5c-3.162 0-6.133-.815-8.716-2.247m0 0A9.015 9.015 0 0 1 3 12c0-1.605.42-3.113 1.157-4.418"
/>
</svg>
src/lib/components/workspace/Functions.svelte
View file @
edbd07f8
...
@@ -14,7 +14,8 @@
...
@@ -14,7 +14,8 @@
exportFunctions,
exportFunctions,
getFunctionById,
getFunctionById,
getFunctions,
getFunctions,
toggleFunctionById
toggleFunctionById,
toggleGlobalById
} from '$lib/apis/functions';
} from '$lib/apis/functions';
import ArrowDownTray from '../icons/ArrowDownTray.svelte';
import ArrowDownTray from '../icons/ArrowDownTray.svelte';
...
@@ -113,6 +114,22 @@
...
@@ -113,6 +114,22 @@
models.set(await getModels(localStorage.token));
models.set(await getModels(localStorage.token));
}
}
};
};
const toggleGlobalHandler = async (func) => {
const res = await toggleGlobalById(localStorage.token, func.id).catch((error) => {
toast.error(error);
});
if (res) {
if (func.is_global) {
toast.success($i18n.t('Filter is now globally enabled'));
} else {
toast.success($i18n.t('Filter is now globally disabled'));
}
functions.set(await getFunctions(localStorage.token));
}
};
</script>
</script>
<svelte:head>
<svelte:head>
...
@@ -259,6 +276,7 @@
...
@@ -259,6 +276,7 @@
</Tooltip>
</Tooltip>
<FunctionMenu
<FunctionMenu
{func}
editHandler={() => {
editHandler={() => {
goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
}}
}}
...
@@ -275,6 +293,11 @@
...
@@ -275,6 +293,11 @@
selectedFunction = func;
selectedFunction = func;
showDeleteConfirm = true;
showDeleteConfirm = true;
}}
}}
toggleGlobalHandler={() => {
if (func.type === 'filter') {
toggleGlobalHandler(func);
}
}}
onClose={() => {}}
onClose={() => {}}
>
>
<button
<button
...
...
src/lib/components/workspace/Functions/FunctionMenu.svelte
View file @
edbd07f8
...
@@ -5,21 +5,24 @@
...
@@ -5,21 +5,24 @@
import Dropdown from '$lib/components/common/Dropdown.svelte';
import Dropdown from '$lib/components/common/Dropdown.svelte';
import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
import Pencil from '$lib/components/icons/Pencil.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Tags from '$lib/components/chat/Tags.svelte';
import Share from '$lib/components/icons/Share.svelte';
import Share from '$lib/components/icons/Share.svelte';
import ArchiveBox from '$lib/components/icons/ArchiveBox.svelte';
import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
import ArrowDownTray from '$lib/components/icons/ArrowDownTray.svelte';
import ArrowDownTray from '$lib/components/icons/ArrowDownTray.svelte';
import Switch from '$lib/components/common/Switch.svelte';
import GlobeAlt from '$lib/components/icons/GlobeAlt.svelte';
const i18n = getContext('i18n');
const i18n = getContext('i18n');
export let func;
export let editHandler: Function;
export let editHandler: Function;
export let shareHandler: Function;
export let shareHandler: Function;
export let cloneHandler: Function;
export let cloneHandler: Function;
export let exportHandler: Function;
export let exportHandler: Function;
export let deleteHandler: Function;
export let deleteHandler: Function;
export let toggleGlobalHandler: Function;
export let onClose: Function;
export let onClose: Function;
let show = false;
let show = false;
...
@@ -45,6 +48,24 @@
...
@@ -45,6 +48,24 @@
align="start"
align="start"
transition={flyAndScale}
transition={flyAndScale}
>
>
{#if func.type === 'filter'}
<div
class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
>
<div class="flex gap-2 items-center">
<GlobeAlt />
<div class="flex items-center">{$i18n.t('Global')}</div>
</div>
<div>
<Switch on:change={toggleGlobalHandler} bind:state={func.is_global} />
</div>
</div>
<hr class="border-gray-100 dark:border-gray-800 my-1" />
{/if}
<DropdownMenu.Item
<DropdownMenu.Item
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
on:click={() => {
on:click={() => {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment