main.py 5.79 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
3
from fastapi.responses import StreamingResponse, JSONResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

import requests
import json
from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY

# FastAPI application hosting the OpenAI configuration and proxy routes
# defined in this module.
app = FastAPI()

# Fully permissive CORS: any origin, method and header is accepted, and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Runtime-mutable copies of the configured upstream settings. The
# /url/update and /key/update endpoints overwrite these on app.state.
app.state.OPENAI_API_BASE_URL, app.state.OPENAI_API_KEY = (
    OPENAI_API_BASE_URL,
    OPENAI_API_KEY,
)


# Request body for POST /url/update.
# (Documented with comments rather than a docstring: pydantic would expose a
# class docstring as the model's schema description.)
class UrlUpdateForm(BaseModel):
    url: str  # new upstream base URL; update_openai_url stores it as-is


# Request body for POST /key/update.
# (Documented with comments rather than a docstring: pydantic would expose a
# class docstring as the model's schema description.)
class KeyUpdateForm(BaseModel):
    key: str  # new upstream API key; update_openai_key stores it as-is


@app.get("/url")
async def get_openai_url(user=Depends(get_current_user)):
    # Admin-only: report the upstream base URL currently held on app.state.
    # Any caller that is missing or not an admin gets a 401.
    # (Comments, not a docstring: FastAPI would publish a docstring as the
    # OpenAPI operation description.)
    if not (user and user.role == "admin"):
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
Timothy J. Baek's avatar
Timothy J. Baek committed
42
43
44


@app.post("/url/update")
async def update_openai_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    # Admin-only: replace the upstream base URL on app.state and echo the
    # stored value back. Non-admin or missing callers get a 401.
    # Only app.state is written here; nothing is saved to config.
    is_admin = bool(user) and user.role == "admin"
    if not is_admin:
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    new_url = form_data.url
    app.state.OPENAI_API_BASE_URL = new_url
    return {"OPENAI_API_BASE_URL": new_url}
Timothy J. Baek's avatar
Timothy J. Baek committed
53
54
55
56
57
58
59


@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
    # Admin-only: return the upstream API key currently held on app.state.
    # The raw secret is sent to the admin caller; everyone else gets a 401.
    if not (user and user.role == "admin"):
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
Timothy J. Baek's avatar
Timothy J. Baek committed
62
63
64


@app.post("/key/update")
async def update_openai_key(
    form_data: KeyUpdateForm, user=Depends(get_current_user)
):
    # Admin-only: replace the upstream API key on app.state and echo the
    # stored value back. Non-admin or missing callers get a 401.
    # Only app.state is written here; nothing is saved to config.
    is_admin = bool(user) and user.role == "admin"
    if not is_admin:
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    new_key = form_data.key
    app.state.OPENAI_API_KEY = new_key
    return {"OPENAI_API_KEY": new_key}
Timothy J. Baek's avatar
Timothy J. Baek committed
73
74
75
76
77


def _rewrite_request_body(body_str: str) -> str:
    """Return the request body text to forward upstream.

    An empty body is forwarded unchanged. A non-empty body must be valid
    JSON. When it is a JSON object whose ``model`` is
    ``gpt-4-vision-preview``, ``max_tokens`` is forced to 10000 (a
    workaround until OpenAI fixes that model's behaviour). The result is
    re-serialized with ``json.dumps``, which escapes non-ASCII by default,
    so the forwarded text is pure ASCII.

    Raises:
        HTTPException: 400 if the body is not valid JSON.
    """
    if not body_str:
        return body_str

    try:
        body = json.loads(body_str)
    except json.JSONDecodeError as e:
        print("Error loading request body into a dictionary:", e)
        raise HTTPException(
            status_code=400, detail="Invalid JSON in request body"
        ) from e

    # Only a JSON object can carry a "model" key. Arrays or scalars are
    # forwarded as-is instead of crashing on .get().
    if isinstance(body, dict) and body.get("model") == "gpt-4-vision-preview":
        body["max_tokens"] = 10000

    return json.dumps(body)


def _upstream_error_detail(r, error: Exception) -> str:
    """Build the HTTPException detail for a failed proxied request.

    With no upstream response, a generic connection error is reported. If
    the upstream body is JSON with an ``error`` field, that field is
    surfaced. If the body cannot be inspected, the exception text is used
    instead.
    """
    if r is None:
        return "Ollama WebUI: Server Connection Error"
    try:
        res = r.json()
        if "error" in res:
            return f"External: {res['error']}"
    except Exception:
        return f"External: {error}"
    return "Ollama WebUI: Server Connection Error"


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    # Forward any other request to the configured OpenAI-compatible API,
    # authenticating upstream with the stored key.
    # - Server-sent-event responses are streamed back to the caller.
    # - All other responses are parsed as JSON and returned.
    # - For OpenAI's /models, only ids containing "gpt" are kept.
    # (Comments, not a docstring: FastAPI would publish a docstring as the
    # OpenAPI operation description.)
    if not user or user.role not in ["user", "admin"]:
        raise HTTPException(status_code=401,
                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    # `not` also rejects a missing (None) key, not just an empty string.
    if not app.state.OPENAI_API_KEY:
        raise HTTPException(status_code=401,
                            detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    # Strip a trailing slash so a base URL like ".../v1/" doesn't yield "//".
    target_url = f"{app.state.OPENAI_API_BASE_URL.rstrip('/')}/{path}"
    # Log the target only -- never the API key.
    print(target_url)

    try:
        body_str = (await request.body()).decode("utf-8")
    except UnicodeDecodeError as e:
        print("Error decoding request body:", e)
        raise HTTPException(status_code=400,
                            detail="Invalid request body") from e
    body_json = _rewrite_request_body(body_str)

    headers = {
        "Authorization": f"Bearer {app.state.OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }

    # Bound before the try so the error path can always test it.
    r = None
    try:
        # NOTE(review): `requests` is synchronous, so this call blocks the
        # event loop for the duration of the upstream request.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body_json,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # NOTE(review): upstream headers are forwarded verbatim. Confirm
            # that Content-Length/Content-Encoding still match the bytes
            # produced by iter_content.
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )

        response_data = r.json()
        if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]
        return response_data
    except Exception as e:
        print(e)
        error_detail = _upstream_error_detail(r, e)
        # Propagate an upstream error status. A local failure (no response,
        # or a 2xx we could not process) becomes a 500 rather than an
        # "error" carrying status 200.
        status_code = (
            r.status_code if r is not None and r.status_code >= 400 else 500
        )
        raise HTTPException(status_code=status_code, detail=error_detail) from e