Unverified Commit 28682aad authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #368 from ThatOneCalculator/bun

feat:  bun support, backend lint, frontend & backend CI
parents e6c65088 cddfd113
...@@ -4,7 +4,6 @@ about: Create a report to help us improve ...@@ -4,7 +4,6 @@ about: Create a report to help us improve
title: '' title: ''
labels: '' labels: ''
assignees: '' assignees: ''
--- ---
# Bug Report # Bug Report
...@@ -31,6 +30,7 @@ assignees: '' ...@@ -31,6 +30,7 @@ assignees: ''
## Reproduction Details ## Reproduction Details
**Confirmation:** **Confirmation:**
- [ ] I have read and followed all the instructions provided in the README.md. - [ ] I have read and followed all the instructions provided in the README.md.
- [ ] I have reviewed the troubleshooting.md document. - [ ] I have reviewed the troubleshooting.md document.
- [ ] I have included the browser console logs. - [ ] I have included the browser console logs.
......
...@@ -4,7 +4,6 @@ about: Suggest an idea for this project ...@@ -4,7 +4,6 @@ about: Suggest an idea for this project
title: '' title: ''
labels: '' labels: ''
assignees: '' assignees: ''
--- ---
**Is your feature request related to a problem? Please describe.** **Is your feature request related to a problem? Please describe.**
......
name: Node.js CI name: Python CI
on: on:
push: push:
branches: ['main'] branches: ['main']
pull_request: pull_request:
jobs: jobs:
build: build:
name: 'Fmt, Lint, & Build' name: 'Format Backend'
env: env:
PUBLIC_API_BASE_URL: '' PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest runs-on: ubuntu-latest
...@@ -14,14 +14,14 @@ jobs: ...@@ -14,14 +14,14 @@ jobs:
node-version: node-version:
- latest - latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Use Node.js ${{ matrix.node-version }} - name: Use Python
uses: actions/setup-node@v3 uses: actions/setup-python@v4
with: - name: Use Bun
node-version: ${{ matrix.node-version }} uses: oven-sh/setup-bun@v1
- run: node --version - name: Install dependencies
- run: npm clean-install run: |
- run: npm run fmt python -m pip install --upgrade pip
#- run: npm run lint pip install yapf
#- run: npm run lint:types - name: Format backend
- run: npm run build run: bun run format:backend
name: Bun CI
on:
push:
branches: ['main']
pull_request:
jobs:
build:
name: 'Format & Build Frontend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Bun
uses: oven-sh/setup-bun@v1
- run: bun --version
- name: Install frontend dependencies
run: bun install --frozen-lockfile
- name: Format frontend
run: bun run format
- name: Build frontend
run: bun run build
name: Python CI
on:
push:
branches: ['main']
pull_request:
jobs:
build:
name: 'Lint Backend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest
strategy:
matrix:
node-version:
- latest
steps:
- uses: actions/checkout@v4
- name: Use Python
uses: actions/setup-python@v4
- name: Use Bun
uses: oven-sh/setup-bun@v1
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Lint backend
run: bun run lint:backend
name: Bun CI
on:
push:
branches: ['main']
pull_request:
jobs:
build:
name: 'Lint Frontend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Bun
uses: oven-sh/setup-bun@v1
- run: bun --version
- name: Install frontend dependencies
run: bun install --frozen-lockfile
- run: bun run lint:frontend
- run: bun run lint:types
if: success() || failure()
\ No newline at end of file
...@@ -198,9 +198,15 @@ While we strongly recommend using our convenient Docker container installation f ...@@ -198,9 +198,15 @@ While we strongly recommend using our convenient Docker container installation f
The Ollama Web UI consists of two primary components: the frontend and the backend (which serves as a reverse proxy, handling static frontend files, and additional features). Both need to be running concurrently for the development environment. The Ollama Web UI consists of two primary components: the frontend and the backend (which serves as a reverse proxy, handling static frontend files, and additional features). Both need to be running concurrently for the development environment.
**Warning: Backend Dependency for Proper Functionality** > [!IMPORTANT]
> The backend is required for proper functionality
### TL;DR 🚀 ### Requirements 📦
- 🐰 [Bun](https://bun.sh) >= 1.0.21 or 🐢 [Node.js](https://nodejs.org/en) >= 20.10
- 🐍 [Python](https://python.org) >= 3.11
### Build and Install 🛠️
Run the following commands to install: Run the following commands to install:
...@@ -211,10 +217,14 @@ cd ollama-webui/ ...@@ -211,10 +217,14 @@ cd ollama-webui/
# Copying required .env file # Copying required .env file
cp -RPp example.env .env cp -RPp example.env .env
# Building Frontend # Building Frontend Using Node
npm i npm i
npm run build npm run build
# or Building Frontend Using Bun
# bun install
# bun run build
# Serving Frontend with the Backend # Serving Frontend with the Backend
cd ./backend cd ./backend
pip install -r requirements.txt pip install -r requirements.txt
......
...@@ -30,7 +30,8 @@ async def get_ollama_api_url(user=Depends(get_current_user)): ...@@ -30,7 +30,8 @@ async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin": if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
...@@ -38,14 +39,14 @@ class UrlUpdateForm(BaseModel): ...@@ -38,14 +39,14 @@ class UrlUpdateForm(BaseModel):
@app.post("/url/update") @app.post("/url/update")
async def update_ollama_api_url( async def update_ollama_api_url(form_data: UrlUpdateForm,
form_data: UrlUpdateForm, user=Depends(get_current_user) user=Depends(get_current_user)):
):
if user and user.role == "admin": if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
...@@ -58,11 +59,11 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): ...@@ -58,11 +59,11 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
if user.role in ["user", "admin"]: if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]: if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(status_code=401,
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
)
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("Host", None) headers.pop("Host", None)
headers.pop("Authorization", None) headers.pop("Authorization", None)
......
from flask import Flask, request, Response, jsonify from flask import Flask, request, Response, jsonify
from flask_cors import CORS from flask_cors import CORS
import requests import requests
import json import json
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token from utils.utils import decode_token
...@@ -77,7 +75,9 @@ def update_ollama_api_url(): ...@@ -77,7 +75,9 @@ def update_ollama_api_url():
) )
@app.route("/", defaults={"path": ""}, methods=["GET", "POST", "PUT", "DELETE"]) @app.route("/",
defaults={"path": ""},
methods=["GET", "POST", "PUT", "DELETE"])
@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"]) @app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
def proxy(path): def proxy(path):
# Combine the base URL of the target server with the requested path # Combine the base URL of the target server with the requested path
...@@ -106,13 +106,17 @@ def proxy(path): ...@@ -106,13 +106,17 @@ def proxy(path):
pass pass
else: else:
return ( return (
jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), jsonify({
"detail":
ERROR_MESSAGES.ACCESS_PROHIBITED
}),
401, 401,
) )
else: else:
pass pass
else: else:
return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401 return jsonify(
{"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
else: else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else: else:
...@@ -162,12 +166,10 @@ def proxy(path): ...@@ -162,12 +166,10 @@ def proxy(path):
print(res) print(res)
return ( return (
jsonify( jsonify({
{
"detail": error_detail, "detail": error_detail,
"message": str(e), "message": str(e),
} }),
),
400, 400,
) )
......
...@@ -37,16 +37,19 @@ async def get_openai_url(user=Depends(get_current_user)): ...@@ -37,16 +37,19 @@ async def get_openai_url(user=Depends(get_current_user)):
if user and user.role == "admin": if user and user.role == "admin":
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/url/update") @app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): async def update_openai_url(form_data: UrlUpdateForm,
user=Depends(get_current_user)):
if user and user.role == "admin": if user and user.role == "admin":
app.state.OPENAI_API_BASE_URL = form_data.url app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/key") @app.get("/key")
...@@ -54,16 +57,19 @@ async def get_openai_key(user=Depends(get_current_user)): ...@@ -54,16 +57,19 @@ async def get_openai_key(user=Depends(get_current_user)):
if user and user.role == "admin": if user and user.role == "admin":
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/key/update") @app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): async def update_openai_key(form_data: KeyUpdateForm,
user=Depends(get_current_user)):
if user and user.role == "admin": if user and user.role == "admin":
app.state.OPENAI_API_KEY = form_data.key app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
...@@ -72,9 +78,11 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): ...@@ -72,9 +78,11 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
print(target_url, app.state.OPENAI_API_KEY) print(target_url, app.state.OPENAI_API_KEY)
if user.role not in ["user", "admin"]: if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "": if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body() body = await request.body()
# headers = dict(request.headers) # headers = dict(request.headers)
...@@ -117,8 +125,8 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): ...@@ -117,8 +125,8 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
if "openai" in app.state.OPENAI_API_BASE_URL and path == "models": if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list( response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"]) filter(lambda model: "gpt" in model["id"],
) response_data["data"]))
return response_data return response_data
except Exception as e: except Exception as e:
......
...@@ -22,10 +22,11 @@ app.add_middleware( ...@@ -22,10 +22,11 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(modelfiles.router,
prefix="/modelfiles",
tags=["modelfiles"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
......
...@@ -4,7 +4,6 @@ import time ...@@ -4,7 +4,6 @@ import time
import uuid import uuid
from peewee import * from peewee import *
from apps.web.models.users import UserModel, Users from apps.web.models.users import UserModel, Users
from utils.utils import ( from utils.utils import (
verify_password, verify_password,
...@@ -76,20 +75,26 @@ class SignupForm(BaseModel): ...@@ -76,20 +75,26 @@ class SignupForm(BaseModel):
class AuthsTable: class AuthsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Auth]) self.db.create_tables([Auth])
def insert_new_auth( def insert_new_auth(self,
self, email: str, password: str, name: str, role: str = "pending" email: str,
) -> Optional[UserModel]: password: str,
name: str,
role: str = "pending") -> Optional[UserModel]:
print("insert_new_auth") print("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
auth = AuthModel( auth = AuthModel(**{
**{"id": id, "email": email, "password": password, "active": True} "id": id,
) "email": email,
"password": password,
"active": True
})
result = Auth.create(**auth.model_dump()) result = Auth.create(**auth.model_dump())
user = Users.insert_new_user(id, name, email, role) user = Users.insert_new_user(id, name, email, role)
...@@ -99,7 +104,8 @@ class AuthsTable: ...@@ -99,7 +104,8 @@ class AuthsTable:
else: else:
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: def authenticate_user(self, email: str,
password: str) -> Optional[UserModel]:
print("authenticate_user", email) print("authenticate_user", email)
try: try:
auth = Auth.get(Auth.email == email, Auth.active == True) auth = Auth.get(Auth.email == email, Auth.active == True)
...@@ -131,7 +137,8 @@ class AuthsTable: ...@@ -131,7 +137,8 @@ class AuthsTable:
if result: if result:
# Delete Auth # Delete Auth
query = Auth.delete().where(Auth.id == id) query = Auth.delete().where(Auth.id == id)
query.execute() # Remove the rows, return number of rows removed. query.execute(
) # Remove the rows, return number of rows removed.
return True return True
else: else:
......
...@@ -3,14 +3,12 @@ from typing import List, Union, Optional ...@@ -3,14 +3,12 @@ from typing import List, Union, Optional
from peewee import * from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
from apps.web.internal.db import DB from apps.web.internal.db import DB
#################### ####################
# Chat DB Schema # Chat DB Schema
#################### ####################
...@@ -62,23 +60,23 @@ class ChatTitleIdResponse(BaseModel): ...@@ -62,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
db.create_tables([Chat]) db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str,
form_data: ChatForm) -> Optional[ChatModel]:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": form_data.chat["title"] "title": form_data.chat["title"] if "title" in
if "title" in form_data.chat form_data.chat else "New Chat",
else "New Chat",
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
result = Chat.create(**chat.model_dump()) result = Chat.create(**chat.model_dump())
return chat if result else None return chat if result else None
...@@ -111,27 +109,25 @@ class ChatTable: ...@@ -111,27 +109,25 @@ class ChatTable:
except: except:
return None return None
def get_chat_lists_by_user_id( def get_chat_lists_by_user_id(self,
self, user_id: str, skip: int = 0, limit: int = 50 user_id: str,
) -> List[ChatModel]: skip: int = 0,
limit: int = 50) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
for chat in Chat.select() Chat.user_id == user_id).order_by(Chat.timestamp.desc())
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
# .limit(limit) # .limit(limit)
# .offset(skip) # .offset(skip)
] ]
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
for chat in Chat.select() Chat.user_id == user_id).order_by(Chat.timestamp.desc())
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
] ]
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str,
user_id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id) chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat)) return ChatModel(**model_to_dict(chat))
...@@ -146,7 +142,8 @@ class ChatTable: ...@@ -146,7 +142,8 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) query = Chat.delete().where((Chat.id == id)
& (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
return True return True
......
...@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel): ...@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel):
class ModelfilesTable: class ModelfilesTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Modelfile]) self.db.create_tables([Modelfile])
def insert_new_modelfile( def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm self, user_id: str,
) -> Optional[ModelfileModel]: form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile: if "tagName" in form_data.modelfile:
modelfile = ModelfileModel( modelfile = ModelfileModel(
**{ **{
...@@ -72,8 +73,7 @@ class ModelfilesTable: ...@@ -72,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"], "tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile), "modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
try: try:
result = Modelfile.create(**modelfile.model_dump()) result = Modelfile.create(**modelfile.model_dump())
...@@ -87,28 +87,29 @@ class ModelfilesTable: ...@@ -87,28 +87,29 @@ class ModelfilesTable:
else: else:
return None return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try: try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name) modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile)) return ModelfileModel(**model_to_dict(modelfile))
except: except:
return None return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [ return [
ModelfileResponse( ModelfileResponse(
**{ **{
**model_to_dict(modelfile), **model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) }) for modelfile in Modelfile.select()
for modelfile in Modelfile.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_modelfile_by_tag_name( def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
) -> Optional[ModelfileModel]:
try: try:
query = Modelfile.update( query = Modelfile.update(
modelfile=json.dumps(modelfile), modelfile=json.dumps(modelfile),
......
...@@ -47,13 +47,13 @@ class PromptForm(BaseModel): ...@@ -47,13 +47,13 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Prompt]) self.db.create_tables([Prompt])
def insert_new_prompt( def insert_new_prompt(self, user_id: str,
self, user_id: str, form_data: PromptForm form_data: PromptForm) -> Optional[PromptModel]:
) -> Optional[PromptModel]:
prompt = PromptModel( prompt = PromptModel(
**{ **{
"user_id": user_id, "user_id": user_id,
...@@ -61,8 +61,7 @@ class PromptsTable: ...@@ -61,8 +61,7 @@ class PromptsTable:
"title": form_data.title, "title": form_data.title,
"content": form_data.content, "content": form_data.content,
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
try: try:
result = Prompt.create(**prompt.model_dump()) result = Prompt.create(**prompt.model_dump())
...@@ -82,14 +81,13 @@ class PromptsTable: ...@@ -82,14 +81,13 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
return [ return [
PromptModel(**model_to_dict(prompt)) PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
for prompt in Prompt.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str,
) -> Optional[PromptModel]: form_data: PromptForm) -> Optional[PromptModel]:
try: try:
query = Prompt.update( query = Prompt.update(
title=form_data.title, title=form_data.title,
......
...@@ -8,7 +8,6 @@ from utils.misc import get_gravatar_url ...@@ -8,7 +8,6 @@ from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.web.internal.db import DB
from apps.web.models.chats import Chats from apps.web.models.chats import Chats
#################### ####################
# User DB Schema # User DB Schema
#################### ####################
...@@ -46,13 +45,16 @@ class UserRoleUpdateForm(BaseModel): ...@@ -46,13 +45,16 @@ class UserRoleUpdateForm(BaseModel):
class UsersTable: class UsersTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([User]) self.db.create_tables([User])
def insert_new_user( def insert_new_user(self,
self, id: str, name: str, email: str, role: str = "pending" id: str,
) -> Optional[UserModel]: name: str,
email: str,
role: str = "pending") -> Optional[UserModel]:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
...@@ -61,8 +63,7 @@ class UsersTable: ...@@ -61,8 +63,7 @@ class UsersTable:
"role": role, "role": role,
"profile_image_url": get_gravatar_url(email), "profile_image_url": get_gravatar_url(email),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
result = User.create(**user.model_dump()) result = User.create(**user.model_dump())
if result: if result:
return user return user
...@@ -92,7 +93,8 @@ class UsersTable: ...@@ -92,7 +93,8 @@ class UsersTable:
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
return User.select().count() return User.select().count()
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str,
role: str) -> Optional[UserModel]:
try: try:
query = User.update(role=role).where(User.id == id) query = User.update(role=role).where(User.id == id)
query.execute() query.execute()
...@@ -110,7 +112,8 @@ class UsersTable: ...@@ -110,7 +112,8 @@ class UsersTable:
if result: if result:
# Delete User # Delete User
query = User.delete().where(User.id == id) query = User.delete().where(User.id == id)
query.execute() # Remove the rows, return number of rows removed. query.execute(
) # Remove the rows, return number of rows removed.
return True return True
else: else:
......
...@@ -8,7 +8,6 @@ from pydantic import BaseModel ...@@ -8,7 +8,6 @@ from pydantic import BaseModel
import time import time
import uuid import uuid
from apps.web.models.auths import ( from apps.web.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
...@@ -19,12 +18,10 @@ from apps.web.models.auths import ( ...@@ -19,12 +18,10 @@ from apps.web.models.auths import (
) )
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
############################ ############################
...@@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)): ...@@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)):
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password( async def update_password(form_data: UpdatePasswordForm,
form_data: UpdatePasswordForm, session_user=Depends(get_current_user) session_user=Depends(get_current_user)):
):
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
...@@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm):
try: try:
role = "admin" if Users.get_num_users() == 0 else "pending" role = "admin" if Users.get_num_users() == 0 else "pending"
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(form_data.email.lower(),
form_data.email.lower(), hashed, form_data.name, role hashed, form_data.name, role)
)
if user: if user:
token = create_token(data={"email": user.email}) token = create_token(data={"email": user.email})
...@@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm):
} }
else: else:
raise HTTPException( raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
)
except Exception as err: except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) raise HTTPException(500,
detail=ERROR_MESSAGES.DEFAULT(err))
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) raise HTTPException(400,
detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
......
...@@ -17,8 +17,7 @@ from apps.web.models.chats import ( ...@@ -17,8 +17,7 @@ from apps.web.models.chats import (
) )
from utils.utils import ( from utils.utils import (
bearer_scheme, bearer_scheme, )
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -30,8 +29,7 @@ router = APIRouter() ...@@ -30,8 +29,7 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats( async def get_user_chats(
user=Depends(get_current_user), skip: int = 0, limit: int = 50 user=Depends(get_current_user), skip: int = 0, limit: int = 50):
):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit) return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
...@@ -43,8 +41,9 @@ async def get_user_chats( ...@@ -43,8 +41,9 @@ async def get_user_chats(
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(user=Depends(get_current_user)): async def get_all_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{
for chat in Chats.get_all_chats_by_user_id(user.id) **chat.model_dump(), "chat": json.loads(chat.chat)
}) for chat in Chats.get_all_chats_by_user_id(user.id)
] ]
...@@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else: else:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND detail=ERROR_MESSAGES.NOT_FOUND)
)
############################ ############################
...@@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id( async def update_chat_by_id(id: str,
id: str, form_data: ChatForm, user=Depends(get_current_user) form_data: ChatForm,
): user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat) chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
......
...@@ -10,7 +10,6 @@ import uuid ...@@ -10,7 +10,6 @@ import uuid
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -28,9 +27,9 @@ class SetDefaultModelsForm(BaseModel): ...@@ -28,9 +27,9 @@ class SetDefaultModelsForm(BaseModel):
@router.post("/default/models", response_model=str) @router.post("/default/models", response_model=str)
async def set_global_default_models( async def set_global_default_models(request: Request,
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) form_data: SetDefaultModelsForm,
): user=Depends(get_current_user)):
if user.role == "admin": if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS return request.app.state.DEFAULT_MODELS
......
...@@ -24,7 +24,9 @@ router = APIRouter() ...@@ -24,7 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)): async def get_modelfiles(skip: int = 0,
limit: int = 50,
user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit) return Modelfiles.get_modelfiles(skip, limit)
...@@ -34,9 +36,8 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_curren ...@@ -34,9 +36,8 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_curren
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile( async def create_new_modelfile(form_data: ModelfileForm,
form_data: ModelfileForm, user=Depends(get_current_user) user=Depends(get_current_user)):
):
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -49,9 +50,9 @@ async def create_new_modelfile( ...@@ -49,9 +50,9 @@ async def create_new_modelfile(
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) })
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -65,16 +66,17 @@ async def create_new_modelfile( ...@@ -65,16 +66,17 @@ async def create_new_modelfile(
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)): async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) })
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -88,9 +90,8 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depend ...@@ -88,9 +90,8 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depend
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name( async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
form_data: ModelfileUpdateForm, user=Depends(get_current_user) user=Depends(get_current_user)):
):
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -104,15 +105,14 @@ async def update_modelfile_by_tag_name( ...@@ -104,15 +105,14 @@ async def update_modelfile_by_tag_name(
} }
modelfile = Modelfiles.update_modelfile_by_tag_name( modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile form_data.tag_name, updated_modelfile)
)
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) })
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -126,9 +126,8 @@ async def update_modelfile_by_tag_name( ...@@ -126,9 +126,8 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name( async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
form_data: ModelfileTagNameForm, user=Depends(get_current_user) user=Depends(get_current_user)):
):
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment