from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse

import requests
import aiohttp
import asyncio
import json
import logging

from pydantic import BaseModel
from starlette.background import BackgroundTask

from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES

from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.task import prompt_template

from config import (
    SRC_LOG_LEVELS,
    ENABLE_OPENAI_API,
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    AppConfig,
)

from typing import List, Optional

import hashlib
from pathlib import Path

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

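# This FastAPI sub-app proxies requests to one or more OpenAI-compatible
# backends. Base URLs and API keys are managed through the admin endpoints
# below, and the merged model list is cached in app.state.MODELS.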
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API

app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

app.state.MODELS = {}

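# Lazily populate the model cache: the first request that arrives while
# app.state.MODELS is still empty triggers a fetch from all configured
# backends via get_all_models().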
@app.middleware("http")
async def check_url(request: Request, call_next):
    if len(app.state.MODELS) == 0:
        await get_all_models()

    response = await call_next(request)
    return response


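# Admin-only endpoints for toggling the OpenAI API integration and for
# reading or replacing the configured base URLs and API keys.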
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


class OpenAIConfigForm(BaseModel):
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


class UrlsUpdateForm(BaseModel):
    urls: List[str]


class KeysUpdateForm(BaseModel):
    keys: List[str]


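# The update endpoints below replace the whole list in a single call, e.g.
#   POST /urls/update  {"urls": ["https://api.openai.com/v1"]}
#   POST /keys/update  {"keys": ["sk-..."]}
# The two lists are kept index-aligned: keys[i] is used for urls[i].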
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}


@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    await get_all_models()
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}


@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    app.state.config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}


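# Text-to-speech proxy: forwards the request body to the /audio/speech
# endpoint of the "https://api.openai.com/v1" backend and caches the audio on
# disk under CACHE_DIR, keyed by the SHA-256 hash of the request body, so
# repeated requests are served without another upstream call. If that backend
# is not configured, a 401 with OPENAI_NOT_FOUND is returned.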
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    idx = None
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")

        body = await request.body()
        name = hashlib.sha256(body).hexdigest()

        SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
        SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
        file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

        # Check if the file already exists in the cache
        if file_path.is_file():
            return FileResponse(file_path)

        headers = {}
        headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
        headers["Content-Type"] = "application/json"
        if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
            headers["HTTP-Referer"] = "https://openwebui.com/"
            headers["X-Title"] = "Open WebUI"

        r = None
        try:
            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
                data=body,
                headers=headers,
                stream=True,
            )

            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']}"
                except:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500, detail=error_detail
            )

    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)


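# Fetch a URL with a bearer token and a 5-second total timeout; returns the
# parsed JSON body, or None if the request fails for any reason.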
async def fetch_url(url, key):
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None


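# Used as a StreamingResponse background task to close the upstream aiohttp
# response and session once the stream has been fully forwarded to the client.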
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()


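# Flatten the per-backend model lists into a single list, tagging each entry
# with the index of the backend it came from ("urlIdx") and the raw upstream
# record ("openai"). For api.openai.com only "gpt" models are kept; other
# backends pass through unfiltered.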
def merge_models_lists(model_lists):
    log.debug(f"merge_models_lists {model_lists}")
    merged_list = []

    for idx, models in enumerate(model_lists):
        if models is not None and "error" not in models:
            merged_list.extend(
                [
                    {
                        **model,
                        "name": model.get("name", model["id"]),
                        "owned_by": "openai",
                        "openai": model,
                        "urlIdx": idx,
                    }
                    for model in models
                    if "api.openai.com"
                    not in app.state.config.OPENAI_API_BASE_URLS[idx]
                    or "gpt" in model["id"]
                ]
            )

    return merged_list


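# Query every configured backend's /models endpoint concurrently and cache the
# merged result in app.state.MODELS. If the integration is disabled or only a
# single empty key is configured, an empty list is returned; with raw=True the
# per-backend responses are returned unmerged.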
async def get_all_models(raw: bool = False):
    log.info("get_all_models()")

    if (
        len(app.state.config.OPENAI_API_KEYS) == 1
        and app.state.config.OPENAI_API_KEYS[0] == ""
    ) or not app.state.config.ENABLE_OPENAI_API:
        models = {"data": []}
    else:
        # Ensure the API key list and the base URL list have the same length
        if len(app.state.config.OPENAI_API_KEYS) != len(
            app.state.config.OPENAI_API_BASE_URLS
        ):
            # if there are more keys than urls, remove the extra keys
            if len(app.state.config.OPENAI_API_KEYS) > len(
                app.state.config.OPENAI_API_BASE_URLS
            ):
                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
                    : len(app.state.config.OPENAI_API_BASE_URLS)
                ]
            # if there are more urls than keys, add empty keys
            else:
                app.state.config.OPENAI_API_KEYS += [
                    ""
                    for _ in range(
                        len(app.state.config.OPENAI_API_BASE_URLS)
                        - len(app.state.config.OPENAI_API_KEYS)
                    )
                ]

        tasks = [
            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
        ]

        responses = await asyncio.gather(*tasks)
        log.debug(f"get_all_models:responses() {responses}")

        if raw:
            return responses

        models = {
            "data": merge_models_lists(
                list(
                    map(
                        lambda response: (
                            response["data"]
                            if (response and "data" in response)
                            else (response if isinstance(response, list) else None)
                        ),
                        responses,
                    )
                )
            )
        }

        log.debug(f"models: {models}")
        app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models


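# GET /models returns the merged model list, restricted by MODEL_FILTER_LIST
# for users with the "user" role when the model filter is enabled.
# GET /models/{url_idx} queries a single backend directly and, for
# api.openai.com, keeps only "gpt" models.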
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    if url_idx is None:
        models = await get_all_models()
        if app.state.config.ENABLE_MODEL_FILTER:
            if user.role == "user":
                models["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
                        models["data"],
                    )
                )
                return models
        return models
    else:
        url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
        key = app.state.config.OPENAI_API_KEYS[url_idx]

        headers = {}
        headers["Authorization"] = f"Bearer {key}"
        headers["Content-Type"] = "application/json"

        r = None

        try:
            r = requests.request(method="GET", url=f"{url}/models", headers=headers)
            r.raise_for_status()

            response_data = r.json()
            if "api.openai.com" in url:
                response_data["data"] = list(
                    filter(lambda model: "gpt" in model["id"], response_data["data"])
                )

            return response_data
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']}"
                except:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )


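# Chat completion proxy. If the requested model id matches a record stored via
# Models.get_model_by_id, the request is rewritten to its base model and the
# stored parameters (temperature, top_p, max_tokens, frequency_penalty, seed,
# stop, system prompt) are applied before forwarding. SSE responses from the
# backend are streamed back to the client unchanged.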
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    idx = 0
    payload = {**form_data}

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        model_info.params = model_info.params.model_dump()

        if model_info.params:
            if model_info.params.get("temperature", None) is not None:
                payload["temperature"] = float(model_info.params.get("temperature"))

            if model_info.params.get("top_p", None):
                payload["top_p"] = float(model_info.params.get("top_p", None))

            if model_info.params.get("max_tokens", None):
                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))

            if model_info.params.get("frequency_penalty", None):
                payload["frequency_penalty"] = float(
                    model_info.params.get("frequency_penalty", None)
                )

            if model_info.params.get("seed", None):
                payload["seed"] = model_info.params.get("seed", None)

            if model_info.params.get("stop", None):
                payload["stop"] = (
                    [
                        bytes(stop, "utf-8").decode("unicode_escape")
                        for stop in model_info.params["stop"]
                    ]
                    if model_info.params.get("stop", None)
                    else None
                )

        system = model_info.params.get("system", None)
        if system:
            system = prompt_template(
                system,
                **(
                    {
                        "user_name": user.name,
                        "user_location": (
                            user.info.get("location") if user.info else None
                        ),
                    }
                    if user
                    else {}
                ),
            )
            # Check if the payload already has a system message
            # If not, add a system message to the payload
            if payload.get("messages"):
                for message in payload["messages"]:
                    if message.get("role") == "system":
                        message["content"] = system + message["content"]
                        break
                else:
                    payload["messages"].insert(
                        0,
                        {
                            "role": "system",
                            "content": system,
                        },
                    )

    model = app.state.MODELS[payload.get("model")]
    idx = model["urlIdx"]

    if "pipeline" in model and model.get("pipeline"):
        payload["user"] = {
            "name": user.name,
            "id": user.id,
            "email": user.email,
            "role": user.role,
        }

    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
    # This is a workaround until OpenAI fixes the issue with this model
    if payload.get("model") == "gpt-4-vision-preview":
        if "max_tokens" not in payload:
            payload["max_tokens"] = 4000
        log.debug(f"Modified payload: {payload}")

    # Convert the modified body back to JSON
    payload = json.dumps(payload)

    log.debug(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/chat/completions",
            data=payload,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.error(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except:
                error_detail = f"External: {e}"
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        if not streaming and session:
            if r:
                r.close()
            await session.close()


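# Catch-all proxy: any other path is forwarded with the original method and
# body, plus bearer auth, to the first configured backend; SSE responses are
# streamed back to the client.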
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    idx = 0

    body = await request.body()
    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.error(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except:
                error_detail = f"External: {e}"

        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        if not streaming and session:
            if r:
                r.close()
            await session.close()