main.py 9.03 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
Timothy J. Baek's avatar
Timothy J. Baek committed
9

Timothy J. Baek's avatar
Timothy J. Baek committed
10
11
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
14
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
15
16
17
18
19
20
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22
23
from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
24
25
26

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28
29
30
31
32
33
34
35
36

app = FastAPI()

# Wide-open CORS: every origin, method and header is accepted and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Runtime-mutable configuration, seeded from `config` and replaceable through
# the /urls/update and /keys/update endpoints. Keys are looked up by the same
# index as their base URL.
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS

# Model id -> model dict (carrying "urlIdx"); filled in by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
42

Timothy J. Baek's avatar
Timothy J. Baek committed
43
44
45
46
47
48
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Make sure the model registry is populated before handling a request.

    ``app.state.MODELS`` is filled lazily. Whenever it is empty, get_all_models()
    is awaited before the request is passed down the stack. This covers the
    first request, and the case where no backend returned models last time.
    """
    if not app.state.MODELS:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
52
53


Timothy J. Baek's avatar
Timothy J. Baek committed
54
55
class UrlsUpdateForm(BaseModel):
    """Request body for ``POST /urls/update``: the replacement list of base URLs."""

    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
56
57


Timothy J. Baek's avatar
Timothy J. Baek committed
58
59
class KeysUpdateForm(BaseModel):
    """Request body for ``POST /keys/update``: the replacement list of API keys."""

    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
60
61


Timothy J. Baek's avatar
Timothy J. Baek committed
62
63
64
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured base URLs (guarded by the ``get_admin_user`` dependency)."""
    configured_urls = app.state.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": configured_urls}
65

Timothy J. Baek's avatar
Timothy J. Baek committed
66

Timothy J. Baek's avatar
Timothy J. Baek committed
67
68
69
70
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured base URLs and return the new list.

    Guarded by the ``get_admin_user`` dependency. The cached model registry is
    cleared: every cached model's ``urlIdx`` is a position in the *old* URL
    list. The ``check_url`` middleware rebuilds the registry on the next
    request, because it refetches whenever ``app.state.MODELS`` is empty.
    """
    app.state.OPENAI_API_BASE_URLS = form_data.urls
    # Stale urlIdx values would route proxied requests to the wrong backend,
    # or past the end of a shortened list.
    app.state.MODELS = {}
    return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
71
72


Timothy J. Baek's avatar
Timothy J. Baek committed
73
74
75
76
77
78
79
80
81
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys (guarded by the ``get_admin_user`` dependency).

    Keys are positionally paired with ``app.state.OPENAI_API_BASE_URLS``.
    """
    configured_keys = app.state.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": configured_keys}


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured API keys and return the new list.

    Guarded by the ``get_admin_user`` dependency. The cached model registry is
    cleared because it was built with the old keys. A backend whose key failed
    contributed no models, and one whose key was revoked may still be listed.
    The ``check_url`` middleware refetches on the next request, because
    ``app.state.MODELS`` is then empty.
    """
    app.state.OPENAI_API_KEYS = form_data.keys
    app.state.MODELS = {}
    return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
82
83


Timothy J. Baek's avatar
Timothy J. Baek committed
84
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to OpenAI, caching the audio on disk.

    The SHA-256 of the raw request body is the cache key. If
    ``<CACHE_DIR>/audio/speech/<hash>.mp3`` already exists, it is returned
    directly. Otherwise the body is forwarded unchanged to
    ``https://api.openai.com/v1/audio/speech``, using the API key configured
    for that base URL. The streamed audio is stored as ``<hash>.mp3``, and the
    request body is saved beside it as ``<hash>.json``.

    Raises:
        HTTPException: 401 (``ERROR_MESSAGES.OPENAI_NOT_FOUND``) when
            ``https://api.openai.com/v1`` is not a configured base URL.
            For any failure while calling upstream or writing the cache, the
            status is the upstream error status, or 500 when upstream did not
            return an error status. The detail is taken from the upstream
            ``"error"`` field when available.
    """
    try:
        idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(
            status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND
        ) from None

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
    # Audio is streamed here first and renamed into place only when complete.
    tmp_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.part")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    # Stays None when the upstream call itself raises (e.g. connection error),
    # so the error handler below can tell "no response" from "error response".
    r = None
    try:
        # NOTE(review): requests is blocking and stalls the event loop while
        # upstream responds; consider an async client or a threadpool.
        r = requests.post(
            url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Save the streaming content to a temporary file, then move it into
        # place. An interrupted download must never leave a truncated .mp3,
        # because the cache check above would then serve it forever.
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        tmp_path.replace(file_path)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Return the saved file
        return FileResponse(file_path)
    except Exception as e:
        print(e)
        tmp_path.unlink(missing_ok=True)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Only propagate an upstream *error* status; a local failure after a
        # successful upstream call must not be reported with status 200.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail)
Timothy J. Baek's avatar
Timothy J. Baek committed
141
142


Timothy J. Baek's avatar
Timothy J. Baek committed
143
async def fetch_url(url, key):
    """GET ``url`` with a Bearer ``key`` and return the decoded JSON body.

    Any exception raised while requesting or decoding is printed and turned
    into a ``None`` return instead of propagating. That way callers that
    gather several URLs (get_all_models) can simply skip the failures.
    """
    try:
        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=auth_headers) as resp:
                payload = await resp.json()
    except Exception as err:
        # Handle connection error here
        print(f"Connection error: {err}")
        return None
    return payload


