main.py 9.01 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
Timothy J. Baek's avatar
Timothy J. Baek committed
9

Timothy J. Baek's avatar
Timothy J. Baek committed
10
11
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
14
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
15
16
17
18
19
20
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22
23
from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
24
25
26

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28
29
30
31
32
33
34
35
36

# FastAPI application serving the OpenAI proxy endpoints defined in this
# module. CORS is fully open: any origin, method and header, with
# credentials allowed.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

Timothy J. Baek's avatar
Timothy J. Baek committed
37
38
39
40
41
# Registry of {model_id: model_dict}, each entry carrying the "urlIdx" of the
# backend that serves it; starts empty and is filled by get_all_models().
app.state.MODELS = {}

# Runtime configuration seeded from `config`. Base URLs and keys are parallel
# lists (keys[i] authenticates against urls[i]); the /urls/update and
# /keys/update endpoints replace them while the server is running.
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS

Timothy J. Baek's avatar
Timothy J. Baek committed
42

Timothy J. Baek's avatar
Timothy J. Baek committed
43
44
45
46
47
48
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Make sure the model registry is populated before handling a request.

    When ``app.state.MODELS`` is empty, ``get_all_models`` is awaited first so
    downstream handlers can map a model id to its backend index. The request
    is then passed through unchanged and its response returned as-is.

    If the fetch yields no models, the registry stays empty and the fetch is
    attempted again on the next request.
    """
    registry_is_empty = not app.state.MODELS
    if registry_is_empty:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
52
53


Timothy J. Baek's avatar
Timothy J. Baek committed
54
55
class UrlsUpdateForm(BaseModel):
    """Request body for ``POST /urls/update``: the full replacement URL list."""

    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
56
57


Timothy J. Baek's avatar
Timothy J. Baek committed
58
59
class KeysUpdateForm(BaseModel):
    """Request body for ``POST /keys/update``: the full replacement key list."""

    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
60
61


Timothy J. Baek's avatar
Timothy J. Baek committed
62
63
64
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URLs.

    Access is gated by the ``get_admin_user`` dependency.
    """
    configured_urls = app.state.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": configured_urls}
65

Timothy J. Baek's avatar
Timothy J. Baek committed
66

Timothy J. Baek's avatar
Timothy J. Baek committed
67
68
69
70
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured base URLs and echo the new list back.

    ``app.state.MODELS`` records, for every model id, the *position* of the
    URL that served it (``urlIdx``). Those positions are meaningless once the
    URL list changes, so the registry is cleared here. The ``check_url``
    middleware refetches it on the next request because it is empty.

    Access is gated by the ``get_admin_user`` dependency.
    """
    app.state.OPENAI_API_BASE_URLS = form_data.urls
    # Stale urlIdx values would route requests to the wrong backend, or raise
    # IndexError in the proxy if the list shrank.
    app.state.MODELS = {}
    return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
71
72


Timothy J. Baek's avatar
Timothy J. Baek committed
73
74
75
76
77
78
79
80
81
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys, positionally paired with the base URLs.

    Access is gated by the ``get_admin_user`` dependency.
    """
    configured_keys = app.state.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": configured_keys}


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured API keys and echo the new list back.

    Keys are used by position (``OPENAI_API_KEYS[i]`` for
    ``OPENAI_API_BASE_URLS[i]``), so the new list should line up with the URLs.
    Access is gated by the ``get_admin_user`` dependency.
    """
    replacement_keys = form_data.keys
    app.state.OPENAI_API_KEYS = replacement_keys
    return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
82
83


Timothy J. Baek's avatar
Timothy J. Baek committed
84
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to OpenAI, caching the audio on disk.

    The raw request body is hashed with SHA-256 to form a cache key. A cached
    ``<hash>.mp3`` under ``CACHE_DIR/audio/speech`` is returned directly.
    Otherwise the body is forwarded to ``<openai base url>/audio/speech``, the
    streamed audio is saved, and the request JSON is stored next to it as
    ``<hash>.json``.

    Raises:
        HTTPException(401): ``https://api.openai.com/v1`` is not among the
            configured base URLs.
        HTTPException: the upstream call or cache write failed. The status
            mirrors the upstream response, or is 500 when no response was
            received.
    """
    try:
        idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
    # Audio is streamed into a scratch file and only renamed to `file_path`
    # once complete, so an interrupted download never becomes a cache hit.
    partial_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.part")

    # Serve straight from the cache when this exact request was seen before.
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    r = None  # stays None if the request itself fails (e.g. connection error)
    try:
        r = requests.post(
            url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Save the streaming content to the scratch file
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Publish the finished audio in a single rename.
        partial_path.replace(file_path)

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        print(e)
        # Never leave a truncated download behind.
        partial_path.unlink(missing_ok=True)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
141
142


Timothy J. Baek's avatar
Timothy J. Baek committed
143
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    Best-effort: any failure is printed and ``None`` is returned instead of
    raising. Failures include a connection error, a non-2xx status, and a
    body aiohttp refuses to decode as JSON.

    Args:
        url: Full URL to fetch.
        key: API key sent as ``Authorization: Bearer <key>``.
    """
    try:
        headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                # Without this, an error payload such as {"error": {...}} would
                # be handed back to the caller as if it were a model listing.
                response.raise_for_status()
                return await response.json()
    except Exception as e:
        # Treat any failure as "backend unavailable".
        print(f"Connection error: {e}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-backend model lists into one list tagged with their source.

    ``model_lists[i]`` is treated as the models served by
    ``app.state.OPENAI_API_BASE_URLS[i]``. Each kept model is shallow-copied
    with an extra ``"urlIdx": i`` key. For backends whose URL contains
    ``api.openai.com``, only models whose id contains ``"gpt"`` are kept.
    """
    merged = []

    for url_idx, models in enumerate(model_lists):
        for model in models:
            base_url = app.state.OPENAI_API_BASE_URLS[url_idx]
            # Skip non-"gpt" models served by the official OpenAI endpoint.
            if "api.openai.com" in base_url and "gpt" not in model["id"]:
                continue
            merged.append({**model, "urlIdx": url_idx})

    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
169
170


Timothy J. Baek's avatar
Timothy J. Baek committed
171
172
173
174
175
176
177
178
async def get_all_models():
    """Fetch ``/models`` from every configured backend and rebuild the registry.

    Returns ``{"data": [...]}`` with the merged model list; each entry carries
    the ``urlIdx`` of its backend. Also stores ``{model_id: model}`` in
    ``app.state.MODELS``.

    A backend that failed (``fetch_url`` returned ``None``), or whose payload
    has no ``"data"`` list, contributes an empty list *in its own position*.
    ``merge_models_lists`` derives ``urlIdx`` from list position, so dropping
    failed entries would shift every later backend's models onto the wrong
    URL.
    """
    print("get_all_models")
    tasks = [
        fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
        for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
    ]
    responses = await asyncio.gather(*tasks)

    # One entry per configured URL, in URL order.
    model_lists = [
        response["data"]
        if isinstance(response, dict) and isinstance(response.get("data"), list)
        else []
        for response in responses
    ]

    models = {"data": merge_models_lists(model_lists)}
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
188

Timothy J. Baek's avatar
Timothy J. Baek committed
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223

# NOTE(review): this endpoint performs no authentication; a
# `user=Depends(get_current_user)` parameter was left commented out here.
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None):
    """List models from all backends, or from one backend selected by index.

    Without ``url_idx`` this delegates to ``get_all_models``. With it, that
    backend's ``/models`` endpoint is queried directly, authenticated with the
    key at the same index. For ``api.openai.com`` only ids containing
    ``"gpt"`` are kept.

    Raises:
        HTTPException: the upstream call failed. The status mirrors the
            upstream response, or is 500 when no response was received.
    """
    if url_idx is None:
        return await get_all_models()

    url = app.state.OPENAI_API_BASE_URLS[url_idx]
    key = app.state.OPENAI_API_KEYS[url_idx]
    headers = {"Authorization": f"Bearer {key}"}

    r = None  # stays None if the request itself fails (e.g. connection error)
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = list(
                filter(lambda model: "gpt" in model["id"], response_data["data"])
            )

        return response_data
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            # A requests.Response is falsy for 4xx/5xx, so test for None
            # explicitly instead of relying on truthiness.
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
224
225


Timothy J. Baek's avatar
Timothy J. Baek committed
226
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the backend that serves the requested model.

    Backend selection:
    - For a JSON-object body, the ``"model"`` field selects the backend via
      ``app.state.MODELS[model]["urlIdx"]``.
    - Bodies that are not a UTF-8 JSON object, or that name a model missing
      from the registry, go to backend 0.
    - JSON bodies are patched before forwarding (see inline comments); other
      bodies are forwarded byte-for-byte.

    Server-sent-event responses are streamed back; anything else is returned
    as parsed JSON.

    Raises:
        HTTPException(401): the selected backend has an empty API key.
        HTTPException: the upstream call failed. The status mirrors the
            upstream response, or is 500 when no response was received.
    """
    idx = 0

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body as JSON (required to add max_tokens for gpt-4-vision)
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Binary or non-JSON bodies are forwarded unchanged as bytes.
        print("Error loading request body into a dictionary:", e)
    else:
        if isinstance(payload, dict):
            model = payload.get("model")
            # Unknown models fall back to backend 0, which reports the error
            # itself, instead of crashing here with a KeyError.
            if model in app.state.MODELS:
                idx = app.state.MODELS[model]["urlIdx"]

            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
            # This is a workaround until OpenAI fixes the issue with this model
            if model == "gpt-4-vision-preview":
                if "max_tokens" not in payload:
                    payload["max_tokens"] = 4000
                print("Modified body_dict:", payload)

            # Fix for ChatGPT calls failing because the num_ctx key is in body:
            # leaving it there generates an error with the OpenAI API (Feb 2024)
            payload.pop("num_ctx", None)

            # Convert the modified body back to JSON
            body = json.dumps(payload)

    url = app.state.OPENAI_API_BASE_URLS[idx]
    key = app.state.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None  # stays None if the request itself fails (e.g. connection error)
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # NOTE(review): upstream headers are copied verbatim. If the
            # upstream ever sends Content-Encoding/Content-Length, they may not
            # match what iter_content yields — confirm.
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )