from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool

import requests
import json
from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH

# FastAPI application exposing the URL-management endpoints and the
# catch-all Ollama proxy route defined below.
app = FastAPI()

# Permissive CORS: any origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site issue credentialed requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Base URL of the upstream Ollama API. Seeded from config; admins can change
# it at runtime through POST /url/update.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL


@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    # Report the Ollama API base URL the proxy currently forwards to.
    # Admin-only: any other caller (including a falsy user) gets HTTP 401
    # with ERROR_MESSAGES.ACCESS_PROHIBITED.
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


# Request body for POST /url/update; ``url`` is stored verbatim.
class UrlUpdateForm(BaseModel):
    url: str


@app.post("/url/update")
async def update_ollama_api_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    # Replace the Ollama API base URL used by the proxy and echo it back.
    # Admin-only: any other caller (including a falsy user) gets HTTP 401
    # with ERROR_MESSAGES.ACCESS_PROHIBITED.
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    new_url = form_data.url
    app.state.OLLAMA_API_BASE_URL = new_url
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    # Catch-all route: forward the request to the Ollama API at
    # ``app.state.OLLAMA_API_BASE_URL`` and stream the upstream reply back.
    #
    # Access rules:
    #   * only the "user" and "admin" roles may use the proxy;
    #   * model-management paths (pull, delete, push, copy, create) are
    #     admin-only;
    #   * every other caller gets HTTP 401 with ACCESS_PROHIBITED.
    #
    # Upstream failures are raised as HTTPException carrying Ollama's status
    # code (500 when no response was received at all) and, when the body is
    # JSON with an "error" key, Ollama's own message.

    # Authorize before reading the request body. A falsy ``user`` is rejected
    # like in the other endpoints of this module instead of raising
    # AttributeError on ``user.role``.
    if not user or user.role not in ("user", "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    if path in ("pull", "delete", "push", "copy", "create") and user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    # Strip a trailing slash from the configured base URL so a value such as
    # "http://host:11434/api/" does not yield ".../api//tags".
    target_url = f"{app.state.OLLAMA_API_BASE_URL.rstrip('/')}/{path}"

    body = await request.body()

    # Drop headers tied to the client's connection to this server. The
    # "authorization" header presumably carries the WebUI credential consumed
    # by get_current_user and should not reach Ollama — confirm.
    headers = dict(request.headers)
    for name in ("host", "authorization", "origin", "referer"):
        headers.pop(name, None)

    # Assigned inside get_request() so the error handler below can inspect
    # the upstream response, if one was received.
    r = None

    def get_request():
        # Blocking upstream call; executed through run_in_threadpool below so
        # it does not stall the event loop.
        nonlocal r
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # NOTE(review): upstream headers are forwarded verbatim, while
        # iter_content() transparently decodes gzip/deflate bodies; confirm
        # Content-Length / Content-Encoding stay consistent for Ollama.
        return StreamingResponse(
            r.iter_content(chunk_size=8192),
            status_code=r.status_code,
            headers=dict(r.headers),
        )

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Ollama WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                # Body could not be parsed/indexed as JSON: report the
                # original exception instead.
                error_detail = f"Ollama: {e}"
            finally:
                # This response is never handed to a StreamingResponse, so
                # release its connection explicitly.
                r.close()

        raise HTTPException(
            # Compare against None explicitly: requests.Response.__bool__
            # returns ``r.ok``, so every 4xx/5xx response is falsy and
            # ``r.status_code if r else 500`` reported all of them as 500.
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        ) from e