main.py 3.39 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

import requests
import json
from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY

# FastAPI application exposing the admin settings endpoints and the catch-all
# proxy route defined below.
app = FastAPI()
# CORS is fully open: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Upstream settings held on app.state, seeded from config. The /url/update and
# /key/update endpoints overwrite them at runtime, and the proxy route reads
# them on every request.
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = OPENAI_API_KEY


class UrlUpdateForm(BaseModel):
    # Request body accepted by POST /url/update.
    # `url` is stored verbatim as app.state.OPENAI_API_BASE_URL.
    url: str


class KeyUpdateForm(BaseModel):
    # Request body accepted by POST /key/update.
    # `key` is stored verbatim as app.state.OPENAI_API_KEY.
    key: str


@app.get("/url")
async def get_openai_url(user=Depends(get_current_user)):
    # Admin-only: report the base URL currently used for upstream requests.
    # Anyone else (including a missing user) gets a 401.
    if not user or user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}


@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
    # Admin-only: replace the upstream base URL and echo back the stored value.
    # Anyone else (including a missing user) gets a 401.
    if not user or user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    app.state.OPENAI_API_BASE_URL = form_data.url
    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}


@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
    # Admin-only: report the API key currently used for upstream requests.
    # Anyone else (including a missing user) gets a 401.
    if not user or user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}


@app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
    # Admin-only: replace the upstream API key and echo back the stored value.
    # Anyone else (including a missing user) gets a 401.
    if not user or user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    app.state.OPENAI_API_KEY = form_data.key
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    # Forward the incoming request to `{OPENAI_API_BASE_URL}/{path}` using the
    # server-side API key and stream the upstream response back to the caller.
    #
    # Parameters:
    #   path    -- everything after this app's mount point; appended to the
    #              configured base URL.
    #   request -- the incoming request; its method, body and (filtered)
    #              headers are forwarded.
    #   user    -- resolved by get_current_user; must have role "user" or
    #              "admin".
    #
    # Raises:
    #   HTTPException 401 -- no user, or a role that is not allowed.
    #   HTTPException <upstream status> -- upstream answered with an error.
    #   HTTPException 500 -- upstream could not be reached at all.
    #
    # NOTE(review): `requests` is a blocking client called from an async
    # handler, so the event loop is stalled while waiting for upstream.

    # Authorise before doing any work; `user` may be None, as the sibling
    # endpoints already assume.
    if user is None or user.role not in ["user", "admin"]:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
    body = await request.body()

    # Starlette exposes header names in lowercase, so the names must be
    # compared case-insensitively. Otherwise the caller's own token and the
    # local Host/Origin/Referer would be sent upstream.
    dropped_request_headers = {"host", "authorization", "origin", "referer"}
    headers = {
        name: value
        for name, value in dict(request.headers).items()
        if name.lower() not in dropped_request_headers
    }
    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"

    # iter_content() yields the *decoded* body, so the upstream framing and
    # encoding headers no longer describe what is sent to the client.
    dropped_response_headers = {
        "content-encoding",
        "content-length",
        "transfer-encoding",
        "connection",
    }

    # Bound before the try so the error path can tell "no response at all"
    # apart from "upstream responded with an error".
    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        def stream_and_close():
            # Release the upstream connection once the body has been relayed
            # or the generator is discarded.
            try:
                yield from r.iter_content(chunk_size=8192)
            finally:
                r.close()

        response_headers = {
            name: value
            for name, value in r.headers.items()
            if name.lower() not in dropped_response_headers
        }

        return StreamingResponse(
            stream_and_close(),
            status_code=r.status_code,
            headers=response_headers,
        )
    except Exception as e:
        print(e)
        error_detail = "Ollama WebUI: Server Connection Error"
        status_code = 500

        if r is not None:
            # Only propagate genuine error statuses; a failure after a 2xx
            # upstream answer is still an error for the caller.
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"
            finally:
                r.close()

        raise HTTPException(status_code=status_code, detail=error_detail)