main.py 10.5 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
24
25
26
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    ENABLE_MODEL_FILTER,
28
29
    MODEL_FILTER_LIST,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
30
31
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
32
33
34

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
35

36
37
38
# Module logger; its level comes from the "OPENAI" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

# Sub-application serving the OpenAI-compatible proxy endpoints below.
app = FastAPI()

# Permissive CORS: every origin, method and header is accepted, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model-filter settings, seeded from config and kept on app.state so they can
# be changed at runtime.
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Upstream connection settings. The lists are positional: the key at index i
# is sent to the base URL at index i.
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS

# Cache of model id -> model metadata, (re)built by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
56

Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
61
62
@app.middleware("http")
async def check_url(request: Request, call_next):
    """HTTP middleware that makes sure the model cache is populated.

    If ``app.state.MODELS`` is empty, ``get_all_models()`` is awaited to build
    it before the request continues down the middleware chain. The request
    itself is passed through unchanged.
    """
    cache_is_empty = not app.state.MODELS
    if cache_is_empty:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
66
67


Timothy J. Baek's avatar
Timothy J. Baek committed
68
69
class UrlsUpdateForm(BaseModel):
    """Request body for ``POST /urls/update``: the full replacement list of base URLs."""

    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
70
71


Timothy J. Baek's avatar
Timothy J. Baek committed
72
73
class KeysUpdateForm(BaseModel):
    """Request body for ``POST /keys/update``: the full replacement list of API keys."""

    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
74
75


Timothy J. Baek's avatar
Timothy J. Baek committed
76
77
78
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the currently configured upstream base URLs.

    Access is gated by the ``get_admin_user`` dependency.
    """
    configured_urls = app.state.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": configured_urls}
79

Timothy J. Baek's avatar
Timothy J. Baek committed
80

Timothy J. Baek's avatar
Timothy J. Baek committed
81
82
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the upstream base URLs and rebuild the model cache.

    Access is gated by the ``get_admin_user`` dependency. The URLs are stored
    first and only then is ``get_all_models()`` awaited, so that
    ``app.state.MODELS`` (and each model's ``urlIdx``) is built from the new
    URL list rather than the one being replaced.

    Returns:
        ``{"OPENAI_API_BASE_URLS": [...]}`` with the stored URLs.
    """
    app.state.OPENAI_API_BASE_URLS = form_data.urls
    # Refresh after storing: get_all_models() reads app.state.OPENAI_API_BASE_URLS.
    await get_all_models()
    return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
86
87


Timothy J. Baek's avatar
Timothy J. Baek committed
88
89
90
91
92
93
94
95
96
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the currently configured API keys (positional, one per base URL).

    Access is gated by the ``get_admin_user`` dependency.
    """
    configured_keys = app.state.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": configured_keys}


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the stored API keys and echo the new list back.

    Access is gated by the ``get_admin_user`` dependency. Keys are positional:
    the key at index i is used with ``app.state.OPENAI_API_BASE_URLS[i]``.
    """
    new_keys = form_data.keys
    app.state.OPENAI_API_KEYS = new_keys
    return {"OPENAI_API_KEYS": new_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
97
98


Timothy J. Baek's avatar
Timothy J. Baek committed
99
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to OpenAI and cache the audio on disk.

    Only the configured base URL equal to ``https://api.openai.com/v1`` is used,
    authorized with the key at the same index in ``app.state.OPENAI_API_KEYS``.
    The cache key is the SHA-256 of the raw request body, so an identical
    request is answered from ``CACHE_DIR/audio/speech/<hash>.mp3`` without
    contacting the upstream. The request body is also saved next to it as
    ``<hash>.json``.

    Raises:
        HTTPException(401): no ``https://api.openai.com/v1`` base URL is configured.
        HTTPException: the upstream call or caching failed; the status mirrors
            the upstream error response when one was received, otherwise 500.
    """
    try:
        idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        # list.index raises ValueError when the OpenAI URL is not configured.
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
    # Audio is streamed into this temp file and renamed only once complete, so
    # a failed or interrupted download can never leave a truncated .mp3 that
    # the cache check below would serve on later requests.
    tmp_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.tmp")

    # Serve from cache if this exact request body was synthesized before.
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        r = requests.post(
            url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        tmp_path.replace(file_path)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)
        # Drop any partially written audio so it is not mistaken for a cache hit.
        tmp_path.unlink(missing_ok=True)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    if isinstance(error, dict) and "message" in error:
                        error = error["message"]
                    error_detail = f"External: {error}"
            except Exception:
                error_detail = f"External: {e}"

        # NOTE: `if r` would be wrong here: requests.Response.__bool__ returns
        # `r.ok`, which is False exactly when the upstream returned an error.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
159
160


Timothy J. Baek's avatar
Timothy J. Baek committed
161
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    Any exception raised while connecting or decoding is logged and turned
    into a ``None`` return value instead of propagating, so that callers
    gathering several upstreams can treat a failed one as "no result".
    """
    auth_headers = {"Authorization": f"Bearer {key}"}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=auth_headers) as resp:
                return await resp.json()
    except Exception as err:
        # Best effort: report the connection problem and signal failure with None.
        log.error(f"Connection error: {err}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-upstream model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` holds the models returned by
    ``app.state.OPENAI_API_BASE_URLS[i]``. Entries that are ``None`` or contain
    ``"error"`` are skipped. Every kept model is copied with an extra
    ``"urlIdx": i`` field so later requests can be routed back to that
    upstream. For upstreams whose URL contains ``api.openai.com``, only models
    whose ``id`` contains ``"gpt"`` are kept.
    """
    log.info(f"merge_models_lists {model_lists}")

    merged = []
    for url_idx, models in enumerate(model_lists):
        if models is None or "error" in models or not models:
            continue

        # Same for every model of this upstream, so evaluate it once.
        is_openai = "api.openai.com" in app.state.OPENAI_API_BASE_URLS[url_idx]
        merged.extend(
            {**model, "urlIdx": url_idx}
            for model in models
            if not is_openai or "gpt" in model["id"]
        )

    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
189
190


Timothy J. Baek's avatar
Timothy J. Baek committed
191
async def get_all_models():
    """Fetch ``/models`` from every configured upstream and refresh the cache.

    Returns:
        ``{"data": [...]}``, the merged model list from all upstreams. Each
        model carries a ``urlIdx`` pointing at its upstream. The same models
        are stored in ``app.state.MODELS`` keyed by ``id``.

    When the key list consists of a single empty string, no upstream is
    contacted: an empty list is returned and the cache is cleared. Previously
    this path returned ``None`` and left a stale cache behind.
    """
    log.info("get_all_models()")

    if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
        models = {"data": []}
    else:
        # Keys are positional: the key at index i belongs to base URL i.
        tasks = [
            fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
            for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
        ]
        responses = await asyncio.gather(*tasks)
        log.info(f"get_all_models:responses() {responses}")

        # Normalize each reply to a list of models (or None when unusable):
        # {"data": [...]} -> [...]; a bare list is kept as-is; anything else
        # (including a failed fetch, which yields None) -> None.
        model_lists = [
            (
                response["data"]
                if (response and "data" in response)
                else (response if isinstance(response, list) else None)
            )
            for response in responses
        ]
        models = {"data": merge_models_lists(model_lists)}
        log.info(f"models: {models}")

    # Always resync the cache, including clearing it when nothing is configured.
    app.state.MODELS = {model["id"]: model for model in models["data"]}
    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
224

Timothy J. Baek's avatar
Timothy J. Baek committed
225
226
227

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``, the merged list from ``get_all_models()`` is returned.
    When ``app.state.ENABLE_MODEL_FILTER`` is on, users with role ``"user"``
    only see models whose id is in ``app.state.MODEL_FILTER_LIST``.

    With ``url_idx``, that single upstream's ``/models`` endpoint is queried
    directly, authenticated with its positional API key when one is set. For
    ``api.openai.com`` the result is restricted to ids containing ``"gpt"``.

    Raises:
        HTTPException: the upstream request failed; the status mirrors the
            upstream error response when one was received, otherwise 500.
    """
    if url_idx is None:
        # Defensive: treat a missing result as "no models" rather than crashing
        # on models["data"] below.
        models = await get_all_models() or {"data": []}
        if app.state.ENABLE_MODEL_FILTER and user.role == "user":
            models["data"] = [
                model
                for model in models["data"]
                if model["id"] in app.state.MODEL_FILTER_LIST
            ]
        return models

    url = app.state.OPENAI_API_BASE_URLS[url_idx]
    try:
        key = app.state.OPENAI_API_KEYS[url_idx]
    except IndexError:
        key = ""
    # Keyed upstreams (e.g. api.openai.com) reject /models without auth; only
    # send the header when a key is actually configured.
    headers = {"Authorization": f"Bearer {key}"} if key else {}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    if isinstance(error, dict) and "message" in error:
                        error = error["message"]
                    error_detail = f"External: {error}"
            except Exception:
                error_detail = f"External: {e}"

        # `if r` would evaluate r.ok, which is False precisely for upstream
        # errors, so test for None explicitly.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
272
273


Timothy J. Baek's avatar
Timothy J. Baek committed
274
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the upstream that serves the requested model.

    Routing: when the body is a JSON object whose ``model`` is in
    ``app.state.MODELS``, the request goes to that model's ``urlIdx``.
    Otherwise it goes to upstream 0.

    Body patches applied to JSON-object bodies:
      * ``gpt-4-vision-preview``: ``max_tokens`` defaults to 4000.
      * ``num_ctx`` is removed (the OpenAI API rejects it).
    Bodies that are empty, not UTF-8, or not JSON are forwarded unchanged.

    Returns:
        A ``StreamingResponse`` for ``text/event-stream`` replies, otherwise
        the upstream's decoded JSON.

    Raises:
        HTTPException(401): the selected upstream has an empty API key.
        HTTPException: the upstream call failed; the status mirrors the
            upstream error response when one was received, otherwise 500.
    """
    idx = 0

    body = await request.body()
    payload = None
    if body:
        # TODO: Remove below after gpt-4-vision fix from Open AI
        # Decode the body so it can be routed by model and patched.
        try:
            payload = json.loads(body.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            log.error("Error loading request body into a dictionary: %s", e)

    if isinstance(payload, dict):
        model_id = payload.get("model")
        # Unknown or missing model ids keep the default upstream (idx 0)
        # instead of failing with a KeyError.
        model_info = app.state.MODELS.get(model_id) if isinstance(model_id, str) else None
        if model_info is not None:
            idx = model_info["urlIdx"]

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if model_id == "gpt-4-vision-preview":
            if "max_tokens" not in payload:
                payload["max_tokens"] = 4000
            log.debug("Modified body_dict: %s", payload)

        # Fix for ChatGPT calls failing because the num_ctx key is in body:
        # leaving it there generates an error with the OpenAI API (Feb 2024).
        payload.pop("num_ctx", None)

        # Convert the modified body back to JSON
        body = json.dumps(payload)

    url = app.state.OPENAI_API_BASE_URLS[idx]
    key = app.state.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Server-sent events are relayed chunk by chunk; everything else is JSON.
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        return r.json()
    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    # OpenAI-style errors are objects with a "message"; others
                    # may be plain strings.
                    if isinstance(error, dict) and "message" in error:
                        error = error["message"]
                    error_detail = f"External: {error}"
            except Exception:
                error_detail = f"External: {e}"

        # `if r` would evaluate r.ok, which is False precisely for upstream
        # errors, so test for None explicitly.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail) from e