Unverified Commit 67a5020c authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #2505 from open-webui/dev-models

feat: openai api compatible model presets (profiles/modelfiles)
parents d0d76e2a a6af20e1
@@ -18,8 +18,9 @@ import requests
 from pydantic import BaseModel, ConfigDict
 from typing import Optional, List

+from apps.web.models.models import Models
 from utils.utils import get_verified_user, get_current_user, get_admin_user
-from config import SRC_LOG_LEVELS, ENV
+from config import SRC_LOG_LEVELS
 from constants import MESSAGES

 import os
@@ -77,7 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
 app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
+app.state.MODEL_CONFIG = Models.get_all_models()

 app.state.ENABLE = ENABLE_LITELLM
 app.state.CONFIG = litellm_config
@@ -261,6 +262,14 @@ async def get_models(user=Depends(get_current_user)):
                 "object": "model",
                 "created": int(time.time()),
                 "owned_by": "openai",
+                "custom_info": next(
+                    (
+                        item
+                        for item in app.state.MODEL_CONFIG
+                        if item.id == model["model_name"]
+                    ),
+                    None,
+                ),
             }
             for model in app.state.CONFIG["model_list"]
         ],
......
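For illustration, a single entry of the LiteLLM `/models` response now carries the matching preset when one exists. A sketch with made-up values (the exact serialized shape of `custom_info` depends on how FastAPI dumps the `ModelModel`):

# One /models entry after this change (illustrative values only):
# {
#     "id": "gpt-3.5-turbo",
#     "object": "model",
#     "created": 1715000000,
#     "owned_by": "openai",
#     "custom_info": {"id": "gpt-3.5-turbo", "name": "My preset", ...}
# }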
@@ -29,7 +29,7 @@ import time
 from urllib.parse import urlparse
 from typing import Optional, List, Union

+from apps.web.models.models import Models
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
@@ -39,6 +39,8 @@ from utils.utils import (
     get_admin_user,
 )

+from utils.models import get_model_id_from_custom_model_id
+
 from config import (
     SRC_LOG_LEVELS,
@@ -68,7 +70,6 @@ app.state.config = AppConfig()
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

 app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS

 app.state.MODELS = {}
@@ -875,14 +876,93 @@ async def generate_chat_completion(
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
-        model = form_data.model
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )

-        if ":" not in model:
-            model = f"{model}:latest"
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+            payload["options"]["stop"] = (
+                [
+                    bytes(stop, "utf-8").decode("unicode_escape")
+                    for stop in model_info.params["stop"]
+                ]
+                if model_info.params.get("stop", None)
+                else None
+            )
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -892,16 +972,12 @@ async def generate_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

+    print(payload)
+
     r = None

-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
-
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r

         request_id = str(uuid.uuid4())
@@ -910,7 +986,7 @@ async def generate_chat_completion(
         def stream_content():
             try:
-                if form_data.stream:
+                if payload.get("stream", None):
                     yield json.dumps({"id": request_id, "done": False}) + "\n"

                 for chunk in r.iter_content(chunk_size=8192):
@@ -928,7 +1004,7 @@ async def generate_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/chat",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
@@ -984,14 +1060,62 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
-        model = form_data.model
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }

-        if ":" not in model:
-            model = f"{model}:latest"
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+            payload["stop"] = (
+                [
+                    bytes(stop, "utf-8").decode("unicode_escape")
+                    for stop in model_info.params["stop"]
+                ]
+                if model_info.params.get("stop", None)
+                else None
+            )
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -1004,7 +1128,7 @@ async def generate_openai_chat_completion(
     r = None

     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r

         request_id = str(uuid.uuid4())
@@ -1013,7 +1137,7 @@ async def generate_openai_chat_completion(
         def stream_content():
             try:
-                if form_data.stream:
+                if payload.get("stream"):
                     yield json.dumps(
                         {"request_id": request_id, "done": False}
                     ) + "\n"
@@ -1033,7 +1157,7 @@ async def generate_openai_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/v1/chat/completions",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
......
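To make the mapping above concrete, here is a minimal sketch (the preset values are hypothetical) of how a stored preset's params land in an Ollama request; the OpenAI-style handler does the same with OpenAI parameter names:

# Hypothetical preset params as stored on a Models row:
params = {"temperature": 0.7, "max_tokens": 256, "system": "Answer briefly."}

# The Ollama handler above builds, roughly:
options = {
    "temperature": params.get("temperature", None),  # 0.7
    "num_predict": params.get("max_tokens", None),   # max_tokens maps to num_predict
}
# ...and "Answer briefly." is prepended to the existing system message,
# or inserted as a new system message if the conversation has none.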
@@ -10,7 +10,7 @@ import logging

 from pydantic import BaseModel

+from apps.web.models.models import Models
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
@@ -53,7 +53,6 @@ app.state.config = AppConfig()
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

 app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
 app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
@@ -206,7 +205,13 @@ def merge_models_lists(model_lists):
         if models is not None and "error" not in models:
             merged_list.extend(
                 [
-                    {**model, "urlIdx": idx}
+                    {
+                        **model,
+                        "name": model.get("name", model["id"]),
+                        "owned_by": "openai",
+                        "openai": model,
+                        "urlIdx": idx,
+                    }
                     for model in models
                     if "api.openai.com"
                     not in app.state.config.OPENAI_API_BASE_URLS[idx]
@@ -252,7 +257,7 @@ async def get_all_models():
     log.info(f"models: {models}")
     app.state.MODELS = {model["id"]: model for model in models["data"]}

     return models


@app.get("/models")
@@ -310,39 +315,93 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     body = await request.body()

     # TODO: Remove below after gpt-4-vision fix from Open AI
     # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)

+    payload = None
+
     try:
-        body = body.decode("utf-8")
-        body = json.loads(body)
+        if "chat/completions" in path:
+            body = body.decode("utf-8")
+            body = json.loads(body)

-        model = app.state.MODELS[body.get("model")]
-        idx = model["urlIdx"]
+            payload = {**body}

-        if "pipeline" in model and model.get("pipeline"):
-            body["user"] = {"name": user.name, "id": user.id}
-            body["title"] = (
-                True if body["stream"] == False and body["max_tokens"] == 50 else False
-            )
+            model_id = body.get("model")
+            model_info = Models.get_model_by_id(model_id)
+
+            if model_info:
+                print(model_info)
+                if model_info.base_model_id:
+                    payload["model"] = model_info.base_model_id
+
+                model_info.params = model_info.params.model_dump()
+
+                if model_info.params:
+                    payload["temperature"] = model_info.params.get("temperature", None)
+                    payload["top_p"] = model_info.params.get("top_p", None)
+                    payload["max_tokens"] = model_info.params.get("max_tokens", None)
+                    payload["frequency_penalty"] = model_info.params.get(
+                        "frequency_penalty", None
+                    )
+                    payload["seed"] = model_info.params.get("seed", None)
+                    payload["stop"] = (
+                        [
+                            bytes(stop, "utf-8").decode("unicode_escape")
+                            for stop in model_info.params["stop"]
+                        ]
+                        if model_info.params.get("stop", None)
+                        else None
+                    )
+
+                if model_info.params.get("system", None):
+                    # Check if the payload already has a system message
+                    # If not, add a system message to the payload
+                    if payload.get("messages"):
+                        for message in payload["messages"]:
+                            if message.get("role") == "system":
+                                message["content"] = (
+                                    model_info.params.get("system", None)
+                                    + message["content"]
+                                )
+                                break
+                        else:
+                            payload["messages"].insert(
+                                0,
+                                {
+                                    "role": "system",
+                                    "content": model_info.params.get("system", None),
+                                },
+                            )
+            else:
+                pass
+
+            print(app.state.MODELS)
+            model = app.state.MODELS[payload.get("model")]
+
+            idx = model["urlIdx"]
+
+            if "pipeline" in model and model.get("pipeline"):
+                payload["user"] = {"name": user.name, "id": user.id}
+                payload["title"] = (
+                    True
+                    if payload["stream"] == False and payload["max_tokens"] == 50
+                    else False
+                )

-        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-        # This is a workaround until OpenAI fixes the issue with this model
-        if body.get("model") == "gpt-4-vision-preview":
-            if "max_tokens" not in body:
-                body["max_tokens"] = 4000
-            log.debug("Modified body_dict:", body)
-
-        # Fix for ChatGPT calls failing because the num_ctx key is in body
-        if "num_ctx" in body:
-            # If 'num_ctx' is in the dictionary, delete it
-            # Leaving it there generates an error with the
-            # OpenAI API (Feb 2024)
-            del body["num_ctx"]
-
-        # Convert the modified body back to JSON
-        body = json.dumps(body)
+            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+            # This is a workaround until OpenAI fixes the issue with this model
+            if payload.get("model") == "gpt-4-vision-preview":
+                if "max_tokens" not in payload:
+                    payload["max_tokens"] = 4000
+                log.debug("Modified payload:", payload)
+
+            # Convert the modified body back to JSON
+            payload = json.dumps(payload)
+
     except json.JSONDecodeError as e:
         log.error("Error loading request body into a dictionary:", e)

+    print(payload)
+
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
@@ -361,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     r = requests.request(
         method=request.method,
         url=target_url,
-        data=payload if payload else body,
+        headers=headers,
         stream=True,
     )
......
+import json
+
 from peewee import *
 from peewee_migrate import Router
 from playhouse.db_url import connect
@@ -8,6 +10,16 @@ import logging
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])

+
+class JSONField(TextField):
+    def db_value(self, value):
+        return json.dumps(value)
+
+    def python_value(self, value):
+        if value is not None:
+            return json.loads(value)
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
......
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Model(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
base_model_id = pw.TextField(null=True)
name = pw.TextField()
meta = pw.TextField()
params = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "model"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("model")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
import json
from utils.misc import parse_ollama_modelfile
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model(migrator, database)
# Drop the 'modelfile' table
migrator.remove_model("modelfile")
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
ModelFile = migrator.orm["modelfile"]
Model = migrator.orm["model"]
modelfiles = ModelFile.select()
for modelfile in modelfiles:
# Extract and transform data in Python
modelfile.modelfile = json.loads(modelfile.modelfile)
meta = json.dumps(
{
"description": modelfile.modelfile.get("desc"),
"profile_image_url": modelfile.modelfile.get("imageUrl"),
"ollama": {"modelfile": modelfile.modelfile.get("content")},
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
"categories": modelfile.modelfile.get("categories"),
"user": {**modelfile.modelfile.get("user", {}), "community": True},
}
)
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
# Insert the processed data into the 'model' table
Model.create(
id=f"ollama-{modelfile.tag_name}",
user_id=modelfile.user_id,
base_model_id=info.get("base_model_id"),
name=modelfile.modelfile.get("title"),
meta=meta,
params=json.dumps(info.get("params", {})),
created_at=modelfile.timestamp,
updated_at=modelfile.timestamp,
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
recreate_modelfile_table(migrator, database)
move_data_back_to_modelfile(migrator, database)
migrator.remove_model("model")
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
query = """
CREATE TABLE IF NOT EXISTS modelfile (
user_id TEXT,
tag_name TEXT,
modelfile JSON,
timestamp BIGINT
)
"""
migrator.sql(query)
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
Model = migrator.orm["model"]
Modelfile = migrator.orm["modelfile"]
models = Model.select()
for model in models:
# Extract and transform data in Python
meta = json.loads(model.meta)
modelfile_data = {
"title": model.name,
"desc": meta.get("description"),
"imageUrl": meta.get("profile_image_url"),
"content": meta.get("ollama", {}).get("modelfile"),
"suggestionPrompts": meta.get("suggestion_prompts"),
"categories": meta.get("categories"),
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
}
# Insert the processed data back into the 'modelfile' table
Modelfile.create(
user_id=model.user_id,
tag_name=model.id,
modelfile=modelfile_data,
timestamp=model.created_at,
)
@@ -6,7 +6,7 @@ from apps.web.routers import (
     users,
     chats,
     documents,
-    modelfiles,
+    models,
     prompts,
     configs,
     memories,
@@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
 app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
 app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
 app.state.config.WEBHOOK_URL = WEBHOOK_URL

+app.state.MODELS = {}
+
 app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
@@ -56,11 +59,10 @@ app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
-app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
+app.include_router(models.router, prefix="/models", tags=["models"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
 app.include_router(memories.router, prefix="/memories", tags=["memories"])

 app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
......
################################################################################
# DEPRECATION NOTICE #
# #
# This file has been deprecated since version 0.2.0. #
# #
################################################################################
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
......
import json
import logging
from typing import Optional

import peewee as pw
from peewee import *

from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict

from apps.web.internal.db import DB, JSONField

from typing import List, Union, Optional
from config import SRC_LOG_LEVELS

import time

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])


####################
# Models DB Schema
####################


# ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel):
    model_config = ConfigDict(extra="allow")
    pass


# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
    profile_image_url: Optional[str] = "/favicon.png"

    description: Optional[str] = None
    """
    User-facing description of the model.
    """

    capabilities: Optional[dict] = None

    model_config = ConfigDict(extra="allow")
    pass


class Model(pw.Model):
    id = pw.TextField(unique=True)
    """
    The model's id as used in the API. If set to an existing model, it will override the model.
    """

    user_id = pw.TextField()

    base_model_id = pw.TextField(null=True)
    """
    An optional pointer to the actual model that should be used when proxying requests.
    """

    name = pw.TextField()
    """
    The human-readable display name of the model.
    """

    params = JSONField()
    """
    Holds a JSON encoded blob of parameters, see `ModelParams`.
    """

    meta = JSONField()
    """
    Holds a JSON encoded blob of metadata, see `ModelMeta`.
    """

    updated_at = BigIntegerField()
    created_at = BigIntegerField()

    class Meta:
        database = DB


class ModelModel(BaseModel):
    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    params: ModelParams
    meta: ModelMeta

    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


####################
# Forms
####################


class ModelResponse(BaseModel):
    id: str
    name: str
    meta: ModelMeta

    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


class ModelForm(BaseModel):
    id: str
    base_model_id: Optional[str] = None
    name: str
    meta: ModelMeta
    params: ModelParams


class ModelsTable:
    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        self.db = db
        self.db.create_tables([Model])

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": int(time.time()),
                "updated_at": int(time.time()),
            }
        )
        try:
            result = Model.create(**model.model_dump())

            if result:
                return model
            else:
                return None
        except Exception as e:
            print(e)
            return None

    def get_all_models(self) -> List[ModelModel]:
        return [ModelModel(**model_to_dict(model)) for model in Model.select()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        try:
            model = Model.get(Model.id == id)
            return ModelModel(**model_to_dict(model))
        except:
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        try:
            # update only the fields that are present in the model
            query = Model.update(**model.model_dump()).where(Model.id == id)
            query.execute()

            model = Model.get(Model.id == id)
            return ModelModel(**model_to_dict(model))
        except Exception as e:
            print(e)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        try:
            query = Model.delete().where(Model.id == id)
            query.execute()

            return True
        except:
            return False


Models = ModelsTable(DB)
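A short usage sketch for the table above; the ids and values are made up:

# Hypothetical admin flow: register a preset that proxies to a base model.
form = ModelForm(
    id="gpt-4-concise",  # new preset id (made up)
    base_model_id="gpt-4",  # requests get proxied to this model
    name="GPT-4 (concise)",
    params=ModelParams(temperature=0.3, system="Answer concisely."),
    meta=ModelMeta(description="Low-temperature GPT-4 preset"),
)
preset = Models.insert_new_model(form, user_id="admin-uuid")  # ModelModel or None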
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional

from fastapi import APIRouter
from pydantic import BaseModel
import json

from apps.web.models.modelfiles import (
    Modelfiles,
    ModelfileForm,
    ModelfileTagNameForm,
    ModelfileUpdateForm,
    ModelfileResponse,
)

from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES

router = APIRouter()

############################
# GetModelfiles
############################


@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(
    skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
    return Modelfiles.get_modelfiles(skip, limit)


############################
# CreateNewModelfile
############################


@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
    modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)

    if modelfile:
        return ModelfileResponse(
            **{
                **modelfile.model_dump(),
                "modelfile": json.loads(modelfile.modelfile),
            }
        )
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


############################
# GetModelfileByTagName
############################


@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
    form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)

    if modelfile:
        return ModelfileResponse(
            **{
                **modelfile.model_dump(),
                "modelfile": json.loads(modelfile.modelfile),
            }
        )
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )


############################
# UpdateModelfileByTagName
############################


@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
    form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
    if modelfile:
        updated_modelfile = {
            **json.loads(modelfile.modelfile),
            **form_data.modelfile,
        }

        modelfile = Modelfiles.update_modelfile_by_tag_name(
            form_data.tag_name, updated_modelfile
        )

        return ModelfileResponse(
            **{
                **modelfile.model_dump(),
                "modelfile": json.loads(modelfile.modelfile),
            }
        )
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )


############################
# DeleteModelfileByTagName
############################


@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
    form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
    result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
    return result
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional

from fastapi import APIRouter
from pydantic import BaseModel
import json

from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse

from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES

router = APIRouter()

###########################
# getModels
###########################


@router.get("/", response_model=List[ModelResponse])
async def get_models(user=Depends(get_verified_user)):
    return Models.get_all_models()


############################
# AddNewModel
############################


@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
    request: Request, form_data: ModelForm, user=Depends(get_admin_user)
):
    if form_data.id in request.app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
        )
    else:
        model = Models.insert_new_model(form_data, user.id)

        if model:
            return model
        else:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.DEFAULT(),
            )


############################
# GetModelById
############################


@router.get("/{id}", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
    model = Models.get_model_by_id(id)

    if model:
        return model
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )


############################
# UpdateModelById
############################


@router.post("/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id(
    request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
):
    model = Models.get_model_by_id(id)
    if model:
        model = Models.update_model_by_id(id, form_data)
        return model
    else:
        if form_data.id in request.app.state.MODELS:
            model = Models.insert_new_model(form_data, user.id)
            print(model)
            if model:
                return model
            else:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=ERROR_MESSAGES.DEFAULT(),
                )
        else:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.DEFAULT(),
            )


############################
# DeleteModelById
############################


@router.delete("/{id}/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
    result = Models.delete_model_by_id(id)
    return result
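Putting the router together with the `/api/v1` mount of the webui app (visible later in this diff), adding a preset over HTTP looks roughly like this; the host, token, and payload values are hypothetical:

import requests

res = requests.post(
    "http://localhost:8080/api/v1/models/add",  # assumed host; prefix from the mounts
    headers={"Authorization": "Bearer <admin-token>"},
    json={
        "id": "gpt-4-concise",
        "base_model_id": "gpt-4",
        "name": "GPT-4 (concise)",
        "params": {"temperature": 0.3},
        "meta": {"description": "Low-temperature preset"},
    },
)
print(res.status_code, res.json())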
@@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum):
     COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
     FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."

+    MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
+
     NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
     INVALID_TOKEN = (
         "Your session has expired or the token is invalid. Please sign in again."
......
@@ -19,8 +19,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.responses import StreamingResponse, Response

-from apps.ollama.main import app as ollama_app
-from apps.openai.main import app as openai_app
+from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
+from apps.openai.main import app as openai_app, get_all_models as get_openai_models

 from apps.litellm.main import (
     app as litellm_app,
@@ -36,10 +36,10 @@ from apps.web.main import app as webui_app

 import asyncio
 from pydantic import BaseModel
-from typing import List
+from typing import List, Optional

+from apps.web.models.models import Models, ModelModel

-from utils.utils import get_admin_user
+from utils.utils import get_admin_user, get_verified_user
 from apps.rag.utils import rag_messages

 from config import (
@@ -53,6 +53,8 @@ from config import (
     FRONTEND_BUILD_DIR,
     CACHE_DIR,
     STATIC_DIR,
+    ENABLE_OPENAI_API,
+    ENABLE_OLLAMA_API,
     ENABLE_LITELLM,
     ENABLE_MODEL_FILTER,
     MODEL_FILTER_LIST,
@@ -110,11 +112,19 @@ app = FastAPI(
 )

 app.state.config = AppConfig()
+
+app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
+app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
+
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

 app.state.config.WEBHOOK_URL = WEBHOOK_URL

+app.state.MODELS = {}
+
 origins = ["*"]
@@ -231,6 +241,11 @@ app.add_middleware(
 @app.middleware("http")
 async def check_url(request: Request, call_next):
+    if len(app.state.MODELS) == 0:
+        await get_all_models()
+    else:
+        pass
+
     start_time = int(time.time())
     response = await call_next(request)
     process_time = int(time.time()) - start_time
@@ -247,9 +262,11 @@ async def update_embedding_function(request: Request, call_next):
     return response


+# TODO: Deprecate LiteLLM
 app.mount("/litellm/api", litellm_app)

 app.mount("/ollama", ollama_app)
-app.mount("/openai/api", openai_app)
+app.mount("/openai", openai_app)

 app.mount("/images/api/v1", images_app)
 app.mount("/audio/api/v1", audio_app)
@@ -260,6 +277,87 @@ app.mount("/api/v1", webui_app)

 webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
+
+
+async def get_all_models():
+    openai_models = []
+    ollama_models = []
+
+    if app.state.config.ENABLE_OPENAI_API:
+        openai_models = await get_openai_models()
+        openai_models = openai_models["data"]
+
+    if app.state.config.ENABLE_OLLAMA_API:
+        ollama_models = await get_ollama_models()
+        ollama_models = [
+            {
+                "id": model["model"],
+                "name": model["name"],
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "ollama",
+                "ollama": model,
+            }
+            for model in ollama_models["models"]
+        ]
+
+    models = openai_models + ollama_models
+    custom_models = Models.get_all_models()
+
+    for custom_model in custom_models:
+        if custom_model.base_model_id == None:
+            for model in models:
+                if (
+                    custom_model.id == model["id"]
+                    or custom_model.id == model["id"].split(":")[0]
+                ):
+                    model["name"] = custom_model.name
+                    model["info"] = custom_model.model_dump()
+        else:
+            owned_by = "openai"
+            for model in models:
+                if (
+                    custom_model.base_model_id == model["id"]
+                    or custom_model.base_model_id == model["id"].split(":")[0]
+                ):
+                    owned_by = model["owned_by"]
+                    break
+
+            models.append(
+                {
+                    "id": custom_model.id,
+                    "name": custom_model.name,
+                    "object": "model",
+                    "created": custom_model.created_at,
+                    "owned_by": owned_by,
+                    "info": custom_model.model_dump(),
+                    "preset": True,
+                }
+            )
+
+    app.state.MODELS = {model["id"]: model for model in models}
+    webui_app.state.MODELS = app.state.MODELS
+
+    return models
+
+
+@app.get("/api/models")
+async def get_models(user=Depends(get_verified_user)):
+    models = await get_all_models()
+
+    if app.state.config.ENABLE_MODEL_FILTER:
+        if user.role == "user":
+            models = list(
+                filter(
+                    lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
+                    models,
+                )
+            )
+            return {"data": models}
+
+    return {"data": models}
@app.get("/api/config") @app.get("/api/config")
async def get_app_config(): async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA # Checking and Handling the Absence of 'ui' in CONFIG_DATA
......
 from pathlib import Path
 import hashlib
+import json
 import re
 from datetime import timedelta
 from typing import Optional
@@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]:
             total_duration += timedelta(weeks=number)

     return total_duration
+
+
+def parse_ollama_modelfile(model_text):
+    parameters_meta = {
+        "mirostat": int,
+        "mirostat_eta": float,
+        "mirostat_tau": float,
+        "num_ctx": int,
+        "repeat_last_n": int,
+        "repeat_penalty": float,
+        "temperature": float,
+        "seed": int,
+        "stop": str,
+        "tfs_z": float,
+        "num_predict": int,
+        "top_k": int,
+        "top_p": float,
+    }
+
+    data = {"base_model_id": None, "params": {}}
+
+    # Parse base model
+    base_model_match = re.search(
+        r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
+    )
+    if base_model_match:
+        data["base_model_id"] = base_model_match.group(1)
+
+    # Parse template
+    template_match = re.search(
+        r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
+    )
+    if template_match:
+        data["params"] = {"template": template_match.group(1).strip()}
+
+    # Parse stops
+    stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
+    if stops:
+        data["params"]["stop"] = stops
+
+    # Parse other parameters from the provided list
+    for param, param_type in parameters_meta.items():
+        param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
+        if param_match:
+            value = param_match.group(1)
+            if param_type == int:
+                value = int(value)
+            elif param_type == float:
+                value = float(value)
+            data["params"][param] = value
+
+    # Parse adapter
+    adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
+    if adapter_match:
+        data["params"]["adapter"] = adapter_match.group(1)
+
+    # Parse system description
+    system_desc_match = re.search(
+        r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
+    )
+    if system_desc_match:
+        data["params"]["system"] = system_desc_match.group(1).strip()
+
+    # Parse messages
+    messages = []
+    message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
+    for role, content in message_matches:
+        messages.append({"role": role, "content": content})
+
+    if messages:
+        data["params"]["messages"] = messages
+
+    return data
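A quick usage sketch for the parser (the Modelfile text is made up). Note that `stop` appears both in the dedicated regex and in `parameters_meta`, so a `PARAMETER stop` line is matched twice; the example avoids `stop` for clarity:

modelfile = '''
FROM llama3
PARAMETER temperature 0.7
PARAMETER top_k 40
SYSTEM """You are a terse assistant."""
'''

print(parse_ollama_modelfile(modelfile))
# {'base_model_id': 'llama3',
#  'params': {'temperature': 0.7, 'top_k': 40,
#             'system': 'You are a terse assistant.'}}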
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse


def get_model_id_from_custom_model_id(id: str):
    model = Models.get_model_by_id(id)

    if model:
        return model.id
    else:
        return id
 import { WEBUI_BASE_URL } from '$lib/constants';

+export const getModels = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/models`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	let models = res?.data ?? [];
+
+	models = models
+		.filter((models) => models)
+		.sort((a, b) => {
+			// Compare case-insensitively
+			const lowerA = a.name.toLowerCase();
+			const lowerB = b.name.toLowerCase();
+
+			if (lowerA < lowerB) return -1;
+			if (lowerA > lowerB) return 1;
+
+			// If same case-insensitively, sort by original strings,
+			// lowercase will come before uppercase due to ASCII values
+			if (a < b) return -1;
+			if (a > b) return 1;
+
+			return 0; // They are equal
+		});
+
+	console.log(models);
+
+	return models;
+};
 export const getBackendConfig = async () => {
 	let error = null;
@@ -196,3 +245,77 @@ export const updateWebhookUrl = async (token: string, url: string) => {
 	return res.url;
 };
+
+export const getModelConfig = async (token: string): Promise<GlobalModelConfig> => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.models;
+};
+
+export interface ModelConfig {
+	id: string;
+	name: string;
+	meta: ModelMeta;
+	base_model_id?: string;
+	params: ModelParams;
+}
+
+export interface ModelMeta {
+	description?: string;
+	capabilities?: object;
+}
+
+export interface ModelParams {}
+
+export type GlobalModelConfig = ModelConfig[];
+
+export const updateModelConfig = async (token: string, config: GlobalModelConfig) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			models: config
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
@@ -33,7 +33,8 @@ export const getLiteLLMModels = async (token: string = '') => {
 				id: model.id,
 				name: model.name ?? model.id,
 				external: true,
-				source: 'LiteLLM'
+				source: 'LiteLLM',
+				custom_info: model.custom_info
 			}))
 			.sort((a, b) => {
 				return a.name.localeCompare(b.name);
 import { WEBUI_API_BASE_URL } from '$lib/constants';

-export const createNewModelfile = async (token: string, modelfile: object) => {
+export const addNewModel = async (token: string, model: object) => {
 	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/create`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
 			'Content-Type': 'application/json',
 			authorization: `Bearer ${token}`
 		},
-		body: JSON.stringify({
-			modelfile: modelfile
-		})
+		body: JSON.stringify(model)
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -31,10 +29,10 @@ export const addNewModel = async (token: string, model: object) => {
 	return res;
 };

-export const getModelfiles = async (token: string = '') => {
+export const getModelInfos = async (token: string = '') => {
 	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -59,22 +57,19 @@ export const getModelInfos = async (token: string = '') => {
 		throw error;
 	}

-	return res.map((modelfile) => modelfile.modelfile);
+	return res;
 };

-export const getModelfileByTagName = async (token: string, tagName: string) => {
+export const getModelById = async (token: string, id: string) => {
 	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, {
-		method: 'POST',
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, {
+		method: 'GET',
 		headers: {
 			Accept: 'application/json',
 			'Content-Type': 'application/json',
 			authorization: `Bearer ${token}`
-		},
-		body: JSON.stringify({
-			tag_name: tagName
-		})
+		}
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -94,27 +89,20 @@ export const getModelById = async (token: string, id: string) => {
 		throw error;
 	}

-	return res.modelfile;
+	return res;
 };

-export const updateModelfileByTagName = async (
-	token: string,
-	tagName: string,
-	modelfile: object
-) => {
+export const updateModelById = async (token: string, id: string, model: object) => {
 	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/update`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/update`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
 			'Content-Type': 'application/json',
 			authorization: `Bearer ${token}`
 		},
-		body: JSON.stringify({
-			tag_name: tagName,
-			modelfile: modelfile
-		})
+		body: JSON.stringify(model)
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -137,19 +125,16 @@ export const updateModelById = async (token: string, id: string, model: object) => {
 	return res;
 };

-export const deleteModelfileByTagName = async (token: string, tagName: string) => {
+export const deleteModelById = async (token: string, id: string) => {
 	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/delete`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/delete`, {
 		method: 'DELETE',
 		headers: {
 			Accept: 'application/json',
 			'Content-Type': 'application/json',
 			authorization: `Bearer ${token}`
-		},
-		body: JSON.stringify({
-			tag_name: tagName
-		})
+		}
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
......
@@ -230,7 +230,12 @@ export const getOpenAIModels = async (token: string = '') => {
 	return models
 		? models
-				.map((model) => ({ id: model.id, name: model.name ?? model.id, external: true }))
+				.map((model) => ({
+					id: model.id,
+					name: model.name ?? model.id,
+					external: true,
+					custom_info: model.custom_info
+				}))
 				.sort((a, b) => {
 					return a.name.localeCompare(b.name);
 				})
......
@@ -10,7 +10,7 @@
 		chatId,
 		chats,
 		config,
-		modelfiles,
+		type Model,
 		models,
 		settings,
 		showSidebar,
@@ -60,25 +60,7 @@
 	let showModelSelector = true;

 	let selectedModels = [''];
-	let atSelectedModel = '';
-
-	let selectedModelfile = null;
-	$: selectedModelfile =
-		selectedModels.length === 1 &&
-		$modelfiles.filter((modelfile) => modelfile.tagName === selectedModels[0]).length > 0
-			? $modelfiles.filter((modelfile) => modelfile.tagName === selectedModels[0])[0]
-			: null;
-
-	let selectedModelfiles = {};
-	$: selectedModelfiles = selectedModels.reduce((a, tagName, i, arr) => {
-		const modelfile =
-			$modelfiles.filter((modelfile) => modelfile.tagName === tagName)?.at(0) ?? undefined;
-
-		return {
-			...a,
-			...(modelfile && { [tagName]: modelfile })
-		};
-	}, {});
+	let atSelectedModel: Model | undefined;

 	let chat = null;
 	let tags = [];
@@ -164,6 +146,7 @@
 		if ($page.url.searchParams.get('q')) {
 			prompt = $page.url.searchParams.get('q') ?? '';
+
 			if (prompt) {
 				await tick();
 				submitPrompt(prompt);
@@ -211,7 +194,7 @@
 		await settings.set({
 			..._settings,
 			system: chatContent.system ?? _settings.system,
-			options: chatContent.options ?? _settings.options
+			params: chatContent.options ?? _settings.params
 		});
 		autoScroll = true;
 		await tick();
@@ -300,7 +283,7 @@
 			models: selectedModels,
 			system: $settings.system ?? undefined,
 			options: {
-				...($settings.options ?? {})
+				...($settings.params ?? {})
 			},
 			messages: messages,
 			history: history,
@@ -317,6 +300,7 @@
 		// Reset chat input textarea
 		prompt = '';
+		document.getElementById('chat-textarea').style.height = '';
 		files = [];

 		// Send prompt
@@ -328,75 +312,92 @@
 		const _chatId = JSON.parse(JSON.stringify($chatId));

 		await Promise.all(
-			(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
-				async (modelId) => {
-					console.log('modelId', modelId);
-					const model = $models.filter((m) => m.id === modelId).at(0);
-
-					if (model) {
-						// Create response message
-						let responseMessageId = uuidv4();
-						let responseMessage = {
-							parentId: parentId,
-							id: responseMessageId,
-							childrenIds: [],
-							role: 'assistant',
-							content: '',
-							model: model.id,
-							userContext: null,
-							timestamp: Math.floor(Date.now() / 1000) // Unix epoch
-						};
-
-						// Add message to history and Set currentId to messageId
-						history.messages[responseMessageId] = responseMessage;
-						history.currentId = responseMessageId;
-
-						// Append messageId to childrenIds of parent message
-						if (parentId !== null) {
-							history.messages[parentId].childrenIds = [
-								...history.messages[parentId].childrenIds,
-								responseMessageId
-							];
-						}
-
-						await tick();
-
-						let userContext = null;
-						if ($settings?.memory ?? false) {
-							if (userContext === null) {
-								const res = await queryMemory(localStorage.token, prompt).catch((error) => {
-									toast.error(error);
-									return null;
-								});
-
-								if (res) {
-									if (res.documents[0].length > 0) {
-										userContext = res.documents.reduce((acc, doc, index) => {
-											const createdAtTimestamp = res.metadatas[index][0].created_at;
-											const createdAtDate = new Date(createdAtTimestamp * 1000)
-												.toISOString()
-												.split('T')[0];
-											acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
-											return acc;
-										}, []);
-									}
-								}
-
-								console.log(userContext);
-							}
-						}
-						responseMessage.userContext = userContext;
-
-						if (model?.external) {
-							await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
-						} else if (model) {
-							await sendPromptOllama(model, prompt, responseMessageId, _chatId);
-						}
-					} else {
-						toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
-					}
-				}
-			)
+			(modelId
+				? [modelId]
+				: atSelectedModel !== undefined
+				? [atSelectedModel.id]
+				: selectedModels
+			).map(async (modelId) => {
+				console.log('modelId', modelId);
+				const model = $models.filter((m) => m.id === modelId).at(0);
+
+				if (model) {
+					// If there are image files, check if model is vision capable
+					const hasImages = messages.some((message) =>
+						message.files?.some((file) => file.type === 'image')
+					);
+
+					if (hasImages && !(model.info?.meta?.capabilities?.vision ?? true)) {
+						toast.error(
+							$i18n.t('Model {{modelName}} is not vision capable', {
+								modelName: model.name ?? model.id
+							})
+						);
+					}
+
+					// Create response message
+					let responseMessageId = uuidv4();
+					let responseMessage = {
+						parentId: parentId,
+						id: responseMessageId,
+						childrenIds: [],
+						role: 'assistant',
+						content: '',
+						model: model.id,
+						modelName: model.name ?? model.id,
+						userContext: null,
+						timestamp: Math.floor(Date.now() / 1000) // Unix epoch
+					};
+
+					// Add message to history and Set currentId to messageId
+					history.messages[responseMessageId] = responseMessage;
+					history.currentId = responseMessageId;
+
+					// Append messageId to childrenIds of parent message
+					if (parentId !== null) {
+						history.messages[parentId].childrenIds = [
+							...history.messages[parentId].childrenIds,
+							responseMessageId
+						];
+					}
+
+					await tick();
+
+					let userContext = null;
+					if ($settings?.memory ?? false) {
+						if (userContext === null) {
+							const res = await queryMemory(localStorage.token, prompt).catch((error) => {
+								toast.error(error);
+								return null;
+							});
+
+							if (res) {
+								if (res.documents[0].length > 0) {
+									userContext = res.documents.reduce((acc, doc, index) => {
+										const createdAtTimestamp = res.metadatas[index][0].created_at;
+										const createdAtDate = new Date(createdAtTimestamp * 1000)
+											.toISOString()
+											.split('T')[0];
+										acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
+										return acc;
+									}, []);
+								}
+							}
+
+							console.log(userContext);
+						}
+					}
+					responseMessage.userContext = userContext;
+
+					if (model?.owned_by === 'openai') {
+						await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
+					} else if (model) {
+						await sendPromptOllama(model, prompt, responseMessageId, _chatId);
+					}
+				} else {
+					toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
+				}
+			})
 		);

 		await chats.set(await getChatList(localStorage.token));
@@ -430,7 +431,7 @@
 			// Prepare the base message object
 			const baseMessage = {
 				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				content: message.content
 			};

 			// Extract and format image URLs if any exist
@@ -442,7 +443,6 @@
 			if (imageUrls && imageUrls.length > 0 && message.role === 'user') {
 				baseMessage.images = imageUrls;
 			}
-
 			return baseMessage;
 		});
@@ -473,13 +473,15 @@
 				model: model,
 				messages: messagesBody,
 				options: {
-					...($settings.options ?? {}),
+					...($settings.params ?? {}),
 					stop:
-						$settings?.options?.stop ?? undefined
-							? $settings.options.stop.map((str) =>
+						$settings?.params?.stop ?? undefined
+							? $settings.params.stop.map((str) =>
 									decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
 							  )
-							: undefined
+							: undefined,
+					num_predict: $settings?.params?.max_tokens ?? undefined,
+					repeat_penalty: $settings?.params?.frequency_penalty ?? undefined
 				},
 				format: $settings.requestFormat ?? undefined,
 				keep_alive: $settings.keepAlive ?? undefined,
@@ -605,7 +607,8 @@
 		if ($settings.saveChatHistory ?? true) {
 			chat = await updateChatById(localStorage.token, _chatId, {
 				messages: messages,
-				history: history
+				history: history,
+				models: selectedModels
 			});
 			await chats.set(await getChatList(localStorage.token));
 		}
@@ -716,18 +719,17 @@
 							: message?.raContent ?? message.content
 					})
 				})),
-				seed: $settings?.options?.seed ?? undefined,
+				seed: $settings?.params?.seed ?? undefined,
 				stop:
-					$settings?.options?.stop ?? undefined
-						? $settings.options.stop.map((str) =>
+					$settings?.params?.stop ?? undefined
+						? $settings.params.stop.map((str) =>
 								decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
 						  )
 						: undefined,
-				temperature: $settings?.options?.temperature ?? undefined,
-				top_p: $settings?.options?.top_p ?? undefined,
-				num_ctx: $settings?.options?.num_ctx ?? undefined,
-				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined,
+				temperature: $settings?.params?.temperature ?? undefined,
+				top_p: $settings?.params?.top_p ?? undefined,
+				frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
+				max_tokens: $settings?.params?.max_tokens ?? undefined,
 				docs: docs.length > 0 ? docs : undefined,
 				citations: docs.length > 0
 			},
@@ -797,6 +799,7 @@
 		if ($chatId == _chatId) {
 			if ($settings.saveChatHistory ?? true) {
 				chat = await updateChatById(localStorage.token, _chatId, {
+					models: selectedModels,
 					messages: messages,
 					history: history
 				});
@@ -935,10 +938,8 @@
 				) + ' {{prompt}}',
 				titleModelId,
 				userPrompt,
-				titleModel?.external ?? false
-					? titleModel?.source?.toLowerCase() === 'litellm'
-						? `${LITELLM_API_BASE_URL}/v1`
-						: `${OPENAI_API_BASE_URL}`
+				titleModel?.owned_by === 'openai' ?? false
+					? `${OPENAI_API_BASE_URL}`
 					: `${OLLAMA_API_BASE_URL}/v1`
 			);
@@ -1025,16 +1026,12 @@
 			<Messages
 				chatId={$chatId}
 				{selectedModels}
-				{selectedModelfiles}
 				{processing}
 				bind:history
 				bind:messages
 				bind:autoScroll
 				bind:prompt
 				bottomPadding={files.length > 0}
-				suggestionPrompts={chatIdProp
-					? []
-					: selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions}
 				{sendPrompt}
 				{continueGeneration}
 				{regenerateResponse}
@@ -1048,7 +1045,8 @@
 				bind:files
 				bind:prompt
 				bind:autoScroll
-				bind:selectedModel={atSelectedModel}
+				bind:atSelectedModel
+				{selectedModels}
 				{messages}
 				{submitPrompt}
 				{stopResponse}
......