main.py 9.44 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
Timothy J. Baek's avatar
Timothy J. Baek committed
9

Timothy J. Baek's avatar
Timothy J. Baek committed
10
11
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
14
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
15
16
17
18
19
20
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22
23
from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
24
25
26

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28
29
30
31
32
33
34
35
36

app = FastAPI()

# CORS: accept requests from any origin, with any method and header, and
# allow credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Optional allow-list of model ids; the GET /models handler applies it to
# users whose role is "user" when MODEL_FILTER_ENABLED is true.
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []

# Parallel lists: OPENAI_API_KEYS[i] is the key sent to OPENAI_API_BASE_URLS[i].
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS

# Model id -> model entry (carrying its backend "urlIdx"); filled lazily by
# the check_url middleware via get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
45

Timothy J. Baek's avatar
Timothy J. Baek committed
46
47
48
49
50
51
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Ensure the model cache is populated before any request is handled.

    ``app.state.MODELS`` is rebuilt by ``get_all_models()`` whenever it is
    empty. The request is then passed on unchanged.

    NOTE: if every backend fails, ``get_all_models()`` leaves the cache empty,
    so the fetch is retried on every subsequent request.
    """
    if not app.state.MODELS:
        await get_all_models()

    response = await call_next(request)
    return response
Timothy J. Baek's avatar
Timothy J. Baek committed
55
56


Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
class UrlsUpdateForm(BaseModel):
    """Request body for ``POST /urls/update``: the replacement base-URL list."""

    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
59
60


Timothy J. Baek's avatar
Timothy J. Baek committed
61
62
class KeysUpdateForm(BaseModel):
    """Request body for ``POST /keys/update``: the replacement API-key list."""

    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
63
64


Timothy J. Baek's avatar
Timothy J. Baek committed
65
66
67
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured base URLs (guarded by the ``get_admin_user`` dependency)."""
    configured_urls = app.state.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": configured_urls}
68

Timothy J. Baek's avatar
Timothy J. Baek committed
69

