main.py 17.9 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
from pydantic import BaseModel
12
from starlette.background import BackgroundTask
Timothy J. Baek's avatar
Timothy J. Baek committed
13

14
15
from apps.webui.models.models import Models
from apps.webui.models.users import Users
Timothy J. Baek's avatar
Timothy J. Baek committed
16
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
17
18
from utils.utils import (
    decode_token,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
19
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
20
21
22
    get_verified_user,
    get_admin_user,
)
23
from utils.task import prompt_template
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
24
from utils.misc import add_or_update_system_message
25

26
from config import (
27
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
28
    ENABLE_OPENAI_API,
29
    AIOHTTP_CLIENT_TIMEOUT,
30
31
32
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
    ENABLE_MODEL_FILTER,
34
    MODEL_FILTER_LIST,
35
    AppConfig,
36
)
Timothy J. Baek's avatar
Timothy J. Baek committed
37
38
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
39
40
41

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
42

43
44
45
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])


# FastAPI sub-application that hosts every OpenAI-compatible route below.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive CORS configuration available; presumably acceptable because
# each route is guarded by an auth dependency — confirm for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Runtime-editable settings, seeded from the values imported from `config`.
app.state.config = AppConfig()

# Model allow-list; get_models() applies it to users whose role is "user".
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Master switch plus the index-aligned lists of base URLs and their API keys
# (OPENAI_API_KEYS[i] authenticates against OPENAI_API_BASE_URLS[i]).
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Cache of model id -> merged model entry, rebuilt by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
67

Timothy J. Baek's avatar
Timothy J. Baek committed
68
69
70
71
72
73
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Lazily populate the model cache before handling a request.

    While ``app.state.MODELS`` is empty, each incoming request first triggers
    ``get_all_models()``. Once the cache holds entries, requests go straight
    through. get_all_models() returns without touching the cache when OpenAI
    is disabled or no key is configured, so in that case the refresh runs on
    every request.
    """
    cache_is_empty = not app.state.MODELS
    if cache_is_empty:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
77
78


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Report whether the OpenAI integration is enabled (guarded by get_admin_user)."""
    enabled = app.state.config.ENABLE_OPENAI_API
    return {"ENABLE_OPENAI_API": enabled}


class OpenAIConfigForm(BaseModel):
    """Request body for POST /config/update."""
    # Optional: the field may be omitted, in which case it is None.
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    """Update the OpenAI enable flag (guarded by get_admin_user).

    ``enable_openai_api`` is optional in the form. When it is omitted, the
    stored setting is left unchanged rather than being overwritten with
    ``None``, which is not a boolean.

    Returns:
        ``{"ENABLE_OPENAI_API": <current value>}``.
    """
    if form_data.enable_openai_api is not None:
        app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


Timothy J. Baek's avatar
Timothy J. Baek committed
94
95
class UrlsUpdateForm(BaseModel):
    """Request body for POST /urls/update: the full replacement list of base URLs."""
    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
96
97


Timothy J. Baek's avatar
Timothy J. Baek committed
98
99
class KeysUpdateForm(BaseModel):
    """Request body for POST /keys/update: the full replacement list of API keys."""
    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
100
101


Timothy J. Baek's avatar
Timothy J. Baek committed
102
103
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured base URLs (guarded by get_admin_user)."""
    base_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
105

Timothy J. Baek's avatar
Timothy J. Baek committed
106

Timothy J. Baek's avatar
Timothy J. Baek committed
107
108
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the base URLs (guarded by get_admin_user), then refresh models.

    The new URLs are stored *before* get_all_models() runs, so the refresh
    queries the new endpoints. Refreshing first would rebuild
    ``app.state.MODELS`` from the old URLs, leaving each cached model's
    ``urlIdx`` pointing into a list that is replaced right afterwards.

    Returns:
        ``{"OPENAI_API_BASE_URLS": <stored list>}``.
    """
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
    # When OpenAI is enabled and keys are configured, this also pads/trims
    # OPENAI_API_KEYS to match the new number of URLs.
    await get_all_models()
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
112
113


Timothy J. Baek's avatar
Timothy J. Baek committed
114
115
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys (guarded by get_admin_user)."""
    api_keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": api_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
117
118
119
120


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the API keys (guarded by get_admin_user) and echo the stored list."""
    config = app.state.config
    config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
123
124