def merge_models_lists(model_lists):
    """Flatten per-backend model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` must hold the models served by
    ``app.state.OPENAI_API_BASE_URLS[i]``. Each model is copied with an extra
    ``"urlIdx": i`` entry. For a base URL containing ``api.openai.com``, only
    models whose id contains ``"gpt"`` are kept.
    """
    merged = []
    for url_idx, models in enumerate(model_lists):
        for model in models:
            from_openai = "api.openai.com" in app.state.OPENAI_API_BASE_URLS[url_idx]
            if not from_openai or "gpt" in model["id"]:
                merged.append({**model, "urlIdx": url_idx})
    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
169
170


Timothy J. Baek's avatar
Timothy J. Baek committed
171
172
173
174
175
176
177
async def get_all_models():
    """Fetch ``/models`` from every configured base URL and cache the result.

    Each base URL is queried concurrently with the API key at the same index
    in ``app.state.OPENAI_API_KEYS``. A failed response (``None``, an
    ``"error"`` payload, or a non-dict body) contributes no models but keeps
    its slot. This way every model's ``urlIdx`` is the index of the URL it
    actually came from.

    Side effect: replaces ``app.state.MODELS`` with a mapping of
    model id -> model dict.

    Returns:
        ``{"data": [...]}`` holding the merged model list.
    """
    print("get_all_models")
    tasks = [
        fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
        for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
    ]
    # gather() preserves task order, so responses[i] belongs to URL i.
    responses = await asyncio.gather(*tasks)

    # Do NOT drop failed responses: merge_models_lists derives urlIdx from the
    # list position, so removing an entry would shift every later backend's
    # models onto the wrong URL (and the wrong API key).
    model_lists = [
        response.get("data", [])
        if isinstance(response, dict) and "error" not in response
        else []
        for response in responses
    ]

    models = {"data": merge_models_lists(model_lists)}
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
187

Timothy J. Baek's avatar
Timothy J. Baek committed
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222

# NOTE(review): this endpoint has no auth dependency; a
# `user=Depends(get_current_user)` parameter used to be commented out here.
# Confirm that leaving it unauthenticated is intentional.
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None):
    """List available models, from all backends or from one backend by index.

    Without ``url_idx`` this delegates to get_all_models(), which also
    refreshes ``app.state.MODELS``. With ``url_idx``, it GETs
    ``<base_url>/models`` from that single backend with its API key. For
    ``api.openai.com`` backends, only models whose id contains ``"gpt"`` are
    kept.

    Raises:
        HTTPException: the upstream error status (500 when there is none,
            e.g. a connection failure), with a detail taken from the upstream
            ``"error"`` field when available.
    """
    if url_idx is None:
        return await get_all_models()

    url = app.state.OPENAI_API_BASE_URLS[url_idx]
    key = app.state.OPENAI_API_KEYS[url_idx]

    # Stays None when the request itself raises, so the handler below never
    # touches an unbound name.
    r = None
    try:
        # Authenticate the same way fetch_url does for the aggregated listing;
        # OpenAI's /models rejects unauthenticated requests.
        r = requests.request(
            method="GET",
            url=f"{url}/models",
            headers={"Authorization": f"Bearer {key}"},
        )
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # `if r` would use Response.__bool__ (== r.ok), which is False for
        # every error response, and would hide the upstream status behind 500.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail)
Timothy J. Baek's avatar
Timothy J. Baek committed
223
224


Timothy J. Baek's avatar
Timothy J. Baek committed
225
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the OpenAI-compatible backend serving it.

    The backend is chosen by looking up ``body["model"]`` in
    ``app.state.MODELS``. Requests without a JSON-object body, or naming a
    model that is not registered, go to backend 0. JSON-object bodies are
    patched before forwarding: ``max_tokens`` defaults to 4000 for
    ``gpt-4-vision-preview``, and ``num_ctx`` is removed.

    Server-sent-event responses are streamed back; other responses are
    returned as decoded JSON.

    Raises:
        HTTPException: 401 (``ERROR_MESSAGES.API_KEY_NOT_FOUND``) when the
            selected backend has an empty API key. Otherwise, the upstream
            error status (500 when there is none), with a detail taken from
            the upstream ``"error"`` field when available.
    """
    idx = 0

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body of the request from bytes to a UTF-8 string
    # (required to add max_tokens to fix gpt-4-vision).
    payload = None
    try:
        body = body.decode("utf-8")
        payload = json.loads(body)
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Not UTF-8 JSON: forward the body as-is to the default backend.
        print("Error loading request body into a dictionary:", e)

    # Only a JSON object can name a model or be patched; other JSON values
    # are forwarded unchanged.
    if isinstance(payload, dict):
        model_info = app.state.MODELS.get(payload.get("model"))
        if model_info is not None:
            idx = model_info["urlIdx"]

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if payload.get("model") == "gpt-4-vision-preview":
            if "max_tokens" not in payload:
                payload["max_tokens"] = 4000
            print("Modified body_dict:", payload)

        # Fix for ChatGPT calls failing because the num_ctx key is in body.
        # Leaving it there generates an error with the OpenAI API (Feb 2024).
        if "num_ctx" in payload:
            del payload["num_ctx"]

        # Convert the modified body back to JSON
        body = json.dumps(payload)

    url = app.state.OPENAI_API_BASE_URLS[idx]
    key = app.state.OPENAI_API_KEYS[idx]

    # NOTE(review): the incoming query string (request.url.query) is not
    # forwarded; confirm no proxied endpoint relies on query parameters.
    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    # Stays None when the request itself raises, so the handler below never
    # touches an unbound name.
    r = None
    try:
        # NOTE(review): requests is blocking and stalls the event loop while
        # upstream responds; consider an async client or a threadpool.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # NOTE(review): all upstream headers are copied, including
            # hop-by-hop / encoding headers (Content-Encoding, Content-Length,
            # Transfer-Encoding) that may not match the re-streamed body.
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        return r.json()
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # A local failure after a 2xx upstream reply (e.g. a non-JSON body)
        # must not be reported with a success status.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail)