Unverified Commit 92d9b381 authored by Timothy Jaeryang Baek, committed by GitHub

Merge branch 'dev' into feat/openai-embeddings-batch

parents 0cb81633 36a66fcf
......@@ -70,8 +70,10 @@ jobs:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
flavor: |
prefix=cache-${{ matrix.platform }}-
latest=false
- name: Build Docker image (latest)
uses: docker/build-push-action@v5
......@@ -158,8 +160,10 @@ jobs:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
flavor: |
prefix=cache-cuda-${{ matrix.platform }}-
latest=false
- name: Build Docker image (cuda)
uses: docker/build-push-action@v5
......@@ -247,8 +251,10 @@ jobs:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
flavor: |
prefix=cache-ollama-${{ matrix.platform }}-
latest=false
- name: Build Docker image (ollama)
uses: docker/build-push-action@v5
......
......@@ -5,6 +5,34 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.2.2] - 2024-06-02
### Added
- **🌊 Mermaid Rendering Support**: We've included support for Mermaid rendering. This allows you to create beautiful diagrams and flowcharts directly within Open WebUI.
- **🔄 New Environment Variable 'RESET_CONFIG_ON_START'**: Set this variable to 'true' to reset your configuration settings when the application starts, making it easier to revert to the default configuration.
### Fixed
- **🔧 Pipelines Filter Issue**: We've addressed an issue with the pipelines where filters were not functioning as expected.
## [0.2.1] - 2024-06-02
### Added
- **🖱️ Single Model Export Button**: Easily export models with just one click using the new single model export button.
- **🖥️ Advanced Parameters Support**: Added support for 'num_thread', 'use_mmap', and 'use_mlock' parameters for Ollama.
- **🌐 Improved Vietnamese Translation**: Enhanced Vietnamese language support for a better user experience for our Vietnamese-speaking community.
### Fixed
- **🔧 OpenAI URL API Save Issue**: Corrected a problem that prevented OpenAI API URL settings from being saved.
- **🚫 Display Issue with Disabled Ollama API**: Fixed the display bug causing models to appear in settings when the Ollama API was disabled.
### Changed
- **💡 Versioning Update**: As a reminder from our previous update, version 0.2.y will focus primarily on bug fixes, while major updates will be designated as 0.x from now on for better version tracking.
## [0.2.0] - 2024-06-01
### Added
......
......@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
......@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
REQUEST_POOL = []
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
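(Editorial aside, not part of this commit: as a rough illustration of one of the strategies the TODO names, a weighted round-robin selector over the configured base URLs could be sketched as below. The URLs and weights are placeholders, and this is not the project's implementation, which currently picks a backend with random.choice.)

```python
import itertools
import random


def weighted_round_robin(urls_with_weights):
    """Cycle through base URLs in proportion to their weights.

    Illustrative sketch only. `urls_with_weights` is a list of (url, weight)
    tuples, e.g. [("http://ollama-a:11434", 3), ("http://ollama-b:11434", 1)].
    """
    expanded = [url for url, weight in urls_with_weights for _ in range(weight)]
    random.shuffle(expanded)  # avoid always starting with the same backend
    return itertools.cycle(expanded)


# Roughly 3 out of every 4 requests would go to the heavier-weighted instance.
selector = weighted_round_robin([("http://ollama-a:11434", 3), ("http://ollama-b:11434", 1)])
next_url = next(selector)
```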
......@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
if user:
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
return True
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def fetch_url(url):
timeout = aiohttp.ClientTimeout(total=5)
try:
......@@ -154,6 +143,45 @@ async def fetch_url(url):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
async def post_streaming_url(url: str, payload: str):
r = None
try:
session = aiohttp.ClientSession()
r = await session.post(url, data=payload)
r.raise_for_status()
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(cleanup_response, response=r, session=session),
)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=error_detail,
)
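(Editorial aside: a hedged sketch, not part of the commit, of how the handlers later in this diff delegate streaming to this helper. The "/api/echo-example" route and dict body are hypothetical; post_streaming_url, app, Depends, and get_current_user are names from this module.)

```python
# Hedged sketch: a hypothetical route delegating to post_streaming_url.
# The real handlers below (pull_model, push_model, create_model, generate_*)
# follow this pattern with their own upstream paths and payloads.
import json


@app.post("/api/echo-example")
async def echo_example(form_data: dict, user=Depends(get_current_user)):
    url = app.state.config.OLLAMA_BASE_URLS[0]
    # The aiohttp session opened inside post_streaming_url stays alive while the
    # client consumes the StreamingResponse; BackgroundTask(cleanup_response, ...)
    # closes the response and the session once streaming finishes.
    return await post_streaming_url(f"{url}/api/echo", json.dumps(form_data))
```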
def merge_models_lists(model_lists):
merged_models = {}
......@@ -313,65 +341,7 @@ async def pull_model(
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
def get_request():
nonlocal url
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/pull",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
class PushModelForm(BaseModel):
......@@ -399,50 +369,9 @@ async def push_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/push",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
)
class CreateModelForm(BaseModel):
......@@ -461,53 +390,9 @@ async def create_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/create",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
log.debug(f"r: {r}")
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
)
class CopyModelForm(BaseModel):
......@@ -797,66 +682,9 @@ async def generate_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal form_data
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if form_data.stream:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/generate",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
)
class ChatMessage(BaseModel):
......@@ -906,44 +734,77 @@ async def generate_chat_completion(
if model_info.params:
payload["options"] = {}
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if model_info.params.get("mirostat", None):
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if model_info.params.get("mirostat_eta", None):
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("mirostat_tau", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("num_ctx", None):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("repeat_last_n", None):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
if model_info.params.get("frequency_penalty", None):
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("temperature", None):
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
if model_info.params.get("seed", None):
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("tfs_z", None):
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("max_tokens", None):
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
if model_info.params.get("top_k", None):
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("top_p", None):
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("use_mmap", None):
payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
if model_info.params.get("use_mlock", None):
payload["options"]["use_mlock"] = model_info.params.get(
"use_mlock", None
)
if model_info.params.get("num_thread", None):
payload["options"]["num_thread"] = model_info.params.get(
"num_thread", None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
......@@ -981,67 +842,7 @@ async def generate_chat_completion(
print(payload)
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream", None):
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/chat",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
log.exception(e)
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
# TODO: we should update this part once Ollama supports other types
......@@ -1132,68 +933,7 @@ async def generate_openai_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream"):
yield json.dumps(
{"request_id": request_id, "done": False}
) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/v1/chat/completions",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
@app.get("/v1/models")
......@@ -1522,7 +1262,7 @@ async def deprecated_proxy(
if path == "generate":
data = json.loads(body.decode("utf-8"))
if not ("stream" in data and data["stream"] == False):
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":
......
......@@ -9,6 +9,7 @@ import json
import logging
from pydantic import BaseModel
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
......@@ -194,6 +195,16 @@ async def fetch_url(url, key):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
......@@ -228,6 +239,27 @@ async def get_all_models(raw: bool = False):
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []}
else:
# Check if the number of API keys matches the number of API URLs
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
):
# if there are more keys than urls, remove the extra keys
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [
""
for _ in range(
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
)
]
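(Editorial aside: a hedged, standalone sketch of the truncate-or-pad behaviour above; the reconcile function and the values are illustrative, not the app's state objects.)

```python
# Illustrative sketch of the key/URL reconciliation above, outside app.state.
def reconcile(keys: list[str], urls: list[str]) -> list[str]:
    if len(keys) > len(urls):
        return keys[: len(urls)]  # drop the extra keys
    return keys + [""] * (len(urls) - len(keys))  # pad with empty keys


assert reconcile(["k1", "k2", "k3"], ["u1"]) == ["k1"]
assert reconcile(["k1"], ["u1", "u2", "u3"]) == ["k1", "", ""]
```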
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
......@@ -426,40 +458,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
r = requests.request(
session = aiohttp.ClientSession()
r = await session.request(
method=request.method,
url=target_url,
data=payload if payload else body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = r.json()
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
import logging
import requests
from typing import List
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
......@@ -9,20 +10,52 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(query_url: str, query: str, count: int) -> list[SearchResult]:
"""Search a SearXNG instance for a query and return the results as a list of SearchResult objects.
def search_searxng(query_url: str, query: str, count: int, **kwargs) -> List[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
The function accepts additional keyword arguments such as language or time_range to tailor the search results.
Args:
query_url (str): The URL of the SearXNG instance to search. Must contain "<query>" as a placeholder
query (str): The query to search for
query_url (str): The base URL of the SearXNG server with a placeholder for the query "<query>".
query (str): The search term or question to find in the SearXNG database.
count (int): The maximum number of results to retrieve from the search.
Keyword Args:
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories (Optional[List[str]]): Specific categories to search within; defaults to an empty list if not provided.
Returns:
List[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
Raises:
requests.exceptions.RequestException: If a request error occurs during the search process.
"""
url = query_url.replace("<query>", query)
if "&format=json" not in url:
url += "&format=json"
log.debug(f"searching {url}")
# Default values for optional parameters are provided as empty strings or None when not specified.
language = kwargs.get('language', 'en-US')
time_range = kwargs.get('time_range', '')
categories = ''.join(kwargs.get('categories', []))
params = {
"q": query,
"format": "json",
"pageno": 1,
"results_per_page": count,
'language': language,
'time_range': time_range,
'engines': '',
'categories': categories,
'theme': 'simple',
'image_proxy': 0
r = requests.get(
url,
}
log.debug(f"searching {query_url}")
response = requests.get(
query_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
......@@ -30,15 +63,17 @@ def search_searxng(query_url: str, query: str, count: int) -> list[SearchResult]
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
},
params=params,
)
r.raise_for_status()
json_response = r.json()
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
)
for result in sorted_results[:count]
for result in sorted_results
]
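(Editorial aside: a hedged usage sketch for the updated signature. The instance URL, query, and keyword values below are placeholders, not values from the commit.)

```python
# Hypothetical call against a SearXNG instance; count is forwarded as
# results_per_page, and the keyword arguments populate the request params.
results = search_searxng(
    "http://searxng.local/search",   # placeholder SearXNG search endpoint
    "open webui mermaid support",    # query
    5,                               # count
    language="en-US",
    categories=["it"],               # joined into the 'categories' request param
)
for r in results:
    print(r.link, r.title)
```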
......@@ -298,6 +298,15 @@ class ChatTable:
# .limit(limit).offset(skip)
]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
......
......@@ -113,6 +113,19 @@ async def get_user_chats(user=Depends(get_current_user)):
]
############################
# GetArchivedChats
############################
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
############################
# GetAllChatsInDB
############################
......
......@@ -180,6 +180,17 @@ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except:
pass
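(Editorial note, not part of the diff: the check above is case-insensitive but only the literal string "true" enables the reset, so the flag can be exercised like this before the backend starts.)

```python
# Hedged sketch: RESET_CONFIG_ON_START must resolve to "true" (any casing) in the
# backend's environment before config.py is imported; "1" or "yes" would not trigger it.
import os

os.environ["RESET_CONFIG_ON_START"] = "true"
# equivalent shell form: RESET_CONFIG_ON_START=true <start command>
```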
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except:
......
......@@ -498,10 +498,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
model = app.state.MODELS[model_id]
print(model_id)
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
......@@ -550,7 +552,11 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
responses = await get_openai_models(raw=True)
print(responses)
urlIdxs = [idx for idx, response in enumerate(responses) if "pipelines" in response]
urlIdxs = [
idx
for idx, response in enumerate(responses)
if response != None and "pipelines" in response
]
return {
"data": [
......
This diff is collapsed.
{
"name": "open-webui",
"version": "0.2.0",
"version": "0.2.2",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
......@@ -63,6 +63,7 @@
"js-sha256": "^0.10.1",
"katex": "^0.16.9",
"marked": "^9.1.0",
"mermaid": "^10.9.1",
"pyodide": "^0.26.0-alpha.4",
"sortablejs": "^1.15.2",
"svelte-sonner": "^0.3.19",
......
......@@ -162,6 +162,37 @@ export const getAllChats = async (token: string) => {
return res;
};
export const getAllArchivedChats = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/archived`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getAllUserChats = async (token: string) => {
let error = null;
......
......@@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) =
return [res, controller];
};
export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
method: 'GET',
headers: {
'Content-Type': 'text/event-stream',
Authorization: `Bearer ${token}`
}
}).catch((err) => {
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const createModel = async (token: string, tagName: string, content: string) => {
let error = null;
......@@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
let error = null;
const controller = new AbortController();
const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
signal: controller.signal,
method: 'POST',
headers: {
Accept: 'application/json',
......@@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
if (error) {
throw error;
}
return res;
return [res, controller];
};
export const downloadModel = async (
......
<script lang="ts">
import { v4 as uuidv4 } from 'uuid';
import { toast } from 'svelte-sonner';
import mermaid from 'mermaid';
import { getContext, onMount, tick } from 'svelte';
import { goto } from '$app/navigation';
......@@ -26,7 +27,7 @@
splitStream
} from '$lib/utils';
import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
import { generateChatCompletion } from '$lib/apis/ollama';
import {
addTagById,
createNewChat,
......@@ -65,7 +66,6 @@
let autoScroll = true;
let processing = '';
let messagesContainerElement: HTMLDivElement;
let currentRequestId = null;
let showModelSelector = true;
......@@ -130,10 +130,6 @@
//////////////////////////
const initNewChat = async () => {
if (currentRequestId !== null) {
await cancelOllamaRequest(localStorage.token, currentRequestId);
currentRequestId = null;
}
window.history.replaceState(history.state, '', `/`);
await chatId.set('');
......@@ -251,6 +247,39 @@
}
};
const chatCompletedHandler = async (modelId, messages) => {
await mermaid.run({
querySelector: '.mermaid'
});
const res = await chatCompleted(localStorage.token, {
model: modelId,
messages: messages.map((m) => ({
id: m.id,
role: m.role,
content: m.content,
timestamp: m.timestamp
})),
chat_id: $chatId
}).catch((error) => {
console.error(error);
return null;
});
if (res !== null) {
// Update chat history with the new messages
for (const message of res.messages) {
history.messages[message.id] = {
...history.messages[message.id],
...(history.messages[message.id].content !== message.content
? { originalContent: history.messages[message.id].content }
: {}),
...message
};
}
}
};
//////////////////////////
// Ollama functions
//////////////////////////
......@@ -616,39 +645,11 @@
if (stopResponseFlag) {
controller.abort('User: Stop Response');
await cancelOllamaRequest(localStorage.token, currentRequestId);
} else {
const messages = createMessagesList(responseMessageId);
const res = await chatCompleted(localStorage.token, {
model: model,
messages: messages.map((m) => ({
id: m.id,
role: m.role,
content: m.content,
timestamp: m.timestamp
})),
chat_id: $chatId
}).catch((error) => {
console.error(error);
return null;
});
if (res !== null) {
// Update chat history with the new messages
for (const message of res.messages) {
history.messages[message.id] = {
...history.messages[message.id],
...(history.messages[message.id].content !== message.content
? { originalContent: history.messages[message.id].content }
: {}),
...message
};
}
}
await chatCompletedHandler(model, messages);
}
currentRequestId = null;
break;
}
......@@ -669,63 +670,58 @@
throw data;
}
if ('id' in data) {
    console.log(data);
    currentRequestId = data.id;
} else {
    if (data.done == false) {
        if (responseMessage.content == '' && data.message.content == '\n') {
            continue;
        } else {
            responseMessage.content += data.message.content;
            messages = messages;
        }
    } else {
        responseMessage.done = true;
        if (responseMessage.content == '') {
            responseMessage.error = {
                code: 400,
                content: `Oops! No text generated from Ollama, Please try again.`
            };
        }
        responseMessage.context = data.context ?? null;
        responseMessage.info = {
            total_duration: data.total_duration,
            load_duration: data.load_duration,
            sample_count: data.sample_count,
            sample_duration: data.sample_duration,
            prompt_eval_count: data.prompt_eval_count,
            prompt_eval_duration: data.prompt_eval_duration,
            eval_count: data.eval_count,
            eval_duration: data.eval_duration
        };
        if ($settings.notificationEnabled && !document.hasFocus()) {
            const notification = new Notification(
                selectedModelfile
                    ? `${
                            selectedModelfile.title.charAt(0).toUpperCase() +
                            selectedModelfile.title.slice(1)
                        }`
                    : `${model}`,
                {
                    body: responseMessage.content,
                    icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
                }
            );
        }
        if ($settings.responseAutoCopy) {
            copyToClipboard(responseMessage.content);
        }
        if ($settings.responseAutoPlayback) {
            await tick();
            document.getElementById(`speak-button-${responseMessage.id}`)?.click();
        }
    }
}

if (data.done == false) {
    if (responseMessage.content == '' && data.message.content == '\n') {
        continue;
    } else {
        responseMessage.content += data.message.content;
        messages = messages;
    }
} else {
    responseMessage.done = true;
    if (responseMessage.content == '') {
        responseMessage.error = {
            code: 400,
            content: `Oops! No text generated from Ollama, Please try again.`
        };
    }
    responseMessage.context = data.context ?? null;
    responseMessage.info = {
        total_duration: data.total_duration,
        load_duration: data.load_duration,
        sample_count: data.sample_count,
        sample_duration: data.sample_duration,
        prompt_eval_count: data.prompt_eval_count,
        prompt_eval_duration: data.prompt_eval_duration,
        eval_count: data.eval_count,
        eval_duration: data.eval_duration
    };
    messages = messages;
    if ($settings.notificationEnabled && !document.hasFocus()) {
        const notification = new Notification(
            selectedModelfile
                ? `${
                        selectedModelfile.title.charAt(0).toUpperCase() +
                        selectedModelfile.title.slice(1)
                    }`
                : `${model}`,
            {
                body: responseMessage.content,
                icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
            }
        );
    }
    if ($settings.responseAutoCopy) {
        copyToClipboard(responseMessage.content);
    }
    if ($settings.responseAutoPlayback) {
        await tick();
        document.getElementById(`speak-button-${responseMessage.id}`)?.click();
    }
}
......@@ -906,32 +902,7 @@
} else {
const messages = createMessagesList(responseMessageId);
const res = await chatCompleted(localStorage.token, {
model: model.id,
messages: messages.map((m) => ({
id: m.id,
role: m.role,
content: m.content,
timestamp: m.timestamp
})),
chat_id: $chatId
}).catch((error) => {
console.error(error);
return null;
});
if (res !== null) {
// Update chat history with the new messages
for (const message of res.messages) {
history.messages[message.id] = {
...history.messages[message.id],
...(history.messages[message.id].content !== message.content
? { originalContent: history.messages[message.id].content }
: {}),
...message
};
}
}
await chatCompletedHandler(model.id, messages);
}
break;
......
......@@ -810,10 +810,7 @@
? $i18n.t('Listening...')
: $i18n.t('Send a Message')}
bind:value={prompt}
on:keypress={(e) => {}}
on:keydown={async (e) => {
// Check if the device is not a mobile device or if it is a mobile device, check if it is not a touch device
// This is to prevent the Enter key from submitting the prompt on mobile devices
on:keypress={(e) => {
if (
!$mobile ||
!(
......@@ -822,22 +819,18 @@
navigator.msMaxTouchPoints > 0
)
) {
// Check if Enter is pressed
// Check if Shift key is not pressed
// Prevent Enter key from creating a new line
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
}
if (e.key === 'Enter' && !e.shiftKey && prompt !== '') {
// Submit the prompt when Enter key is pressed
if (prompt !== '' && e.key === 'Enter' && !e.shiftKey) {
submitPrompt(prompt, user);
return;
}
if (e.key === 'Enter' && e.shiftKey && prompt !== '') {
return;
}
}
}}
on:keydown={async (e) => {
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
// Check if Ctrl + R is pressed
......@@ -898,7 +891,9 @@
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
if (commandOptionButton) {
if (e.shiftKey) {
prompt = `${prompt}\n`;
} else if (commandOptionButton) {
commandOptionButton?.click();
} else {
document.getElementById('send-message-button')?.click();
......
<script lang="ts">
import { v4 as uuidv4 } from 'uuid';
import { chats, config, settings, user as _user, mobile } from '$lib/stores';
import { tick, getContext } from 'svelte';
import { tick, getContext, onMount } from 'svelte';
import { toast } from 'svelte-sonner';
import { getChatList, updateChatById } from '$lib/apis/chats';
......
......@@ -5,6 +5,7 @@
import tippy from 'tippy.js';
import auto_render from 'katex/dist/contrib/auto-render.mjs';
import 'katex/dist/katex.min.css';
import mermaid from 'mermaid';
import { fade } from 'svelte/transition';
import { createEventDispatcher } from 'svelte';
......@@ -340,9 +341,24 @@
generatingImage = false;
};
$: if (!edit) {
(async () => {
await tick();
renderStyling();
await mermaid.run({
querySelector: '.mermaid'
});
})();
}
onMount(async () => {
await tick();
renderStyling();
await mermaid.run({
querySelector: '.mermaid'
});
});
</script>
......@@ -458,11 +474,15 @@
<!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content -->
{#each tokens as token, tokenIdx}
{#if token.type === 'code'}
<CodeBlock
id={`${message.id}-${tokenIdx}`}
lang={token?.lang ?? ''}
code={revertSanitizedResponseContent(token?.text ?? '')}
/>
{#if token.lang === 'mermaid'}
<pre class="mermaid">{revertSanitizedResponseContent(token.text)}</pre>
{:else}
<CodeBlock
id={`${message.id}-${tokenIdx}`}
lang={token?.lang ?? ''}
code={revertSanitizedResponseContent(token?.text ?? '')}
/>
{/if}
{:else}
{@html marked.parse(token.raw, {
...defaults,
......
......@@ -196,7 +196,7 @@
<div class=" mt-2 mb-1 flex justify-end space-x-1.5 text-sm font-medium">
<button
id="close-edit-message-button"
class="px-4 py-2 bg-white hover:bg-gray-100 text-gray-800 transition rounded-3xl"
class="px-4 py-2 bg-white dark:bg-gray-900 hover:bg-gray-100 text-gray-800 dark:text-gray-100 transition rounded-3xl"
on:click={() => {
cancelEditMessage();
}}
......@@ -206,7 +206,7 @@
<button
id="save-edit-message-button"
class=" px-4 py-2 bg-gray-900 hover:bg-gray-850 text-gray-100 transition rounded-3xl"
class=" px-4 py-2 bg-gray-900 dark:bg-white hover:bg-gray-850 text-gray-100 dark:text-gray-800 transition rounded-3xl"
on:click={() => {
editMessageConfirmHandler();
}}
......
......@@ -8,7 +8,7 @@
import Check from '$lib/components/icons/Check.svelte';
import Search from '$lib/components/icons/Search.svelte';
import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
import { toast } from 'svelte-sonner';
......@@ -72,10 +72,12 @@
return;
}
const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
toast.error(error);
return null;
});
const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
(error) => {
toast.error(error);
return null;
}
);
if (res) {
const reader = res.body
......@@ -83,6 +85,16 @@
.pipeThrough(splitStream('\n'))
.getReader();
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
abortController: controller,
reader,
done: false
}
});
while (true) {
try {
const { value, done } = await reader.read();
......@@ -101,19 +113,6 @@
throw data.detail;
}
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
if (data.status) {
if (data.digest) {
let downloadProgress = 0;
......@@ -181,11 +180,12 @@
});
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
if (abortController) {
abortController.abort();
}
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete $MODEL_DOWNLOAD_POOL[model];
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL
......
......@@ -20,6 +20,9 @@
tfs_z: '',
num_ctx: '',
max_tokens: '',
use_mmap: null,
use_mlock: null,
num_thread: null,
template: null
};
......@@ -559,6 +562,7 @@
</div>
{/if}
</div>
<div class=" py-0.5 w-full justify-between">
<div class="flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('Max Tokens (num_predict)')}</div>
......@@ -604,6 +608,93 @@
</div>
{/if}
</div>
<div class=" py-0.5 w-full justify-between">
<div class="flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('use_mmap (Ollama)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
type="button"
on:click={() => {
params.use_mmap = (params?.use_mmap ?? null) === null ? true : null;
}}
>
{#if (params?.use_mmap ?? null) === null}
<span class="ml-2 self-center">{$i18n.t('Default')}</span>
{:else}
<span class="ml-2 self-center">{$i18n.t('On')}</span>
{/if}
</button>
</div>
</div>
<div class=" py-0.5 w-full justify-between">
<div class="flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('use_mlock (Ollama)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
type="button"
on:click={() => {
params.use_mlock = (params?.use_mlock ?? null) === null ? true : null;
}}
>
{#if (params?.use_mlock ?? null) === null}
<span class="ml-2 self-center">{$i18n.t('Default')}</span>
{:else}
<span class="ml-2 self-center">{$i18n.t('On')}</span>
{/if}
</button>
</div>
</div>
<div class=" py-0.5 w-full justify-between">
<div class="flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('num_thread (Ollama)')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
type="button"
on:click={() => {
params.num_thread = (params?.num_thread ?? null) === null ? 2 : null;
}}
>
{#if (params?.num_thread ?? null) === null}
<span class="ml-2 self-center">{$i18n.t('Default')}</span>
{:else}
<span class="ml-2 self-center">{$i18n.t('Custom')}</span>
{/if}
</button>
</div>
{#if (params?.num_thread ?? null) !== null}
<div class="flex mt-0.5 space-x-2">
<div class=" flex-1">
<input
id="steps-range"
type="range"
min="1"
max="256"
step="1"
bind:value={params.num_thread}
class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
/>
</div>
<div class="">
<input
bind:value={params.num_thread}
type="number"
class=" bg-transparent text-center w-14"
min="1"
max="256"
step="1"
/>
</div>
</div>
{/if}
</div>
<div class=" py-0.5 w-full justify-between">
<div class="flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('Template')}</div>
......