from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse

import requests
import aiohttp
import asyncio
import json
import logging

from pydantic import BaseModel

from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES

from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from config import (
    SRC_LOG_LEVELS,
    ENABLE_OPENAI_API,
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    AppConfig,
)

from typing import List, Optional

import hashlib
from pathlib import Path

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

app.state.MODELS = {}


@app.middleware("http")
async def check_url(request: Request, call_next):
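    # Lazily populate the model cache: the first request through this app
    # triggers get_all_models() so app.state.MODELS is available to the
    # routes below.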
    if len(app.state.MODELS) == 0:
        await get_all_models()

    response = await call_next(request)
    return response


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


class OpenAIConfigForm(BaseModel):
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


class UrlsUpdateForm(BaseModel):
    urls: List[str]


class KeysUpdateForm(BaseModel):
    keys: List[str]


@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}


@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    await get_all_models()
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}


@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    app.state.config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}


@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
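    # Text-to-speech passthrough: the raw request body is forwarded unchanged
    # to the official OpenAI /audio/speech endpoint, and the resulting audio is
    # cached on disk under CACHE_DIR, keyed by the SHA-256 hash of the body.
    # If https://api.openai.com/v1 is not among the configured base URLs, the
    # .index() call below raises ValueError and the route responds with 401.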
    idx = None
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")

        body = await request.body()
        name = hashlib.sha256(body).hexdigest()

        SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
        SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
        file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

        # Check if the file already exists in the cache
        if file_path.is_file():
            return FileResponse(file_path)

        headers = {}
        headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
        headers["Content-Type"] = "application/json"

        if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
            headers["HTTP-Referer"] = "https://openwebui.com/"
            headers["X-Title"] = "Open WebUI"

        r = None
        try:
            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
                data=body,
                headers=headers,
                stream=True,
            )

            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500, detail=error_detail
            )

    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)


async def fetch_url(url, key):
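    # Fetch JSON from a single OpenAI-compatible endpoint with a short timeout;
    # returns None on any failure so one unreachable backend does not break
    # model aggregation in get_all_models().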
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None


def merge_models_lists(model_lists):
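    # Flatten the per-backend model lists into a single list, tagging each
    # entry with the index of the backend it came from (urlIdx). Models from
    # api.openai.com are narrowed to ids containing "gpt".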
    log.debug(f"merge_models_lists {model_lists}")
    merged_list = []

    for idx, models in enumerate(model_lists):
        if models is not None and "error" not in models:
            merged_list.extend(
                [
                    {
                        **model,
                        "name": model.get("name", model["id"]),
                        "owned_by": "openai",
                        "openai": model,
                        "urlIdx": idx,
                    }
                    for model in models
                    if "api.openai.com"
                    not in app.state.config.OPENAI_API_BASE_URLS[idx]
                    or "gpt" in model["id"]
                ]
            )

    return merged_list


async def get_all_models(raw: bool = False):
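    # Query /models on every configured base URL concurrently and cache the
    # merged result in app.state.MODELS. With raw=True the unmerged
    # per-backend responses are returned instead.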
    log.info("get_all_models()")

    if (
        len(app.state.config.OPENAI_API_KEYS) == 1
        and app.state.config.OPENAI_API_KEYS[0] == ""
    ) or not app.state.config.ENABLE_OPENAI_API:
        models = {"data": []}
    else:
        # Make sure the number of API keys matches the number of base URLs
        if len(app.state.config.OPENAI_API_KEYS) != len(
            app.state.config.OPENAI_API_BASE_URLS
        ):
            # if there are more keys than urls, remove the extra keys
            if len(app.state.config.OPENAI_API_KEYS) > len(
                app.state.config.OPENAI_API_BASE_URLS
            ):
                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
                    : len(app.state.config.OPENAI_API_BASE_URLS)
                ]
            # if there are more urls than keys, add empty keys
            else:
                app.state.config.OPENAI_API_KEYS += [
                    ""
                    for _ in range(
                        len(app.state.config.OPENAI_API_BASE_URLS)
                        - len(app.state.config.OPENAI_API_KEYS)
                    )
                ]

        tasks = [
            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
        ]

        responses = await asyncio.gather(*tasks)
        log.debug(f"get_all_models:responses() {responses}")

        if raw:
            return responses

        models = {
            "data": merge_models_lists(
                list(
                    map(
                        lambda response: (
                            response["data"]
                            if (response and "data" in response)
                            else (response if isinstance(response, list) else None)
                        ),
                        responses,
                    )
                )
            )
        }

        log.debug(f"models: {models}")
        app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models


@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
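    # Without url_idx, return the merged model list (filtered through
    # MODEL_FILTER_LIST for non-admin users when the filter is enabled).
    # With url_idx, query that single backend's /models endpoint directly.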
    if url_idx is None:
        models = await get_all_models()
        if app.state.config.ENABLE_MODEL_FILTER:
            if user.role == "user":
                models["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
                        models["data"],
                    )
                )
                return models
        return models
    else:
        url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
        key = app.state.config.OPENAI_API_KEYS[url_idx]

        headers = {}
        headers["Authorization"] = f"Bearer {key}"
        headers["Content-Type"] = "application/json"

        r = None

        try:
            r = requests.request(method="GET", url=f"{url}/models", headers=headers)
            r.raise_for_status()

            response_data = r.json()
            if "api.openai.com" in url:
                response_data["data"] = list(
                    filter(lambda model: "gpt" in model["id"], response_data["data"])
                )

            return response_data
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
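    # Catch-all proxy to the configured OpenAI-compatible backends. For
    # chat/completions requests the JSON body is rewritten first: per-model
    # overrides (base model id, sampling parameters, system prompt) are
    # applied and the target backend is picked via the model's urlIdx.
    # Everything else is forwarded as-is to the first configured backend.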
    idx = 0

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from OpenAI
    # Try to decode the request body from bytes to a UTF-8 string
    # (the gpt-4-vision workaround below requires adding max_tokens)

    payload = None

    try:
        if "chat/completions" in path:
            body = body.decode("utf-8")
            body = json.loads(body)

            payload = {**body}

            model_id = body.get("model")
            model_info = Models.get_model_by_id(model_id)

            if model_info:
                print(model_info)
                if model_info.base_model_id:
                    payload["model"] = model_info.base_model_id

                model_info.params = model_info.params.model_dump()

                if model_info.params:
                    # Note: temperature, top_p, and frequency_penalty are
                    # floats; casting them to int would truncate values
                    # such as 0.7 down to 0.
                    if model_info.params.get("temperature", None):
                        payload["temperature"] = float(
                            model_info.params.get("temperature")
                        )

                    if model_info.params.get("top_p", None):
                        payload["top_p"] = float(model_info.params.get("top_p", None))

                    if model_info.params.get("max_tokens", None):
                        payload["max_tokens"] = int(
                            model_info.params.get("max_tokens", None)
                        )

                    if model_info.params.get("frequency_penalty", None):
                        payload["frequency_penalty"] = float(
                            model_info.params.get("frequency_penalty", None)
                        )

                    if model_info.params.get("seed", None):
                        payload["seed"] = model_info.params.get("seed", None)

                    if model_info.params.get("stop", None):
                        payload["stop"] = (
                            [
                                bytes(stop, "utf-8").decode("unicode_escape")
                                for stop in model_info.params["stop"]
                            ]
                            if model_info.params.get("stop", None)
                            else None
                        )

                if model_info.params.get("system", None):
                    # Check if the payload already has a system message
                    # If not, add a system message to the payload
                    if payload.get("messages"):
                        for message in payload["messages"]:
                            if message.get("role") == "system":
                                message["content"] = (
                                    model_info.params.get("system", None)
                                    + message["content"]
                                )
                                break
                        else:
                            payload["messages"].insert(
                                0,
                                {
                                    "role": "system",
                                    "content": model_info.params.get("system", None),
                                },
                            )

            model = app.state.MODELS[payload.get("model")]

            idx = model["urlIdx"]

            if "pipeline" in model and model.get("pipeline"):
                payload["user"] = {"name": user.name, "id": user.id}

            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
            # This is a workaround until OpenAI fixes the issue with this model
            if payload.get("model") == "gpt-4-vision-preview":
                if "max_tokens" not in payload:
                    payload["max_tokens"] = 4000
                log.debug(f"Modified payload: {payload}")

            # Convert the modified body back to JSON
            payload = json.dumps(payload)

    except json.JSONDecodeError as e:
        log.error(f"Error loading request body into a dictionary: {e}")

    print(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None

    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=payload if payload else body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                print(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500, detail=error_detail
        )