Unverified Commit a0935dec authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge branch 'dev' into feature/support_auth_by_api_key

parents b4b56f9c 587a8c59
name: Python CI name: Python CI
on: on:
push: push:
branches: ['main'] branches:
- main
- dev
pull_request: pull_request:
branches:
- main
- dev
jobs: jobs:
build: build:
name: 'Format Backend' name: 'Format Backend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
node-version: python-version: [3.11]
- latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Use Python
uses: actions/setup-python@v4 - name: Set up Python
- name: Use Bun uses: actions/setup-python@v2
uses: oven-sh/setup-bun@v1 with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install yapf pip install black
- name: Format backend - name: Format backend
run: bun run format:backend run: black . --exclude "/venv/"
- name: Check for changes after format
run: git diff --exit-code
name: Bun CI name: Frontend Build
on: on:
push: push:
branches: ['main'] branches:
- main
- dev
pull_request: pull_request:
branches:
- main
- dev
jobs: jobs:
build: build:
name: 'Format & Build Frontend' name: 'Format & Build Frontend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - name: Checkout Repository
- name: Use Bun uses: actions/checkout@v4
uses: oven-sh/setup-bun@v1
- run: bun --version - name: Setup Node.js
- name: Install frontend dependencies uses: actions/setup-node@v3
run: bun install with:
- name: Format frontend node-version: '20' # Or specify any other version you want to use
run: bun run format
- name: Build frontend - name: Install Dependencies
run: bun run build run: npm install
- name: Format Frontend
run: npm run format
- name: Check for Changes After Format
run: git diff --exit-code
- name: Build Frontend
run: npm run build
...@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.116] - 2024-03-31
### Added
- **🔄 Enhanced UI**: Model selector now conveniently located in the navbar, enabling seamless switching between multiple models during conversations.
- **🔍 Improved Model Selector**: Directly pull a model from the selector/Models now display detailed information for better understanding.
- **💬 Webhook Support**: Now compatible with Google Chat and Microsoft Teams.
- **🌐 Localization**: Korean translation (I18n) now available.
- **🌑 Dark Theme**: OLED dark theme introduced for reduced strain during prolonged usage.
- **🏷️ Tag Autocomplete**: Dropdown feature added for effortless chat tagging.
### Fixed
- **🔽 Auto-Scrolling**: Addressed OpenAI auto-scrolling issue.
- **🏷️ Tag Validation**: Implemented tag validation to prevent empty string tags.
- **🚫 Model Whitelisting**: Resolved LiteLLM model whitelisting issue.
- **✅ Spelling**: Corrected various spelling issues for improved readability.
## [0.1.115] - 2024-03-24 ## [0.1.115] - 2024-03-24
### Added ### Added
......
...@@ -22,7 +22,13 @@ from utils.utils import ( ...@@ -22,7 +22,13 @@ from utils.utils import (
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
from config import SRC_LOG_LEVELS, CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
UPLOAD_DIR,
WHISPER_MODEL,
WHISPER_MODEL_DIR,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"]) log.setLevel(SRC_LOG_LEVELS["AUDIO"])
......
...@@ -325,7 +325,7 @@ def save_url_image(url): ...@@ -325,7 +325,7 @@ def save_url_image(url):
return image_id return image_id
except Exception as e: except Exception as e:
print(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None
...@@ -397,7 +397,7 @@ def generate_image( ...@@ -397,7 +397,7 @@ def generate_image(
user.id, user.id,
app.state.COMFYUI_BASE_URL, app.state.COMFYUI_BASE_URL,
) )
print(res) log.debug(f"res: {res}")
images = [] images = []
...@@ -409,7 +409,7 @@ def generate_image( ...@@ -409,7 +409,7 @@ def generate_image(
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
json.dump(data.model_dump(exclude_none=True), f) json.dump(data.model_dump(exclude_none=True), f)
print(images) log.debug(f"images: {images}")
return images return images
else: else:
if form_data.model: if form_data.model:
......
...@@ -4,6 +4,12 @@ import json ...@@ -4,6 +4,12 @@ import json
import urllib.request import urllib.request
import urllib.parse import urllib.parse
import random import random
import logging
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
from pydantic import BaseModel from pydantic import BaseModel
...@@ -121,7 +127,7 @@ COMFYUI_DEFAULT_PROMPT = """ ...@@ -121,7 +127,7 @@ COMFYUI_DEFAULT_PROMPT = """
def queue_prompt(prompt, client_id, base_url): def queue_prompt(prompt, client_id, base_url):
print("queue_prompt") log.info("queue_prompt")
p = {"prompt": prompt, "client_id": client_id} p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8") data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(f"{base_url}/prompt", data=data) req = urllib.request.Request(f"{base_url}/prompt", data=data)
...@@ -129,7 +135,7 @@ def queue_prompt(prompt, client_id, base_url): ...@@ -129,7 +135,7 @@ def queue_prompt(prompt, client_id, base_url):
def get_image(filename, subfolder, folder_type, base_url): def get_image(filename, subfolder, folder_type, base_url):
print("get_image") log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type} data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data) url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response: with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
...@@ -137,14 +143,14 @@ def get_image(filename, subfolder, folder_type, base_url): ...@@ -137,14 +143,14 @@ def get_image(filename, subfolder, folder_type, base_url):
def get_image_url(filename, subfolder, folder_type, base_url): def get_image_url(filename, subfolder, folder_type, base_url):
print("get_image") log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type} data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data) url_values = urllib.parse.urlencode(data)
return f"{base_url}/view?{url_values}" return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url): def get_history(prompt_id, base_url):
print("get_history") log.info("get_history")
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response: with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
return json.loads(response.read()) return json.loads(response.read())
...@@ -212,15 +218,15 @@ def comfyui_generate_image( ...@@ -212,15 +218,15 @@ def comfyui_generate_image(
try: try:
ws = websocket.WebSocket() ws = websocket.WebSocket()
ws.connect(f"ws://{host}/ws?clientId={client_id}") ws.connect(f"ws://{host}/ws?clientId={client_id}")
print("WebSocket connection established.") log.info("WebSocket connection established.")
except Exception as e: except Exception as e:
print(f"Failed to connect to WebSocket server: {e}") log.exception(f"Failed to connect to WebSocket server: {e}")
return None return None
try: try:
images = get_images(ws, comfyui_prompt, client_id, base_url) images = get_images(ws, comfyui_prompt, client_id, base_url)
except Exception as e: except Exception as e:
print(f"Error while receiving images: {e}") log.exception(f"Error while receiving images: {e}")
images = None images = None
ws.close() ws.close()
......
...@@ -33,7 +33,13 @@ from constants import ERROR_MESSAGES ...@@ -33,7 +33,13 @@ from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user from utils.utils import decode_token, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR from config import (
SRC_LOG_LEVELS,
OLLAMA_BASE_URLS,
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
UPLOAD_DIR,
)
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -266,7 +272,7 @@ async def pull_model( ...@@ -266,7 +272,7 @@ async def pull_model(
if request_id in REQUEST_POOL: if request_id in REQUEST_POOL:
yield chunk yield chunk
else: else:
print("User: canceled request") log.warning("User: canceled request")
break break
finally: finally:
if hasattr(r, "close"): if hasattr(r, "close"):
...@@ -664,7 +670,7 @@ async def generate_completion( ...@@ -664,7 +670,7 @@ async def generate_completion(
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="error_detail", detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
...@@ -770,7 +776,11 @@ async def generate_chat_completion( ...@@ -770,7 +776,11 @@ async def generate_chat_completion(
r = None r = None
log.debug("form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(form_data.model_dump_json(exclude_none=True).encode())) log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
def get_request(): def get_request():
nonlocal form_data nonlocal form_data
......
...@@ -21,6 +21,7 @@ from langchain_community.document_loaders import ( ...@@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
TextLoader, TextLoader,
PyPDFLoader, PyPDFLoader,
CSVLoader, CSVLoader,
BSHTMLLoader,
Docx2txtLoader, Docx2txtLoader,
UnstructuredEPubLoader, UnstructuredEPubLoader,
UnstructuredWordDocumentLoader, UnstructuredWordDocumentLoader,
...@@ -114,6 +115,7 @@ class CollectionNameForm(BaseModel): ...@@ -114,6 +115,7 @@ class CollectionNameForm(BaseModel):
class StoreWebForm(CollectionNameForm): class StoreWebForm(CollectionNameForm):
url: str url: str
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return { return {
...@@ -296,13 +298,18 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): ...@@ -296,13 +298,18 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.split_documents(data) docs = text_splitter.split_documents(data)
return store_docs_in_vector_db(docs, collection_name, overwrite)
if len(docs) > 0:
return store_docs_in_vector_db(docs, collection_name, overwrite), None
else:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
def store_text_in_vector_db( def store_text_in_vector_db(
...@@ -318,6 +325,7 @@ def store_text_in_vector_db( ...@@ -318,6 +325,7 @@ def store_text_in_vector_db(
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
texts = [doc.page_content for doc in docs] texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
...@@ -325,7 +333,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b ...@@ -325,7 +333,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
if overwrite: if overwrite:
for collection in CHROMA_CLIENT.list_collections(): for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name: if collection_name == collection.name:
print(f"deleting existing collection {collection_name}") log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name) CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection( collection = CHROMA_CLIENT.create_collection(
...@@ -338,7 +346,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b ...@@ -338,7 +346,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
) )
return True return True
except Exception as e: except Exception as e:
print(e) log.exception(e)
if e.__class__.__name__ == "UniqueConstraintError": if e.__class__.__name__ == "UniqueConstraintError":
return True return True
...@@ -402,6 +410,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ...@@ -402,6 +410,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
loader = UnstructuredRSTLoader(file_path, mode="elements") loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml": elif file_ext == "xml":
loader = UnstructuredXMLLoader(file_path) loader = UnstructuredXMLLoader(file_path)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md": elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path) loader = UnstructuredMarkdownLoader(file_path)
elif file_content_type == "application/epub+zip": elif file_content_type == "application/epub+zip":
...@@ -452,19 +462,21 @@ def store_doc( ...@@ -452,19 +462,21 @@ def store_doc(
loader, known_type = get_loader(file.filename, file.content_type, file_path) loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load() data = loader.load()
result = store_data_in_vector_db(data, collection_name)
try:
if result: result = store_data_in_vector_db(data, collection_name)
return {
"status": True, if result:
"collection_name": collection_name, return {
"filename": filename, "status": True,
"known_type": known_type, "collection_name": collection_name,
} "filename": filename,
else: "known_type": known_type,
}
except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(), detail=e,
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
...@@ -529,38 +541,42 @@ def scan_docs_dir(user=Depends(get_admin_user)): ...@@ -529,38 +541,42 @@ def scan_docs_dir(user=Depends(get_admin_user)):
) )
data = loader.load() data = loader.load()
result = store_data_in_vector_db(data, collection_name) try:
result = store_data_in_vector_db(data, collection_name)
if result:
sanitized_filename = sanitize_filename(filename) if result:
doc = Documents.get_doc_by_name(sanitized_filename) sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None:
doc = Documents.insert_new_doc( if doc == None:
user.id, doc = Documents.insert_new_doc(
DocumentForm( user.id,
**{ DocumentForm(
"name": sanitized_filename, **{
"title": filename, "name": sanitized_filename,
"collection_name": collection_name, "title": filename,
"filename": filename, "collection_name": collection_name,
"content": ( "filename": filename,
json.dumps( "content": (
{ json.dumps(
"tags": list( {
map( "tags": list(
lambda name: {"name": name}, map(
tags, lambda name: {"name": name},
tags,
)
) )
) }
} )
) if len(tags)
if len(tags) else "{}"
else "{}" ),
), }
} ),
), )
) except Exception as e:
log.exception(e)
pass
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
......
...@@ -11,6 +11,7 @@ from utils.utils import verify_password ...@@ -11,6 +11,7 @@ from utils.utils import verify_password
from apps.web.internal.db import DB from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
......
...@@ -13,6 +13,7 @@ from apps.web.internal.db import DB ...@@ -13,6 +13,7 @@ from apps.web.internal.db import DB
import json import json
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
......
...@@ -64,8 +64,8 @@ class ModelfilesTable: ...@@ -64,8 +64,8 @@ class ModelfilesTable:
self.db.create_tables([Modelfile]) self.db.create_tables([Modelfile])
def insert_new_modelfile( def insert_new_modelfile(
self, user_id: str, self, user_id: str, form_data: ModelfileForm
form_data: ModelfileForm) -> Optional[ModelfileModel]: ) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile: if "tagName" in form_data.modelfile:
modelfile = ModelfileModel( modelfile = ModelfileModel(
**{ **{
...@@ -73,7 +73,8 @@ class ModelfilesTable: ...@@ -73,7 +73,8 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"], "tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile), "modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()), "timestamp": int(time.time()),
}) }
)
try: try:
result = Modelfile.create(**modelfile.model_dump()) result = Modelfile.create(**modelfile.model_dump())
...@@ -87,29 +88,28 @@ class ModelfilesTable: ...@@ -87,29 +88,28 @@ class ModelfilesTable:
else: else:
return None return None
def get_modelfile_by_tag_name(self, def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
tag_name: str) -> Optional[ModelfileModel]:
try: try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name) modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile)) return ModelfileModel(**model_to_dict(modelfile))
except: except:
return None return None
def get_modelfiles(self, def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [ return [
ModelfileResponse( ModelfileResponse(
**{ **{
**model_to_dict(modelfile), **model_to_dict(modelfile),
"modelfile": "modelfile": json.loads(modelfile.modelfile),
json.loads(modelfile.modelfile), }
}) for modelfile in Modelfile.select() )
for modelfile in Modelfile.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_modelfile_by_tag_name( def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]: self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
try: try:
query = Modelfile.update( query = Modelfile.update(
modelfile=json.dumps(modelfile), modelfile=json.dumps(modelfile),
......
...@@ -52,8 +52,9 @@ class PromptsTable: ...@@ -52,8 +52,9 @@ class PromptsTable:
self.db = db self.db = db
self.db.create_tables([Prompt]) self.db.create_tables([Prompt])
def insert_new_prompt(self, user_id: str, def insert_new_prompt(
form_data: PromptForm) -> Optional[PromptModel]: self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
prompt = PromptModel( prompt = PromptModel(
**{ **{
"user_id": user_id, "user_id": user_id,
...@@ -61,7 +62,8 @@ class PromptsTable: ...@@ -61,7 +62,8 @@ class PromptsTable:
"title": form_data.title, "title": form_data.title,
"content": form_data.content, "content": form_data.content,
"timestamp": int(time.time()), "timestamp": int(time.time()),
}) }
)
try: try:
result = Prompt.create(**prompt.model_dump()) result = Prompt.create(**prompt.model_dump())
...@@ -81,13 +83,14 @@ class PromptsTable: ...@@ -81,13 +83,14 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
return [ return [
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select() PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, self, command: str, form_data: PromptForm
form_data: PromptForm) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
query = Prompt.update( query = Prompt.update(
title=form_data.title, title=form_data.title,
......
...@@ -11,6 +11,7 @@ import logging ...@@ -11,6 +11,7 @@ import logging
from apps.web.internal.db import DB from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
......
...@@ -29,6 +29,7 @@ from apps.web.models.tags import ( ...@@ -29,6 +29,7 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
......
...@@ -10,7 +10,12 @@ import uuid ...@@ -10,7 +10,12 @@ import uuid
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token from utils.utils import (
get_password_hash,
get_current_user,
get_admin_user,
create_token,
)
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -43,7 +48,6 @@ async def set_global_default_models( ...@@ -43,7 +48,6 @@ async def set_global_default_models(
return request.app.state.DEFAULT_MODELS return request.app.state.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion]) @router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions( async def set_global_default_suggestions(
request: Request, request: Request,
......
...@@ -24,9 +24,9 @@ router = APIRouter() ...@@ -24,9 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, async def get_modelfiles(
limit: int = 50, skip: int = 0, limit: int = 50, user=Depends(get_current_user)
user=Depends(get_current_user)): ):
return Modelfiles.get_modelfiles(skip, limit) return Modelfiles.get_modelfiles(skip, limit)
...@@ -36,17 +36,16 @@ async def get_modelfiles(skip: int = 0, ...@@ -36,17 +36,16 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile: if modelfile:
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": "modelfile": json.loads(modelfile.modelfile),
json.loads(modelfile.modelfile), }
}) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -60,17 +59,18 @@ async def create_new_modelfile(form_data: ModelfileForm, ...@@ -60,17 +59,18 @@ async def create_new_modelfile(form_data: ModelfileForm,
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, async def get_modelfile_by_tag_name(
user=Depends(get_current_user)): form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": "modelfile": json.loads(modelfile.modelfile),
json.loads(modelfile.modelfile), }
}) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -84,8 +84,9 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, ...@@ -84,8 +84,9 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, async def update_modelfile_by_tag_name(
user=Depends(get_admin_user)): form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
updated_modelfile = { updated_modelfile = {
...@@ -94,14 +95,15 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, ...@@ -94,14 +95,15 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
} }
modelfile = Modelfiles.update_modelfile_by_tag_name( modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile) form_data.tag_name, updated_modelfile
)
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": "modelfile": json.loads(modelfile.modelfile),
json.loads(modelfile.modelfile), }
}) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -115,7 +117,8 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, ...@@ -115,7 +117,8 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, async def delete_modelfile_by_tag_name(
user=Depends(get_admin_user)): form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result return result
...@@ -16,6 +16,7 @@ from utils.utils import get_current_user, get_password_hash, get_admin_user ...@@ -16,6 +16,7 @@ from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
......
...@@ -26,6 +26,7 @@ except ImportError: ...@@ -26,6 +26,7 @@ except ImportError:
log.warning("dotenv not installed, skipping...") log.warning("dotenv not installed, skipping...")
WEBUI_NAME = "Open WebUI" WEBUI_NAME = "Open WebUI"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
shutil.copyfile("../build/favicon.png", "./static/favicon.png") shutil.copyfile("../build/favicon.png", "./static/favicon.png")
#################################### ####################################
...@@ -116,7 +117,20 @@ else: ...@@ -116,7 +117,20 @@ else:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = ["AUDIO", "CONFIG", "DB", "IMAGES", "LITELLM", "MAIN", "MODELS", "OLLAMA", "OPENAI", "RAG"] log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"LITELLM",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {} SRC_LOG_LEVELS = {}
...@@ -141,7 +155,7 @@ if CUSTOM_NAME: ...@@ -141,7 +155,7 @@ if CUSTOM_NAME:
data = r.json() data = r.json()
if r.ok: if r.ok:
if "logo" in data: if "logo" in data:
url = ( WEBUI_FAVICON_URL = url = (
f"https://api.openwebui.com{data['logo']}" f"https://api.openwebui.com{data['logo']}"
if data["logo"][0] == "/" if data["logo"][0] == "/"
else data["logo"] else data["logo"]
...@@ -238,7 +252,7 @@ OLLAMA_API_BASE_URL = os.environ.get( ...@@ -238,7 +252,7 @@ OLLAMA_API_BASE_URL = os.environ.get(
) )
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
K8S_FLAG = os.environ.get("K8S_FLAG", "")
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "": if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
OLLAMA_BASE_URL = ( OLLAMA_BASE_URL = (
...@@ -251,6 +265,9 @@ if ENV == "prod": ...@@ -251,6 +265,9 @@ if ENV == "prod":
if OLLAMA_BASE_URL == "/ollama": if OLLAMA_BASE_URL == "/ollama":
OLLAMA_BASE_URL = "http://host.docker.internal:11434" OLLAMA_BASE_URL = "http://host.docker.internal:11434"
elif K8S_FLAG:
OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
......
...@@ -61,3 +61,6 @@ class ERROR_MESSAGES(str, Enum): ...@@ -61,3 +61,6 @@ class ERROR_MESSAGES(str, Enum):
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
{ {
"version": 0, "version": 0,
"ui": { "ui": {
"prompt_suggestions": [ "default_locale": "en-US",
{ "prompt_suggestions": [
"title": [ {
"Help me study", "title": ["Help me study", "vocabulary for a college entrance exam"],
"vocabulary for a college entrance exam" "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
], },
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option." {
}, "title": ["Give me ideas", "for what to do with my kids' art"],
{ "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."
"title": [ },
"Give me ideas", {
"for what to do with my kids' art" "title": ["Tell me a fun fact", "about the Roman Empire"],
], "content": "Tell me a random fun fact about the Roman Empire"
"content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter." },
}, {
{ "title": ["Show me a code snippet", "of a website's sticky header"],
"title": [ "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
"Tell me a fun fact", }
"about the Roman Empire" ]
], }
"content": "Tell me a random fun fact about the Roman Empire" }
},
{
"title": [
"Show me a code snippet",
"of a website's sticky header"
],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
}
]
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment