main.py 11.2 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
24
25
26
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    ENABLE_MODEL_FILTER,
28
    MODEL_FILTER_LIST,
29
    MODEL_CONFIG,
30
    AppConfig,
31
)
Timothy J. Baek's avatar
Timothy J. Baek committed
32
33
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
36

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
37

38
39
40
# Logger for this sub-application; its level comes from the "OPENAI" entry
# of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

# FastAPI sub-application exposing the OpenAI-compatible proxy routes.
app = FastAPI()

# Permissive CORS policy: any origin, method and header, credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Runtime configuration object; the endpoint URLs and keys are stored on it
# below so the admin routes (/urls/update, /keys/update) can replace them.
app.state.config = AppConfig()

# Model visibility settings consulted by GET /models.
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
# Per-model custom metadata for the "openai" provider ([] when absent).
app.state.MODEL_CONFIG = MODEL_CONFIG.value.get("openai", [])

# Parallel lists: OPENAI_API_KEYS[i] is sent to OPENAI_API_BASE_URLS[i].
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Cache of model id -> model dict, (re)filled by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
61

Timothy J. Baek's avatar
Timothy J. Baek committed
62
63
64
65
66
67
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Warm the model cache lazily before handling any request.

    When ``app.state.MODELS`` is empty, the model lists are fetched via
    ``get_all_models()`` before the request is passed on to its route.

    A failure during this warm-up is logged rather than propagated. This keeps
    unrelated endpoints reachable when the configuration is broken, including
    the admin ``/urls/update`` and ``/keys/update`` routes used to repair it.
    GET /models calls ``get_all_models()`` itself and still reports such
    errors to its caller.
    """
    if not app.state.MODELS:
        try:
            await get_all_models()
        except Exception as e:
            # Best-effort cache fill; never block the actual request on it.
            log.exception(e)

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
71
72


Timothy J. Baek's avatar
Timothy J. Baek committed
73
74
# Request body for POST /urls/update: the complete replacement list of
# OpenAI-compatible base URLs (entry i is paired with OPENAI_API_KEYS[i]).
class UrlsUpdateForm(BaseModel):
    urls: List[str]  # new value for app.state.config.OPENAI_API_BASE_URLS
Timothy J. Baek's avatar
Timothy J. Baek committed
75
76


Timothy J. Baek's avatar
Timothy J. Baek committed
77
78
# Request body for POST /keys/update: the complete replacement list of API
# keys, positionally paired with OPENAI_API_BASE_URLS.
class KeysUpdateForm(BaseModel):
    keys: List[str]  # new value for app.state.config.OPENAI_API_KEYS
Timothy J. Baek's avatar
Timothy J. Baek committed
79
80


Timothy J. Baek's avatar
Timothy J. Baek committed
81
82
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URLs (admin only)."""
    base_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
84

Timothy J. Baek's avatar
Timothy J. Baek committed
85

Timothy J. Baek's avatar
Timothy J. Baek committed
86
87
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured OpenAI-compatible base URLs (admin only).

    The new list is stored first, and only then is the model cache refreshed.
    This way ``app.state.MODELS``, and the ``urlIdx`` of every cached model
    (an index into this list), describe the new endpoints rather than the
    previous ones.

    If the refresh raises, the exception propagates; the new URLs have
    already been stored at that point.

    Returns:
        ``{"OPENAI_API_BASE_URLS": [...]}`` with the stored URLs.
    """
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
    await get_all_models()
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
91
92


Timothy J. Baek's avatar
Timothy J. Baek committed
93
94
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys, one per base URL (admin only)."""
    api_keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": api_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
96
97
98
99


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured API keys (admin only) and echo the stored list.

    Keys are positional: ``keys[i]`` is used for ``OPENAI_API_BASE_URLS[i]``.
    """
    config = app.state.config
    config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
102
103


Timothy J. Baek's avatar
Timothy J. Baek committed
104
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to the official OpenAI endpoint, with caching.

    The raw request body is hashed with SHA-256 to form a cache key.

    - If ``<CACHE_DIR>/audio/speech/<hash>.mp3`` exists, it is served directly.
    - Otherwise the body is forwarded to ``<base_url>/audio/speech`` of the
      configured ``https://api.openai.com/v1`` entry. The audio is stored in
      the cache, together with a ``<hash>.json`` copy of the request body, and
      the cached file is returned.

    The ``.mp3`` only appears under its final name once it has been fully
    written. An interrupted download therefore never poisons the cache.

    Raises:
        HTTPException(401): no ``https://api.openai.com/v1`` base URL is configured.
        HTTPException: the upstream call or the caching failed. The upstream
            status code is used when upstream replied with an error, 500 otherwise.
    """
    import tempfile

    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        # Speech synthesis is only routed to the official OpenAI endpoint.
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    base_url = app.state.config.OPENAI_API_BASE_URLS[idx]

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
    headers["Content-Type"] = "application/json"

    r = None
    tmp_path = None
    try:
        r = requests.post(
            url=f"{base_url}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Stream into a uniquely named temporary file in the cache directory.
        # It is renamed to its final name only after everything succeeded, so
        # a failure here never leaves a truncated .mp3 that later requests
        # would be served from the cache.
        with tempfile.NamedTemporaryFile(
            dir=str(SPEECH_CACHE_DIR), suffix=".part", delete=False
        ) as f:
            tmp_path = Path(f.name)
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Same directory, so this rename is atomic (os.replace semantics).
        tmp_path.replace(file_path)
        tmp_path = None

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    message = (
                        error["message"]
                        if isinstance(error, dict) and "message" in error
                        else error
                    )
                    error_detail = f"External: {message}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx statuses, so test for None
        # explicitly. The upstream status is only meaningful when upstream
        # actually replied with an error.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        ) from e
    finally:
        # Remove any partially written audio and release the connection.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        if r is not None:
            r.close()
Timothy J. Baek's avatar
Timothy J. Baek committed
166
167


Timothy J. Baek's avatar
Timothy J. Baek committed
168
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    Any failure (connection error, undecodable response, ...) is logged and
    turned into a ``None`` return value instead of an exception.
    """
    try:
        request_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as client, client.get(
            url, headers=request_headers
        ) as resp:
            return await resp.json()
    except Exception as err:
        # Connection problems are reported but never propagated to the caller.
        log.error(f"Connection error: {err}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-endpoint model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` holds the models returned by the i-th configured base
    URL; entries that are ``None`` or contain ``"error"`` are skipped. Every
    kept model is copied with an extra ``"urlIdx": i``. For endpoints whose
    URL contains ``api.openai.com``, only models whose id contains ``"gpt"``
    are kept.
    """
    log.info(f"merge_models_lists {model_lists}")

    merged_list = []
    for url_idx, models in enumerate(model_lists):
        if models is None or "error" in models:
            continue
        for model in models:
            url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
            if "api.openai.com" in url and "gpt" not in model["id"]:
                continue
            merged_list.append({**model, "urlIdx": url_idx})

    return merged_list
Timothy J. Baek's avatar
Timothy J. Baek committed
197
198


Timothy J. Baek's avatar
Timothy J. Baek committed
199
async def get_all_models():
    """Fetch, merge and cache the model lists of every configured endpoint.

    Each base URL's ``/models`` is requested concurrently through
    ``fetch_url``, using the key at the same position in ``OPENAI_API_KEYS``.
    A URL without a matching key entry is queried with an empty key.

    Returns:
        ``{"data": [...]}``. Each entry is a model dict tagged with ``urlIdx``
        (its base-URL index) and ``custom_info``. The result is
        ``{"data": []}`` when the only configured key is ``""``.

    Side effects:
        Replaces ``app.state.MODELS`` with a ``{model_id: model}`` mapping of
        the returned models. It is emptied in the no-key case, so stale
        entries are not used for routing.
    """
    log.info("get_all_models()")

    api_keys = app.state.config.OPENAI_API_KEYS

    if len(api_keys) == 1 and api_keys[0] == "":
        models = {"data": []}
    else:
        tasks = [
            # Tolerate fewer keys than URLs. An IndexError here would make the
            # HTTP middleware fail every request to this sub-application.
            fetch_url(f"{url}/models", api_keys[idx] if idx < len(api_keys) else "")
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
        ]

        responses = await asyncio.gather(*tasks)
        log.info(f"get_all_models:responses() {responses}")

        # Normalise each response to a plain list of models, or None when it
        # carries neither a "data" field nor a bare list (e.g. failed fetches).
        model_lists = [
            (
                response["data"]
                if (response and "data" in response)
                else (response if isinstance(response, list) else None)
            )
            for response in responses
        ]
        models = {"data": merge_models_lists(model_lists)}

        for model in models["data"]:
            add_custom_info_to_model(model)

        log.info(f"models: {models}")

    app.state.MODELS = {model["id"]: model for model in models["data"]}
    return models


def add_custom_info_to_model(model: dict):
    """Attach the matching ``MODEL_CONFIG`` entry to ``model`` as ``custom_info``.

    The first item of ``app.state.MODEL_CONFIG`` whose ``"id"`` equals
    ``model["id"]`` is stored under ``model["custom_info"]``. When nothing
    matches, ``{}`` is stored instead. ``model`` is modified in place and
    nothing is returned.
    """
    custom_info = {}
    for item in app.state.MODEL_CONFIG:
        if item["id"] == model["id"]:
            custom_info = item
            break
    model["custom_info"] = custom_info
Timothy J. Baek's avatar
Timothy J. Baek committed
244

Timothy J. Baek's avatar
Timothy J. Baek committed
245
246
247

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``: returns the merged model list of all endpoints, which
    ``get_all_models()`` refreshes on every call. When the model filter is
    enabled, users with role ``"user"`` only see models listed in
    ``MODEL_FILTER_LIST``.

    With ``url_idx``: queries the ``/models`` endpoint of that single base URL
    directly, authenticated with its API key. For ``api.openai.com`` only
    models whose id contains ``"gpt"`` are kept.

    Raises:
        HTTPException: the upstream request failed. The upstream status code
            is used when upstream replied with an error, 500 otherwise.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.ENABLE_MODEL_FILTER and user.role == "user":
            models["data"] = [
                model
                for model in models["data"]
                if model["id"] in app.state.MODEL_FILTER_LIST
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    try:
        key = app.state.config.OPENAI_API_KEYS[url_idx]
    except IndexError:
        key = ""

    # Authenticate like get_all_models() and proxy() do; endpoints such as
    # api.openai.com require a bearer token for /models.
    headers = {"Authorization": f"Bearer {key}"} if key else {}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    message = (
                        error["message"]
                        if isinstance(error, dict) and "message" in error
                        else error
                    )
                    error_detail = f"External: {message}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx statuses, so test for None
        # explicitly and forward the upstream status only for upstream errors.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        ) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
292
293


Timothy J. Baek's avatar
Timothy J. Baek committed
294
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward ``request`` to ``<base_url>/<path>`` of the endpoint serving it.

    Endpoint selection:
        The ``urlIdx`` of the model named in the JSON body, looked up in
        ``app.state.MODELS``. Requests without a JSON-object body, or naming
        a model that is not cached, go to endpoint 0.

    Body handling:
        JSON-object bodies are patched before forwarding: a default
        ``max_tokens`` is added for ``gpt-4-vision-preview``, and
        ``num_ctx`` is removed. Bodies that are not UTF-8 JSON are forwarded
        unchanged as raw bytes.

    Response handling:
        Server-sent-event responses are streamed back. Any other response is
        returned as its parsed JSON.

    Raises:
        HTTPException(401): the selected endpoint has no API key.
        HTTPException: the upstream request failed. The upstream status code
            is used when upstream replied with an error, 500 otherwise.
    """
    idx = 0

    raw_body = await request.body()
    # Forwarded as-is unless it parses as JSON below; kept as bytes so any
    # encoding survives untouched.
    body = raw_body
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
    try:
        payload = json.loads(raw_body.decode("utf-8"))

        if isinstance(payload, dict):
            model_id = payload.get("model")
            # Unknown or missing models fall back to endpoint 0 instead of
            # failing with a KeyError.
            model_info = (
                app.state.MODELS.get(model_id) if isinstance(model_id, str) else None
            )
            if model_info is not None:
                idx = model_info["urlIdx"]

            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
            # This is a workaround until OpenAI fixes the issue with this model
            if model_id == "gpt-4-vision-preview":
                if "max_tokens" not in payload:
                    payload["max_tokens"] = 4000
                log.debug("Modified body_dict: %s", payload)

            # Fix for ChatGPT calls failing because the num_ctx key is in body.
            # Leaving it there generates an error with the OpenAI API (Feb 2024)
            payload.pop("num_ctx", None)

        # Convert the (possibly modified) body back to JSON
        body = json.dumps(payload)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        log.error("Error loading request body into a dictionary: %s", e)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    try:
        key = app.state.config.OPENAI_API_KEYS[idx]
    except IndexError:
        # Fewer keys than URLs: treat as "no key" and report it below.
        key = ""

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None

    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # iter_content() yields already-decoded bytes. Upstream framing
            # headers therefore no longer describe the body we send and must
            # not be forwarded; the ASGI server sets its own.
            dropped = {
                "content-encoding",
                "content-length",
                "transfer-encoding",
                "connection",
                "keep-alive",
            }
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers={
                    name: value
                    for name, value in r.headers.items()
                    if name.lower() not in dropped
                },
            )
        else:
            return r.json()
    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    message = (
                        error["message"]
                        if isinstance(error, dict) and "message" in error
                        else error
                    )
                    error_detail = f"External: {message}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx statuses, so test for None
        # explicitly and forward the upstream status only for upstream errors.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        ) from e