main.py 2.91 KB
Newer Older
1
2
3
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6

import requests
import json
7
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
8

9
10
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
11
from utils.utils import decode_token, get_current_user
12
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
Timothy J. Baek's avatar
Timothy J. Baek committed
13

14
15
16
17
18
19
20
21
app = FastAPI()

# NOTE(review): a wildcard origin together with allow_credentials=True may let
# arbitrary sites make credentialed cross-origin requests to this API (the
# middleware can reflect the caller's origin) — confirm this is intended, or
# restrict allow_origins to the WebUI's own origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initial Ollama base URL from config; admins can replace it at runtime via
# POST /url/update, and the catch-all proxy reads it on every request.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27


28
29
30
31
32
33
34
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    """Report the Ollama API base URL the proxy currently targets.

    Restricted to admins; any other caller receives HTTP 401.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    current_url = app.state.OLLAMA_API_BASE_URL
    return {"OLLAMA_API_BASE_URL": current_url}

Timothy J. Baek's avatar
Timothy J. Baek committed
35

36
37
38
39
40
41
42
43
44
45
46
class UrlUpdateForm(BaseModel):
    """Request body accepted by the ``POST /url/update`` endpoint."""

    # Replacement value for app.state.OLLAMA_API_BASE_URL.
    url: str


@app.post("/url/update")
async def update_ollama_api_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    """Replace the Ollama API base URL used by the proxy (admins only).

    Trailing slashes are stripped: the catch-all proxy builds its target as
    ``f"{base}/{path}"``, so a base ending in ``/`` would yield ``//path``.

    Returns:
        dict: ``{"OLLAMA_API_BASE_URL": <stored value>}``.

    Raises:
        HTTPException: 401 when the caller is not an admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    app.state.OLLAMA_API_BASE_URL = form_data.url.rstrip("/")
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}

50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward an authenticated request to the configured Ollama API.

    The request's method, body and sanitised headers are replayed against
    ``{app.state.OLLAMA_API_BASE_URL}/{path}`` and the upstream response body
    is streamed back to the client.

    Access rules:
      * only users whose role is ``"user"`` or ``"admin"`` may proxy;
      * ``pull``, ``delete``, ``push``, ``copy`` and ``create`` are admin-only.

    Raises:
        HTTPException: 401 when the caller's role is not allowed for ``path``;
            otherwise, when the upstream call fails, the upstream status code
            (or 500 if no upstream response was received at all).
    """
    # Local import keeps this edit self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    if not user or user.role not in ("user", "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    if path in ("pull", "delete", "push", "copy", "create") and user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
    body = await request.body()

    # Starlette exposes header names in lower case, so they must be removed by
    # their lower-case names; otherwise the caller's Host and bearer token
    # would be forwarded to Ollama.
    headers = dict(request.headers)
    for name in ("host", "authorization", "origin", "referer"):
        headers.pop(name, None)

    r = None  # stays None when the upstream request itself fails
    try:
        # requests is synchronous; run it in a worker thread so the event loop
        # is not blocked while waiting for Ollama to respond.
        r = await run_in_threadpool(
            requests.request,
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # iter_content() transparently decodes any Content-Encoding, so the
        # upstream Content-Encoding/Content-Length no longer describe the bytes
        # we stream; hop-by-hop headers must not be relayed either.
        dropped = {"content-encoding", "content-length", "transfer-encoding", "connection"}
        response_headers = {
            name: value
            for name, value in r.headers.items()
            if name.lower() not in dropped
        }
        return StreamingResponse(
            r.iter_content(chunk_size=8192),
            status_code=r.status_code,
            headers=response_headers,
        )
    except Exception as e:
        print(e)
        error_detail = "Ollama WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if isinstance(res, dict) and "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except (ValueError, requests.exceptions.RequestException):
                # Body was not JSON (or could not be read): report the error.
                error_detail = f"Ollama: {e}"
            finally:
                r.close()

        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        ) from e