"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "98948814fd28508d968b47c0ea092784874778ad"
Commit 4d57e08b authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

feat: modelfiles to models

parent 17e4be49
...@@ -6,7 +6,7 @@ from apps.web.routers import ( ...@@ -6,7 +6,7 @@ from apps.web.routers import (
users, users,
chats, chats,
documents, documents,
modelfiles, models,
prompts, prompts,
configs, configs,
memories, memories,
...@@ -56,11 +56,10 @@ app.include_router(users.router, prefix="/users", tags=["users"]) ...@@ -56,11 +56,10 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
......
################################################################################
# DEPRECATION NOTICE #
# #
# This file has been deprecated since version 0.2.0. #
# #
################################################################################
from pydantic import BaseModel from pydantic import BaseModel
from peewee import * from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
......
...@@ -3,13 +3,18 @@ import logging ...@@ -3,13 +3,18 @@ import logging
from typing import Optional from typing import Optional
import peewee as pw import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from apps.web.internal.db import DB, JSONField from apps.web.internal.db import DB, JSONField
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
import time
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
...@@ -20,10 +25,8 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -20,10 +25,8 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
# ModelParams is a model for the data stored in the params field of the Model table # ModelParams is a model for the data stored in the params field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelParams(BaseModel): class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
pass pass
...@@ -55,7 +58,6 @@ class Model(pw.Model): ...@@ -55,7 +58,6 @@ class Model(pw.Model):
base_model_id = pw.TextField(null=True) base_model_id = pw.TextField(null=True)
""" """
An optional pointer to the actual model that should be used when proxying requests. An optional pointer to the actual model that should be used when proxying requests.
Currently unused - but will be used to support Modelfile like behaviour in the future
""" """
name = pw.TextField() name = pw.TextField()
...@@ -73,8 +75,8 @@ class Model(pw.Model): ...@@ -73,8 +75,8 @@ class Model(pw.Model):
Holds a JSON encoded blob of metadata, see `ModelMeta`. Holds a JSON encoded blob of metadata, see `ModelMeta`.
""" """
updated_at: int # timestamp in epoch updated_at = BigIntegerField()
created_at: int # timestamp in epoch created_at = BigIntegerField()
class Meta: class Meta:
database = DB database = DB
...@@ -83,16 +85,36 @@ class Model(pw.Model): ...@@ -83,16 +85,36 @@ class Model(pw.Model):
class ModelModel(BaseModel): class ModelModel(BaseModel):
id: str id: str
base_model_id: Optional[str] = None base_model_id: Optional[str] = None
name: str name: str
params: ModelParams params: ModelParams
meta: ModelMeta meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
#################### ####################
# Forms # Forms
#################### ####################
class ModelResponse(BaseModel):
id: str
name: str
meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ModelForm(BaseModel):
id: str
base_model_id: Optional[str] = None
name: str
meta: ModelMeta
params: ModelParams
class ModelsTable: class ModelsTable:
def __init__( def __init__(
self, self,
...@@ -101,44 +123,47 @@ class ModelsTable: ...@@ -101,44 +123,47 @@ class ModelsTable:
self.db = db self.db = db
self.db.create_tables([Model]) self.db.create_tables([Model])
def get_all_models(self) -> list[ModelModel]: def insert_new_model(self, model: ModelForm, user_id: str) -> Optional[ModelModel]:
try:
model = Model.create(
**{
**model.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
return ModelModel(**model_to_dict(model))
except:
return None
def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()] return [ModelModel(**model_to_dict(model)) for model in Model.select()]
def update_all_models(self, models: list[ModelModel]) -> bool: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with self.db.atomic(): model = Model.get(Model.id == id)
# Fetch current models from the database return ModelModel(**model_to_dict(model))
current_models = self.get_all_models() except:
current_model_dict = {model.id: model for model in current_models} return None
# Create a set of model IDs from the current models and the new models
current_model_keys = set(current_model_dict.keys())
new_model_keys = set(model.id for model in models)
# Determine which models need to be created, updated, or deleted
models_to_create = [
model for model in models if model.id not in current_model_keys
]
models_to_update = [
model for model in models if model.id in current_model_keys
]
models_to_delete = current_model_keys - new_model_keys
# Perform the necessary database operations
for model in models_to_create:
Model.create(**model.model_dump())
for model in models_to_update:
Model.update(**model.model_dump()).where(
Model.id == model.id
).execute()
for model_id, model_source in models_to_delete:
Model.delete().where(Model.id == model_id).execute()
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
query = Model.update(**model.model_dump()).where(Model.id == id)
query.execute()
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
except:
return None
def delete_model_by_id(self, id: str) -> bool:
try:
query = Model.delete().where(Model.id == id)
query.execute()
return True return True
except Exception as e: except:
log.exception(e)
return False return False
......
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.modelfiles import (
Modelfiles,
ModelfileForm,
ModelfileTagNameForm,
ModelfileUpdateForm,
ModelfileResponse,
)
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
############################
# GetModelfiles
############################
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(
skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
return Modelfiles.get_modelfiles(skip, limit)
############################
# CreateNewModelfile
############################
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################
# GetModelfileByTagName
############################
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelfileByTagName
############################
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
**json.loads(modelfile.modelfile),
**form_data.modelfile,
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeleteModelfileByTagName
############################
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
###########################
# getAllModels
###########################
@router.get("/", response_model=List[ModelResponse])
async def get_models(user=Depends(get_verified_user)):
return Models.get_all_models()
############################
# AddNewModel
############################
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(form_data: ModelForm, user=Depends(get_admin_user)):
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################
# GetModelById
############################
@router.get("/{id}", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelById
############################
@router.post("/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id(
id: str, form_data: ModelForm, user=Depends(get_admin_user)
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeleteModelById
############################
@router.delete("/{id}/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(id)
return result
...@@ -320,33 +320,6 @@ async def update_model_filter_config( ...@@ -320,33 +320,6 @@ async def update_model_filter_config(
} }
class SetModelConfigForm(BaseModel):
models: List[ModelModel]
@app.post("/api/config/models")
async def update_model_config(
form_data: SetModelConfigForm, user=Depends(get_admin_user)
):
if not Models.update_all_models(form_data.models):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"),
)
ollama_app.state.MODEL_CONFIG = form_data.models
openai_app.state.MODEL_CONFIG = form_data.models
litellm_app.state.MODEL_CONFIG = form_data.models
app.state.MODEL_CONFIG = form_data.models
return {"models": app.state.MODEL_CONFIG}
@app.get("/api/config/models")
async def get_model_config(user=Depends(get_admin_user)):
return {"models": app.state.MODEL_CONFIG}
@app.get("/api/webhook") @app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)): async def get_webhook_url(user=Depends(get_admin_user)):
return { return {
......
import { WEBUI_API_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewModelfile = async (token: string, modelfile: object) => { export const addNewModel = async (token: string, model: object) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/create`, { const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify(model)
modelfile: modelfile
})
}) })
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
...@@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => { ...@@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => {
return res; return res;
}; };
export const getModelfiles = async (token: string = '') => { export const getModels = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { const res = await fetch(`${WEBUI_API_BASE_URL}/models/`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
...@@ -59,22 +57,19 @@ export const getModelfiles = async (token: string = '') => { ...@@ -59,22 +57,19 @@ export const getModelfiles = async (token: string = '') => {
throw error; throw error;
} }
return res.map((modelfile) => modelfile.modelfile); return res;
}; };
export const getModelfileByTagName = async (token: string, tagName: string) => { export const getModelById = async (token: string, id: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, {
method: 'POST', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
}, }
body: JSON.stringify({
tag_name: tagName
})
}) })
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
...@@ -94,27 +89,20 @@ export const getModelfileByTagName = async (token: string, tagName: string) => { ...@@ -94,27 +89,20 @@ export const getModelfileByTagName = async (token: string, tagName: string) => {
throw error; throw error;
} }
return res.modelfile; return res;
}; };
export const updateModelfileByTagName = async ( export const updateModelById = async (token: string, id: string, model: object) => {
token: string,
tagName: string,
modelfile: object
) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/update`, { const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/update`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify(model)
tag_name: tagName,
modelfile: modelfile
})
}) })
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
...@@ -137,19 +125,16 @@ export const updateModelfileByTagName = async ( ...@@ -137,19 +125,16 @@ export const updateModelfileByTagName = async (
return res; return res;
}; };
export const deleteModelfileByTagName = async (token: string, tagName: string) => { export const deleteModelById = async (token: string, id: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/delete`, { const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/delete`, {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
}, }
body: JSON.stringify({
tag_name: tagName
})
}) })
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { import {
capitalizeFirstLetter, capitalizeFirstLetter,
getModels, getAllModels,
sanitizeResponseContent, sanitizeResponseContent,
splitStream splitStream
} from '$lib/utils'; } from '$lib/utils';
...@@ -159,7 +159,7 @@ ...@@ -159,7 +159,7 @@
}) })
); );
models.set(await getModels(localStorage.token)); models.set(await getAllModels(localStorage.token));
} else { } else {
toast.error($i18n.t('Download canceled')); toast.error($i18n.t('Download canceled'));
} }
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
const i18n = getContext('i18n'); const i18n = getContext('i18n');
export let getModels: Function; export let getAllModels: Function;
// External // External
let OLLAMA_BASE_URLS = ['']; let OLLAMA_BASE_URLS = [''];
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS); OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS);
OPENAI_API_KEYS = await updateOpenAIKeys(localStorage.token, OPENAI_API_KEYS); OPENAI_API_KEYS = await updateOpenAIKeys(localStorage.token, OPENAI_API_KEYS);
await models.set(await getModels()); await models.set(await getAllModels());
}; };
const updateOllamaUrlsHandler = async () => { const updateOllamaUrlsHandler = async () => {
...@@ -51,7 +51,7 @@ ...@@ -51,7 +51,7 @@
if (ollamaVersion) { if (ollamaVersion) {
toast.success($i18n.t('Server connection verified')); toast.success($i18n.t('Server connection verified'));
await models.set(await getModels()); await models.set(await getAllModels());
} }
}; };
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
import AdvancedParams from './Advanced/AdvancedParams.svelte'; import AdvancedParams from './Advanced/AdvancedParams.svelte';
export let saveSettings: Function; export let saveSettings: Function;
export let getModels: Function; export let getAllModels: Function;
// General // General
let themes = ['dark', 'light', 'rose-pine dark', 'rose-pine-dawn light', 'oled-dark']; let themes = ['dark', 'light', 'rose-pine dark', 'rose-pine-dawn light', 'oled-dark'];
......
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
let imageSize = ''; let imageSize = '';
let steps = 50; let steps = 50;
const getModels = async () => { const getAllModels = async () => {
models = await getImageGenerationModels(localStorage.token).catch((error) => { models = await getImageGenerationModels(localStorage.token).catch((error) => {
toast.error(error); toast.error(error);
return null; return null;
...@@ -66,7 +66,7 @@ ...@@ -66,7 +66,7 @@
if (res) { if (res) {
COMFYUI_BASE_URL = res.COMFYUI_BASE_URL; COMFYUI_BASE_URL = res.COMFYUI_BASE_URL;
await getModels(); await getAllModels();
if (models) { if (models) {
toast.success($i18n.t('Server connection verified')); toast.success($i18n.t('Server connection verified'));
...@@ -85,7 +85,7 @@ ...@@ -85,7 +85,7 @@
if (res) { if (res) {
AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL; AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL;
await getModels(); await getAllModels();
if (models) { if (models) {
toast.success($i18n.t('Server connection verified')); toast.success($i18n.t('Server connection verified'));
...@@ -112,7 +112,7 @@ ...@@ -112,7 +112,7 @@
if (enableImageGeneration) { if (enableImageGeneration) {
config.set(await getBackendConfig(localStorage.token)); config.set(await getBackendConfig(localStorage.token));
getModels(); getAllModels();
} }
}; };
...@@ -141,7 +141,7 @@ ...@@ -141,7 +141,7 @@
steps = await getImageSteps(localStorage.token); steps = await getImageSteps(localStorage.token);
if (enableImageGeneration) { if (enableImageGeneration) {
getModels(); getAllModels();
} }
} }
}); });
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
const i18n = getContext('i18n'); const i18n = getContext('i18n');
export let getModels: Function; export let getAllModels: Function;
let showLiteLLM = false; let showLiteLLM = false;
let showLiteLLMParams = false; let showLiteLLMParams = false;
...@@ -261,7 +261,7 @@ ...@@ -261,7 +261,7 @@
}) })
); );
models.set(await getModels(localStorage.token)); models.set(await getAllModels(localStorage.token));
} else { } else {
toast.error($i18n.t('Download canceled')); toast.error($i18n.t('Download canceled'));
} }
...@@ -424,7 +424,7 @@ ...@@ -424,7 +424,7 @@
modelTransferring = false; modelTransferring = false;
uploadProgress = null; uploadProgress = null;
models.set(await getModels()); models.set(await getAllModels());
}; };
const deleteModelHandler = async () => { const deleteModelHandler = async () => {
...@@ -439,7 +439,7 @@ ...@@ -439,7 +439,7 @@
} }
deleteModelTag = ''; deleteModelTag = '';
models.set(await getModels()); models.set(await getAllModels());
}; };
const cancelModelPullHandler = async (model: string) => { const cancelModelPullHandler = async (model: string) => {
...@@ -488,7 +488,7 @@ ...@@ -488,7 +488,7 @@
liteLLMMaxTokens = ''; liteLLMMaxTokens = '';
liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token); liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
models.set(await getModels()); models.set(await getAllModels());
}; };
const deleteLiteLLMModelHandler = async () => { const deleteLiteLLMModelHandler = async () => {
...@@ -507,7 +507,7 @@ ...@@ -507,7 +507,7 @@
deleteLiteLLMModelName = ''; deleteLiteLLMModelName = '';
liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token); liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
models.set(await getModels()); models.set(await getAllModels());
}; };
const addModelInfoHandler = async () => { const addModelInfoHandler = async () => {
...@@ -519,9 +519,7 @@ ...@@ -519,9 +519,7 @@
return; return;
} }
// Remove any existing config // Remove any existing config
modelConfig = modelConfig.filter( modelConfig = modelConfig.filter((m) => !(m.id === selectedModelId));
(m) => !(m.id === selectedModelId)
);
// Add new config // Add new config
modelConfig.push({ modelConfig.push({
id: selectedModelId, id: selectedModelId,
...@@ -536,7 +534,7 @@ ...@@ -536,7 +534,7 @@
toast.success( toast.success(
$i18n.t('Model info for {{modelName}} added successfully', { modelName: selectedModelId }) $i18n.t('Model info for {{modelName}} added successfully', { modelName: selectedModelId })
); );
models.set(await getModels()); models.set(await getAllModels());
}; };
const deleteModelInfoHandler = async () => { const deleteModelInfoHandler = async () => {
...@@ -547,14 +545,12 @@ ...@@ -547,14 +545,12 @@
if (!model) { if (!model) {
return; return;
} }
modelConfig = modelConfig.filter( modelConfig = modelConfig.filter((m) => !(m.id === selectedModelId));
(m) => !(m.id === selectedModelId)
);
await updateModelConfig(localStorage.token, modelConfig); await updateModelConfig(localStorage.token, modelConfig);
toast.success( toast.success(
$i18n.t('Model info for {{modelName}} deleted successfully', { modelName: selectedModelId }) $i18n.t('Model info for {{modelName}} deleted successfully', { modelName: selectedModelId })
); );
models.set(await getModels()); models.set(await getAllModels());
}; };
const toggleIsVisionCapable = () => { const toggleIsVisionCapable = () => {
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { models, settings, user } from '$lib/stores'; import { models, settings, user } from '$lib/stores';
import { getModels as _getModels } from '$lib/utils'; import { getAllModels as _getAllModels } from '$lib/utils';
import Modal from '../common/Modal.svelte'; import Modal from '../common/Modal.svelte';
import Account from './Settings/Account.svelte'; import Account from './Settings/Account.svelte';
...@@ -25,12 +25,12 @@ ...@@ -25,12 +25,12 @@
const saveSettings = async (updated) => { const saveSettings = async (updated) => {
console.log(updated); console.log(updated);
await settings.set({ ...$settings, ...updated }); await settings.set({ ...$settings, ...updated });
await models.set(await getModels()); await models.set(await getAllModels());
localStorage.setItem('settings', JSON.stringify($settings)); localStorage.setItem('settings', JSON.stringify($settings));
}; };
const getModels = async () => { const getAllModels = async () => {
return await _getModels(localStorage.token); return await _getAllModels(localStorage.token);
}; };
let selectedTab = 'general'; let selectedTab = 'general';
...@@ -318,17 +318,17 @@ ...@@ -318,17 +318,17 @@
<div class="flex-1 md:min-h-[28rem]"> <div class="flex-1 md:min-h-[28rem]">
{#if selectedTab === 'general'} {#if selectedTab === 'general'}
<General <General
{getModels} {getAllModels}
{saveSettings} {saveSettings}
on:save={() => { on:save={() => {
toast.success($i18n.t('Settings saved successfully!')); toast.success($i18n.t('Settings saved successfully!'));
}} }}
/> />
{:else if selectedTab === 'models'} {:else if selectedTab === 'models'}
<Models {getModels} /> <Models {getAllModels} />
{:else if selectedTab === 'connections'} {:else if selectedTab === 'connections'}
<Connections <Connections
{getModels} {getAllModels}
on:save={() => { on:save={() => {
toast.success($i18n.t('Settings saved successfully!')); toast.success($i18n.t('Settings saved successfully!'));
}} }}
......
...@@ -7,11 +7,7 @@ ...@@ -7,11 +7,7 @@
import { WEBUI_NAME, modelfiles, settings, user } from '$lib/stores'; import { WEBUI_NAME, modelfiles, settings, user } from '$lib/stores';
import { createModel, deleteModel } from '$lib/apis/ollama'; import { createModel, deleteModel } from '$lib/apis/ollama';
import { import { addNewModel, deleteModelById, getModels } from '$lib/apis/models';
createNewModelfile,
deleteModelfileByTagName,
getModelfiles
} from '$lib/apis/modelfiles';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
...@@ -36,8 +32,8 @@ ...@@ -36,8 +32,8 @@
const deleteModelfile = async (tagName) => { const deleteModelfile = async (tagName) => {
await deleteModelHandler(tagName); await deleteModelHandler(tagName);
await deleteModelfileByTagName(localStorage.token, tagName); await deleteModelById(localStorage.token, tagName);
await modelfiles.set(await getModelfiles(localStorage.token)); await modelfiles.set(await getModels(localStorage.token));
}; };
const shareModelfile = async (modelfile) => { const shareModelfile = async (modelfile) => {
...@@ -246,12 +242,12 @@ ...@@ -246,12 +242,12 @@
console.log(savedModelfiles); console.log(savedModelfiles);
for (const modelfile of savedModelfiles) { for (const modelfile of savedModelfiles) {
await createNewModelfile(localStorage.token, modelfile).catch((error) => { await addNewModel(localStorage.token, modelfile).catch((error) => {
return null; return null;
}); });
} }
await modelfiles.set(await getModelfiles(localStorage.token)); await modelfiles.set(await getModels(localStorage.token));
}; };
reader.readAsText(importFiles[0]); reader.readAsText(importFiles[0]);
...@@ -318,7 +314,7 @@ ...@@ -318,7 +314,7 @@
class="self-center w-fit text-sm px-3 py-1 border dark:border-gray-600 rounded-xl flex" class="self-center w-fit text-sm px-3 py-1 border dark:border-gray-600 rounded-xl flex"
on:click={async () => { on:click={async () => {
for (const modelfile of localModelfiles) { for (const modelfile of localModelfiles) {
await createNewModelfile(localStorage.token, modelfile).catch((error) => { await addNewModel(localStorage.token, modelfile).catch((error) => {
return null; return null;
}); });
} }
...@@ -326,7 +322,7 @@ ...@@ -326,7 +322,7 @@
saveModelfiles(localModelfiles); saveModelfiles(localModelfiles);
localStorage.removeItem('modelfiles'); localStorage.removeItem('modelfiles');
localModelfiles = JSON.parse(localStorage.getItem('modelfiles') ?? '[]'); localModelfiles = JSON.parse(localStorage.getItem('modelfiles') ?? '[]');
await modelfiles.set(await getModelfiles(localStorage.token)); await modelfiles.set(await getModels(localStorage.token));
}} }}
> >
<div class=" self-center mr-2 font-medium">{$i18n.t('Sync All')}</div> <div class=" self-center mr-2 font-medium">{$i18n.t('Sync All')}</div>
...@@ -354,7 +350,7 @@ ...@@ -354,7 +350,7 @@
localStorage.removeItem('modelfiles'); localStorage.removeItem('modelfiles');
localModelfiles = JSON.parse(localStorage.getItem('modelfiles') ?? '[]'); localModelfiles = JSON.parse(localStorage.getItem('modelfiles') ?? '[]');
await modelfiles.set(await getModelfiles(localStorage.token)); await modelfiles.set(await getModels(localStorage.token));
}} }}
> >
<div class=" self-center"> <div class=" self-center">
......
...@@ -4,7 +4,7 @@ import { getOllamaModels } from '$lib/apis/ollama'; ...@@ -4,7 +4,7 @@ import { getOllamaModels } from '$lib/apis/ollama';
import { getOpenAIModels } from '$lib/apis/openai'; import { getOpenAIModels } from '$lib/apis/openai';
import { getLiteLLMModels } from '$lib/apis/litellm'; import { getLiteLLMModels } from '$lib/apis/litellm';
export const getModels = async (token: string) => { export const getAllModels = async (token: string) => {
let models = await Promise.all([ let models = await Promise.all([
getOllamaModels(token).catch((error) => { getOllamaModels(token).catch((error) => {
console.log(error); console.log(error);
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { getModels as _getModels } from '$lib/utils'; import { getAllModels as _getAllModels } from '$lib/utils';
import { getOllamaVersion } from '$lib/apis/ollama'; import { getOllamaVersion } from '$lib/apis/ollama';
import { getModelfiles } from '$lib/apis/modelfiles'; import { getModels } from '$lib/apis/models';
import { getPrompts } from '$lib/apis/prompts'; import { getPrompts } from '$lib/apis/prompts';
import { getDocs } from '$lib/apis/documents'; import { getDocs } from '$lib/apis/documents';
...@@ -46,8 +46,8 @@ ...@@ -46,8 +46,8 @@
let showShortcuts = false; let showShortcuts = false;
const getModels = async () => { const getAllModels = async () => {
return _getModels(localStorage.token); return _getAllModels(localStorage.token);
}; };
const setOllamaVersion = async (version: string = '') => { const setOllamaVersion = async (version: string = '') => {
...@@ -91,10 +91,10 @@ ...@@ -91,10 +91,10 @@
await Promise.all([ await Promise.all([
(async () => { (async () => {
models.set(await getModels()); models.set(await getAllModels());
})(), })(),
(async () => { (async () => {
modelfiles.set(await getModelfiles(localStorage.token)); modelfiles.set(await getModels(localStorage.token));
})(), })(),
(async () => { (async () => {
prompts.set(await getPrompts(localStorage.token)); prompts.set(await getPrompts(localStorage.token));
...@@ -109,7 +109,7 @@ ...@@ -109,7 +109,7 @@
modelfiles.subscribe(async () => { modelfiles.subscribe(async () => {
// should fetch models // should fetch models
models.set(await getModels()); models.set(await getAllModels());
}); });
document.addEventListener('keydown', function (event) { document.addEventListener('keydown', function (event) {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
import { splitStream } from '$lib/utils'; import { splitStream } from '$lib/utils';
import { onMount, tick, getContext } from 'svelte'; import { onMount, tick, getContext } from 'svelte';
import { createModel } from '$lib/apis/ollama'; import { createModel } from '$lib/apis/ollama';
import { createNewModelfile, getModelfileByTagName, getModelfiles } from '$lib/apis/modelfiles'; import { addNewModel, getModelById, getModels } from '$lib/apis/models';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
...@@ -98,8 +98,8 @@ SYSTEM """${system}"""`.replace(/^\s*\n/gm, ''); ...@@ -98,8 +98,8 @@ SYSTEM """${system}"""`.replace(/^\s*\n/gm, '');
}; };
const saveModelfile = async (modelfile) => { const saveModelfile = async (modelfile) => {
await createNewModelfile(localStorage.token, modelfile); await addNewModel(localStorage.token, modelfile);
await modelfiles.set(await getModelfiles(localStorage.token)); await modelfiles.set(await getModels(localStorage.token));
}; };
const submitHandler = async () => { const submitHandler = async () => {
...@@ -116,7 +116,7 @@ SYSTEM """${system}"""`.replace(/^\s*\n/gm, ''); ...@@ -116,7 +116,7 @@ SYSTEM """${system}"""`.replace(/^\s*\n/gm, '');
if ( if (
$models.map((model) => model.name).includes(tagName) || $models.map((model) => model.name).includes(tagName) ||
(await getModelfileByTagName(localStorage.token, tagName).catch(() => false)) (await getModelById(localStorage.token, tagName).catch(() => false))
) { ) {
toast.error( toast.error(
`Uh-oh! It looks like you already have a model named '${tagName}'. Please choose a different name to complete your modelfile.` `Uh-oh! It looks like you already have a model named '${tagName}'. Please choose a different name to complete your modelfile.`
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
import { splitStream } from '$lib/utils'; import { splitStream } from '$lib/utils';
import { createModel } from '$lib/apis/ollama'; import { createModel } from '$lib/apis/ollama';
import { getModelfiles, updateModelfileByTagName } from '$lib/apis/modelfiles'; import { getModels, updateModelById } from '$lib/apis/models';
import AdvancedParams from '$lib/components/chat/Settings/Advanced/AdvancedParams.svelte'; import AdvancedParams from '$lib/components/chat/Settings/Advanced/AdvancedParams.svelte';
...@@ -85,8 +85,8 @@ ...@@ -85,8 +85,8 @@
}); });
const updateModelfile = async (modelfile) => { const updateModelfile = async (modelfile) => {
await updateModelfileByTagName(localStorage.token, modelfile.tagName, modelfile); await updateModelById(localStorage.token, modelfile.tagName, modelfile);
await modelfiles.set(await getModelfiles(localStorage.token)); await modelfiles.set(await getModels(localStorage.token));
}; };
const updateHandler = async () => { const updateHandler = async () => {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment