Unverified Commit 371dfc11 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge branch 'dev' into debug_print

parents f74f2ea7 a1faa307
...@@ -166,7 +166,7 @@ cython_debug/ ...@@ -166,7 +166,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ .idea/
# Logs # Logs
logs logs
......
...@@ -18,6 +18,8 @@ from utils.utils import ( ...@@ -18,6 +18,8 @@ from utils.utils import (
get_current_user, get_current_user,
get_admin_user, get_admin_user,
) )
from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
...@@ -27,7 +29,8 @@ import base64 ...@@ -27,7 +29,8 @@ import base64
import json import json
import logging import logging
from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"]) log.setLevel(SRC_LOG_LEVELS["IMAGES"])
...@@ -52,6 +55,8 @@ app.state.MODEL = "" ...@@ -52,6 +55,8 @@ app.state.MODEL = ""
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.IMAGE_SIZE = "512x512" app.state.IMAGE_SIZE = "512x512"
app.state.IMAGE_STEPS = 50 app.state.IMAGE_STEPS = 50
...@@ -74,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user ...@@ -74,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class UrlUpdateForm(BaseModel): class EngineUrlUpdateForm(BaseModel):
url: str AUTOMATIC1111_BASE_URL: Optional[str] = None
COMFYUI_BASE_URL: Optional[str] = None
@app.get("/url") @app.get("/url")
async def get_automatic1111_url(user=Depends(get_admin_user)): async def get_engine_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
}
@app.post("/url/update") @app.post("/url/update")
async def update_automatic1111_url( async def update_engine_url(
form_data: UrlUpdateForm, user=Depends(get_admin_user) form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.url == "": if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else: else:
url = form_data.url.strip("/") url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url app.state.AUTOMATIC1111_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
app.state.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
"status": True, "status": True,
} }
...@@ -189,6 +210,18 @@ def get_models(user=Depends(get_current_user)): ...@@ -189,6 +210,18 @@ def get_models(user=Depends(get_current_user)):
{"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"}, {"id": "dall-e-3", "name": "DALL·E 3"},
] ]
elif app.state.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
info = r.json()
return list(
map(
lambda model: {"id": model, "name": model},
info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
)
)
else: else:
r = requests.get( r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
...@@ -210,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)): ...@@ -210,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
elif app.state.ENGINE == "comfyui":
return {"model": app.state.MODEL if app.state.MODEL else ""}
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json() options = r.json()
...@@ -224,10 +259,12 @@ class UpdateModelForm(BaseModel): ...@@ -224,10 +259,12 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str): def set_model_handler(model: str):
if app.state.ENGINE == "openai": if app.state.ENGINE == "openai":
app.state.MODEL = model app.state.MODEL = model
return app.state.MODEL return app.state.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json() options = r.json()
...@@ -275,12 +312,31 @@ def save_b64_image(b64_str): ...@@ -275,12 +312,31 @@ def save_b64_image(b64_str):
return None return None
def save_url_image(url):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try:
r = requests.get(url)
r.raise_for_status()
with open(file_path, "wb") as image_file:
image_file.write(r.content)
return image_id
except Exception as e:
print(f"Error saving image: {e}")
return None
@app.post("/generations") @app.post("/generations")
def generate_image( def generate_image(
form_data: GenerateImageForm, form_data: GenerateImageForm,
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
r = None r = None
try: try:
if app.state.ENGINE == "openai": if app.state.ENGINE == "openai":
...@@ -318,12 +374,47 @@ def generate_image( ...@@ -318,12 +374,47 @@ def generate_image(
return images return images
elif app.state.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
"width": width,
"height": height,
"n": form_data.n,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
app.state.MODEL,
data,
user.id,
app.state.COMFYUI_BASE_URL,
)
print(res)
images = []
for image in res["data"]:
image_id = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
json.dump(data.model_dump(exclude_none=True), f)
print(images)
return images
else: else:
if form_data.model: if form_data.model:
set_model_handler(form_data.model) set_model_handler(form_data.model)
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
data = { data = {
"prompt": form_data.prompt, "prompt": form_data.prompt,
"batch_size": form_data.n, "batch_size": form_data.n,
......
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random
from pydantic import BaseModel
from typing import Optional
COMFYUI_DEFAULT_PROMPT = """
{
"3": {
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "model.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "Negative Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
"""
def queue_prompt(prompt, client_id, base_url):
print("queue_prompt")
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(f"{base_url}/prompt", data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type, base_url):
print("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
return response.read()
def get_image_url(filename, subfolder, folder_type, base_url):
print("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url):
print("get_history")
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
return json.loads(response.read())
def get_images(ws, prompt, client_id, base_url):
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
output_images = []
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message["type"] == "executing":
data = message["data"]
if data["node"] is None and data["prompt_id"] == prompt_id:
break # Execution is done
else:
continue # previews are binary data
history = get_history(prompt_id, base_url)[prompt_id]
for o in history["outputs"]:
for node_id in history["outputs"]:
node_output = history["outputs"][node_id]
if "images" in node_output:
for image in node_output["images"]:
url = get_image_url(
image["filename"], image["subfolder"], image["type"], base_url
)
output_images.append({"url": url})
return {"data": output_images}
class ImageGenerationPayload(BaseModel):
prompt: str
negative_prompt: Optional[str] = ""
steps: Optional[int] = None
seed: Optional[int] = None
width: int
height: int
n: int = 1
def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url
):
host = base_url.replace("http://", "").replace("https://", "")
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width
comfyui_prompt["5"]["inputs"]["height"] = payload.height
# set the text prompt for our positive CLIPTextEncode
comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
if payload.steps:
comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
comfyui_prompt["3"]["inputs"]["seed"] = (
payload.seed if payload.seed else random.randint(0, 18446744073709551614)
)
try:
ws = websocket.WebSocket()
ws.connect(f"ws://{host}/ws?clientId={client_id}")
print("WebSocket connection established.")
except Exception as e:
print(f"Failed to connect to WebSocket server: {e}")
return None
try:
images = get_images(ws, comfyui_prompt, client_id, base_url)
except Exception as e:
print(f"Error while receiving images: {e}")
images = None
ws.close()
return images
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status from fastapi import (
FastAPI,
Request,
Response,
HTTPException,
Depends,
status,
UploadFile,
File,
BackgroundTasks,
)
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
import os
import copy
import random import random
import requests import requests
import json import json
...@@ -12,13 +24,17 @@ import uuid ...@@ -12,13 +24,17 @@ import uuid
import aiohttp import aiohttp
import asyncio import asyncio
import logging import logging
from urllib.parse import urlparse
from typing import Optional, List, Union
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user from utils.utils import decode_token, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
from typing import Optional, List, Union
from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR
from utils.misc import calculate_sha256
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
...@@ -237,11 +253,26 @@ async def pull_model( ...@@ -237,11 +253,26 @@ async def pull_model(
def get_request(): def get_request():
nonlocal url nonlocal url
nonlocal r nonlocal r
request_id = str(uuid.uuid4())
try: try:
REQUEST_POOL.append(request_id)
def stream_content(): def stream_content():
try:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192): for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk yield chunk
else:
print("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request( r = requests.request(
method="POST", method="POST",
...@@ -262,6 +293,7 @@ async def pull_model( ...@@ -262,6 +293,7 @@ async def pull_model(
try: try:
return await run_in_threadpool(get_request) return await run_in_threadpool(get_request)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
...@@ -900,6 +932,211 @@ async def generate_openai_chat_completion( ...@@ -900,6 +932,211 @@ async def generate_openai_chat_completion(
) )
class UrlForm(BaseModel):
url: str
class UploadBlobForm(BaseModel):
filename: str
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
async def download_file_stream(
ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(file_url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise "Ollama: Could not create blob, Please try again."
# def number_generator():
# for i in range(1, 101):
# yield f"data: {i}\n"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
form_data: UrlForm,
url_idx: Optional[int] = None,
):
if url_idx == None:
url_idx = 0
url = app.state.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, form_data.url, file_path, file_name),
)
else:
return None
@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx == None:
url_idx = 0
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
url = app.state.OLLAMA_BASE_URLS[0] url = app.state.OLLAMA_BASE_URLS[0]
......
...@@ -114,40 +114,6 @@ class CollectionNameForm(BaseModel): ...@@ -114,40 +114,6 @@ class CollectionNameForm(BaseModel):
class StoreWebForm(CollectionNameForm): class StoreWebForm(CollectionNameForm):
url: str url: str
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
log.exception(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return { return {
...@@ -329,6 +295,56 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): ...@@ -329,6 +295,56 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
) )
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.split_documents(data)
return store_docs_in_vector_db(docs, collection_name, overwrite)
def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
return store_docs_in_vector_db(docs, collection_name, overwrite)
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
print(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
def get_loader(filename: str, file_content_type: str, file_path: str): def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower() file_ext = filename.split(".")[-1].lower()
known_type = True known_type = True
...@@ -464,6 +480,37 @@ def store_doc( ...@@ -464,6 +480,37 @@ def store_doc(
) )
class TextRAGForm(BaseModel):
name: str
content: str
collection_name: Optional[str] = None
@app.post("/text")
def store_text(
form_data: TextRAGForm,
user=Depends(get_current_user),
):
collection_name = form_data.collection_name
if collection_name == None:
collection_name = calculate_sha256_string(form_data.content)
result = store_text_in_vector_db(
form_data.content,
metadata={"name": form_data.name, "created_by": user.id},
collection_name=collection_name,
)
if result:
return {"status": True, "collection_name": collection_name}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
@app.get("/scan") @app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)): def scan_docs_dir(user=Depends(get_admin_user)):
for path in Path(DOCS_DIR).rglob("./**/*"): for path in Path(DOCS_DIR).rglob("./**/*"):
......
...@@ -141,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function): ...@@ -141,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
k=k, k=k,
embedding_function=embedding_function, embedding_function=embedding_function,
) )
elif doc["type"] == "text":
context = doc["content"]
else: else:
context = query_doc( context = query_doc(
collection_name=doc["collection_name"], collection_name=doc["collection_name"],
......
...@@ -95,20 +95,6 @@ class ChatTable: ...@@ -95,20 +95,6 @@ class ChatTable:
except: except:
return None return None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
query = Chat.update(
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
timestamp=int(time.time()),
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_lists_by_user_id( def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
......
...@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES ...@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
class UploadBlobForm(BaseModel):
filename: str
from urllib.parse import urlparse
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise "Ollama: Could not create blob, Please try again."
@router.get("/download")
async def download(
url: str,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, file_path, file_name),
media_type="text/event-stream",
)
else:
return None
@router.post("/upload")
def upload(file: UploadFile = File(...)):
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
@router.get("/gravatar") @router.get("/gravatar")
async def get_gravatar( async def get_gravatar(
email: str, email: str,
......
...@@ -406,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models" ...@@ -406,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
#################################### ####################################
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
...@@ -45,3 +45,4 @@ PyJWT ...@@ -45,3 +45,4 @@ PyJWT
pyjwt[crypto] pyjwt[crypto]
black black
langfuse
This diff is collapsed.
{ {
"name": "open-webui", "name": "open-webui",
"version": "0.1.114", "version": "0.1.115",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "vite dev --host", "dev": "vite dev --host",
......
...@@ -139,7 +139,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => { ...@@ -139,7 +139,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
return res.OPENAI_API_KEY; return res.OPENAI_API_KEY;
}; };
export const getAUTOMATIC1111Url = async (token: string = '') => { export const getImageGenerationEngineUrls = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/url`, { const res = await fetch(`${IMAGES_API_BASE_URL}/url`, {
...@@ -168,10 +168,10 @@ export const getAUTOMATIC1111Url = async (token: string = '') => { ...@@ -168,10 +168,10 @@ export const getAUTOMATIC1111Url = async (token: string = '') => {
throw error; throw error;
} }
return res.AUTOMATIC1111_BASE_URL; return res;
}; };
export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => { export const updateImageGenerationEngineUrls = async (token: string = '', urls: object = {}) => {
let error = null; let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, { const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, {
...@@ -182,7 +182,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => ...@@ -182,7 +182,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) =>
...(token && { authorization: `Bearer ${token}` }) ...(token && { authorization: `Bearer ${token}` })
}, },
body: JSON.stringify({ body: JSON.stringify({
url: url ...urls
}) })
}) })
.then(async (res) => { .then(async (res) => {
...@@ -203,7 +203,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => ...@@ -203,7 +203,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) =>
throw error; throw error;
} }
return res.AUTOMATIC1111_BASE_URL; return res;
}; };
export const getImageSize = async (token: string = '') => { export const getImageSize = async (token: string = '') => {
......
...@@ -271,7 +271,7 @@ export const generateChatCompletion = async (token: string = '', body: object) = ...@@ -271,7 +271,7 @@ export const generateChatCompletion = async (token: string = '', body: object) =
return [res, controller]; return [res, controller];
}; };
export const cancelChatCompletion = async (token: string = '', requestId: string) => { export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
...@@ -390,6 +390,73 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string | ...@@ -390,6 +390,73 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
return res; return res;
}; };
export const downloadModel = async (
token: string,
download_url: string,
urlIdx: string | null = null
) => {
let error = null;
const res = await fetch(
`${OLLAMA_API_BASE_URL}/models/download${urlIdx !== null ? `/${urlIdx}` : ''}`,
{
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
url: download_url
})
}
).catch((err) => {
console.log(err);
error = err;
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
return res;
};
export const uploadModel = async (token: string, file: File, urlIdx: string | null = null) => {
let error = null;
const formData = new FormData();
formData.append('file', file);
const res = await fetch(
`${OLLAMA_API_BASE_URL}/models/upload${urlIdx !== null ? `/${urlIdx}` : ''}`,
{
method: 'POST',
headers: {
Authorization: `Bearer ${token}`
},
body: formData
}
).catch((err) => {
console.log(err);
error = err;
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
return res;
};
// export const pullModel = async (token: string, tagName: string) => { // export const pullModel = async (token: string, tagName: string) => {
// return await fetch(`${OLLAMA_API_BASE_URL}/pull`, { // return await fetch(`${OLLAMA_API_BASE_URL}/pull`, {
// method: 'POST', // method: 'POST',
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
<img <img
src={modelfiles[model]?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`} src={modelfiles[model]?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`}
alt="modelfile" alt="modelfile"
class=" w-14 rounded-full border-[1px] border-gray-200 dark:border-none" class=" size-12 rounded-full border-[1px] border-gray-200 dark:border-none"
draggable="false" draggable="false"
/> />
{:else} {:else}
...@@ -41,7 +41,7 @@ ...@@ -41,7 +41,7 @@
src={models.length === 1 src={models.length === 1
? `${WEBUI_BASE_URL}/static/favicon.png` ? `${WEBUI_BASE_URL}/static/favicon.png`
: `${WEBUI_BASE_URL}/static/favicon.png`} : `${WEBUI_BASE_URL}/static/favicon.png`}
class=" w-14 rounded-full border-[1px] border-gray-200 dark:border-none" class=" size-12 rounded-full border-[1px] border-gray-200 dark:border-none"
alt="logo" alt="logo"
draggable="false" draggable="false"
/> />
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import { models, showSettings, settings, user } from '$lib/stores'; import { models, showSettings, settings, user } from '$lib/stores';
import { onMount, tick, getContext } from 'svelte'; import { onMount, tick, getContext } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import Select from '../common/Select.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
...@@ -32,30 +33,24 @@ ...@@ -32,30 +33,24 @@
} }
</script> </script>
<div class="flex flex-col my-2"> <div class="flex flex-col my-2 w-full">
{#each selectedModels as selectedModel, selectedModelIdx} {#each selectedModels as selectedModel, selectedModelIdx}
<div class="flex"> <div class="flex w-full">
<select <div class="overflow-hidden w-full">
id="models" <div class="mr-2 max-w-full">
class="outline-none bg-transparent text-lg font-semibold rounded-lg block w-full placeholder-gray-400" <Select
placeholder={$i18n.t('Select a model')}
items={$models
.filter((model) => model.name !== 'hr')
.map((model) => ({
value: model.id,
label:
model.name + `${model.size ? ` (${(model.size / 1024 ** 3).toFixed(1)}GB)` : ''}`
}))}
bind:value={selectedModel} bind:value={selectedModel}
{disabled} />
> </div>
<option class=" text-gray-700" value="" selected disabled </div>
>{$i18n.t('Select a model')}</option
>
{#each $models as model}
{#if model.name === 'hr'}
<hr />
{:else}
<option value={model.id} class="text-gray-700 text-lg"
>{model.name +
`${model.size ? ` (${(model.size / 1024 ** 3).toFixed(1)}GB)` : ''}`}</option
>
{/if}
{/each}
</select>
{#if selectedModelIdx === 0} {#if selectedModelIdx === 0}
<button <button
...@@ -136,6 +131,6 @@ ...@@ -136,6 +131,6 @@
{/each} {/each}
</div> </div>
<div class="text-left mt-1.5 text-xs text-gray-500"> <div class="text-left mt-1.5 ml-1 text-xs text-gray-500">
<button on:click={saveDefaultModel}> {$i18n.t('Set as default')}</button> <button on:click={saveDefaultModel}> {$i18n.t('Set as default')}</button>
</div> </div>
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import fileSaver from 'file-saver'; import fileSaver from 'file-saver';
const { saveAs } = fileSaver; const { saveAs } = fileSaver;
import { resetVectorDB } from '$lib/apis/rag';
import { chats, user } from '$lib/stores'; import { chats, user } from '$lib/stores';
import { import {
...@@ -330,38 +329,6 @@ ...@@ -330,38 +329,6 @@
{$i18n.t('Export All Chats (All Users)')} {$i18n.t('Export All Chats (All Users)')}
</div> </div>
</button> </button>
<hr class=" dark:border-gray-700" />
<button
class=" flex rounded-md py-2 px-3.5 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
on:click={() => {
const res = resetVectorDB(localStorage.token).catch((error) => {
toast.error(error);
return null;
});
if (res) {
toast.success($i18n.t('Success'));
}
}}
>
<div class=" self-center mr-3">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M3.5 2A1.5 1.5 0 0 0 2 3.5v9A1.5 1.5 0 0 0 3.5 14h9a1.5 1.5 0 0 0 1.5-1.5v-7A1.5 1.5 0 0 0 12.5 4H9.621a1.5 1.5 0 0 1-1.06-.44L7.439 2.44A1.5 1.5 0 0 0 6.38 2H3.5Zm6.75 7.75a.75.75 0 0 0 0-1.5h-4.5a.75.75 0 0 0 0 1.5h4.5Z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class=" self-center text-sm font-medium">{$i18n.t('Reset Vector Storage')}</div>
</button>
{/if} {/if}
</div> </div>
</div> </div>
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
import { createEventDispatcher, onMount, getContext } from 'svelte'; import { createEventDispatcher, onMount, getContext } from 'svelte';
import { config, user } from '$lib/stores'; import { config, user } from '$lib/stores';
import { import {
getAUTOMATIC1111Url,
getImageGenerationModels, getImageGenerationModels,
getDefaultImageGenerationModel, getDefaultImageGenerationModel,
updateDefaultImageGenerationModel, updateDefaultImageGenerationModel,
getImageSize, getImageSize,
getImageGenerationConfig, getImageGenerationConfig,
updateImageGenerationConfig, updateImageGenerationConfig,
updateAUTOMATIC1111Url, getImageGenerationEngineUrls,
updateImageGenerationEngineUrls,
updateImageSize, updateImageSize,
getImageSteps, getImageSteps,
updateImageSteps, updateImageSteps,
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
let enableImageGeneration = false; let enableImageGeneration = false;
let AUTOMATIC1111_BASE_URL = ''; let AUTOMATIC1111_BASE_URL = '';
let COMFYUI_BASE_URL = '';
let OPENAI_API_KEY = ''; let OPENAI_API_KEY = '';
let selectedModel = ''; let selectedModel = '';
...@@ -49,16 +51,38 @@ ...@@ -49,16 +51,38 @@
}); });
}; };
const updateAUTOMATIC1111UrlHandler = async () => { const updateUrlHandler = async () => {
const res = await updateAUTOMATIC1111Url(localStorage.token, AUTOMATIC1111_BASE_URL).catch( if (imageGenerationEngine === 'comfyui') {
(error) => { const res = await updateImageGenerationEngineUrls(localStorage.token, {
COMFYUI_BASE_URL: COMFYUI_BASE_URL
}).catch((error) => {
toast.error(error); toast.error(error);
console.log(error);
return null; return null;
});
if (res) {
COMFYUI_BASE_URL = res.COMFYUI_BASE_URL;
await getModels();
if (models) {
toast.success($i18n.t('Server connection verified'));
}
} else {
({ COMFYUI_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token));
} }
); } else {
const res = await updateImageGenerationEngineUrls(localStorage.token, {
AUTOMATIC1111_BASE_URL: AUTOMATIC1111_BASE_URL
}).catch((error) => {
toast.error(error);
return null;
});
if (res) { if (res) {
AUTOMATIC1111_BASE_URL = res; AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL;
await getModels(); await getModels();
...@@ -66,7 +90,8 @@ ...@@ -66,7 +90,8 @@
toast.success($i18n.t('Server connection verified')); toast.success($i18n.t('Server connection verified'));
} }
} else { } else {
AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); ({ AUTOMATIC1111_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token));
}
} }
}; };
const updateImageGeneration = async () => { const updateImageGeneration = async () => {
...@@ -101,7 +126,11 @@ ...@@ -101,7 +126,11 @@
imageGenerationEngine = res.engine; imageGenerationEngine = res.engine;
enableImageGeneration = res.enabled; enableImageGeneration = res.enabled;
} }
AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); const URLS = await getImageGenerationEngineUrls(localStorage.token);
AUTOMATIC1111_BASE_URL = URLS.AUTOMATIC1111_BASE_URL;
COMFYUI_BASE_URL = URLS.COMFYUI_BASE_URL;
OPENAI_API_KEY = await getOpenAIKey(localStorage.token); OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
imageSize = await getImageSize(localStorage.token); imageSize = await getImageSize(localStorage.token);
...@@ -154,6 +183,7 @@ ...@@ -154,6 +183,7 @@
}} }}
> >
<option value="">{$i18n.t('Default (Automatic1111)')}</option> <option value="">{$i18n.t('Default (Automatic1111)')}</option>
<option value="comfyui">{$i18n.t('ComfyUI')}</option>
<option value="openai">{$i18n.t('Open AI (Dall-E)')}</option> <option value="openai">{$i18n.t('Open AI (Dall-E)')}</option>
</select> </select>
</div> </div>
...@@ -171,6 +201,9 @@ ...@@ -171,6 +201,9 @@
if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') { if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
toast.error($i18n.t('AUTOMATIC1111 Base URL is required.')); toast.error($i18n.t('AUTOMATIC1111 Base URL is required.'));
enableImageGeneration = false; enableImageGeneration = false;
} else if (imageGenerationEngine === 'comfyui' && COMFYUI_BASE_URL === '') {
toast.error($i18n.t('ComfyUI Base URL is required.'));
enableImageGeneration = false;
} else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') { } else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') {
toast.error($i18n.t('OpenAI API Key is required.')); toast.error($i18n.t('OpenAI API Key is required.'));
enableImageGeneration = false; enableImageGeneration = false;
...@@ -204,12 +237,10 @@ ...@@ -204,12 +237,10 @@
/> />
</div> </div>
<button <button
class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded-lg transition" class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
type="button" type="button"
on:click={() => { on:click={() => {
// updateOllamaAPIUrlHandler(); updateUrlHandler();
updateAUTOMATIC1111UrlHandler();
}} }}
> >
<svg <svg
...@@ -237,6 +268,37 @@ ...@@ -237,6 +268,37 @@
{$i18n.t('(e.g. `sh webui.sh --api`)')} {$i18n.t('(e.g. `sh webui.sh --api`)')}
</a> </a>
</div> </div>
{:else if imageGenerationEngine === 'comfyui'}
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('ComfyUI Base URL')}</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
bind:value={COMFYUI_BASE_URL}
/>
</div>
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
type="button"
on:click={() => {
updateUrlHandler();
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
clip-rule="evenodd"
/>
</svg>
</button>
</div>
{:else if imageGenerationEngine === 'openai'} {:else if imageGenerationEngine === 'openai'}
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('OpenAI API Key')}</div> <div class=" mb-2.5 text-sm font-medium">{$i18n.t('OpenAI API Key')}</div>
<div class="flex w-full"> <div class="flex w-full">
...@@ -261,6 +323,7 @@ ...@@ -261,6 +323,7 @@
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedModel} bind:value={selectedModel}
placeholder={$i18n.t('Select a model')} placeholder={$i18n.t('Select a model')}
required
> >
{#if !selectedModel} {#if !selectedModel}
<option value="" disabled selected>{$i18n.t('Select a model')}</option> <option value="" disabled selected>{$i18n.t('Select a model')}</option>
......
...@@ -5,9 +5,12 @@ ...@@ -5,9 +5,12 @@
import { import {
createModel, createModel,
deleteModel, deleteModel,
downloadModel,
getOllamaUrls, getOllamaUrls,
getOllamaVersion, getOllamaVersion,
pullModel pullModel,
cancelOllamaRequest,
uploadModel
} from '$lib/apis/ollama'; } from '$lib/apis/ollama';
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { WEBUI_NAME, models, user } from '$lib/stores'; import { WEBUI_NAME, models, user } from '$lib/stores';
...@@ -60,11 +63,13 @@ ...@@ -60,11 +63,13 @@
let pullProgress = null; let pullProgress = null;
let modelUploadMode = 'file'; let modelUploadMode = 'file';
let modelInputFile = ''; let modelInputFile: File[] | null = null;
let modelFileUrl = ''; let modelFileUrl = '';
let modelFileContent = `TEMPLATE """{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: """\nPARAMETER num_ctx 4096\nPARAMETER stop "</s>"\nPARAMETER stop "USER:"\nPARAMETER stop "ASSISTANT:"`; let modelFileContent = `TEMPLATE """{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: """\nPARAMETER num_ctx 4096\nPARAMETER stop "</s>"\nPARAMETER stop "USER:"\nPARAMETER stop "ASSISTANT:"`;
let modelFileDigest = ''; let modelFileDigest = '';
let uploadProgress = null; let uploadProgress = null;
let uploadMessage = '';
let deleteModelTag = ''; let deleteModelTag = '';
...@@ -159,7 +164,7 @@ ...@@ -159,7 +164,7 @@
// Remove the downloaded model // Remove the downloaded model
delete modelDownloadStatus[modelName]; delete modelDownloadStatus[modelName];
console.log(data); modelDownloadStatus = { ...modelDownloadStatus };
if (!data.success) { if (!data.success) {
toast.error(data.error); toast.error(data.error);
...@@ -184,35 +189,32 @@ ...@@ -184,35 +189,32 @@
const uploadModelHandler = async () => { const uploadModelHandler = async () => {
modelTransferring = true; modelTransferring = true;
uploadProgress = 0;
let uploaded = false; let uploaded = false;
let fileResponse = null; let fileResponse = null;
let name = ''; let name = '';
if (modelUploadMode === 'file') { if (modelUploadMode === 'file') {
const file = modelInputFile[0]; const file = modelInputFile ? modelInputFile[0] : null;
const formData = new FormData();
formData.append('file', file); if (file) {
uploadMessage = 'Uploading...';
fileResponse = await fetch(`${WEBUI_API_BASE_URL}/utils/upload`, {
method: 'POST', fileResponse = await uploadModel(localStorage.token, file, selectedOllamaUrlIdx).catch(
headers: { (error) => {
...($user && { Authorization: `Bearer ${localStorage.token}` }) toast.error(error);
},
body: formData
}).catch((error) => {
console.log(error);
return null; return null;
});
} else {
fileResponse = await fetch(`${WEBUI_API_BASE_URL}/utils/download?url=${modelFileUrl}`, {
method: 'GET',
headers: {
...($user && { Authorization: `Bearer ${localStorage.token}` })
} }
}).catch((error) => { );
console.log(error); }
} else {
uploadProgress = 0;
fileResponse = await downloadModel(
localStorage.token,
modelFileUrl,
selectedOllamaUrlIdx
).catch((error) => {
toast.error(error);
return null; return null;
}); });
} }
...@@ -235,6 +237,9 @@ ...@@ -235,6 +237,9 @@
let data = JSON.parse(line.replace(/^data: /, '')); let data = JSON.parse(line.replace(/^data: /, ''));
if (data.progress) { if (data.progress) {
if (uploadMessage) {
uploadMessage = '';
}
uploadProgress = data.progress; uploadProgress = data.progress;
} }
...@@ -318,7 +323,11 @@ ...@@ -318,7 +323,11 @@
} }
modelFileUrl = ''; modelFileUrl = '';
modelInputFile = '';
if (modelUploadInputElement) {
modelUploadInputElement.value = '';
}
modelInputFile = null;
modelTransferring = false; modelTransferring = false;
uploadProgress = null; uploadProgress = null;
...@@ -364,12 +373,24 @@ ...@@ -364,12 +373,24 @@
for (const line of lines) { for (const line of lines) {
if (line !== '') { if (line !== '') {
let data = JSON.parse(line); let data = JSON.parse(line);
console.log(data);
if (data.error) { if (data.error) {
throw data.error; throw data.error;
} }
if (data.detail) { if (data.detail) {
throw data.detail; throw data.detail;
} }
if (data.id) {
modelDownloadStatus[opts.modelName] = {
...modelDownloadStatus[opts.modelName],
requestId: data.id,
reader,
done: false
};
console.log(data);
}
if (data.status) { if (data.status) {
if (data.digest) { if (data.digest) {
let downloadProgress = 0; let downloadProgress = 0;
...@@ -379,11 +400,17 @@ ...@@ -379,11 +400,17 @@
downloadProgress = 100; downloadProgress = 100;
} }
modelDownloadStatus[opts.modelName] = { modelDownloadStatus[opts.modelName] = {
...modelDownloadStatus[opts.modelName],
pullProgress: downloadProgress, pullProgress: downloadProgress,
digest: data.digest digest: data.digest
}; };
} else { } else {
toast.success(data.status); toast.success(data.status);
modelDownloadStatus[opts.modelName] = {
...modelDownloadStatus[opts.modelName],
done: data.status === 'success'
};
} }
} }
} }
...@@ -396,7 +423,14 @@ ...@@ -396,7 +423,14 @@
opts.callback({ success: false, error, modelName: opts.modelName }); opts.callback({ success: false, error, modelName: opts.modelName });
} }
} }
console.log(modelDownloadStatus[opts.modelName]);
if (modelDownloadStatus[opts.modelName].done) {
opts.callback({ success: true, modelName: opts.modelName }); opts.callback({ success: true, modelName: opts.modelName });
} else {
opts.callback({ success: false, error: 'Download canceled', modelName: opts.modelName });
}
} }
}; };
...@@ -466,6 +500,18 @@ ...@@ -466,6 +500,18 @@
ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false); ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false);
liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token); liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
}); });
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = modelDownloadStatus[model];
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete modelDownloadStatus[model];
await deleteModel(localStorage.token, model);
toast.success(`${model} download has been canceled`);
}
};
</script> </script>
<div class="flex flex-col h-full justify-between text-sm"> <div class="flex flex-col h-full justify-between text-sm">
...@@ -596,20 +642,58 @@ ...@@ -596,20 +642,58 @@
{#if Object.keys(modelDownloadStatus).length > 0} {#if Object.keys(modelDownloadStatus).length > 0}
{#each Object.keys(modelDownloadStatus) as model} {#each Object.keys(modelDownloadStatus) as model}
{#if 'pullProgress' in modelDownloadStatus[model]}
<div class="flex flex-col"> <div class="flex flex-col">
<div class="font-medium mb-1">{model}</div> <div class="font-medium mb-1">{model}</div>
<div class=""> <div class="">
<div class="flex flex-row justify-between space-x-4 pr-2">
<div class=" flex-1">
<div <div
class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full" class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
style="width: {Math.max(15, modelDownloadStatus[model].pullProgress ?? 0)}%" style="width: {Math.max(
15,
modelDownloadStatus[model].pullProgress ?? 0
)}%"
> >
{modelDownloadStatus[model].pullProgress ?? 0}% {modelDownloadStatus[model].pullProgress ?? 0}%
</div> </div>
</div>
<Tooltip content="Cancel">
<button
class="text-gray-800 dark:text-gray-100"
on:click={() => {
cancelModelPullHandler(model);
}}
>
<svg
class="w-4 h-4 text-gray-800 dark:text-white"
aria-hidden="true"
xmlns="http://www.w3.org/2000/svg"
width="24"
height="24"
fill="currentColor"
viewBox="0 0 24 24"
>
<path
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18 17.94 6M18 18 6.06 6"
/>
</svg>
</button>
</Tooltip>
</div>
{#if 'digest' in modelDownloadStatus[model]}
<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;"> <div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
{modelDownloadStatus[model].digest} {modelDownloadStatus[model].digest}
</div> </div>
{/if}
</div> </div>
</div> </div>
{/if}
{/each} {/each}
{/if} {/if}
</div> </div>
...@@ -715,7 +799,7 @@ ...@@ -715,7 +799,7 @@
<button <button
type="button" type="button"
class="w-full rounded-lg text-left py-2 px-4 dark:text-gray-300 dark:bg-gray-850" class="w-full rounded-lg text-left py-2 px-4 bg-white dark:text-gray-300 dark:bg-gray-850"
on:click={() => { on:click={() => {
modelUploadInputElement.click(); modelUploadInputElement.click();
}} }}
...@@ -730,7 +814,7 @@ ...@@ -730,7 +814,7 @@
{:else} {:else}
<div class="flex-1 {modelFileUrl !== '' ? 'mr-2' : ''}"> <div class="flex-1 {modelFileUrl !== '' ? 'mr-2' : ''}">
<input <input
class="w-full rounded-lg text-left py-2 px-4 dark:text-gray-300 dark:bg-gray-850 outline-none {modelFileUrl !== class="w-full rounded-lg text-left py-2 px-4 bg-white dark:text-gray-300 dark:bg-gray-850 outline-none {modelFileUrl !==
'' ''
? 'mr-2' ? 'mr-2'
: ''}" : ''}"
...@@ -745,7 +829,7 @@ ...@@ -745,7 +829,7 @@
{#if (modelUploadMode === 'file' && modelInputFile && modelInputFile.length > 0) || (modelUploadMode === 'url' && modelFileUrl !== '')} {#if (modelUploadMode === 'file' && modelInputFile && modelInputFile.length > 0) || (modelUploadMode === 'url' && modelFileUrl !== '')}
<button <button
class="px-3 text-gray-100 bg-emerald-600 hover:bg-emerald-700 disabled:bg-gray-700 disabled:cursor-not-allowed rounded transition" class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg disabled:cursor-not-allowed transition"
type="submit" type="submit"
disabled={modelTransferring} disabled={modelTransferring}
> >
...@@ -800,7 +884,7 @@ ...@@ -800,7 +884,7 @@
<div class=" my-2.5 text-sm font-medium">{$i18n.t('Modelfile Content')}</div> <div class=" my-2.5 text-sm font-medium">{$i18n.t('Modelfile Content')}</div>
<textarea <textarea
bind:value={modelFileContent} bind:value={modelFileContent}
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none" class="w-full rounded-lg py-2 px-4 text-sm bg-gray-100 dark:text-gray-100 dark:bg-gray-850 outline-none resize-none"
rows="6" rows="6"
/> />
</div> </div>
...@@ -815,7 +899,23 @@ ...@@ -815,7 +899,23 @@
> >
</div> </div>
{#if uploadProgress !== null} {#if uploadMessage}
<div class="mt-2">
<div class=" mb-2 text-xs">{$i18n.t('Upload Progress')}</div>
<div class="w-full rounded-full dark:bg-gray-800">
<div
class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
style="width: 100%"
>
{uploadMessage}
</div>
</div>
<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
{modelFileDigest}
</div>
</div>
{:else if uploadProgress !== null}
<div class="mt-2"> <div class="mt-2">
<div class=" mb-2 text-xs">{$i18n.t('Upload Progress')}</div> <div class=" mb-2 text-xs">{$i18n.t('Upload Progress')}</div>
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
import { DropdownMenu } from 'bits-ui'; import { DropdownMenu } from 'bits-ui';
import { createEventDispatcher } from 'svelte'; import { createEventDispatcher } from 'svelte';
import { flyAndScale } from '$lib/utils/transitions';
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
</script> </script>
...@@ -20,6 +22,7 @@ ...@@ -20,6 +22,7 @@
sideOffset={8} sideOffset={8}
side="bottom" side="bottom"
align="start" align="start"
transition={flyAndScale}
> >
<DropdownMenu.Item class="flex items-center px-3 py-2 text-sm font-medium"> <DropdownMenu.Item class="flex items-center px-3 py-2 text-sm font-medium">
<div class="flex items-center">Profile</div> <div class="flex items-center">Profile</div>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment