import json
import logging
import posixpath

import requests
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH

# Application whose routes below proxy authenticated requests to Ollama.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True let any
# site make credentialed cross-origin requests; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Base URL every proxied request is forwarded to. It lives on ``app.state``
# rather than in a module constant because admins can change it at runtime
# through POST /url/update.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL


@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    """Return the Ollama base URL the proxy currently forwards to.

    Admin-only. Any other caller, including a falsy ``user``, gets a 401 with
    ``ERROR_MESSAGES.ACCESS_PROHIBITED``.
    """
    is_admin = bool(user) and user.role == "admin"
    if not is_admin:
        raise HTTPException(status_code=401,
                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


class UrlUpdateForm(BaseModel):
    """Request body accepted by ``POST /url/update``."""

    url: str  # new Ollama API base URL for the proxy


@app.post("/url/update")
async def update_ollama_api_url(form_data: UrlUpdateForm,
                                user=Depends(get_current_user)):
    """Point the proxy at a new Ollama base URL and echo it back.

    Admin-only. The URL is stored on ``app.state`` exactly as submitted.
    Any other caller, including a falsy ``user``, gets a 401 with
    ``ERROR_MESSAGES.ACCESS_PROHIBITED``.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401,
                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    state = app.state
    state.OLLAMA_API_BASE_URL = form_data.url
    return {"OLLAMA_API_BASE_URL": state.OLLAMA_API_BASE_URL}


log = logging.getLogger(__name__)

# Ollama endpoints that modify the model store; only admins may call them.
_ADMIN_ONLY_ENDPOINTS = frozenset({"pull", "delete", "push", "copy", "create"})

# Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection and
# must not be forwarded by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailer", "transfer-encoding", "upgrade",
})

# Client headers never sent upstream: the WebUI's own credentials and the
# browser's host/origin information.
_DROPPED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {
    "host", "authorization", "origin", "referer",
}

# Upstream headers never relayed back: ``iter_content`` decodes any
# Content-Encoding, so the upstream encoding and length no longer describe the
# body that is actually streamed to the client.
_DROPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {
    "content-encoding", "content-length",
}


def _filter_headers(headers, dropped):
    """Return *headers* as a plain dict without any name listed in *dropped*.

    Names are compared case-insensitively. ASGI servers deliver header names
    lower-cased, while ``requests`` keeps the upstream spelling, so an exact
    match such as ``pop("Host")`` silently misses.
    """
    return {
        name: value
        for name, value in dict(headers).items()
        if name.lower() not in dropped
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward a request to the configured Ollama server and stream back the reply.

    Callers with role "user" or "admin" may proxy requests, but the endpoints
    in ``_ADMIN_ONLY_ENDPOINTS`` are restricted to admins.

    Raises:
        HTTPException: 401 (``ERROR_MESSAGES.ACCESS_PROHIBITED``) on an access
            violation. If forwarding fails, the status is the upstream error
            status, or 500 when there is no upstream error status. The detail
            is Ollama's own ``error`` message when it sent one.
    """
    # Normalise the path so spellings such as "pull/" or "x/../pull" cannot
    # slip past the admin-only check while still reaching the same endpoint.
    endpoint = posixpath.normpath("/" + path).strip("/")

    role = user.role if user else None
    if role not in ("user", "admin") or (
        endpoint in _ADMIN_ONLY_ENDPOINTS and role != "admin"
    ):
        raise HTTPException(status_code=401,
                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    # rstrip avoids a "//" in the target when the stored base URL ends in "/".
    target_url = f"{app.state.OLLAMA_API_BASE_URL.rstrip('/')}/{endpoint}"
    body = await request.body()
    headers = _filter_headers(request.headers, _DROPPED_REQUEST_HEADERS)

    r = None  # stays None if the upstream request never produced a response
    try:
        # requests is blocking; run it in the threadpool so a slow upstream
        # call does not stall the event loop for every other request.
        r = await run_in_threadpool(
            requests.request,
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        return StreamingResponse(
            r.iter_content(chunk_size=8192),
            status_code=r.status_code,
            headers=_filter_headers(r.headers, _DROPPED_RESPONSE_HEADERS),
        )
    except Exception as e:
        log.exception("Proxying %s %s to Ollama failed",
                      request.method, target_url)
        error_detail = "Ollama WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        # Pass through genuine upstream error codes only. A failure with no
        # response (connection refused, timeout, ...) is reported as a 500.
        if r is not None and r.status_code >= 400:
            status_code = r.status_code
        else:
            status_code = 500
        raise HTTPException(status_code=status_code,
                            detail=error_detail) from e