main.py 17.7 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
from pydantic import BaseModel
12
from starlette.background import BackgroundTask
Timothy J. Baek's avatar
Timothy J. Baek committed
13

14
15
from apps.webui.models.models import Models
from apps.webui.models.users import Users
Timothy J. Baek's avatar
Timothy J. Baek committed
16
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
17
18
from utils.utils import (
    decode_token,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
19
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
20
21
22
    get_verified_user,
    get_admin_user,
)
23
24
from utils.task import prompt_template

25
from config import (
26
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
27
    ENABLE_OPENAI_API,
28
    AIOHTTP_CLIENT_TIMEOUT,
29
30
31
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
    ENABLE_MODEL_FILTER,
33
    MODEL_FILTER_LIST,
34
    AppConfig,
35
)
Timothy J. Baek's avatar
Timothy J. Baek committed
36
37
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
38
39
40

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
41

42
43
44
# Module-level logger; its level is taken from the "OPENAI" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

Timothy J. Baek's avatar
Timothy J. Baek committed
45
46
47
48
49
50
51
52
53
# FastAPI application that serves the routes defined in this module.
app = FastAPI()
# NOTE(review): CORS accepts every origin, method and header with credentials
# allowed — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

Timothy J. Baek's avatar
Timothy J. Baek committed
54

55
56
# Settings that the routes below read and update at runtime, seeded from the
# values imported from `config`.
app.state.config = AppConfig()

# Upstream API settings.
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Model filter settings, applied in get_models() for users with role "user".
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Cache of merged model entries keyed by model id; filled by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
66

Timothy J. Baek's avatar
Timothy J. Baek committed
67
68
69
70
71
72
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Fill the model cache if it is empty, then hand the request on."""
    # An empty cache means nothing has been loaded yet (or the last load
    # produced no models), so try again before serving this request.
    if not app.state.MODELS:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
76
77


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Return the current ENABLE_OPENAI_API setting.

    Access is gated by the get_admin_user dependency.
    """
    enabled = app.state.config.ENABLE_OPENAI_API
    return {"ENABLE_OPENAI_API": enabled}


# Request body for POST /config/update.
class OpenAIConfigForm(BaseModel):
    # New value for ENABLE_OPENAI_API; defaults to None when the field is omitted.
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    """Store the submitted enable flag and return the value now in effect.

    Access is gated by the get_admin_user dependency.
    """
    config = app.state.config
    config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": config.ENABLE_OPENAI_API}


Timothy J. Baek's avatar
Timothy J. Baek committed
93
94
# Request body for POST /urls/update.
class UrlsUpdateForm(BaseModel):
    # Replacement list for OPENAI_API_BASE_URLS.
    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
95
96


Timothy J. Baek's avatar
Timothy J. Baek committed
97
98
# Request body for POST /keys/update.
class KeysUpdateForm(BaseModel):
    # Replacement list for OPENAI_API_KEYS.
    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
99
100


Timothy J. Baek's avatar
Timothy J. Baek committed
101
102
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured upstream base URLs.

    Access is gated by the get_admin_user dependency.
    """
    base_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
104

Timothy J. Baek's avatar
Timothy J. Baek committed
105

Timothy J. Baek's avatar
Timothy J. Baek committed
106
107
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured base URLs, then rebuild the model cache from them.

    Access is gated by the get_admin_user dependency.

    Returns:
        A dict with the stored list under "OPENAI_API_BASE_URLS".
    """
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls

    # Refresh only after the new list is stored. Cached model entries carry a
    # "urlIdx" that indexes OPENAI_API_BASE_URLS, so refreshing first would
    # leave the cache describing the previous list.
    await get_all_models()

    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
111
112


Timothy J. Baek's avatar
Timothy J. Baek committed
113
114
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured upstream API keys.

    Access is gated by the get_admin_user dependency.
    """
    api_keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": api_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
116
117
118
119


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured API keys and return the stored list.

    Access is gated by the get_admin_user dependency.
    """
    config = app.state.config
    config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
122
123