Timothy J. Baek's avatar
Timothy J. Baek committed
70
71
72
73
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the list of base URLs (guarded by ``get_admin_user``).

    Trailing slashes are stripped because every consumer builds endpoints as
    ``f"{url}/<path>"``, and ``/audio/speech`` looks up the exact string
    ``https://api.openai.com/v1``.

    The cached model map is cleared. Its ``urlIdx`` values are positions in
    the *previous* URL list, and the ``check_url`` middleware only rebuilds
    an empty cache. Without the reset, requests would keep being routed by
    stale indices.

    Returns:
        dict: ``{"OPENAI_API_BASE_URLS": [...]}`` with the stored URLs.
    """
    app.state.OPENAI_API_BASE_URLS = [url.rstrip("/") for url in form_data.urls]
    app.state.MODELS = {}
    return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
74
75


Timothy J. Baek's avatar
Timothy J. Baek committed
76
77
78
79
80
81
82
83
84
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys, index-aligned with the base URLs.

    Guarded by the ``get_admin_user`` dependency.
    """
    configured_keys = app.state.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": configured_keys}


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the API keys; ``keys[i]`` is used for base URL ``i``.

    The cached model map is cleared because it was built with the previous
    keys. For example, a backend that answered with an error payload
    contributed no models. The ``check_url`` middleware rebuilds the map on
    the next request.

    Returns:
        dict: ``{"OPENAI_API_KEYS": [...]}`` with the stored keys.
    """
    app.state.OPENAI_API_KEYS = form_data.keys
    app.state.MODELS = {}
    return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
85
86


Timothy J. Baek's avatar
Timothy J. Baek committed
87
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to OpenAI and cache the audio on disk.

    Only the ``https://api.openai.com/v1`` base URL (and its matching key) is
    used. The SHA-256 of the raw request body is the cache key:
    ``<CACHE_DIR>/audio/speech/<hash>.mp3`` holds the audio and
    ``<hash>.json`` the request body. A cached ``.mp3`` is returned without
    calling upstream.

    Raises:
        HTTPException: 401 (``OPENAI_NOT_FOUND``) when no OpenAI base URL is
            configured; the upstream status code when OpenAI rejects the
            request; 500 when no usable upstream response was obtained or
            saving the result failed.
    """
    try:
        idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(
            status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND
        ) from None

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
    # Download target; only renamed to ``file_path`` once everything succeeded.
    partial_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.part")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    # Bound before the request so the handler below can distinguish "no
    # response at all" (e.g. connection refused) from an upstream error reply.
    r = None
    try:
        r = requests.post(
            url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Stream into a temporary file first. An interrupted transfer must not
        # leave a truncated .mp3 that the cache check above would serve forever.
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Publish the audio only after the whole request succeeded.
        partial_path.replace(file_path)

        # Return the saved file
        return FileResponse(file_path)
    except Exception as e:
        print(e)

        # Best-effort removal of an incomplete download.
        try:
            partial_path.unlink()
        except OSError:
            pass

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Propagate the upstream status only for an actual upstream error. A
        # local failure after a 2xx reply must not be reported as HTTP 200.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
144
145


Timothy J. Baek's avatar
Timothy J. Baek committed
146
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    Any failure is printed and turned into a ``None`` return value. This
    covers connection errors and a body that cannot be decoded as JSON.
    Callers gathering several backends can therefore drop unreachable ones.
    """
    auth_headers = {"Authorization": f"Bearer {key}"}
    try:
        async with aiohttp.ClientSession() as client:
            async with client.get(url, headers=auth_headers) as resp:
                payload = await resp.json()
    except Exception as err:
        # Unreachable backend or undecodable reply: report and signal with None.
        print(f"Connection error: {err}")
        return None
    return payload


def merge_models_lists(model_lists, base_urls=None):
    """Flatten per-backend model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` holds the models returned by the backend at
    ``base_urls[i]``. Each model dict is copied and given ``"urlIdx": i`` so
    that requests for it can later be routed to that backend. For a backend
    whose URL contains ``api.openai.com``, only models whose id contains
    ``"gpt"`` are kept. Empty or ``None`` entries (a backend that could not be
    queried) are skipped without shifting the indices of the others.

    Args:
        model_lists: Sequence of model-dict lists (each model has an ``"id"``)
            or ``None`` placeholders, index-aligned with ``base_urls``.
        base_urls: Base URL list to align with; defaults to
            ``app.state.OPENAI_API_BASE_URLS``.

    Returns:
        list: The merged, tagged model dicts.
    """
    if base_urls is None:
        base_urls = app.state.OPENAI_API_BASE_URLS

    merged_list = []
    for idx, models in enumerate(model_lists):
        if not models:
            continue
        # Hoisted out of the per-model filter; the URL is the same for the
        # whole list.
        is_openai = "api.openai.com" in base_urls[idx]
        merged_list.extend(
            {**model, "urlIdx": idx}
            for model in models
            if not is_openai or "gpt" in model["id"]
        )

    return merged_list
Timothy J. Baek's avatar
Timothy J. Baek committed
172
173


Timothy J. Baek's avatar
Timothy J. Baek committed
174
175
176
177
178
179
180
async def get_all_models():
    """Fetch ``/models`` from every configured backend and rebuild the cache.

    All backends are queried concurrently. A backend can fail in several
    ways: it is unreachable, it answers with an ``"error"`` payload, or it
    answers without a ``"data"`` list. A failing backend contributes no
    models but keeps its slot, so every model's ``urlIdx`` still points at
    the URL that actually serves it.

    Side effects:
        Replaces ``app.state.MODELS`` with a mapping of model id -> model.

    Returns:
        dict: ``{"data": [...]}`` with the merged model list.
    """
    print("get_all_models")
    keys = app.state.OPENAI_API_KEYS
    tasks = [
        # A URL without a matching key gets an empty one instead of raising
        # IndexError and aborting the whole refresh.
        fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
        for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
    ]
    responses = await asyncio.gather(*tasks)

    # One entry per URL (an empty list for failures). Dropping failures
    # instead would shift every later backend's index onto the wrong URL.
    model_lists = [
        response["data"]
        if isinstance(response, dict)
        and "error" not in response
        and isinstance(response.get("data"), list)
        else []
        for response in responses
    ]

    models = {"data": merge_models_lists(model_lists)}
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
190

Timothy J. Baek's avatar
Timothy J. Baek committed
191
192
193

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List models, either merged across all backends or from a single one.

    Without ``url_idx``, the merged list from ``get_all_models()`` is
    returned; that call also refreshes ``app.state.MODELS``. When
    ``MODEL_FILTER_ENABLED`` is set, users with role ``"user"`` only see ids
    listed in ``app.state.MODEL_LIST``.

    With ``url_idx``, that backend's ``/models`` endpoint is queried directly
    using its configured key. For ``api.openai.com`` only ids containing
    ``"gpt"`` are kept.

    Raises:
        HTTPException: the upstream status code when the backend rejects the
            request; 500 when no usable upstream response was obtained.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.MODEL_FILTER_ENABLED and user.role == "user":
            models["data"] = [
                model for model in models["data"] if model["id"] in app.state.MODEL_LIST
            ]
        return models

    url = app.state.OPENAI_API_BASE_URLS[url_idx]
    keys = app.state.OPENAI_API_KEYS
    key = keys[url_idx] if url_idx < len(keys) else ""
    # Authenticated like get_all_models(): OpenAI's /models rejects anonymous calls.
    headers = {"Authorization": f"Bearer {key}"}

    # Bound before the request so the handler can tell a connection failure
    # (no response at all) apart from an upstream error reply.
    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # NOTE: a requests.Response is falsy for 4xx/5xx, so ``if r`` would
        # hide exactly the upstream codes that should be propagated.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
235
236


Timothy J. Baek's avatar
Timothy J. Baek committed
237
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the backend that serves the requested model.

    For a JSON object body, the backend is looked up in ``app.state.MODELS``
    via the body's ``"model"`` field. For any other body, or an unknown
    model, the first configured backend (index 0) is used. Two fix-ups are
    applied to JSON object bodies before forwarding:

    - ``max_tokens`` defaults to 4000 for ``gpt-4-vision-preview``.
    - Any ``num_ctx`` key is removed.

    Returns:
        A ``StreamingResponse`` for ``text/event-stream`` upstream replies,
        otherwise the decoded JSON reply.

    Raises:
        HTTPException: 401 if the selected backend has no API key; the
            upstream status code when the backend rejects the request; 500
            when no usable upstream response was obtained.
    """
    # Default backend when the body does not name a known model.
    idx = 0

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body as UTF-8 JSON (needed to add max_tokens for
    # gpt-4-vision and to pick the backend from the "model" field).
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Not JSON (e.g. the empty body of a GET): forward the raw bytes as-is.
        print("Error loading request body into a dictionary:", e)
    else:
        if isinstance(payload, dict):
            model_id = payload.get("model")
            # Missing or unknown models fall back to backend 0 rather than
            # failing the request with an unhandled KeyError.
            model_info = (
                app.state.MODELS.get(model_id) if isinstance(model_id, str) else None
            )
            if model_info is not None:
                idx = model_info["urlIdx"]

            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
            # This is a workaround until OpenAI fixes the issue with this model
            if model_id == "gpt-4-vision-preview":
                if "max_tokens" not in payload:
                    payload["max_tokens"] = 4000
                print("Modified body_dict:", payload)

            # Fix for ChatGPT calls failing because the num_ctx key is in body:
            # leaving it there generates an error with the OpenAI API (Feb 2024)
            payload.pop("num_ctx", None)

            # Convert the modified body back to JSON
            body = json.dumps(payload)

    url = app.state.OPENAI_API_BASE_URLS[idx]
    key = app.state.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    # Bound before the request so the handler can tell a connection failure
    # (no response at all) apart from an upstream error reply.
    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Only propagate a genuine upstream error code. A local failure after
        # a 2xx reply (e.g. non-JSON body) must not be reported as HTTP 200.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail) from e