Timothy J. Baek's avatar
Timothy J. Baek committed
125
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to api.openai.com, with a disk cache.

    The raw request body is hashed with SHA-256 to form the cache key. If
    ``<hash>.mp3`` already exists in the speech cache directory, it is returned
    directly. Otherwise the body is forwarded to ``<base_url>/audio/speech``.
    The audio is streamed into the cache next to a ``<hash>.json`` copy of
    the request, and the file is then returned.

    Raises:
        HTTPException(401): no configured base URL equals
            "https://api.openai.com/v1".
        HTTPException: the upstream call or caching failed. The upstream
            status is forwarded when it was an error status, else 500.
    """
    import tempfile

    # Only the lookup can raise ValueError here; keep the try that narrow.
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Cache hit: identical request bodies map to the same file.
    if file_path.is_file():
        return FileResponse(file_path)

    # The URL at idx is exactly "https://api.openai.com/v1", so no
    # provider-specific (e.g. openrouter) headers can apply here.
    base_url = app.state.config.OPENAI_API_BASE_URLS[idx]
    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    r = None
    tmp_path = None
    try:
        # NOTE(review): `requests` is synchronous; this call blocks the event
        # loop for the whole upload/download.
        r = requests.post(
            url=f"{base_url}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Stream into a uniquely named temp file in the cache directory and
        # move it into place only once complete. An interrupted download or a
        # concurrent identical request can then never leave a truncated .mp3
        # that later requests would serve as a cache hit.
        with tempfile.NamedTemporaryFile(
            "wb", dir=SPEECH_CACHE_DIR, suffix=".part", delete=False
        ) as f:
            tmp_path = Path(f.name)
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        tmp_path.replace(file_path)
        tmp_path = None
        return FileResponse(file_path)
    except Exception as e:
        log.exception(e)
        if tmp_path is not None:
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort removal of the partial download

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # requests.Response is falsy for 4xx/5xx (__bool__ returns .ok), so
        # test for None explicitly. Only forward genuine error statuses; a
        # local failure after a 2xx reply is a server error.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
    finally:
        # stream=True keeps the connection open until closed explicitly.
        if r is not None:
            r.close()
Timothy J. Baek's avatar
Timothy J. Baek committed
187
188


Timothy J. Baek's avatar
Timothy J. Baek committed
189
async def fetch_url(url, key):
    """GET ``url`` with bearer ``key`` and return the decoded JSON body.

    Uses a 5-second total timeout. The HTTP status is not checked, so an
    upstream error JSON body is returned as-is. Any exception raised while
    connecting or decoding is logged and turned into ``None``, so one bad
    endpoint does not abort callers that gather several of them.
    """
    auth_headers = {"Authorization": f"Bearer {key}"}
    try:
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=5), trust_env=True
        ) as session:
            async with session.get(url, headers=auth_headers) as resp:
                return await resp.json()
    except Exception as err:
        # Connection problems are reported, never raised.
        log.error(f"Connection error: {err}")
        return None


202
203
204
205
206
207
208
209
210
211
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """Release an aiohttp response and close its session.

    Either argument may be ``None``; each is acted on only when truthy. Used
    as the background task of a StreamingResponse, so it runs after the
    streamed body has been sent.
    """
    if response:
        response.close()

    if not session:
        return
    await session.close()


Timothy J. Baek's avatar
Timothy J. Baek committed
212
def merge_models_lists(model_lists):
    """Flatten per-URL model lists into one list of annotated model entries.

    ``model_lists[i]`` holds the models reported by
    ``OPENAI_API_BASE_URLS[i]``, or ``None`` when that endpoint failed. Each
    kept model is copied and extended with:

    - ``name``, defaulting to its ``id``;
    - ``owned_by="openai"``;
    - the original dict under ``openai``;
    - the source index under ``urlIdx``.

    For an api.openai.com URL, only models whose id contains "gpt" are kept.
    """
    log.debug(f"merge_models_lists {model_lists}")

    merged_list = []
    for url_idx, url_models in enumerate(model_lists):
        # Skip failed endpoints and error payloads; an empty list adds nothing.
        if url_models is None or "error" in url_models or not url_models:
            continue

        base_url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
        gpt_only = "api.openai.com" in base_url

        for model in url_models:
            if gpt_only and "gpt" not in model["id"]:
                continue
            merged_list.append(
                {
                    **model,
                    "name": model.get("name", model["id"]),
                    "owned_by": "openai",
                    "openai": model,
                    "urlIdx": url_idx,
                }
            )

    return merged_list
Timothy J. Baek's avatar
Timothy J. Baek committed
235
236


Timothy J. Baek's avatar
Timothy J. Baek committed
237
async def get_all_models(raw: bool = False):
    """Query every configured base URL for its models and refresh the cache.

    Args:
        raw: when True, return the list of per-URL ``/models`` responses
            (``None`` for failed endpoints) without merging or caching.

    Returns:
        ``{"data": [...]}`` with the merged model entries, which are also
        stored in ``app.state.MODELS`` keyed by model id. When OpenAI is
        disabled, or the key list is the single placeholder ``[""]``, it
        returns ``{"data": []}`` (even when ``raw`` is True) and leaves
        ``app.state.MODELS`` untouched.

    Side effect: otherwise OPENAI_API_KEYS is trimmed, or padded with "", so
    that it has exactly one entry per base URL.
    """
    log.info("get_all_models()")

    config = app.state.config
    no_keys_configured = (
        len(config.OPENAI_API_KEYS) == 1 and config.OPENAI_API_KEYS[0] == ""
    )
    if no_keys_configured or not config.ENABLE_OPENAI_API:
        return {"data": []}

    # Keep keys index-aligned with URLs: drop surplus keys, pad missing ones.
    url_count = len(config.OPENAI_API_BASE_URLS)
    key_count = len(config.OPENAI_API_KEYS)
    if key_count > url_count:
        config.OPENAI_API_KEYS = config.OPENAI_API_KEYS[:url_count]
    elif key_count < url_count:
        config.OPENAI_API_KEYS += ["" for _ in range(url_count - key_count)]

    fetches = []
    for url_idx, base_url in enumerate(config.OPENAI_API_BASE_URLS):
        fetches.append(
            fetch_url(f"{base_url}/models", config.OPENAI_API_KEYS[url_idx])
        )
    responses = await asyncio.gather(*fetches)
    log.debug(f"get_all_models:responses() {responses}")

    if raw:
        return responses

    def _models_of(response):
        # A dict with "data" is an OpenAI-style listing; a bare list is used
        # as-is; anything else (None, error payloads) contributes nothing.
        if response and "data" in response:
            return response["data"]
        return response if isinstance(response, list) else None

    models = {"data": merge_models_lists([_models_of(resp) for resp in responses])}

    log.debug(f"models: {models}")
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models


Timothy J. Baek's avatar
Timothy J. Baek committed
299
300
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
    """List available models.

    Without ``url_idx``: return the merged list from all configured URLs,
    which also refreshes ``app.state.MODELS``. When the model filter is
    enabled, users with role "user" only see ids in MODEL_FILTER_LIST.

    With ``url_idx``: query ``OPENAI_API_BASE_URLS[url_idx]/models`` directly
    and return its JSON. For api.openai.com, only "gpt" models are kept.

    Raises:
        HTTPException: the upstream request failed. The upstream status is
            forwarded when it was an error status, otherwise 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
            allowed = app.state.config.MODEL_FILTER_LIST
            models["data"] = [
                model for model in models["data"] if model["id"] in allowed
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    key = app.state.config.OPENAI_API_KEYS[url_idx]

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE(review): blocking `requests` call inside an async endpoint.
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # requests.Response is falsy for 4xx/5xx (__bool__ returns .ok), so
        # `r if r` would turn exactly the upstream errors into 500. Test for
        # None explicitly, and report local failures after a 2xx as 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
350
351


Timothy J. Baek's avatar
Timothy J. Baek committed
352
353
354
355
356
357
358
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Forward a chat-completion request to the upstream serving the model.

    If ``form_data["model"]`` is a model stored in ``Models``:

    - its ``base_model_id`` (when set) replaces the requested model;
    - its stored params (temperature, top_p, max_tokens, frequency_penalty,
      seed, stop) fill in values the request did not set;
    - its "system" param is templated and injected as the system message.

    The upstream is chosen by the cached model's ``urlIdx``. The ``url_idx``
    path parameter is accepted but not used here.

    Returns:
        A StreamingResponse for ``text/event-stream`` replies, otherwise the
        decoded JSON body.

    Raises:
        HTTPException(404): the resolved model is not in ``app.state.MODELS``.
        HTTPException: upstream/connection failure. The upstream status is
            forwarded when it was an error status, otherwise 500.
    """
    payload = {**form_data}
    # Stripped before forwarding; presumably Open WebUI-internal data the
    # upstream API does not expect.
    payload.pop("metadata", None)

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        # Plain-dict copy; model_info itself is left untouched.
        params = model_info.params.model_dump()

        if params:
            # Stored params only fill gaps: values set in the request win.
            if (
                params.get("temperature", None) is not None
                and payload.get("temperature") is None
            ):
                payload["temperature"] = float(params.get("temperature"))

            # top_p and frequency_penalty are fractional values; int() would
            # truncate e.g. 0.9 -> 0 or -0.5 -> 0, so keep them as floats.
            if params.get("top_p", None) and payload.get("top_p") is None:
                payload["top_p"] = float(params.get("top_p", None))

            if params.get("max_tokens", None) and payload.get("max_tokens") is None:
                payload["max_tokens"] = int(params.get("max_tokens", None))

            if (
                params.get("frequency_penalty", None)
                and payload.get("frequency_penalty") is None
            ):
                payload["frequency_penalty"] = float(
                    params.get("frequency_penalty", None)
                )

            if params.get("seed", None) is not None and payload.get("seed") is None:
                payload["seed"] = params.get("seed", None)

            if params.get("stop", None) and payload.get("stop") is None:
                # Interpret backslash escapes such as "\\n" in stored stop
                # sequences. Encoding via latin-1 + backslashreplace turns
                # non-Latin-1 characters into \uXXXX escapes that
                # unicode_escape decodes back. A plain UTF-8 encode would
                # instead decode multi-byte characters as Latin-1 mojibake.
                payload["stop"] = [
                    stop.encode("latin-1", "backslashreplace").decode("unicode_escape")
                    for stop in params["stop"]
                ]

        system = params.get("system", None)
        if system:
            system = prompt_template(
                system,
                **(
                    {
                        "user_name": user.name,
                        "user_location": (
                            user.info.get("location") if user.info else None
                        ),
                    }
                    if user
                    else {}
                ),
            )
            if payload.get("messages"):
                payload["messages"] = add_or_update_system_message(
                    system, payload["messages"]
                )

    model = app.state.MODELS.get(payload.get("model"))
    if model is None:
        # Previously an unhandled KeyError (bare 500); report it explicitly.
        raise HTTPException(
            status_code=404, detail=f"Model not found: {payload.get('model')}"
        )
    idx = model["urlIdx"]

    if model.get("pipeline"):
        payload["user"] = {
            "name": user.name,
            "id": user.id,
            "email": user.email,
            "role": user.role,
        }

    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
    # This is a workaround until OpenAI fixes the issue with this model
    if payload.get("model") == "gpt-4-vision-preview":
        if "max_tokens" not in payload:
            payload["max_tokens"] = 4000
        # Use a %s placeholder: an extra argument without one makes logging
        # report a formatting error instead of the message.
        log.debug("Modified payload: %s", payload)

    # Convert the modified body back to JSON
    payload = json.dumps(payload)

    log.debug(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    if "openrouter.ai" in url:
        headers["HTTP-Referer"] = "https://openwebui.com/"
        headers["X-Title"] = "Open WebUI"

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )
        r = await session.request(
            method="POST",
            url=f"{url}/chat/completions",
            data=payload,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # r and session are now released by the background task after the
            # stream is sent, not by the finally block below.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        return await r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.debug(f"Upstream error body: {res}")
                if isinstance(res, dict) and "error" in res:
                    err = res["error"]
                    # "error" may be an object with a message or a plain value.
                    if isinstance(err, dict) and "message" in err:
                        err = err["message"]
                    error_detail = f"External: {err}"
            except Exception:
                error_detail = f"External: {e}"
        raise HTTPException(
            status_code=r.status if r is not None and r.status >= 400 else 500,
            detail=error_detail,
        )
    finally:
        if not streaming and session:
            if r is not None:
                r.close()
            await session.close()


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to ``OPENAI_API_BASE_URLS[0]/<path>``.

    The raw body is passed through with the first URL's key and a JSON
    content type. ``text/event-stream`` replies are streamed back; other
    replies are returned as decoded JSON.

    Raises:
        HTTPException: upstream/connection failure. The upstream status is
            forwarded when it was an error status, otherwise 500.
    """
    # NOTE(review): always the first URL, and the query string of `request`
    # is not forwarded — confirm both are intended.
    idx = 0

    body = await request.body()

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # r and session are now released by the background task after the
            # stream is sent, not by the finally block below.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        return await r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.debug(f"Upstream error body: {res}")
                if isinstance(res, dict) and "error" in res:
                    err = res["error"]
                    # "error" may be an object with a message or a plain value.
                    if isinstance(err, dict) and "message" in err:
                        err = err["message"]
                    error_detail = f"External: {err}"
            except Exception:
                error_detail = f"External: {e}"
        raise HTTPException(
            status_code=r.status if r is not None and r.status >= 400 else 500,
            detail=error_detail,
        )
    finally:
        if not streaming and session:
            if r is not None:
                r.close()
            await session.close()