"...resnet50_tensorflow.git" did not exist on "543755a0a8e165a7ef09dd9aef638ca953c29583"
Unverified Commit 7071716f authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #408 from ollama-webui/main

rag
parents b050013c c55c8728
...@@ -4,7 +4,6 @@ about: Create a report to help us improve ...@@ -4,7 +4,6 @@ about: Create a report to help us improve
title: '' title: ''
labels: '' labels: ''
assignees: '' assignees: ''
--- ---
# Bug Report # Bug Report
...@@ -31,6 +30,7 @@ assignees: '' ...@@ -31,6 +30,7 @@ assignees: ''
## Reproduction Details ## Reproduction Details
**Confirmation:** **Confirmation:**
- [ ] I have read and followed all the instructions provided in the README.md. - [ ] I have read and followed all the instructions provided in the README.md.
- [ ] I have reviewed the troubleshooting.md document. - [ ] I have reviewed the troubleshooting.md document.
- [ ] I have included the browser console logs. - [ ] I have included the browser console logs.
......
...@@ -4,7 +4,6 @@ about: Suggest an idea for this project ...@@ -4,7 +4,6 @@ about: Suggest an idea for this project
title: '' title: ''
labels: '' labels: ''
assignees: '' assignees: ''
--- ---
**Is your feature request related to a problem? Please describe.** **Is your feature request related to a problem? Please describe.**
......
name: Node.js CI name: Python CI
on: on:
push: push:
branches: ['main'] branches: ['main']
pull_request: pull_request:
jobs: jobs:
build: build:
name: 'Fmt, Lint, & Build' name: 'Format Backend'
env: env:
PUBLIC_API_BASE_URL: '' PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest runs-on: ubuntu-latest
...@@ -14,14 +14,14 @@ jobs: ...@@ -14,14 +14,14 @@ jobs:
node-version: node-version:
- latest - latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Use Node.js ${{ matrix.node-version }} - name: Use Python
uses: actions/setup-node@v3 uses: actions/setup-python@v4
with: - name: Use Bun
node-version: ${{ matrix.node-version }} uses: oven-sh/setup-bun@v1
- run: node --version - name: Install dependencies
- run: npm clean-install run: |
- run: npm run fmt python -m pip install --upgrade pip
#- run: npm run lint pip install yapf
#- run: npm run lint:types - name: Format backend
- run: npm run build run: bun run format:backend
name: Bun CI
on:
push:
branches: ['main']
pull_request:
jobs:
build:
name: 'Format & Build Frontend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Bun
uses: oven-sh/setup-bun@v1
- run: bun --version
- name: Install frontend dependencies
run: bun install
- name: Format frontend
run: bun run format
- name: Build frontend
run: bun run build
name: Python CI
on:
push:
branches: ['main']
pull_request:
jobs:
build:
name: 'Lint Backend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest
strategy:
matrix:
node-version:
- latest
steps:
- uses: actions/checkout@v4
- name: Use Python
uses: actions/setup-python@v4
- name: Use Bun
uses: oven-sh/setup-bun@v1
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Lint backend
run: bun run lint:backend
name: Bun CI
on:
push:
branches: ['main']
pull_request:
jobs:
build:
name: 'Lint Frontend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Bun
uses: oven-sh/setup-bun@v1
- run: bun --version
- name: Install frontend dependencies
run: bun install --frozen-lockfile
- run: bun run lint:frontend
- run: bun run lint:types
if: success() || failure()
\ No newline at end of file
...@@ -2,12 +2,6 @@ ...@@ -2,12 +2,6 @@
FROM node:alpine as build FROM node:alpine as build
ARG OLLAMA_API_BASE_URL='/ollama/api'
RUN echo $OLLAMA_API_BASE_URL
ENV PUBLIC_API_BASE_URL $OLLAMA_API_BASE_URL
RUN echo $PUBLIC_API_BASE_URL
WORKDIR /app WORKDIR /app
COPY package.json package-lock.json ./ COPY package.json package-lock.json ./
...@@ -18,10 +12,13 @@ RUN npm run build ...@@ -18,10 +12,13 @@ RUN npm run build
FROM python:3.11-slim-buster as base FROM python:3.11-slim-buster as base
ARG OLLAMA_API_BASE_URL='/ollama/api'
ENV ENV=prod ENV ENV=prod
ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL
ENV OLLAMA_API_BASE_URL "/ollama/api"
ENV OPENAI_API_BASE_URL ""
ENV OPENAI_API_KEY ""
ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY" ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY"
WORKDIR /app WORKDIR /app
......
...@@ -198,9 +198,15 @@ While we strongly recommend using our convenient Docker container installation f ...@@ -198,9 +198,15 @@ While we strongly recommend using our convenient Docker container installation f
The Ollama Web UI consists of two primary components: the frontend and the backend (which serves as a reverse proxy, handling static frontend files, and additional features). Both need to be running concurrently for the development environment. The Ollama Web UI consists of two primary components: the frontend and the backend (which serves as a reverse proxy, handling static frontend files, and additional features). Both need to be running concurrently for the development environment.
**Warning: Backend Dependency for Proper Functionality** > [!IMPORTANT]
> The backend is required for proper functionality
### TL;DR 🚀 ### Requirements 📦
- 🐰 [Bun](https://bun.sh) >= 1.0.21 or 🐢 [Node.js](https://nodejs.org/en) >= 20.10
- 🐍 [Python](https://python.org) >= 3.11
### Build and Install 🛠️
Run the following commands to install: Run the following commands to install:
...@@ -211,13 +217,17 @@ cd ollama-webui/ ...@@ -211,13 +217,17 @@ cd ollama-webui/
# Copying required .env file # Copying required .env file
cp -RPp example.env .env cp -RPp example.env .env
# Building Frontend # Building Frontend Using Node
npm i npm i
npm run build npm run build
# or Building Frontend Using Bun
# bun install
# bun run build
# Serving Frontend with the Backend # Serving Frontend with the Backend
cd ./backend cd ./backend
pip install -r requirements.txt pip install -r requirements.txt -U
sh start.sh sh start.sh
``` ```
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
The Ollama WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues. The Ollama WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
- **How it Works**: When you make a request (like `/ollama/api/tags`) from the Ollama WebUI, it doesn’t go directly to the Ollama API. Instead, it first reaches the Ollama WebUI backend. The backend then forwards this request to the Ollama API via the route you define in the `OLLAMA_API_BASE_URL` environment variable. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_API_BASE_URL/tags` in the backend. - **How it Works**: The Ollama WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Ollama WebUI backend via `/ollama/api` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_API_BASE_URL` environment variable. Therefore, a request made to `/ollama/api` in the WebUI is effectively the same as making a request to `OLLAMA_API_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_API_BASE_URL/tags` in the backend.
- **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer. - **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
...@@ -27,6 +27,6 @@ docker run -d --network=host -v ollama-webui:/app/backend/data -e OLLAMA_API_BAS ...@@ -27,6 +27,6 @@ docker run -d --network=host -v ollama-webui:/app/backend/data -e OLLAMA_API_BAS
1. **Verify Ollama URL Format**: 1. **Verify Ollama URL Format**:
- When running the Web UI container, ensure the `OLLAMA_API_BASE_URL` is correctly set, including the `/api` suffix. (e.g., `http://192.168.1.1:11434/api` for different host setups). - When running the Web UI container, ensure the `OLLAMA_API_BASE_URL` is correctly set, including the `/api` suffix. (e.g., `http://192.168.1.1:11434/api` for different host setups).
- In the Ollama WebUI, navigate to "Settings" > "General". - In the Ollama WebUI, navigate to "Settings" > "General".
- Confirm that the Ollama Server URL is correctly set to `/ollama/api`, including the `/api` suffix. - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]/api` (e.g., `http://localhost:11434/api`), including the `/api` suffix.
By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord. By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.
from flask import Flask, request, Response, jsonify from fastapi import FastAPI, Request, Response, HTTPException, Depends
from flask_cors import CORS from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
import requests import requests
import json import json
from pydantic import BaseModel
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = Flask(__name__) app = FastAPI()
CORS( app.add_middleware(
app CORSMiddleware,
) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define the target server URL app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
TARGET_SERVER_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.route("/", defaults={"path": ""}, methods=["GET", "POST", "PUT", "DELETE"])
@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
def proxy(path):
# Combine the base URL of the target server with the requested path
target_url = f"{TARGET_SERVER_URL}/{path}"
print(target_url)
# Get data from the original request @app.get("/url")
data = request.get_data() async def get_ollama_api_url(user=Depends(get_current_user)):
headers = dict(request.headers) if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel):
url: str
# Basic RBAC support
if WEBUI_AUTH: @app.post("/url/update")
if "Authorization" in headers: async def update_ollama_api_url(
_, credentials = headers["Authorization"].split() form_data: UrlUpdateForm, user=Depends(get_current_user)
token_data = decode_token(credentials) ):
if token_data is None or "email" not in token_data: if user and user.role == "admin":
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
user = Users.get_user_by_email(token_data["email"])
if user:
# Only user and admin roles can access
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
# Only admin role can perform actions above
if user.role == "admin":
pass
else:
return (
jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
401,
)
else:
pass
else:
return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else: else:
pass raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
r = None
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("Host", None) headers.pop("Host", None)
headers.pop("Authorization", None) headers.pop("Authorization", None)
headers.pop("Origin", None) headers.pop("Origin", None)
headers.pop("Referer", None) headers.pop("Referer", None)
try: r = None
# Make a request to the target server
r = requests.request(
method=request.method,
url=target_url,
data=data,
headers=headers,
stream=True, # Enable streaming for server-sent events
)
r.raise_for_status()
# Proxy the target server's response to the client
def generate():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
response = Response(generate(), status=r.status_code)
# Copy headers from the target server's response to the client's response def get_request():
for key, value in r.headers.items(): nonlocal r
response.headers[key] = value try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
return response try:
return await run_in_threadpool(get_request)
except Exception as e: except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error" error_detail = "Ollama WebUI: Server Connection Error"
if r != None: if r is not None:
print(r.text) try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
print(res) except:
error_detail = f"Ollama: {e}"
return (
jsonify( raise HTTPException(
{ status_code=r.status_code if r else 500,
"detail": error_detail, detail=error_detail,
"message": str(e),
}
),
400,
) )
if __name__ == "__main__":
app.run(debug=True)
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
import aiohttp
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel):
url: str
@app.post("/url/update")
async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user)
):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# async def fetch_sse(method, target_url, body, headers):
# async with aiohttp.ClientSession() as session:
# try:
# async with session.request(
# method, target_url, data=body, headers=headers
# ) as response:
# print(response.status)
# async for line in response.content:
# yield line
# except Exception as e:
# print(e)
# error_detail = "Ollama WebUI: Server Connection Error"
# yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
print(target_url)
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("Host", None)
headers.pop("Authorization", None)
headers.pop("Origin", None)
headers.pop("Referer", None)
session = aiohttp.ClientSession()
response = None
try:
response = await session.request(
request.method, target_url, data=body, headers=headers
)
print(response)
if not response.ok:
data = await response.json()
print(data)
response.raise_for_status()
async def generate():
async for line in response.content:
print(line)
yield line
await session.close()
return StreamingResponse(generate(), response.status)
except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if response is not None:
try:
res = await response.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
await session.close()
raise HTTPException(
status_code=response.status if response else 500,
detail=error_detail,
)
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = OPENAI_API_KEY
class UrlUpdateForm(BaseModel):
url: str
class KeyUpdateForm(BaseModel):
key: str
@app.get("/url")
async def get_openai_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm,
user=Depends(get_current_user)):
if user and user.role == "admin":
app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm,
user=Depends(get_current_user)):
if user and user.role == "admin":
app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.OPENAI_API_KEY)
if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body()
# headers = dict(request.headers)
# print(headers)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
# For non-SSE, read the response and return it
# response_data = (
# r.json()
# if r.headers.get("Content-Type", "")
# == "application/json"
# else r.text
# )
response_data = r.json()
print(type(response_data))
if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"],
response_data["data"]))
return response_data
except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status_code, detail=error_detail)
...@@ -22,10 +22,11 @@ app.add_middleware( ...@@ -22,10 +22,11 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(modelfiles.router,
prefix="/modelfiles",
tags=["modelfiles"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
......
...@@ -4,7 +4,6 @@ import time ...@@ -4,7 +4,6 @@ import time
import uuid import uuid
from peewee import * from peewee import *
from apps.web.models.users import UserModel, Users from apps.web.models.users import UserModel, Users
from utils.utils import ( from utils.utils import (
verify_password, verify_password,
...@@ -123,6 +122,15 @@ class AuthsTable: ...@@ -123,6 +122,15 @@ class AuthsTable:
except: except:
return False return False
def update_email_by_id(self, id: str, email: str) -> bool:
try:
query = Auth.update(email=email).where(Auth.id == id)
result = query.execute()
return True if result == 1 else False
except:
return False
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
try: try:
# Delete User # Delete User
......
...@@ -3,14 +3,12 @@ from typing import List, Union, Optional ...@@ -3,14 +3,12 @@ from typing import List, Union, Optional
from peewee import * from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
from apps.web.internal.db import DB from apps.web.internal.db import DB
#################### ####################
# Chat DB Schema # Chat DB Schema
#################### ####################
...@@ -62,23 +60,23 @@ class ChatTitleIdResponse(BaseModel): ...@@ -62,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
db.create_tables([Chat]) db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str,
form_data: ChatForm) -> Optional[ChatModel]:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": form_data.chat["title"] "title": form_data.chat["title"] if "title" in
if "title" in form_data.chat form_data.chat else "New Chat",
else "New Chat",
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
result = Chat.create(**chat.model_dump()) result = Chat.create(**chat.model_dump())
return chat if result else None return chat if result else None
...@@ -111,27 +109,25 @@ class ChatTable: ...@@ -111,27 +109,25 @@ class ChatTable:
except: except:
return None return None
def get_chat_lists_by_user_id( def get_chat_lists_by_user_id(self,
self, user_id: str, skip: int = 0, limit: int = 50 user_id: str,
) -> List[ChatModel]: skip: int = 0,
limit: int = 50) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
for chat in Chat.select() Chat.user_id == user_id).order_by(Chat.timestamp.desc())
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
# .limit(limit) # .limit(limit)
# .offset(skip) # .offset(skip)
] ]
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
for chat in Chat.select() Chat.user_id == user_id).order_by(Chat.timestamp.desc())
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
] ]
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str,
user_id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id) chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat)) return ChatModel(**model_to_dict(chat))
...@@ -146,7 +142,8 @@ class ChatTable: ...@@ -146,7 +142,8 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) query = Chat.delete().where((Chat.id == id)
& (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
return True return True
......
...@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel): ...@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel):
class ModelfilesTable: class ModelfilesTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Modelfile]) self.db.create_tables([Modelfile])
def insert_new_modelfile( def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm self, user_id: str,
) -> Optional[ModelfileModel]: form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile: if "tagName" in form_data.modelfile:
modelfile = ModelfileModel( modelfile = ModelfileModel(
**{ **{
...@@ -72,8 +73,7 @@ class ModelfilesTable: ...@@ -72,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"], "tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile), "modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
try: try:
result = Modelfile.create(**modelfile.model_dump()) result = Modelfile.create(**modelfile.model_dump())
...@@ -87,28 +87,29 @@ class ModelfilesTable: ...@@ -87,28 +87,29 @@ class ModelfilesTable:
else: else:
return None return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try: try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name) modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile)) return ModelfileModel(**model_to_dict(modelfile))
except: except:
return None return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [ return [
ModelfileResponse( ModelfileResponse(
**{ **{
**model_to_dict(modelfile), **model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) }) for modelfile in Modelfile.select()
for modelfile in Modelfile.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_modelfile_by_tag_name( def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
) -> Optional[ModelfileModel]:
try: try:
query = Modelfile.update( query = Modelfile.update(
modelfile=json.dumps(modelfile), modelfile=json.dumps(modelfile),
......
...@@ -47,13 +47,13 @@ class PromptForm(BaseModel): ...@@ -47,13 +47,13 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Prompt]) self.db.create_tables([Prompt])
def insert_new_prompt( def insert_new_prompt(self, user_id: str,
self, user_id: str, form_data: PromptForm form_data: PromptForm) -> Optional[PromptModel]:
) -> Optional[PromptModel]:
prompt = PromptModel( prompt = PromptModel(
**{ **{
"user_id": user_id, "user_id": user_id,
...@@ -61,8 +61,7 @@ class PromptsTable: ...@@ -61,8 +61,7 @@ class PromptsTable:
"title": form_data.title, "title": form_data.title,
"content": form_data.content, "content": form_data.content,
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
try: try:
result = Prompt.create(**prompt.model_dump()) result = Prompt.create(**prompt.model_dump())
...@@ -82,14 +81,13 @@ class PromptsTable: ...@@ -82,14 +81,13 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
return [ return [
PromptModel(**model_to_dict(prompt)) PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
for prompt in Prompt.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str,
) -> Optional[PromptModel]: form_data: PromptForm) -> Optional[PromptModel]:
try: try:
query = Prompt.update( query = Prompt.update(
title=form_data.title, title=form_data.title,
......
...@@ -8,7 +8,6 @@ from utils.misc import get_gravatar_url ...@@ -8,7 +8,6 @@ from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.web.internal.db import DB
from apps.web.models.chats import Chats from apps.web.models.chats import Chats
#################### ####################
# User DB Schema # User DB Schema
#################### ####################
...@@ -45,6 +44,13 @@ class UserRoleUpdateForm(BaseModel): ...@@ -45,6 +44,13 @@ class UserRoleUpdateForm(BaseModel):
role: str role: str
class UserUpdateForm(BaseModel):
name: str
email: str
profile_image_url: str
password: Optional[str] = None
class UsersTable: class UsersTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
...@@ -102,6 +108,16 @@ class UsersTable: ...@@ -102,6 +108,16 @@ class UsersTable:
except: except:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
try: try:
# Delete User Chats # Delete User Chats
......
...@@ -8,7 +8,6 @@ from pydantic import BaseModel ...@@ -8,7 +8,6 @@ from pydantic import BaseModel
import time import time
import uuid import uuid
from apps.web.models.auths import ( from apps.web.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
...@@ -19,12 +18,10 @@ from apps.web.models.auths import ( ...@@ -19,12 +18,10 @@ from apps.web.models.auths import (
) )
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
############################ ############################
...@@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)): ...@@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)):
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password( async def update_password(form_data: UpdatePasswordForm,
form_data: UpdatePasswordForm, session_user=Depends(get_current_user) session_user=Depends(get_current_user)):
):
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
...@@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm):
try: try:
role = "admin" if Users.get_num_users() == 0 else "pending" role = "admin" if Users.get_num_users() == 0 else "pending"
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(form_data.email.lower(),
form_data.email.lower(), hashed, form_data.name, role hashed, form_data.name, role)
)
if user: if user:
token = create_token(data={"email": user.email}) token = create_token(data={"email": user.email})
...@@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm):
} }
else: else:
raise HTTPException( raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
)
except Exception as err: except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) raise HTTPException(500,
detail=ERROR_MESSAGES.DEFAULT(err))
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) raise HTTPException(400,
detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
......
...@@ -17,8 +17,7 @@ from apps.web.models.chats import ( ...@@ -17,8 +17,7 @@ from apps.web.models.chats import (
) )
from utils.utils import ( from utils.utils import (
bearer_scheme, bearer_scheme, )
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -30,8 +29,7 @@ router = APIRouter() ...@@ -30,8 +29,7 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats( async def get_user_chats(
user=Depends(get_current_user), skip: int = 0, limit: int = 50 user=Depends(get_current_user), skip: int = 0, limit: int = 50):
):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit) return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
...@@ -43,8 +41,9 @@ async def get_user_chats( ...@@ -43,8 +41,9 @@ async def get_user_chats(
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(user=Depends(get_current_user)): async def get_all_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{
for chat in Chats.get_all_chats_by_user_id(user.id) **chat.model_dump(), "chat": json.loads(chat.chat)
}) for chat in Chats.get_all_chats_by_user_id(user.id)
] ]
...@@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else: else:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND detail=ERROR_MESSAGES.NOT_FOUND)
)
############################ ############################
...@@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id( async def update_chat_by_id(id: str,
id: str, form_data: ChatForm, user=Depends(get_current_user) form_data: ChatForm,
): user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat) chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment