main.py 10.1 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
Timothy J. Baek's avatar
Timothy J. Baek committed
9

Timothy J. Baek's avatar
Timothy J. Baek committed
10
11
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
14
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
15
16
17
18
19
20
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
21
22
23
24
25
26
27
from config import (
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
30
31
32

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
33
34
35
36
37
38
39
40
41
42

# Sub-application that proxies requests to the configured OpenAI-compatible
# backends.
app = FastAPI()

# Accept requests from any origin, with any method/header, credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Model allow-list; GET /models applies it to users whose role is "user".
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Parallel lists: OPENAI_API_KEYS[i] is sent to OPENAI_API_BASE_URLS[i].
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS

# Model id -> model dict (including "urlIdx"); filled by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
51

Timothy J. Baek's avatar
Timothy J. Baek committed
52
53
54
55
56
57
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Make sure the model cache is populated, then handle the request.

    When ``app.state.MODELS`` is empty, the model lists of all configured
    backends are (re)fetched through ``get_all_models()`` before the request
    is passed down the stack unchanged.
    """
    if not app.state.MODELS:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
61
62


Timothy J. Baek's avatar
Timothy J. Baek committed
63
64
# Request body for POST /urls/update: the full replacement list of base URLs.
class UrlsUpdateForm(BaseModel):
    urls: List[str]  # one entry per backend, index-aligned with the keys list
Timothy J. Baek's avatar
Timothy J. Baek committed
65
66


Timothy J. Baek's avatar
Timothy J. Baek committed
67
68
# Request body for POST /keys/update: the full replacement list of API keys.
class KeysUpdateForm(BaseModel):
    keys: List[str]  # one entry per backend, index-aligned with the URLs list
Timothy J. Baek's avatar
Timothy J. Baek committed
69
70


Timothy J. Baek's avatar
Timothy J. Baek committed
71
72
73
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URLs (admin only)."""
    base_urls = app.state.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
74

Timothy J. Baek's avatar
Timothy J. Baek committed
75

Timothy J. Baek's avatar
Timothy J. Baek committed
76
77
78
79
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the list of OpenAI-compatible base URLs (admin only).

    The cached model map is cleared as well: every cached model records the
    *index* (``urlIdx``) of the URL it was fetched from, and once the URL list
    changes those indices can point at a different backend or past the end of
    the list. The ``check_url`` middleware refills the cache on the next
    request because it refetches whenever ``app.state.MODELS`` is empty.

    Returns:
        ``{"OPENAI_API_BASE_URLS": [...]}`` with the new list.
    """
    app.state.OPENAI_API_BASE_URLS = form_data.urls
    app.state.MODELS = {}
    return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
80
81


Timothy J. Baek's avatar
Timothy J. Baek committed
82
83
84
85
86
87
88
89
90
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys, index-aligned with the URLs (admin only)."""
    api_keys = app.state.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": api_keys}


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the list of API keys (admin only).

    The cached model map is cleared so it is rebuilt with the new keys: the
    ``check_url`` middleware only refetches models while ``app.state.MODELS``
    is empty, so a list fetched with old (possibly invalid) keys would
    otherwise never be refreshed.

    Returns:
        ``{"OPENAI_API_KEYS": [...]}`` with the new list.
    """
    app.state.OPENAI_API_KEYS = form_data.keys
    app.state.MODELS = {}
    return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
91
92


Timothy J. Baek's avatar
Timothy J. Baek committed
93
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to OpenAI, caching the audio on disk.

    The SHA-256 of the raw request body is the cache key: the audio is stored
    as ``<CACHE_DIR>/audio/speech/<hash>.mp3`` and the request body as
    ``<hash>.json``. If the ``.mp3`` already exists it is returned without
    contacting the backend.

    Raises:
        HTTPException(401): no base URL equal to ``https://api.openai.com/v1``
            is configured.
        HTTPException: the upstream call failed. The upstream status code is
            forwarded when upstream reported an error, otherwise 500.
    """
    try:
        idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
    headers["Content-Type"] = "application/json"

    # Stream into a temporary file and only move it to the cache path once it
    # is complete, so an interrupted download is never served as cached audio.
    tmp_path = file_path.with_name(f"{name}.mp3.part")

    r = None
    try:
        r = requests.post(
            url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Save the streaming content to a file
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        tmp_path.replace(file_path)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Not `if r`: requests.Response is falsy for 4xx/5xx (bool(r) == r.ok),
        # which turned every upstream error into a 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
    finally:
        # Drop any partial download left behind by a failure (no-op on success,
        # since the file was already moved into place).
        tmp_path.unlink(missing_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
153
154


Timothy J. Baek's avatar
Timothy J. Baek committed
155
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    Any exception (connection failure, non-JSON body, ...) is printed and
    swallowed, and ``None`` is returned instead, so ``asyncio.gather`` in
    ``get_all_models`` still receives one result per backend.
    """
    try:
        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=auth_headers) as resp:
                payload = await resp.json()
        return payload
    except Exception as err:
        # Handle connection error here
        print(f"Connection error: {err}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-backend model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` is the model list returned by backend ``i`` (the entry
    at the same index of ``app.state.OPENAI_API_BASE_URLS``), or ``None`` when
    that backend failed; ``None`` entries and entries containing ``"error"``
    are skipped. Every kept model is copied with an added ``"urlIdx": i``.
    For a backend whose URL contains ``api.openai.com`` only models whose id
    contains ``"gpt"`` are kept.
    """
    merged = []
    for url_idx, models in enumerate(model_lists):
        if models is None or "error" in models:
            continue
        for model in models:
            # URL is looked up per model (lazily), exactly as before.
            keep = (
                "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[url_idx]
                or "gpt" in model["id"]
            )
            if keep:
                merged.append({**model, "urlIdx": url_idx})
    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
182
183


Timothy J. Baek's avatar
Timothy J. Baek committed
184
185
async def get_all_models():
    """Fetch, merge and cache the model lists of every configured backend.

    Queries ``{url}/models`` on each entry of ``app.state.OPENAI_API_BASE_URLS``
    concurrently, authenticating with the key at the same index of
    ``app.state.OPENAI_API_KEYS``. The merged list (see ``merge_models_lists``)
    is cached in ``app.state.MODELS`` as ``{model_id: model}``.

    When the only configured key is the empty string, no backend is queried
    and an empty list is cached.

    Returns:
        ``{"data": [model, ...]}`` in every case (previously ``None`` was
        returned when no key was configured, because ``return`` sat inside
        the ``else`` branch).
    """
    print("get_all_models")

    keys = app.state.OPENAI_API_KEYS

    if len(keys) == 1 and keys[0] == "":
        models = {"data": []}
    else:
        tasks = [
            # A key list shorter than the URL list no longer raises IndexError
            # (which broke every request via the middleware); use no key.
            fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
            for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
        ]

        responses = await asyncio.gather(*tasks)
        models = {
            "data": merge_models_lists(
                [
                    response["data"] if response and "data" in response else None
                    for response in responses
                ]
            )
        }

        print(models)

    # Update the cache on both paths so models loaded with earlier keys do not
    # linger after the keys are cleared.
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
215

Timothy J. Baek's avatar
Timothy J. Baek committed
216
217
218

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``: the merged list from all backends (``get_all_models``).
    When ``app.state.MODEL_FILTER_ENABLED`` is set, users whose role is
    ``"user"`` only see models whose id is in ``app.state.MODEL_FILTER_LIST``.

    With ``url_idx``: query ``{url}/models`` on that single backend, using its
    API key (previously no Authorization header was sent, so authenticated
    backends such as OpenAI rejected the request). For ``api.openai.com`` only
    models whose id contains ``"gpt"`` are returned.

    Raises:
        HTTPException: the per-backend request failed. The upstream status code
            is forwarded when upstream reported an error, otherwise 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.MODEL_FILTER_ENABLED and user.role == "user":
            models["data"] = [
                model
                for model in models["data"]
                if model["id"] in app.state.MODEL_FILTER_LIST
            ]
        return models

    url = app.state.OPENAI_API_BASE_URLS[url_idx]
    keys = app.state.OPENAI_API_KEYS
    # Same tolerance as get_all_models for a key list shorter than the URLs.
    key = keys[url_idx] if url_idx < len(keys) else ""
    headers = {"Authorization": f"Bearer {key}"}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Not `if r`: requests.Response is falsy for 4xx/5xx (bool(r) == r.ok),
        # which turned every upstream error into a 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
263
264


Timothy J. Baek's avatar
Timothy J. Baek committed
265
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to an OpenAI-compatible backend.

    If the body is a JSON object whose ``"model"`` is in ``app.state.MODELS``,
    the request goes to that model's backend (its ``urlIdx``); otherwise it
    goes to backend 0. For JSON-object bodies, ``max_tokens`` defaults to 4000
    for ``gpt-4-vision-preview`` and ``num_ctx`` is removed before forwarding.
    Other bodies are forwarded unchanged.

    Server-sent-event responses are streamed back; any other response is
    parsed as JSON and returned.

    Raises:
        HTTPException(401): the selected backend has an empty API key.
        HTTPException: the upstream call failed. The upstream status code is
            forwarded when upstream reported an error, otherwise 500.
    """
    # Index into OPENAI_API_BASE_URLS / OPENAI_API_KEYS.
    idx = 0

    body = await request.body()

    # TODO: Remove the gpt-4-vision workaround below after OpenAI fixes it.
    # Parse the body so it can be routed by model and patched. Bodies that are
    # not UTF-8 JSON (empty GET bodies, binary uploads) keep their original
    # bytes instead of crashing (UnicodeDecodeError) or being re-sent as str.
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        print("Error loading request body into a dictionary:", e)
        payload = None

    if isinstance(payload, dict):
        # Route to the backend that advertised this model. Models missing from
        # the cache (e.g. non-"gpt" OpenAI models, filtered out when the cache
        # is built) fall back to backend 0 instead of raising KeyError.
        model = app.state.MODELS.get(payload.get("model"))
        if model is not None:
            idx = model["urlIdx"]

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if payload.get("model") == "gpt-4-vision-preview":
            if "max_tokens" not in payload:
                payload["max_tokens"] = 4000
            print("Modified body_dict:", payload)

        # Fix for ChatGPT calls failing because the num_ctx key is in body:
        # leaving it there generates an error with the OpenAI API (Feb 2024).
        payload.pop("num_ctx", None)

        # Convert the modified body back to JSON
        body = json.dumps(payload)

    url = app.state.OPENAI_API_BASE_URLS[idx]
    key = app.state.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        return r.json()
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Not `if r`: requests.Response is falsy for 4xx/5xx (bool(r) == r.ok),
        # which turned every upstream error into a 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )