"test/srt/test_mla_tp.py" did not exist on "4af3f889fc6f406c0fc3b7a310e3ad7220b01ff6"
main.py 17.1 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
from pydantic import BaseModel
12
from starlette.background import BackgroundTask
Timothy J. Baek's avatar
Timothy J. Baek committed
13

14
15
from apps.webui.models.models import Models
from apps.webui.models.users import Users
Timothy J. Baek's avatar
Timothy J. Baek committed
16
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
17
18
19
20
21
22
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
23
from config import (
24
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
25
    ENABLE_OPENAI_API,
26
27
28
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
29
    ENABLE_MODEL_FILTER,
30
    MODEL_FILTER_LIST,
31
    AppConfig,
32
)
Timothy J. Baek's avatar
Timothy J. Baek committed
33
34
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
35
36
37

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
38

39
40
41
# Module logger; its level comes from the "OPENAI" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

# Sub-application proxying OpenAI-compatible APIs. CORS is fully open:
# any origin, method and header, with credentials allowed.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Runtime-mutable settings, seeded from the values imported from `config`.
app.state.config = AppConfig()

# Per-role model visibility filter (applied in get_models).
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Upstream endpoints: OPENAI_API_KEYS[i] is sent to OPENAI_API_BASE_URLS[i].
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Cache of model id -> merged model entry; rebuilt by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
63

Timothy J. Baek's avatar
Timothy J. Baek committed
64
65
66
67
68
69
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Populate the model cache on demand, then pass the request through.

    Whenever ``app.state.MODELS`` is empty, ``get_all_models()`` is awaited
    before the request is handed to the next handler unchanged.
    """
    if not app.state.MODELS:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
73
74


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Report whether the OpenAI integration is enabled.

    Guarded by the ``get_admin_user`` dependency.
    """
    enabled = app.state.config.ENABLE_OPENAI_API
    return {"ENABLE_OPENAI_API": enabled}


class OpenAIConfigForm(BaseModel):
    # New value for ENABLE_OPENAI_API; omitted/null leaves it unchanged.
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    """Enable or disable the OpenAI integration.

    Guarded by the ``get_admin_user`` dependency. Returns the resulting
    ``ENABLE_OPENAI_API`` value. Because ``enable_openai_api`` defaults to
    ``None``, a body that omits it (or sends null) no longer overwrites the
    boolean setting with ``None``; the current value is kept instead.
    """
    if form_data.enable_openai_api is not None:
        app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


Timothy J. Baek's avatar
Timothy J. Baek committed
90
91
# Request body for POST /urls/update: the replacement list of base URLs.
class UrlsUpdateForm(BaseModel):
    urls: List[str]


# Request body for POST /keys/update: the replacement list of API keys,
# positionally matched to the base URLs.
class KeysUpdateForm(BaseModel):
    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
96
97


Timothy J. Baek's avatar
Timothy J. Baek committed
98
99
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URLs.

    Guarded by the ``get_admin_user`` dependency.
    """
    base_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
101

Timothy J. Baek's avatar
Timothy J. Baek committed
102

Timothy J. Baek's avatar
Timothy J. Baek committed
103
104
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured base URLs and rebuild the model cache.

    Guarded by the ``get_admin_user`` dependency. The new URLs are stored
    *before* ``get_all_models()`` runs. That way the cached models' ``urlIdx``
    values index into the new list rather than the old one, and (when the
    integration is enabled) OPENAI_API_KEYS is resized to the new URL count.
    Previously the refresh ran first and left the cache pointing at stale
    indices.
    """
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
    await get_all_models()
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
108
109


Timothy J. Baek's avatar
Timothy J. Baek committed
110
111
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys.

    Guarded by the ``get_admin_user`` dependency.
    """
    api_keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": api_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
113
114
115
116


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured API keys and echo back the stored value.

    Guarded by the ``get_admin_user`` dependency.
    """
    cfg = app.state.config
    cfg.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": cfg.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
119
120


Timothy J. Baek's avatar
Timothy J. Baek committed
121
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to api.openai.com with a disk cache.

    The raw request body is hashed with SHA-256 to form a cache key. If
    ``<CACHE_DIR>/audio/speech/<hash>.mp3`` exists, it is served directly.
    Otherwise the body is POSTed to ``<base_url>/audio/speech`` of the
    configured ``https://api.openai.com/v1`` entry. The returned audio and
    the request body (``<hash>.json``) are then cached, and the audio file is
    returned.

    Raises:
        HTTPException(401): no ``https://api.openai.com/v1`` base URL is
            configured.
        HTTPException: the upstream call or caching failed. An upstream HTTP
            error status is propagated; any other failure is reported as 500.
    """
    # Keep the try narrow: only this lookup means "not configured", so
    # unrelated ValueErrors are not misreported as 401.
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
    # Audio is downloaded here first and only renamed to file_path once
    # everything succeeded. A broken transfer therefore never leaves a
    # truncated .mp3 that the cache check would keep serving.
    partial_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.part")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE(review): `requests` is synchronous, so this call blocks the
        # event loop until the upstream finishes responding.
        r = requests.post(
            url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Save the streaming content to a temporary file
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Publish the complete audio into the cache in one rename.
        partial_path.replace(file_path)

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)

        # Best-effort removal of an incomplete download.
        try:
            partial_path.unlink()
        except OSError:
            pass

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # requests.Response is falsy for 4xx/5xx (bool(r) == r.ok), so test
        # identity. Only propagate real upstream error statuses.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
    finally:
        # stream=True keeps the connection open until content is consumed.
        if r is not None:
            r.close()
Timothy J. Baek's avatar
Timothy J. Baek committed
183
184


Timothy J. Baek's avatar
Timothy J. Baek committed
185
async def fetch_url(url, key):
    """GET ``url`` with a Bearer ``key`` and return the decoded JSON body.

    The session uses a 5-second total timeout and ``trust_env=True`` (proxy
    settings from the environment). Any exception raised while connecting,
    reading or decoding the JSON is logged and reported as ``None``.
    """
    auth_headers = {"Authorization": f"Bearer {key}"}
    try:
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=5), trust_env=True
        ) as client:
            async with client.get(url, headers=auth_headers) as resp:
                return await resp.json()
    except Exception as err:
        # Handle connection error here
        log.error(f"Connection error: {err}")
        return None


198
199
200
201
202
203
204
205
206
207
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """Close an aiohttp response and then its session.

    Either argument may be ``None``, in which case it is skipped. The
    streaming handlers in this module schedule this as the
    ``BackgroundTask`` of their ``StreamingResponse``.
    """
    if response:
        response.close()

    if not session:
        return
    await session.close()


Timothy J. Baek's avatar
Timothy J. Baek committed
208
def merge_models_lists(model_lists):
    """Flatten per-endpoint model lists into one list of annotated entries.

    ``model_lists[i]`` corresponds to ``OPENAI_API_BASE_URLS[i]``. Entries
    that are ``None`` or contain ``"error"`` are skipped. Each model dict is
    copied and extended with:

    - ``name``: falls back to ``id``.
    - ``owned_by``: set to ``"openai"``.
    - ``openai``: the original dict.
    - ``urlIdx``: its endpoint index.

    For base URLs containing "api.openai.com", only models whose ``id``
    contains "gpt" are kept.
    """
    log.debug(f"merge_models_lists {model_lists}")

    merged = []
    for url_idx, models in enumerate(model_lists):
        if models is None or "error" in models:
            continue

        for model in models:
            on_openai = (
                "api.openai.com" in app.state.config.OPENAI_API_BASE_URLS[url_idx]
            )
            if on_openai and "gpt" not in model["id"]:
                continue

            merged.append(
                {
                    **model,
                    "name": model.get("name", model["id"]),
                    "owned_by": "openai",
                    "openai": model,
                    "urlIdx": url_idx,
                }
            )

    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
231
232


Timothy J. Baek's avatar
Timothy J. Baek committed
233
async def get_all_models(raw: bool = False):
    """Fetch ``/models`` from every base URL and rebuild ``app.state.MODELS``.

    Returns ``{"data": [...]}`` with the entries produced by
    merge_models_lists, and stores them in ``app.state.MODELS`` keyed by id.
    With ``raw=True``, the unmerged per-URL responses (``None`` where a fetch
    failed) are returned instead and the cache is left as is.

    If the integration is disabled, or the only configured key is ``""``,
    nothing is fetched and ``{"data": []}`` is returned, even when ``raw``
    is set. ``app.state.MODELS`` is not modified in that case.
    """
    log.info("get_all_models()")

    config = app.state.config
    only_blank_key = (
        len(config.OPENAI_API_KEYS) == 1 and config.OPENAI_API_KEYS[0] == ""
    )
    if only_blank_key or not config.ENABLE_OPENAI_API:
        return {"data": []}

    # Keep exactly one key per base URL: drop surplus keys or pad with "".
    url_count = len(config.OPENAI_API_BASE_URLS)
    key_count = len(config.OPENAI_API_KEYS)
    if key_count > url_count:
        config.OPENAI_API_KEYS = config.OPENAI_API_KEYS[:url_count]
    elif key_count < url_count:
        config.OPENAI_API_KEYS += [""] * (url_count - key_count)

    responses = await asyncio.gather(
        *[
            fetch_url(f"{base_url}/models", config.OPENAI_API_KEYS[i])
            for i, base_url in enumerate(config.OPENAI_API_BASE_URLS)
        ]
    )
    log.debug(f"get_all_models:responses() {responses}")

    if raw:
        return responses

    def _model_list_of(response):
        # OpenAI-style {"data": [...]} payloads, or a bare list; else None.
        if response and "data" in response:
            return response["data"]
        if isinstance(response, list):
            return response
        return None

    models = {"data": merge_models_lists([_model_list_of(r) for r in responses])}

    log.debug(f"models: {models}")
    app.state.MODELS = {entry["id"]: entry for entry in models["data"]}

    return models


Timothy J. Baek's avatar
Timothy J. Baek committed
295
296
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List models, either merged across all endpoints or from one endpoint.

    Without ``url_idx``, the merged listing from ``get_all_models()`` is
    returned. When ENABLE_MODEL_FILTER is on, users whose role is "user"
    only see ids listed in MODEL_FILTER_LIST.

    With ``url_idx``, ``<OPENAI_API_BASE_URLS[url_idx]>/models`` is queried
    directly. For api.openai.com, results are limited to ids containing "gpt".

    Raises:
        HTTPException: the upstream request failed. An upstream HTTP error
            status is propagated; any other failure is reported as 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
            allowed_ids = set(app.state.config.MODEL_FILTER_LIST)
            models["data"] = [
                model for model in models["data"] if model["id"] in allowed_ids
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    key = app.state.config.OPENAI_API_KEYS[url_idx]

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE(review): synchronous `requests` call inside an async handler
        # blocks the event loop for the duration of the request.
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # requests.Response is falsy for 4xx/5xx (bool(r) == r.ok), so the
        # old `if r` always fell back to 500 for real upstream errors.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
346
347


Timothy J. Baek's avatar
Timothy J. Baek committed
348
349
350
351
352
353
354
def _apply_model_params(payload: dict, params: dict) -> None:
    """Copy a stored model's generation ``params`` into ``payload`` in place.

    Application rules:

    - ``temperature`` is applied whenever it is not None.
    - ``top_p``, ``max_tokens``, ``frequency_penalty``, ``seed`` and ``stop``
      are applied only when truthy.
    - ``top_p`` and ``frequency_penalty`` are fractional settings and are sent
      as floats (they were previously truncated by ``int()``).
    - Each ``stop`` string has escape sequences decoded with the
      ``unicode_escape`` codec.

    A ``system`` prompt is prepended to the first system message in
    ``payload["messages"]``. If there is no system message, it is inserted as
    a new first message. Nothing happens when ``messages`` is missing or empty.
    """
    if params.get("temperature", None) is not None:
        payload["temperature"] = float(params["temperature"])

    if params.get("top_p", None):
        payload["top_p"] = float(params["top_p"])

    if params.get("max_tokens", None):
        payload["max_tokens"] = int(params["max_tokens"])

    if params.get("frequency_penalty", None):
        payload["frequency_penalty"] = float(params["frequency_penalty"])

    if params.get("seed", None):
        payload["seed"] = params["seed"]

    if params.get("stop", None):
        payload["stop"] = [
            bytes(stop, "utf-8").decode("unicode_escape") for stop in params["stop"]
        ]

    system = params.get("system", None)
    if system and payload.get("messages"):
        for message in payload["messages"]:
            if message.get("role") == "system":
                message["content"] = system + message["content"]
                break
        else:
            # No existing system message: add one at the front.
            payload["messages"].insert(0, {"role": "system", "content": system})


@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Forward a chat-completion request to the endpoint serving its model.

    If ``form_data["model"]`` names a stored model (``Models``), its
    ``base_model_id`` replaces the model id and its params are merged into
    the payload (see ``_apply_model_params``). The target endpoint is the
    cached model's ``urlIdx`` in ``app.state.MODELS``. Server-sent-event
    responses are streamed back; anything else is returned as parsed JSON.

    NOTE(review): ``url_idx`` is accepted by the route but not used; the
    endpoint always comes from the cached model entry.

    Raises:
        HTTPException(404): the (resolved) model id is not in the cache.
        HTTPException: the upstream request failed. An upstream HTTP error
            status is propagated; any other failure is reported as 500.
    """
    payload = {**form_data}

    model_info = Models.get_model_by_id(form_data.get("model"))
    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id
        # Work on a plain dict rather than overwriting model_info.params.
        _apply_model_params(payload, model_info.params.model_dump())

    model = app.state.MODELS.get(payload.get("model"))
    if model is None:
        # Previously a bare KeyError, which surfaced as an opaque 500.
        raise HTTPException(status_code=404, detail="Model not found")
    idx = model["urlIdx"]

    if model.get("pipeline"):
        payload["user"] = {"name": user.name, "id": user.id}

    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
    # This is a workaround until OpenAI fixes the issue with this model
    if payload.get("model") == "gpt-4-vision-preview":
        if "max_tokens" not in payload:
            payload["max_tokens"] = 4000
        log.debug(f"Modified payload: {payload}")

    # Convert the modified body back to JSON
    payload = json.dumps(payload)
    # Debug-level only: the payload contains the full conversation.
    log.debug(f"payload: {payload}")

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/chat/completions",
            data=payload,
            headers=headers,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # The background task now owns r/session and closes them after
            # the stream is sent, so `finally` must not close them.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        return await r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.debug(res)
                if "error" in res:
                    err = res["error"]
                    message = (
                        err["message"]
                        if isinstance(err, dict) and "message" in err
                        else err
                    )
                    error_detail = f"External: {message}"
            except Exception:
                error_detail = f"External: {e}"
        # Only propagate genuine upstream error statuses; a failure after a
        # 2xx response (e.g. undecodable JSON) is an internal error.
        raise HTTPException(
            status_code=r.status if r is not None and r.status >= 400 else 500,
            detail=error_detail,
        )
    finally:
        if not streaming and session:
            if r is not None:
                r.close()
            await session.close()


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to ``<first base URL>/<path>``.

    The request method and raw body are passed through, with the first
    configured key as Bearer token. Server-sent-event responses are streamed
    back; anything else is returned as parsed JSON.

    Raises:
        HTTPException: the upstream request failed. An upstream HTTP error
            status is propagated; any other failure is reported as 500.
    """
    # The catch-all route always targets the first configured endpoint.
    idx = 0

    body = await request.body()

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # The background task now owns r/session and closes them after
            # the stream is sent, so `finally` must not close them.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        return await r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.debug(res)
                if "error" in res:
                    err = res["error"]
                    message = (
                        err["message"]
                        if isinstance(err, dict) and "message" in err
                        else err
                    )
                    error_detail = f"External: {message}"
            except Exception:
                error_detail = f"External: {e}"
        # Only propagate genuine upstream error statuses; a failure after a
        # 2xx response (e.g. a non-JSON body) is an internal error.
        raise HTTPException(
            status_code=r.status if r is not None and r.status >= 400 else 500,
            detail=error_detail,
        )
    finally:
        if not streaming and session:
            if r is not None:
                r.close()
            await session.close()