Timothy J. Baek's avatar
Timothy J. Baek committed
124
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a speech request to the api.openai.com entry, caching audio on disk.

    The raw request body is hashed with SHA-256 and the digest names the cache
    entry under ``CACHE_DIR/audio/speech``: ``<digest>.mp3`` holds the audio
    and ``<digest>.json`` holds the request body. When the audio file already
    exists it is returned without contacting upstream.

    Raises:
        HTTPException: 401 when "https://api.openai.com/v1" is not among the
            configured base URLs. For any failure while requesting or caching,
            the status code of the upstream response if one was received,
            otherwise 500.
    """
    # Only the lookup can legitimately raise ValueError, so guard just that.
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    cache_dir = Path(CACHE_DIR).joinpath("./audio/speech/")
    cache_dir.mkdir(parents=True, exist_ok=True)
    file_path = cache_dir.joinpath(f"{name}.mp3")
    file_body_path = cache_dir.joinpath(f"{name}.json")
    # Download target. It is renamed to file_path only after a complete
    # download, so an interrupted transfer never becomes a cache hit.
    partial_path = cache_dir.joinpath(f"{name}.mp3.part")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }
    if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
        headers["HTTP-Referer"] = "https://openwebui.com/"
        headers["X-Title"] = "Open WebUI"

    r = None
    try:
        # NOTE(review): requests is a synchronous client, so this call blocks
        # the event loop for the duration of the download.
        r = requests.post(
            url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Save the streaming content to a file
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        partial_path.replace(file_path)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx statuses, so test against
        # None explicitly or every upstream error would be reported as 500.
        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
186
187


Timothy J. Baek's avatar
Timothy J. Baek committed
188
async def fetch_url(url, key):
    """GET ``url`` with ``key`` as a bearer token and return the parsed JSON body.

    Any exception is logged and turned into a ``None`` return value. The
    request uses a total timeout of 5 seconds.
    """
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(
            timeout=timeout, trust_env=True
        ) as session, session.get(url, headers=auth_headers) as response:
            return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None


201
202
203
204
205
206
207
208
209
210
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """Close an upstream response and its client session.

    Either argument may be None (or otherwise falsy), in which case it is
    skipped. This is passed as the BackgroundTask of the StreamingResponse
    objects built in this module.
    """
    if response:
        response.close()
    if session:
        await session.close()


Timothy J. Baek's avatar
Timothy J. Baek committed
211
def merge_models_lists(model_lists):
    """Flatten per-URL model lists into one list of annotated model entries.

    ``model_lists`` holds one item per configured base URL, in the same order.
    Items that are ``None`` or contain "error" are skipped. For a URL
    containing "api.openai.com" only models whose id contains "gpt" are kept.

    Each returned entry copies the source model's fields, then sets "name"
    (the model's own name, falling back to its id), "owned_by", "openai"
    (the source model) and "urlIdx" (position of the list it came from).
    """
    log.debug(f"merge_models_lists {model_lists}")
    merged = []

    for url_idx, models in enumerate(model_lists):
        if models is None or "error" in models:
            continue

        for model in models:
            from_openai = (
                "api.openai.com" in app.state.config.OPENAI_API_BASE_URLS[url_idx]
            )
            if from_openai and "gpt" not in model["id"]:
                continue

            merged.append(
                {
                    **model,
                    "name": model.get("name", model["id"]),
                    "owned_by": "openai",
                    "openai": model,
                    "urlIdx": url_idx,
                }
            )

    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
234
235


Timothy J. Baek's avatar
Timothy J. Baek committed
236
async def get_all_models(raw: bool = False):
    """Fetch ``/models`` from every configured base URL and rebuild the cache.

    Returns ``{"data": []}`` without fetching or touching the cache when the
    only configured key is a single empty string or when ENABLE_OPENAI_API is
    falsy. Otherwise the key list is first made the same length as the URL
    list, then all URLs are queried concurrently.

    Args:
        raw: When true, return the list of per-URL responses exactly as
            fetched, without merging them or updating ``app.state.MODELS``.

    Returns:
        ``{"data": [...]}`` holding the merged model entries, or the raw
        response list when ``raw`` is true.
    """
    log.info("get_all_models()")

    config = app.state.config
    single_blank_key = (
        len(config.OPENAI_API_KEYS) == 1 and config.OPENAI_API_KEYS[0] == ""
    )
    if single_blank_key or not config.ENABLE_OPENAI_API:
        return {"data": []}

    # Keep exactly one key per URL: drop surplus keys, pad missing ones with "".
    url_count = len(config.OPENAI_API_BASE_URLS)
    key_count = len(config.OPENAI_API_KEYS)
    if key_count > url_count:
        config.OPENAI_API_KEYS = config.OPENAI_API_KEYS[:url_count]
    elif key_count < url_count:
        config.OPENAI_API_KEYS += [""] * (url_count - key_count)

    responses = await asyncio.gather(
        *(
            fetch_url(f"{url}/models", config.OPENAI_API_KEYS[idx])
            for idx, url in enumerate(config.OPENAI_API_BASE_URLS)
        )
    )
    log.debug(f"get_all_models:responses() {responses}")

    if raw:
        return responses

    def _model_list(response):
        # Accept either a {"data": [...]} wrapper or a bare list of models.
        if response and "data" in response:
            return response["data"]
        if isinstance(response, list):
            return response
        return None

    models = {
        "data": merge_models_lists([_model_list(response) for response in responses])
    }

    log.debug(f"models: {models}")
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models


Timothy J. Baek's avatar
Timothy J. Baek committed
298
299
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
    """List models from all configured upstreams, or from a single one.

    Without ``url_idx`` the merged list from get_all_models() is returned;
    when ENABLE_MODEL_FILTER is set and the caller's role is "user", it is
    reduced to the ids in MODEL_FILTER_LIST. With ``url_idx`` the ``/models``
    endpoint of that base URL is queried directly, and for a URL containing
    "api.openai.com" only models whose id contains "gpt" are kept.

    Raises:
        HTTPException: when the single-upstream request fails, carrying the
            upstream status code if a response was received, otherwise 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
            models["data"] = [
                model
                for model in models["data"]
                if model["id"] in app.state.config.MODEL_FILTER_LIST
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    key = app.state.config.OPENAI_API_KEYS[url_idx]

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE(review): requests is a synchronous client, so this call blocks
        # the event loop while waiting for upstream.
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx statuses, so test against
        # None explicitly or every upstream error would be reported as 500.
        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
349
350


Timothy J. Baek's avatar
Timothy J. Baek committed
351
352
353
354
355
356
357
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Forward a chat completion request to the upstream that serves the model.

    When a stored model record exists for the requested id, its base model id
    replaces the requested model, its parameters (temperature, top_p,
    max_tokens, frequency_penalty, seed, stop) override the payload, and its
    system prompt is prepended to the first existing system message or
    inserted as a new first message.

    The upstream is picked from the "urlIdx" of the cached model entry;
    ``url_idx`` is accepted for routing compatibility but is not consulted.

    Returns:
        A StreamingResponse relaying the upstream body when upstream answers
        with "text/event-stream", otherwise the parsed JSON body.

    Raises:
        HTTPException: 404 when the model is not in the cache. For upstream
            failures, the upstream status if a response was received,
            otherwise 500.
    """
    payload = {**form_data}

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            if params.get("temperature", None) is not None:
                payload["temperature"] = float(params.get("temperature"))

            # top_p and frequency_penalty are fractional settings; casting
            # them with int() would truncate e.g. 0.9 to 0.
            for param_name, cast in (
                ("top_p", float),
                ("max_tokens", int),
                ("frequency_penalty", float),
            ):
                value = params.get(param_name, None)
                if value:
                    payload[param_name] = cast(value)

            if params.get("seed", None):
                payload["seed"] = params.get("seed", None)

            if params.get("stop", None):
                # Stored stop strings may hold escape sequences such as "\\n";
                # decode them into the characters they stand for.
                payload["stop"] = [
                    bytes(stop, "utf-8").decode("unicode_escape")
                    for stop in params["stop"]
                ]

        system = params.get("system", None)
        if system:
            system = prompt_template(
                system,
                **(
                    {
                        "user_name": user.name,
                        "user_location": (
                            user.info.get("location") if user.info else None
                        ),
                    }
                    if user
                    else {}
                ),
            )
            # Check if the payload already has a system message
            # If not, add a system message to the payload
            if payload.get("messages"):
                for message in payload["messages"]:
                    if message.get("role") == "system":
                        message["content"] = system + message["content"]
                        break
                else:
                    payload["messages"].insert(
                        0,
                        {
                            "role": "system",
                            "content": system,
                        },
                    )

    # An unknown model would otherwise surface as a bare KeyError (HTTP 500).
    model = app.state.MODELS.get(payload.get("model"))
    if model is None:
        raise HTTPException(status_code=404, detail="Model not found")
    idx = model["urlIdx"]

    if model.get("pipeline"):
        payload["user"] = {
            "name": user.name,
            "id": user.id,
            "email": user.email,
            "role": user.role,
        }

    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
    # This is a workaround until OpenAI fixes the issue with this model
    if payload.get("model") == "gpt-4-vision-preview":
        if "max_tokens" not in payload:
            payload["max_tokens"] = 4000
        log.debug("Modified payload: %s", payload)

    # Convert the modified body back to JSON
    payload = json.dumps(payload)

    log.debug(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )
        r = await session.request(
            method="POST",
            url=f"{url}/chat/completions",
            data=payload,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # The response and session must outlive this function while the
            # body is streamed; the background task closes them afterwards.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.debug(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        if not streaming and session:
            if r:
                r.close()
            await session.close()


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Catch-all route: forward the request to the first configured base URL.

    The incoming method and raw body are relayed to ``<base url>/<path>``
    using the first configured API key as a bearer token and a JSON content
    type.

    Returns:
        A StreamingResponse relaying the upstream body when upstream answers
        with "text/event-stream", otherwise the parsed JSON body.

    Raises:
        HTTPException: on any failure, carrying the upstream status if a
            response was received, otherwise 500.
    """
    idx = 0

    body = await request.body()

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # The response and session must outlive this function while the
            # body is streamed; the background task closes them afterwards.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            # Catch Exception rather than everything, so that cancellation of
            # the awaited read below is not converted into an HTTP error.
            try:
                res = await r.json()
                log.debug(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        if not streaming and session:
            if r:
                r.close()
            await session.close()