Unverified Commit 09a81eb2 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3321 from open-webui/functions

feat: functions
parents f9283bc3 9ebd308d
...@@ -53,7 +53,7 @@ from config import ( ...@@ -53,7 +53,7 @@ from config import (
UPLOAD_DIR, UPLOAD_DIR,
AppConfig, AppConfig,
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256, add_or_update_system_message
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
...@@ -834,18 +834,9 @@ async def generate_chat_completion( ...@@ -834,18 +834,9 @@ async def generate_chat_completion(
) )
if payload.get("messages"): if payload.get("messages"):
for message in payload["messages"]: payload["messages"] = add_or_update_system_message(
if message.get("role") == "system": system, payload["messages"]
message["content"] = system + message["content"] )
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
if url_idx == None: if url_idx == None:
if ":" not in payload["model"]: if ":" not in payload["model"]:
......
...@@ -432,7 +432,12 @@ async def generate_chat_completion( ...@@ -432,7 +432,12 @@ async def generate_chat_completion(
idx = model["urlIdx"] idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"): if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id} payload["user"] = {
"name": user.name,
"id": user.id,
"email": user.email,
"role": user.role,
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model # This is a workaround until OpenAI fixes the issue with this model
......
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Function(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
type = pw.TextField()
content = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")
...@@ -13,7 +13,11 @@ from apps.webui.routers import ( ...@@ -13,7 +13,11 @@ from apps.webui.routers import (
memories, memories,
utils, utils,
files, files,
functions,
) )
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from config import ( from config import (
WEBUI_BUILD_HASH, WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS, SHOW_ADMIN_DETAILS,
...@@ -60,7 +64,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING ...@@ -60,7 +64,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {} app.state.MODELS = {}
app.state.TOOLS = {} app.state.TOOLS = {}
app.state.FUNCTIONS = {}
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
...@@ -70,19 +74,22 @@ app.add_middleware( ...@@ -70,19 +74,22 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
app.include_router(files.router, prefix="/files", tags=["files"])
@app.get("/") @app.get("/")
...@@ -93,3 +100,58 @@ async def get_status(): ...@@ -93,3 +100,58 @@ async def get_status():
"default_models": app.state.config.DEFAULT_MODELS, "default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
} }
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe")
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type = load_function_module_by_id(pipe.id)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": pipe.type},
}
)
else:
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": "pipe"},
}
)
return pipe_models
...@@ -55,6 +55,7 @@ class FunctionModel(BaseModel): ...@@ -55,6 +55,7 @@ class FunctionModel(BaseModel):
class FunctionResponse(BaseModel): class FunctionResponse(BaseModel):
id: str id: str
user_id: str user_id: str
type: str
name: str name: str
meta: FunctionMeta meta: FunctionMeta
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
...@@ -64,23 +65,23 @@ class FunctionResponse(BaseModel): ...@@ -64,23 +65,23 @@ class FunctionResponse(BaseModel):
class FunctionForm(BaseModel): class FunctionForm(BaseModel):
id: str id: str
name: str name: str
type: str
content: str content: str
meta: FunctionMeta meta: FunctionMeta
class ToolsTable: class FunctionsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Function]) self.db.create_tables([Function])
def insert_new_function( def insert_new_function(
self, user_id: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
"user_id": user_id, "user_id": user_id,
"type": type,
"updated_at": int(time.time()), "updated_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
} }
...@@ -137,4 +138,4 @@ class ToolsTable: ...@@ -137,4 +138,4 @@ class ToolsTable:
return False return False
Tools = ToolsTable(DB) Functions = FunctionsTable(DB)
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.functions import (
Functions,
FunctionForm,
FunctionModel,
FunctionResponse,
)
from apps.webui.utils import load_function_module_by_id
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
from importlib import util
import os
from pathlib import Path
from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
router = APIRouter()
############################
# GetFunctions
############################
@router.get("/", response_model=List[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions()
############################
# ExportFunctions
############################
@router.get("/export", response_model=List[FunctionModel])
async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
############################
# CreateNewFunction
############################
@router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id)
if function == None:
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(form_data.id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetFunctionById
############################
@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateFunctionById
############################
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
function = Functions.update_function_by_id(id, updated)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteFunctionById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
result = Functions.delete_function_by_id(id)
if result:
FUNCTIONS = request.app.state.FUNCTIONS
if id in FUNCTIONS:
del FUNCTIONS[id]
# delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path)
return result
from importlib import util from importlib import util
import os import os
from config import TOOLS_DIR from config import TOOLS_DIR, FUNCTIONS_DIR
def load_toolkit_module_by_id(toolkit_id): def load_toolkit_module_by_id(toolkit_id):
...@@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id): ...@@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
# Move the file to the error folder # Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error") os.rename(toolkit_path, f"{toolkit_path}.error")
raise e raise e
def load_function_module_by_id(function_id):
function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
spec = util.spec_from_file_location(function_id, function_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipe"):
return module.Pipe(), "pipe"
elif hasattr(module, "Filter"):
return module.Filter(), "filter"
else:
raise Exception("No Function class found")
except Exception as e:
print(f"Error loading module: {function_id}")
# Move the file to the error folder
os.rename(function_path, f"{function_path}.error")
raise e
...@@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") ...@@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Functions DIR
####################################
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# LITELLM_CONFIG # LITELLM_CONFIG
#################################### ####################################
......
This diff is collapsed.
...@@ -4,6 +4,8 @@ import json ...@@ -4,6 +4,8 @@ import json
import re import re
from datetime import timedelta from datetime import timedelta
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
import uuid
import time
def get_last_user_message(messages: List[dict]) -> str: def get_last_user_message(messages: List[dict]) -> str:
...@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]): ...@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages return messages
def stream_message_template(model: str, message: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
def get_gravatar_url(email): def get_gravatar_url(email):
# Trim leading and trailing whitespace from # Trim leading and trailing whitespace from
# an email address and force all characters # an email address and force all characters
......
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewFunction = async (token: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const exportFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFunctionById = async (token: string, id: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
...@@ -278,7 +278,9 @@ ...@@ -278,7 +278,9 @@
})), })),
chat_id: $chatId chat_id: $chatId
}).catch((error) => { }).catch((error) => {
console.error(error); toast.error(error);
messages.at(-1).error = { content: error };
return null; return null;
}); });
...@@ -323,6 +325,13 @@ ...@@ -323,6 +325,13 @@
} else if (messages.length != 0 && messages.at(-1).done != true) { } else if (messages.length != 0 && messages.at(-1).done != true) {
// Response not done // Response not done
console.log('wait'); console.log('wait');
} else if (messages.length != 0 && messages.at(-1).error) {
// Error in response
toast.error(
$i18n.t(
`Oops! There was an error in the previous response. Please try again or contact admin.`
)
);
} else if ( } else if (
files.length > 0 && files.length > 0 &&
files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0 files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0
...@@ -630,7 +639,7 @@ ...@@ -630,7 +639,7 @@
keep_alive: $settings.keepAlive ?? undefined, keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
citations: files.length > 0, citations: files.length > 0 ? true : undefined,
chat_id: $chatId chat_id: $chatId
}); });
...@@ -928,10 +937,11 @@ ...@@ -928,10 +937,11 @@
max_tokens: $settings?.params?.max_tokens ?? undefined, max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
citations: files.length > 0, citations: files.length > 0 ? true : undefined,
chat_id: $chatId chat_id: $chatId
}, },
`${OPENAI_API_BASE_URL}` `${WEBUI_BASE_URL}/api`
); );
// Wait until history/message have been updated // Wait until history/message have been updated
......
...@@ -3,25 +3,27 @@ ...@@ -3,25 +3,27 @@
import fileSaver from 'file-saver'; import fileSaver from 'file-saver';
const { saveAs } = fileSaver; const { saveAs } = fileSaver;
import { WEBUI_NAME, functions, models } from '$lib/stores';
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, prompts, tools } from '$lib/stores';
import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { import {
createNewTool, createNewFunction,
deleteToolById, deleteFunctionById,
exportTools, exportFunctions,
getToolById, getFunctionById,
getTools getFunctions
} from '$lib/apis/tools'; } from '$lib/apis/functions';
import ArrowDownTray from '../icons/ArrowDownTray.svelte'; import ArrowDownTray from '../icons/ArrowDownTray.svelte';
import Tooltip from '../common/Tooltip.svelte'; import Tooltip from '../common/Tooltip.svelte';
import ConfirmDialog from '../common/ConfirmDialog.svelte'; import ConfirmDialog from '../common/ConfirmDialog.svelte';
import { getModels } from '$lib/apis';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
let toolsImportInputElement: HTMLInputElement; let functionsImportInputElement: HTMLInputElement;
let importFiles; let importFiles;
let showConfirm = false; let showConfirm = false;
...@@ -64,7 +66,7 @@ ...@@ -64,7 +66,7 @@
<div> <div>
<a <a
class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1" class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1"
href="/workspace/tools/create" href="/workspace/functions/create"
> >
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
...@@ -82,30 +84,40 @@ ...@@ -82,30 +84,40 @@
<hr class=" dark:border-gray-850 my-2.5" /> <hr class=" dark:border-gray-850 my-2.5" />
<div class="my-3 mb-5"> <div class="my-3 mb-5">
{#each $tools.filter((t) => query === '' || t.name {#each $functions.filter((f) => query === '' || f.name
.toLowerCase() .toLowerCase()
.includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool} .includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
<button <button
class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl" class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
type="button" type="button"
on:click={() => { on:click={() => {
goto(`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`); goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
}} }}
> >
<div class=" flex flex-1 space-x-4 cursor-pointer w-full"> <div class=" flex flex-1 space-x-4 cursor-pointer w-full">
<a <a
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`} href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
class="flex items-center text-left" class="flex items-center text-left"
> >
<div class=" flex-1 self-center pl-5"> <div class=" flex-1 self-center pl-1">
<div class=" font-semibold flex items-center gap-1.5"> <div class=" font-semibold flex items-center gap-1.5">
<div
class=" text-xs font-black px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
>
{func.type}
</div>
<div> <div>
{tool.name} {func.name}
</div> </div>
<div class=" text-gray-500 text-xs font-medium">{tool.id}</div>
</div> </div>
<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
{tool.meta.description} <div class="flex gap-1.5 px-1">
<div class=" text-gray-500 text-xs font-medium">{func.id}</div>
<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
{func.meta.description}
</div>
</div> </div>
</div> </div>
</a> </a>
...@@ -115,7 +127,7 @@ ...@@ -115,7 +127,7 @@
<a <a
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl" class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button" type="button"
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`} href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
> >
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
...@@ -141,18 +153,20 @@ ...@@ -141,18 +153,20 @@
on:click={async (e) => { on:click={async (e) => {
e.stopPropagation(); e.stopPropagation();
const _tool = await getToolById(localStorage.token, tool.id).catch((error) => { const _function = await getFunctionById(localStorage.token, func.id).catch(
toast.error(error); (error) => {
return null; toast.error(error);
}); return null;
}
if (_tool) { );
sessionStorage.tool = JSON.stringify({
..._tool, if (_function) {
id: `${_tool.id}_clone`, sessionStorage.function = JSON.stringify({
name: `${_tool.name} (Clone)` ..._function,
id: `${_function.id}_clone`,
name: `${_function.name} (Clone)`
}); });
goto('/workspace/tools/create'); goto('/workspace/functions/create');
} }
}} }}
> >
...@@ -180,16 +194,18 @@ ...@@ -180,16 +194,18 @@
on:click={async (e) => { on:click={async (e) => {
e.stopPropagation(); e.stopPropagation();
const _tool = await getToolById(localStorage.token, tool.id).catch((error) => { const _function = await getFunctionById(localStorage.token, func.id).catch(
toast.error(error); (error) => {
return null; toast.error(error);
}); return null;
}
);
if (_tool) { if (_function) {
let blob = new Blob([JSON.stringify([_tool])], { let blob = new Blob([JSON.stringify([_function])], {
type: 'application/json' type: 'application/json'
}); });
saveAs(blob, `tool-${_tool.id}-export-${Date.now()}.json`); saveAs(blob, `function-${_function.id}-export-${Date.now()}.json`);
} }
}} }}
> >
...@@ -204,14 +220,16 @@ ...@@ -204,14 +220,16 @@
on:click={async (e) => { on:click={async (e) => {
e.stopPropagation(); e.stopPropagation();
const res = await deleteToolById(localStorage.token, tool.id).catch((error) => { const res = await deleteFunctionById(localStorage.token, func.id).catch((error) => {
toast.error(error); toast.error(error);
return null; return null;
}); });
if (res) { if (res) {
toast.success('Tool deleted successfully'); toast.success('Function deleted successfully');
tools.set(await getTools(localStorage.token));
functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
} }
}} }}
> >
...@@ -246,7 +264,7 @@ ...@@ -246,7 +264,7 @@
<div class="flex space-x-2"> <div class="flex space-x-2">
<input <input
id="documents-import-input" id="documents-import-input"
bind:this={toolsImportInputElement} bind:this={functionsImportInputElement}
bind:files={importFiles} bind:files={importFiles}
type="file" type="file"
accept=".json" accept=".json"
...@@ -260,7 +278,7 @@ ...@@ -260,7 +278,7 @@
<button <button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition" class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={() => { on:click={() => {
toolsImportInputElement.click(); functionsImportInputElement.click();
}} }}
> >
<div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div> <div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div>
...@@ -284,16 +302,16 @@ ...@@ -284,16 +302,16 @@
<button <button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition" class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={async () => { on:click={async () => {
const _tools = await exportTools(localStorage.token).catch((error) => { const _functions = await exportFunctions(localStorage.token).catch((error) => {
toast.error(error); toast.error(error);
return null; return null;
}); });
if (_tools) { if (_functions) {
let blob = new Blob([JSON.stringify(_tools)], { let blob = new Blob([JSON.stringify(_functions)], {
type: 'application/json' type: 'application/json'
}); });
saveAs(blob, `tools-export-${Date.now()}.json`); saveAs(blob, `functions-export-${Date.now()}.json`);
} }
}} }}
> >
...@@ -322,18 +340,19 @@ ...@@ -322,18 +340,19 @@
on:confirm={() => { on:confirm={() => {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = async (event) => { reader.onload = async (event) => {
const _tools = JSON.parse(event.target.result); const _functions = JSON.parse(event.target.result);
console.log(_tools); console.log(_functions);
for (const tool of _tools) { for (const func of _functions) {
const res = await createNewTool(localStorage.token, tool).catch((error) => { const res = await createNewFunction(localStorage.token, func).catch((error) => {
toast.error(error); toast.error(error);
return null; return null;
}); });
} }
toast.success('Tool imported successfully'); toast.success('Functions imported successfully');
tools.set(await getTools(localStorage.token)); functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
}; };
reader.readAsText(importFiles[0]); reader.readAsText(importFiles[0]);
...@@ -344,8 +363,8 @@ ...@@ -344,8 +363,8 @@
<div>Please carefully review the following warnings:</div> <div>Please carefully review the following warnings:</div>
<ul class=" mt-1 list-disc pl-4 text-xs"> <ul class=" mt-1 list-disc pl-4 text-xs">
<li>Tools have a function calling system that allows arbitrary code execution.</li> <li>Functions allow arbitrary code execution.</li>
<li>Do not install tools from sources you do not fully trust.</li> <li>Do not install functions from sources you do not fully trust.</li>
</ul> </ul>
</div> </div>
......
<script>
import { getContext, createEventDispatcher, onMount } from 'svelte';
import { goto } from '$app/navigation';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import CodeEditor from '$lib/components/common/CodeEditor.svelte';
import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
let formElement = null;
let loading = false;
let showConfirm = false;
export let edit = false;
export let clone = false;
export let id = '';
export let name = '';
export let meta = {
description: ''
};
export let content = '';
$: if (name && !edit && !clone) {
id = name.replace(/\s+/g, '_').toLowerCase();
}
let codeEditor;
let boilerplate = `from pydantic import BaseModel
from typing import Optional
class Filter:
class Valves(BaseModel):
max_turns: int = 4
pass
def __init__(self):
# Indicates custom file handling logic. This flag helps disengage default routines in favor of custom
# implementations, informing the WebUI to defer file-related operations to designated methods within this class.
# Alternatively, you can remove the files directly from the body in from the inlet hook
self.file_handler = True
# Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
# which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
self.valves = self.Valves(**{"max_turns": 2})
pass
def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
# Modify the request body or validate it before processing by the chat completion API.
# This function is the pre-processor for the API where various checks on the input can be performed.
# It can also modify the request before sending it to the API.
print(f"inlet:{__name__}")
print(f"inlet:body:{body}")
print(f"inlet:user:{user}")
if user.get("role", "admin") in ["user", "admin"]:
messages = body.get("messages", [])
if len(messages) > self.valves.max_turns:
raise Exception(
f"Conversation turn limit exceeded. Max turns: {self.valves.max_turns}"
)
return body
def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
# Modify or analyze the response body after processing by the API.
# This function is the post-processor for the API, which can be used to modify the response
# or perform additional checks and analytics.
print(f"outlet:{__name__}")
print(f"outlet:body:{body}")
print(f"outlet:user:{user}")
messages = [
{
**message,
"content": f"{message['content']} - @@Modified from Filter Outlet",
}
for message in body.get("messages", [])
]
return {"messages": messages}
`;
const saveHandler = async () => {
loading = true;
dispatch('save', {
id,
name,
meta,
content
});
};
const submitHandler = async () => {
if (codeEditor) {
const res = await codeEditor.formatPythonCodeHandler();
if (res) {
console.log('Code formatted successfully');
saveHandler();
}
}
};
</script>
<div class=" flex flex-col justify-between w-full overflow-y-auto h-full">
<div class="mx-auto w-full md:px-0 h-full">
<form
bind:this={formElement}
class=" flex flex-col max-h-[100dvh] h-full"
on:submit|preventDefault={() => {
if (edit) {
submitHandler();
} else {
showConfirm = true;
}
}}
>
<div class="mb-2.5">
<button
class="flex space-x-1"
on:click={() => {
goto('/workspace/functions');
}}
type="button"
>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M17 10a.75.75 0 01-.75.75H5.612l4.158 3.96a.75.75 0 11-1.04 1.08l-5.5-5.25a.75.75 0 010-1.08l5.5-5.25a.75.75 0 111.04 1.08L5.612 9.25H16.25A.75.75 0 0117 10z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class=" self-center font-medium text-sm">{$i18n.t('Back')}</div>
</button>
</div>
<div class="flex flex-col flex-1 overflow-auto h-0 rounded-lg">
<div class="w-full mb-2 flex flex-col gap-1.5">
<div class="flex gap-2 w-full">
<input
class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function Name (e.g. My Filter)"
bind:value={name}
required
/>
<input
class="w-full px-3 py-2 text-sm font-medium disabled:text-gray-300 dark:disabled:text-gray-700 bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function ID (e.g. my_filter)"
bind:value={id}
required
disabled={edit}
/>
</div>
<input
class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function Description (e.g. A filter to remove profanity from text)"
bind:value={meta.description}
required
/>
</div>
<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
<CodeEditor
bind:value={content}
bind:this={codeEditor}
{boilerplate}
on:save={() => {
if (formElement) {
formElement.requestSubmit();
}
}}
/>
</div>
<div class="pb-3 flex justify-between">
<div class="flex-1 pr-3">
<div class="text-xs text-gray-500 line-clamp-2">
<span class=" font-semibold dark:text-gray-200">Warning:</span> Functions allow
arbitrary code execution <br />—
<span class=" font-medium dark:text-gray-400"
>don't install random functions from sources you don't trust.</span
>
</div>
</div>
<button
class="px-3 py-1.5 text-sm font-medium bg-emerald-600 hover:bg-emerald-700 text-gray-50 transition rounded-lg"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
</div>
</form>
</div>
</div>
<ConfirmDialog
bind:show={showConfirm}
on:confirm={() => {
submitHandler();
}}
>
<div class="text-sm text-gray-500">
<div class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-lg px-4 py-3">
<div>Please carefully review the following warnings:</div>
<ul class=" mt-1 list-disc pl-4 text-xs">
<li>Functions allow arbitrary code execution.</li>
<li>Do not install functions from sources you do not fully trust.</li>
</ul>
</div>
<div class="my-3">
I acknowledge that I have read and I understand the implications of my action. I am aware of
the risks associated with executing arbitrary code and I have verified the trustworthiness of
the source.
</div>
</div>
</ConfirmDialog>
<script lang="ts">
import { getContext, onMount } from 'svelte';
import Checkbox from '$lib/components/common/Checkbox.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
const i18n = getContext('i18n');
export let filters = [];
export let selectedFilterIds = [];
let _filters = {};
onMount(() => {
_filters = filters.reduce((acc, filter) => {
acc[filter.id] = {
...filter,
selected: selectedFilterIds.includes(filter.id)
};
return acc;
}, {});
});
</script>
<div>
<div class="flex w-full justify-between mb-1">
<div class=" self-center text-sm font-semibold">{$i18n.t('Filters')}</div>
</div>
<div class=" text-xs dark:text-gray-500">
{$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
</div>
<!-- TODO: Filer order matters -->
<div class="flex flex-col">
{#if filters.length > 0}
<div class=" flex items-center mt-2 flex-wrap">
{#each Object.keys(_filters) as filter, filterIdx}
<div class=" flex items-center gap-2 mr-3">
<div class="self-center flex items-center">
<Checkbox
state={_filters[filter].selected ? 'checked' : 'unchecked'}
on:change={(e) => {
_filters[filter].selected = e.detail === 'checked';
selectedFilterIds = Object.keys(_filters).filter((t) => _filters[t].selected);
}}
/>
</div>
<div class=" py-0.5 text-sm w-full capitalize font-medium">
<Tooltip content={_filters[filter].meta.description}>
{_filters[filter].name}
</Tooltip>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>
...@@ -27,7 +27,9 @@ export const tags = writable([]); ...@@ -27,7 +27,9 @@ export const tags = writable([]);
export const models: Writable<Model[]> = writable([]); export const models: Writable<Model[]> = writable([]);
export const prompts: Writable<Prompt[]> = writable([]); export const prompts: Writable<Prompt[]> = writable([]);
export const documents: Writable<Document[]> = writable([]); export const documents: Writable<Document[]> = writable([]);
export const tools = writable([]); export const tools = writable([]);
export const functions = writable([]);
export const banners: Writable<Banner[]> = writable([]); export const banners: Writable<Banner[]> = writable([]);
......
<script lang="ts"> <script lang="ts">
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, showSidebar } from '$lib/stores'; import { WEBUI_NAME, showSidebar, functions } from '$lib/stores';
import MenuLines from '$lib/components/icons/MenuLines.svelte'; import MenuLines from '$lib/components/icons/MenuLines.svelte';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { getFunctions } from '$lib/apis/functions';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
onMount(async () => {
functions.set(await getFunctions(localStorage.token));
});
</script> </script>
<svelte:head> <svelte:head>
......
<script> <script>
import { goto } from '$app/navigation';
import { createNewTool, getTools } from '$lib/apis/tools';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { onMount } from 'svelte';
import { goto } from '$app/navigation';
import { functions, models } from '$lib/stores';
import { createNewFunction, getFunctions } from '$lib/apis/functions';
import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
import { getModels } from '$lib/apis';
let mounted = false; let mounted = false;
let clone = false; let clone = false;
let tool = null; let func = null;
const saveHandler = async (data) => { const saveHandler = async (data) => {
console.log(data); console.log(data);
const res = await createNewTool(localStorage.token, { const res = await createNewFunction(localStorage.token, {
id: data.id, id: data.id,
name: data.name, name: data.name,
meta: data.meta, meta: data.meta,
...@@ -23,19 +25,20 @@ ...@@ -23,19 +25,20 @@
}); });
if (res) { if (res) {
toast.success('Tool created successfully'); toast.success('Function created successfully');
tools.set(await getTools(localStorage.token)); functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
await goto('/workspace/tools'); await goto('/workspace/functions');
} }
}; };
onMount(() => { onMount(() => {
if (sessionStorage.tool) { if (sessionStorage.function) {
tool = JSON.parse(sessionStorage.tool); func = JSON.parse(sessionStorage.function);
sessionStorage.removeItem('tool'); sessionStorage.removeItem('function');
console.log(tool); console.log(func);
clone = true; clone = true;
} }
...@@ -44,11 +47,11 @@ ...@@ -44,11 +47,11 @@
</script> </script>
{#if mounted} {#if mounted}
<ToolkitEditor <FunctionEditor
id={tool?.id ?? ''} id={func?.id ?? ''}
name={tool?.name ?? ''} name={func?.name ?? ''}
meta={tool?.meta ?? { description: '' }} meta={func?.meta ?? { description: '' }}
content={tool?.content ?? ''} content={func?.content ?? ''}
{clone} {clone}
on:save={(e) => { on:save={(e) => {
saveHandler(e.detail); saveHandler(e.detail);
......
<script> <script>
import { toast } from 'svelte-sonner';
import { onMount } from 'svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { getToolById, getTools, updateToolById } from '$lib/apis/tools'; import { functions, models } from '$lib/stores';
import { updateFunctionById, getFunctions, getFunctionById } from '$lib/apis/functions';
import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte'; import { getModels } from '$lib/apis';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner';
let tool = null; let func = null;
const saveHandler = async (data) => { const saveHandler = async (data) => {
console.log(data); console.log(data);
const res = await updateToolById(localStorage.token, tool.id, { const res = await updateFunctionById(localStorage.token, func.id, {
id: data.id, id: data.id,
name: data.name, name: data.name,
meta: data.meta, meta: data.meta,
...@@ -23,10 +26,9 @@ ...@@ -23,10 +26,9 @@
}); });
if (res) { if (res) {
toast.success('Tool updated successfully'); toast.success('Function updated successfully');
tools.set(await getTools(localStorage.token)); functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
// await goto('/workspace/tools');
} }
}; };
...@@ -35,24 +37,24 @@ ...@@ -35,24 +37,24 @@
const id = $page.url.searchParams.get('id'); const id = $page.url.searchParams.get('id');
if (id) { if (id) {
tool = await getToolById(localStorage.token, id).catch((error) => { func = await getFunctionById(localStorage.token, id).catch((error) => {
toast.error(error); toast.error(error);
goto('/workspace/tools'); goto('/workspace/functions');
return null; return null;
}); });
console.log(tool); console.log(func);
} }
}); });
</script> </script>
{#if tool} {#if func}
<ToolkitEditor <FunctionEditor
edit={true} edit={true}
id={tool.id} id={func.id}
name={tool.name} name={func.name}
meta={tool.meta} meta={func.meta}
content={tool.content} content={func.content}
on:save={(e) => { on:save={(e) => {
saveHandler(e.detail); saveHandler(e.detail);
}} }}
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { settings, user, config, models, tools } from '$lib/stores'; import { settings, user, config, models, tools, functions } from '$lib/stores';
import { splitStream } from '$lib/utils'; import { splitStream } from '$lib/utils';
import { getModelInfos, updateModelById } from '$lib/apis/models'; import { getModelInfos, updateModelById } from '$lib/apis/models';
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import Tags from '$lib/components/common/Tags.svelte'; import Tags from '$lib/components/common/Tags.svelte';
import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte'; import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte';
import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte'; import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte';
import FiltersSelector from '$lib/components/workspace/Models/FiltersSelector.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
...@@ -62,6 +63,7 @@ ...@@ -62,6 +63,7 @@
let knowledge = []; let knowledge = [];
let toolIds = []; let toolIds = [];
let filterIds = [];
const updateHandler = async () => { const updateHandler = async () => {
loading = true; loading = true;
...@@ -86,6 +88,14 @@ ...@@ -86,6 +88,14 @@
} }
} }
if (filterIds.length > 0) {
info.meta.filterIds = filterIds;
} else {
if (info.meta.filterIds) {
delete info.meta.filterIds;
}
}
info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null; info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null;
Object.keys(info.params).forEach((key) => { Object.keys(info.params).forEach((key) => {
if (info.params[key] === '' || info.params[key] === null) { if (info.params[key] === '' || info.params[key] === null) {
...@@ -147,6 +157,10 @@ ...@@ -147,6 +157,10 @@
toolIds = [...model?.info?.meta?.toolIds]; toolIds = [...model?.info?.meta?.toolIds];
} }
if (model?.info?.meta?.filterIds) {
filterIds = [...model?.info?.meta?.filterIds];
}
if (model?.owned_by === 'openai') { if (model?.owned_by === 'openai') {
capabilities.usage = false; capabilities.usage = false;
} }
...@@ -534,6 +548,13 @@ ...@@ -534,6 +548,13 @@
<ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} /> <ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} />
</div> </div>
<div class="my-2">
<FiltersSelector
bind:selectedFilterIds={filterIds}
filters={$functions.filter((func) => func.type === 'filter')}
/>
</div>
<div class="my-2"> <div class="my-2">
<div class="flex w-full justify-between mb-1"> <div class="flex w-full justify-between mb-1">
<div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div> <div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment