main.py 33.2 KB
Newer Older
1
import uuid
2
from contextlib import asynccontextmanager
3
4
5

from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
8
from bs4 import BeautifulSoup
import json
import markdown
9
import time
Timothy J. Baek's avatar
Timothy J. Baek committed
10
11
import os
import sys
12
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
13
import aiohttp
14
import requests
15
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
16

17
from fastapi import FastAPI, Request, Depends, status
Timothy J. Baek's avatar
Timothy J. Baek committed
18
from fastapi.staticfiles import StaticFiles
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
19
from fastapi.responses import JSONResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
20
21
22
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
23
from starlette.exceptions import HTTPException as StarletteHTTPException
Timothy J. Baek's avatar
Timothy J. Baek committed
24
from starlette.middleware.base import BaseHTTPMiddleware
25
26
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
27

Timothy J. Baek's avatar
Timothy J. Baek committed
28
29

from apps.socket.main import app as socket_app
30
31
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
Timothy J. Baek's avatar
Timothy J. Baek committed
32

Timothy J. Baek's avatar
Timothy J. Baek committed
33
from apps.audio.main import app as audio_app
Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
36
from apps.webui.main import app as webui_app
Timothy J. Baek's avatar
Timothy J. Baek committed
37

Timothy J. Baek's avatar
Timothy J. Baek committed
38
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
39
from pydantic import BaseModel
40
from typing import List, Optional
Timothy J. Baek's avatar
Timothy J. Baek committed
41

42
43
44
45
46
47
48
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from utils.misc import parse_duration
from utils.utils import (
    get_admin_user,
    get_verified_user,
49
50
    get_current_user,
    get_http_authorization_cred,
51
52
53
    get_password_hash,
    create_token,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
54
from apps.rag.utils import rag_messages
Timothy J. Baek's avatar
Timothy J. Baek committed
55

56
from config import (
57
    CONFIG_DATA,
58
    WEBUI_NAME,
59
    WEBUI_URL,
60
    WEBUI_AUTH,
61
62
63
64
    ENV,
    VERSION,
    CHANGELOG,
    FRONTEND_BUILD_DIR,
65
66
    CACHE_DIR,
    STATIC_DIR,
67
68
    ENABLE_OPENAI_API,
    ENABLE_OLLAMA_API,
Timothy J. Baek's avatar
Timothy J. Baek committed
69
    ENABLE_MODEL_FILTER,
70
    MODEL_FILTER_LIST,
71
72
    GLOBAL_LOG_LEVEL,
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
Timothy J. Baek committed
73
    WEBHOOK_URL,
74
    ENABLE_ADMIN_EXPORT,
75
    AppConfig,
76
    WEBUI_BUILD_HASH,
77
    OAUTH_PROVIDERS,
78
79
80
    ENABLE_OAUTH_SIGNUP,
    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
    WEBUI_SECRET_KEY,
81
    WEBUI_SESSION_COOKIE_SAME_SITE,
82
)
83
84
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
85

86
87
88
# Root logging goes to stdout at the global level; this module's own logger
# is then tuned to its per-source level from SRC_LOG_LEVELS["MAIN"].
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
Timothy J. Baek's avatar
Timothy J. Baek committed
89

90

Timothy J. Baek's avatar
Timothy J. Baek committed
91
92
93
94
95
96
97
98
99
100
101
class SPAStaticFiles(StaticFiles):
    """StaticFiles variant for a single-page application.

    A request whose path is not found (404) is answered with ``index.html``
    instead, so client-side routes load the SPA shell. Any other HTTP error
    raised by the underlying StaticFiles lookup is propagated unchanged.
    """

    async def get_response(self, path: str, scope):
        try:
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as exc:
            if exc.status_code != 404:
                raise exc
            # Unknown path: serve the SPA entry point instead.
            return await super().get_response("index.html", scope)


Timothy J. Baek's avatar
Timothy J. Baek committed
102
# Startup banner printed once at import time: ASCII logo, version, the build
# commit (omitted for the "dev-build" hash) and the project URL.
print(
    rf"""
  ___                    __        __   _     _   _ ___ 
 / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || | 
| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || | 
 \___/| .__/ \___|_| |_|    \_/\_/ \___|_.__/ \___/|___|
      |_|                                               

      
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)

Timothy J. Baek's avatar
Timothy J. Baek committed
118

119
120
121
122
123
124
125
126
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook; no startup or shutdown work is done yet."""
    yield


# Swagger docs are exposed only when ENV is "dev"; ReDoc is always disabled.
app = FastAPI(
    docs_url=("/docs" if ENV == "dev" else None),
    redoc_url=None,
    lifespan=lifespan,
)

# Settings container seeded from the values imported from `config`.
app.state.config = AppConfig()

# Which model backends are queried by get_all_models().
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API

# Per-role model allow-list used by get_models().
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.WEBHOOK_URL = WEBHOOK_URL

# Model registry keyed by model id; rebuilt by get_all_models().
app.state.MODELS = {}

# Origins accepted by the CORS middleware (any origin).
origins = ["*"]

Timothy J. Baek's avatar
Timothy J. Baek committed
144
# Custom middleware to add security headers
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
145
146
147
148
149
150
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
#     async def dispatch(self, request: Request, call_next):
#         response: Response = await call_next(request)
#         response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
#         response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
#         return response
Timothy J. Baek's avatar
Timothy J. Baek committed
151
152


Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
153
# app.add_middleware(SecurityHeadersMiddleware)
Timothy J. Baek's avatar
Timothy J. Baek committed
154
155


Timothy J. Baek's avatar
Timothy J. Baek committed
156
157
class RAGMiddleware(BaseHTTPMiddleware):
    """Inject retrieved document context into chat requests.

    Applies to POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions``. The JSON body is rewritten before it reaches the
    mounted app:

    * ``citations`` is removed from the body. When it was truthy and the
      response is an SSE (``text/event-stream``) or NDJSON
      (``application/x-ndjson``) stream, the citations returned by
      ``rag_messages`` are emitted as the first chunk of that stream.
    * ``docs``, when present, is removed and passed to ``rag_messages``,
      whose returned messages replace ``messages``.
    """

    async def dispatch(self, request: Request, call_next):
        return_citations = False
        # Bound up front so a request that sets ``citations`` but sends no
        # ``docs`` injects an empty list instead of raising NameError.
        citations = []

        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            body = await request.body()
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}

            # ``citations`` is a flag for this middleware only; strip it so it
            # is not forwarded downstream.
            return_citations = data.pop("citations", False)

            if "docs" in data:
                data["messages"], citations = rag_messages(
                    docs=data["docs"],
                    messages=data["messages"],
                    template=rag_app.state.config.RAG_TEMPLATE,
                    embedding_function=rag_app.state.EMBEDDING_FUNCTION,
                    k=rag_app.state.config.TOP_K,
                    reranking_function=rag_app.state.sentence_transformer_rf,
                    r=rag_app.state.config.RELEVANCE_THRESHOLD,
                    hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                )
                del data["docs"]

                log.debug(
                    f"data['messages']: {data['messages']}, citations: {citations}"
                )

            modified_body_bytes = json.dumps(data).encode("utf-8")

            # Replace the cached request body so downstream handlers read the
            # rewritten JSON.
            request._body = modified_body_bytes

            # Rewrite the raw header list so Content-Length matches the new body.
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)

        if return_citations and isinstance(response, StreamingResponse):
            # ``or ""`` keeps the substring checks from raising TypeError when
            # the downstream response has no Content-Type header.
            content_type = response.headers.get("Content-Type") or ""

            if "text/event-stream" in content_type:
                wrapped = self.openai_stream_wrapper(response.body_iterator, citations)
            elif "application/x-ndjson" in content_type:
                wrapped = self.ollama_stream_wrapper(response.body_iterator, citations)
            else:
                wrapped = None

            if wrapped is not None:
                # Carry over the upstream status and media type; a bare
                # StreamingResponse would default to 200 with no Content-Type.
                return StreamingResponse(
                    wrapped,
                    status_code=response.status_code,
                    media_type=content_type,
                )

        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}

    async def openai_stream_wrapper(self, original_generator, citations):
        """Yield a leading SSE ``data:`` event with citations, then the stream."""
        yield f"data: {json.dumps({'citations': citations})}\n\n"
        async for data in original_generator:
            yield data

    async def ollama_stream_wrapper(self, original_generator, citations):
        """Yield a leading NDJSON line with citations, then the stream."""
        yield f"{json.dumps({'citations': citations})}\n"
        async for data in original_generator:
            yield data


app.add_middleware(RAGMiddleware)


Timothy J. Baek's avatar
Timothy J. Baek committed
247
248
249
class PipelineMiddleware(BaseHTTPMiddleware):
    """Run pipeline ``inlet`` hooks on chat requests before forwarding them.

    Applies to POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions``. Filter pipelines (``pipeline.type == "filter"``)
    that target the requested model, or all models via ``["*"]``, run in
    ascending ``pipeline.priority`` order. The requested model itself runs
    last when it is a pipeline. Each hook receives a POST of
    ``{"user": ..., "body": ...}`` at ``{base_url}/{id}/filter/inlet``, and its
    JSON reply becomes the new body. An upstream error reply carrying
    ``detail`` is relayed to the client unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            body = await request.body()
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}

            # A missing ``model`` no longer raises KeyError here; the backend
            # app reports it instead.
            model_id = data.get("model")

            filters = [
                candidate
                for candidate in app.state.MODELS.values()
                if "pipeline" in candidate
                and "type" in candidate["pipeline"]
                and candidate["pipeline"]["type"] == "filter"
                and (
                    candidate["pipeline"]["pipelines"] == ["*"]
                    or model_id in candidate["pipeline"]["pipelines"]
                )
            ]
            sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

            # Unknown model ids are treated as plain (non-pipeline) models
            # instead of failing with KeyError.
            model = app.state.MODELS.get(model_id)
            is_pipeline_model = model is not None and "pipeline" in model
            if is_pipeline_model:
                sorted_filters.append(model)

            # Resolve the caller *after* the model's own pipeline is added, so
            # a pipeline model without filters still receives user info.
            # Best effort: the user stays None when auth cannot be resolved.
            user = None
            if sorted_filters:
                try:
                    current_user = get_current_user(
                        get_http_authorization_cred(
                            request.headers.get("Authorization")
                        )
                    )
                    user = {
                        "id": current_user.id,
                        "name": current_user.name,
                        "role": current_user.role,
                    }
                except Exception as e:
                    log.debug(f"Could not resolve user for pipeline inlet: {e}")

            for filter_model in sorted_filters:
                r = None
                try:
                    urlIdx = filter_model["urlIdx"]

                    url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
                    key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

                    # Connections with an empty API key are skipped.
                    if key != "":
                        headers = {"Authorization": f"Bearer {key}"}
                        r = requests.post(
                            f"{url}/{filter_model['id']}/filter/inlet",
                            headers=headers,
                            json={"user": user, "body": data},
                        )

                        r.raise_for_status()
                        data = r.json()
                except Exception as e:
                    log.error(f"Connection error: {e}")

                    # Relay a structured upstream error; otherwise continue
                    # with the body this hook failed to modify.
                    if r is not None:
                        try:
                            res = r.json()
                            if "detail" in res:
                                return JSONResponse(
                                    status_code=r.status_code,
                                    content=res,
                                )
                        except (ValueError, TypeError):
                            # Error body is not JSON or not a container.
                            pass

            # Strip ``chat_id`` and ``title`` unless the target model is a pipeline.
            if not is_pipeline_model:
                data.pop("chat_id", None)
                data.pop("title", None)

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the cached request body so downstream handlers read it.
            request._body = modified_body_bytes
            # Rewrite the raw header list so Content-Length matches the new body.
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)


Timothy J. Baek's avatar
Timothy J. Baek committed
365
366
367
368
369
370
371
372
373
# Permissive CORS: every origin in `origins`, any method and header, with
# credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests — confirm this is intended for deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)


Timothy J. Baek's avatar
Timothy J. Baek committed
374
375
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Populate the model registry if empty and time the request.

    Whenever ``app.state.MODELS`` is empty, ``get_all_models()`` runs before
    the request is handled. The response gets an ``X-Process-Time`` header
    holding the elapsed handling time in whole seconds.
    """
    if not app.state.MODELS:
        await get_all_models()

    started = int(time.time())
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(int(time.time()) - started)
    return response


Timothy J. Baek's avatar
Timothy J. Baek committed
389
390
391
392
393
394
@app.middleware("http")
async def update_embedding_function(request: Request, call_next):
    """Copy RAG's embedding function to the web UI app after an update.

    The copy happens after the downstream handler has run, and only for
    requests whose path contains ``/embedding/update``.
    """
    response = await call_next(request)
    if "/embedding/update" not in request.url.path:
        return response
    webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
    return response
Timothy J. Baek's avatar
Timothy J. Baek committed
395

Timothy J. Baek's avatar
Timothy J. Baek committed
396

Timothy J. Baek's avatar
Timothy J. Baek committed
397
398
399
# Sub-application mounts (registration order kept as-is).
app.mount("/ws", socket_app)

# Model backends queried by get_all_models().
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)

# Feature APIs.
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)

# Web UI API.
app.mount("/api/v1", webui_app)

# Share RAG's embedding function with the web UI app at import time.
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION

411

Timothy J. Baek's avatar
Timothy J. Baek committed
412
async def get_all_models():
    """Rebuild and return the combined model registry.

    Models are collected from the OpenAI-compatible and Ollama backends, each
    only when enabled in ``app.state.config``. Custom models from ``Models``
    are then applied:

    * A custom model without ``base_model_id`` overrides ``name`` and sets
      ``info`` on every model whose id equals the custom id, or whose id
      before the first ``":"`` equals it.
    * A custom model with ``base_model_id`` is appended as a ``"preset"``
      entry. It inherits ``owned_by`` from the first model matching the base
      id, defaulting to ``"openai"``.

    The result is stored by id in ``app.state.MODELS`` and shared with
    ``webui_app.state.MODELS``.

    Returns:
        list[dict]: backend models followed by appended preset entries.
    """
    openai_models = []
    ollama_models = []

    if app.state.config.ENABLE_OPENAI_API:
        openai_models = (await get_openai_models())["data"]

    if app.state.config.ENABLE_OLLAMA_API:
        ollama_response = await get_ollama_models()
        # One timestamp for the whole batch.
        created = int(time.time())
        ollama_models = [
            {
                "id": model["model"],
                "name": model["name"],
                "object": "model",
                "created": created,
                "owned_by": "ollama",
                "ollama": model,
            }
            for model in ollama_response["models"]
        ]

    models = openai_models + ollama_models

    def _matches(model_id, target_id):
        # True when target_id equals model_id or model_id's part before the
        # first ":" (e.g. "llama3" matches "llama3:8b").
        return target_id == model_id or target_id == model_id.split(":")[0]

    for custom_model in Models.get_all_models():
        if custom_model.base_model_id is None:
            # Override display name / info of matching models in place.
            for model in models:
                if _matches(model["id"], custom_model.id):
                    model["name"] = custom_model.name
                    model["info"] = custom_model.model_dump()
        else:
            # Preset on top of a base model: inherit the base model's owner.
            owned_by = next(
                (
                    model["owned_by"]
                    for model in models
                    if _matches(model["id"], custom_model.base_model_id)
                ),
                "openai",
            )
            models.append(
                {
                    "id": custom_model.id,
                    "name": custom_model.name,
                    "object": "model",
                    "created": custom_model.created_at,
                    "owned_by": owned_by,
                    "info": custom_model.model_dump(),
                    "preset": True,
                }
            )

    app.state.MODELS = {model["id"]: model for model in models}

    webui_app.state.MODELS = app.state.MODELS

    return models


@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
    """List the models visible to the requesting user.

    Filter pipelines are never listed. When the model filter is enabled,
    users whose role is ``"user"`` only see models whose id appears in
    ``MODEL_FILTER_LIST``.
    """
    all_models = await get_all_models()

    visible = [
        entry
        for entry in all_models
        if "pipeline" not in entry or entry["pipeline"].get("type", None) != "filter"
    ]

    if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
        visible = [
            entry
            for entry in visible
            if entry["id"] in app.state.config.MODEL_FILTER_LIST
        ]

    return {"data": visible}


501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    """Run pipeline ``outlet`` hooks over a completed chat.

    The requested model's own pipeline runs first, if the model is a pipeline.
    Then every filter pipeline targeting the model, or all models via
    ``["*"]``, runs in ascending ``pipeline.priority`` order. Each hook
    receives a POST of ``{"user": ..., "body": ...}`` at
    ``{base_url}/{id}/filter/outlet``, and its JSON reply becomes the body
    for the next hook.

    Returns:
        The final body, or a ``JSONResponse`` relaying an upstream error reply
        that carried a ``detail`` field.
    """
    data = form_data
    model_id = data.get("model")
    log.debug(f"chat_completed model_id: {model_id}")

    filters = [
        candidate
        for candidate in app.state.MODELS.values()
        if "pipeline" in candidate
        and "type" in candidate["pipeline"]
        and candidate["pipeline"]["type"] == "filter"
        and (
            candidate["pipeline"]["pipelines"] == ["*"]
            or model_id in candidate["pipeline"]["pipelines"]
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    model = app.state.MODELS.get(model_id)
    if model is not None and "pipeline" in model:
        # Outlets run in reverse of inlets: the model's own pipeline first.
        sorted_filters = [model] + sorted_filters

    user_info = {"id": user.id, "name": user.name, "role": user.role}

    for filter_model in sorted_filters:
        r = None
        try:
            urlIdx = filter_model["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            # Connections with an empty API key are skipped.
            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{filter_model['id']}/filter/outlet",
                    headers=headers,
                    json={"user": user_info, "body": data},
                )

                r.raise_for_status()
                data = r.json()
        except Exception as e:
            log.error(f"Connection error: {e}")

            # Relay a structured upstream error; otherwise keep going with the
            # body this hook failed to modify.
            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except (ValueError, TypeError):
                    # Error body is not JSON or not a container.
                    pass

    return data


571
572
@app.get("/api/pipelines/list")
async def get_pipelines_list(user=Depends(get_admin_user)):
    """List the OpenAI-compatible connections that serve pipelines.

    A connection qualifies when its raw response from
    ``get_openai_models(raw=True)`` is not None and contains ``"pipelines"``.

    Returns:
        ``{"data": [{"url": base_url, "idx": index}, ...]}``
    """
    responses = await get_openai_models(raw=True)
    log.debug(f"get_pipelines_list responses: {responses}")

    base_urls = openai_app.state.config.OPENAI_API_BASE_URLS
    return {
        "data": [
            {"url": base_urls[idx], "idx": idx}
            for idx, response in enumerate(responses)
            if response is not None and "pipelines" in response
        ]
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
class AddPipelineForm(BaseModel):
    # Request body for POST /api/pipelines/add.
    url: str  # forwarded as {"url": ...} to {base_url}/pipelines/add
    urlIdx: int  # index into OPENAI_API_BASE_URLS / OPENAI_API_KEYS


@app.post("/api/pipelines/add")
async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
    """Ask the pipelines server at connection ``urlIdx`` to add a pipeline.

    Forwards ``{"url": form_data.url}`` to ``{base_url}/pipelines/add`` and
    returns the server's JSON reply.

    Raises:
        HTTPException: the upstream status with its ``detail`` when the server
            replied with an error, or 404 "Pipeline not found" when no reply
            was obtained (e.g. bad index or connection failure).
    """
    r = None
    try:
        urlIdx = form_data.urlIdx

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.post(
            f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        log.error(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                # Error body is not JSON or not a mapping; keep the default.
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        ) from e


class DeletePipelineForm(BaseModel):
    # Request body for DELETE /api/pipelines/delete.
    id: str  # forwarded as {"id": ...} to {base_url}/pipelines/delete
    urlIdx: int  # index into OPENAI_API_BASE_URLS / OPENAI_API_KEYS


@app.delete("/api/pipelines/delete")
async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
    """Ask the pipelines server at connection ``urlIdx`` to delete a pipeline.

    Forwards ``{"id": form_data.id}`` to ``{base_url}/pipelines/delete`` and
    returns the server's JSON reply.

    Raises:
        HTTPException: the upstream status with its ``detail`` when the server
            replied with an error, or 404 "Pipeline not found" when no reply
            was obtained (e.g. bad index or connection failure).
    """
    r = None
    try:
        urlIdx = form_data.urlIdx

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.delete(
            f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        log.error(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                # Error body is not JSON or not a mapping; keep the default.
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        ) from e


Timothy J. Baek's avatar
Timothy J. Baek committed
679
@app.get("/api/pipelines")
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
    """Proxy ``GET {base_url}/pipelines`` for connection ``urlIdx``.

    Returns:
        The pipelines server's JSON reply.

    Raises:
        HTTPException: the upstream status with its ``detail`` on an upstream
            error reply, otherwise 404 "Pipeline not found". This includes the
            case where ``urlIdx`` is omitted (None) or out of range.
    """
    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/pipelines", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        log.error(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                # Error body is not JSON or not a mapping; keep the default.
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        ) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
712
713


Timothy J. Baek's avatar
Timothy J. Baek committed
714
715
716
717
718
719
720
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
):
    """Proxy ``GET {base_url}/{pipeline_id}/valves`` for connection ``urlIdx``.

    Returns:
        The pipelines server's JSON reply.

    Raises:
        HTTPException: the upstream status with its ``detail`` on an upstream
            error reply, otherwise 404 "Pipeline not found".
    """
    # Called for its side effect of rebuilding app.state.MODELS; the returned
    # list is not used here.
    await get_all_models()

    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        log.error(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                # Error body is not JSON or not a mapping; keep the default.
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        ) from e


@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
):
    """Fetch the valve specification of a pipeline from a configured backend.

    Args:
        urlIdx: Index into ``OPENAI_API_BASE_URLS`` / ``OPENAI_API_KEYS``
            selecting which backend to query.
        pipeline_id: Pipeline whose valve spec is requested.
        user: Resolved by ``get_admin_user`` (presumably rejects non-admins).

    Returns:
        A shallow copy of the JSON object served at
        ``{url}/{pipeline_id}/valves/spec``.

    Raises:
        HTTPException: with the backend's error status and its ``detail``
            message when the backend answered with an error status; otherwise
            404 "Pipeline not found" (bad index, connection failure, or a
            success response whose body is not a JSON object).
    """
    from starlette.concurrency import run_in_threadpool

    # Awaited for its side effects only; the model list itself is not used here.
    # NOTE(review): presumably refreshes cached model state — confirm.
    await get_all_models()

    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        # `requests` is blocking; run it in a worker thread so a slow or hung
        # backend cannot stall the event loop for every other request.
        r = await run_in_threadpool(
            requests.get, f"{url}/{pipeline_id}/valves/spec", headers=headers
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        log.error(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
            except ValueError:
                # Error body is not JSON; keep the generic message.
                res = None
            if isinstance(res, dict) and "detail" in res:
                detail = res["detail"]

        # Only forward the backend status when it is an actual error status;
        # a 2xx with an unusable body must not become an "error" with code 200.
        raise HTTPException(
            status_code=(
                r.status_code
                if r is not None and not r.ok
                else status.HTTP_404_NOT_FOUND
            ),
            detail=detail,
        ) from e


@app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves(
    urlIdx: Optional[int],
    pipeline_id: str,
    form_data: dict,
    user=Depends(get_admin_user),
):
    """Send new valve values for a pipeline to a configured backend.

    Args:
        urlIdx: Index into ``OPENAI_API_BASE_URLS`` / ``OPENAI_API_KEYS``
            selecting which backend to update.
        pipeline_id: Pipeline whose valves are being updated.
        form_data: Valve values, forwarded unchanged as the JSON request body.
        user: Resolved by ``get_admin_user`` (presumably rejects non-admins).

    Returns:
        A shallow copy of the JSON object returned by
        ``{url}/{pipeline_id}/valves/update``.

    Raises:
        HTTPException: with the backend's error status and its ``detail``
            message when the backend answered with an error status; otherwise
            404 "Pipeline not found" (bad index, connection failure, or a
            success response whose body is not a JSON object).
    """
    from starlette.concurrency import run_in_threadpool

    # Awaited for its side effects only; the model list itself is not used here.
    # NOTE(review): presumably refreshes cached model state — confirm.
    await get_all_models()

    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        # `requests` is blocking; run it in a worker thread so a slow or hung
        # backend cannot stall the event loop for every other request.
        r = await run_in_threadpool(
            requests.post,
            f"{url}/{pipeline_id}/valves/update",
            headers=headers,
            json={**form_data},
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        log.error(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
            except ValueError:
                # Error body is not JSON; keep the generic message.
                res = None
            if isinstance(res, dict) and "detail" in res:
                detail = res["detail"]

        # Only forward the backend status when it is an actual error status;
        # a 2xx with an unusable body must not become an "error" with code 200.
        raise HTTPException(
            status_code=(
                r.status_code
                if r is not None and not r.ok
                else status.HTTP_404_NOT_FOUND
            ),
            detail=detail,
        ) from e


Timothy J. Baek's avatar
Timothy J. Baek committed
834
835
@app.get("/api/config")
async def get_app_config():
    """Return the configuration the frontend reads when it loads.

    Includes the app name and version, default locale, default models and
    prompt suggestions, feature flags, and the display name of every
    configured OAuth provider. No auth dependency is declared on this route.
    """
    # CONFIG_DATA may have no "ui" section; the locale then falls back to en-US.
    ui_section = CONFIG_DATA["ui"] if "ui" in CONFIG_DATA else {}
    default_locale = ui_section.get("default_locale", "en-US")

    features = {
        "auth": WEBUI_AUTH,
        "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
        "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
        "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
        "enable_image_generation": images_app.state.config.ENABLED,
        "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
        "enable_admin_export": ENABLE_ADMIN_EXPORT,
    }

    # Only each provider's display name is exposed (its key when no "name"
    # is configured); the rest of the provider config stays server-side.
    provider_names = {}
    for provider_key, provider_cfg in OAUTH_PROVIDERS.items():
        provider_names[provider_key] = provider_cfg.get("name", provider_key)

    return {
        "status": True,
        "name": WEBUI_NAME,
        "version": VERSION,
        "default_locale": default_locale,
        "default_models": webui_app.state.config.DEFAULT_MODELS,
        "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
        "features": features,
        "oauth": {"providers": provider_names},
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
868
869
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
    """Report the model-filter settings held in ``app.state.config``.

    Returns:
        ``{"enabled": ENABLE_MODEL_FILTER, "models": MODEL_FILTER_LIST}``.
    """
    cfg = app.state.config
    enabled = cfg.ENABLE_MODEL_FILTER
    models = cfg.MODEL_FILTER_LIST
    return {"enabled": enabled, "models": models}
Timothy J. Baek's avatar
Timothy J. Baek committed
874
875
876
877
878
879
880
881


class ModelFilterConfigForm(BaseModel):
    """Request body for POST /api/config/model/filter."""

    enabled: bool  # stored as app.state.config.ENABLE_MODEL_FILTER
    models: List[str]  # stored as app.state.config.MODEL_FILTER_LIST (presumably model IDs)


@app.post("/api/config/model/filter")
async def update_model_filter_config(
    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
    """Overwrite the model-filter settings and echo back what was stored.

    The response is read back from ``app.state.config`` after assignment,
    not copied from ``form_data``.
    """
    cfg = app.state.config
    cfg.ENABLE_MODEL_FILTER = form_data.enabled
    cfg.MODEL_FILTER_LIST = form_data.models
    return dict(enabled=cfg.ENABLE_MODEL_FILTER, models=cfg.MODEL_FILTER_LIST)
Timothy J. Baek's avatar
Timothy J. Baek committed
892
893


Timothy J. Baek's avatar
Timothy J. Baek committed
894
895
896
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    """Return the webhook URL currently stored in ``app.state.config``."""
    current_url = app.state.config.WEBHOOK_URL
    return {"url": current_url}


class UrlForm(BaseModel):
    """Request body carrying a single URL (e.g. for POST /api/webhook)."""

    url: str  # the URL to store


@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
    """Store a new webhook URL and propagate it to the webui sub-app.

    Returns:
        ``{"url": <stored URL>}``, read back from ``app.state.config``.
    """
    app.state.config.WEBHOOK_URL = form_data.url
    # oauth_callback reads the URL from ``webui_app.state.config``; keep that
    # copy in sync so sign-up notifications go to the new URL.
    webui_app.state.config.WEBHOOK_URL = app.state.config.WEBHOOK_URL
    # Legacy attribute, still set in case other code reads it.
    webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
    return {"url": app.state.config.WEBHOOK_URL}
910
911


912
913
914
915
916
917
918
@app.get("/api/version")
async def get_app_version():
    """Return the running application version as ``{"version": VERSION}``.

    Named distinctly so it no longer rebinds the module-level
    ``get_app_config`` (the /api/config handler); the route is unchanged.
    """
    return {
        "version": VERSION,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
919
920
@app.get("/api/changelog")
async def get_app_changelog():
    """Return the first five entries of CHANGELOG, preserving their order."""
    leading_keys = list(CHANGELOG)[:5]
    return {entry_key: CHANGELOG[entry_key] for entry_key in leading_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
922
923


924
925
926
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    """Compare the running version with the latest GitHub release.

    Returns:
        ``{"current": VERSION, "latest": <release tag without a leading "v">}``.

    Raises:
        HTTPException: 503 with ``ERROR_MESSAGES.RATE_LIMIT_EXCEEDED`` whenever
            GitHub cannot be reached, times out, or answers with an error
            status (the same detail is used regardless of the cause).
    """
    import asyncio

    # Bound the lookup; without an explicit timeout aiohttp falls back to its
    # much longer default total timeout.
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["tag_name"]

                # Strip a "v" prefix only when present, so an unprefixed tag
                # such as "0.3.0" is not truncated to ".3.0".
                if latest_version.startswith("v"):
                    latest_version = latest_version[1:]

                return {"current": VERSION, "latest": latest_version}
    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
        # Timeouts raise asyncio.TimeoutError, which is not an aiohttp.ClientError.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
        ) from e

Timothy J. Baek's avatar
Timothy J. Baek committed
942

943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
############################
# OAuth Login & Callback
############################

oauth = OAuth()

# One authlib client per configured provider. The provider key doubles as the
# client name that the login/callback routes pass to `oauth.create_client`.
for provider_name, provider_config in OAUTH_PROVIDERS.items():
    oauth.register(
        name=provider_name,
        client_id=provider_config["client_id"],
        client_secret=provider_config["client_secret"],
        server_metadata_url=provider_config["server_metadata_url"],
        client_kwargs=dict(scope=provider_config["scope"]),
    )

# authlib keeps OAuth flow state in the Starlette session, so the session
# middleware is only installed when at least one provider is configured.
if OAUTH_PROVIDERS:
    app.add_middleware(
        SessionMiddleware,
        secret_key=WEBUI_SECRET_KEY,
        session_cookie="oui-session",
        same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
    )


@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
    """Start the OAuth flow by redirecting to ``provider``'s authorize page.

    The redirect URI is this app's ``oauth_callback`` route for the same
    provider. Unconfigured providers get a 404.
    """
    if provider in OAUTH_PROVIDERS:
        callback_url = request.url_for("oauth_callback", provider=provider)
        provider_client = oauth.create_client(provider)
        return await provider_client.authorize_redirect(request, callback_url)
    raise HTTPException(404)


@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
    """Finish an OAuth login for ``provider`` and hand a JWT to the frontend.

    Flow:
      1. Exchange the authorization response for a token.
      2. Look the user up by ``"<provider>@<sub>"``.
      3. If not found and OAUTH_MERGE_ACCOUNTS_BY_EMAIL is on, attach this
         identity to an existing account with the same email.
      4. If still not found and ENABLE_OAUTH_SIGNUP is on, create the account
         (random, unused password) and call the sign-up webhook if one is set.
      5. Redirect to ``<base_url>auth#token=<jwt>``.

    Raises:
        HTTPException: 404 for an unconfigured provider; 400 (INVALID_CRED)
            when the token exchange fails, the token has no userinfo or no
            ``sub``, an email is needed but missing, or sign-up is disabled;
            500 when account creation does not return a user.
    """
    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)
    client = oauth.create_client(provider)
    try:
        token = await client.authorize_access_token(request)
    except Exception as e:
        log.error(f"OAuth callback error: {e}")
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    # NOTE(review): "userinfo" appears to be set by authlib only when the
    # provider returns an OpenID Connect ID token. Indexing it directly turned
    # its absence into a KeyError (HTTP 500); reject it as bad credentials.
    user_data: Optional[UserInfo] = token.get("userinfo")
    if not user_data:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    sub = user_data.get("sub")
    if not sub:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
    provider_sub = f"{provider}@{sub}"

    # Normalised once; `or ""` also covers an email claim that is present but null.
    email = (user_data.get("email") or "").lower()

    # Check if the user exists
    user = Users.get_user_by_oauth_sub(provider_sub)

    if not user and OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
        # Merging by email needs an email to match against.
        if not email:
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        user = Users.get_user_by_email(email, True)
        if user:
            # Record the OAuth identity so later logins match on provider_sub.
            Users.update_user_oauth_sub_by_id(user.id, provider_sub)

    if not user:
        if not ENABLE_OAUTH_SIGNUP.value:
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        if not email:
            # Never create an account with an empty email address.
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        user = Auths.insert_new_auth(
            email=email,
            password=get_password_hash(str(uuid.uuid4())),  # Random password, not used
            name=user_data.get("name", "User"),
            profile_image_url=user_data.get("picture", "/user.png"),
            role=webui_app.state.config.DEFAULT_USER_ROLE,
            oauth_sub=provider_sub,
        )
        if not user:
            # NOTE(review): assumes insert_new_auth signals failure with a
            # falsy return; this previously crashed on `user.name` below.
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to create user account",
            )

        if webui_app.state.config.WEBHOOK_URL:
            post_webhook(
                webui_app.state.config.WEBHOOK_URL,
                WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                {
                    "action": "signup",
                    "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                    "user": user.model_dump_json(exclude_none=True),
                },
            )

    jwt_token = create_token(
        data={"id": user.id},
        expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
    )

    # Redirect back to the frontend with the JWT token
    redirect_url = f"{request.base_url}auth#token={jwt_token}"
    return RedirectResponse(url=redirect_url)


1047
1048
1049
@app.get("/manifest.json")
async def get_manifest_json():
    """Serve the web app manifest, branded with ``WEBUI_NAME``."""
    brand_color = "#343541"
    app_icon = {"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}
    return dict(
        name=WEBUI_NAME,
        short_name=WEBUI_NAME,
        start_url="/",
        display="standalone",
        background_color=brand_color,
        theme_color=brand_color,
        orientation="portrait-primary",
        icons=[app_icon],
    )

Timothy J. Baek's avatar
Timothy J. Baek committed
1060

1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    """Serve an OpenSearch description document for this instance.

    ``WEBUI_NAME`` and ``WEBUI_URL`` are XML-escaped before interpolation.
    Without escaping, a name such as "R&D Chat" would yield a malformed
    document. The URL is also escaped for use inside double-quoted attributes.
    """
    from xml.sax.saxutils import escape

    name = escape(str(WEBUI_NAME))
    url = escape(str(WEBUI_URL), {'"': "&quot;"})
    xml_content = rf"""
    <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
    <ShortName>{name}</ShortName>
    <Description>Search {name}</Description>
    <InputEncoding>UTF-8</InputEncoding>
    <Image width="16" height="16" type="image/x-icon">{url}/favicon.png</Image>
    <Url type="text/html" method="get" template="{url}/?q={{searchTerms}}"/>
    <moz:SearchForm>{url}</moz:SearchForm>
    </OpenSearchDescription>
    """
    return Response(content=xml_content, media_type="application/xml")


Timothy J. Baek's avatar
Timothy J. Baek committed
1076
1077
1078
1079
1080
@app.get("/health")
async def healthcheck():
    """Liveness probe: always answers ``{"status": True}``."""
    return dict(status=True)


1081
1082
# Static assets and the cache directory are always served.
app.mount(path="/static", app=StaticFiles(directory=STATIC_DIR), name="static")
app.mount(path="/cache", app=StaticFiles(directory=CACHE_DIR), name="cache")

if os.path.exists(FRONTEND_BUILD_DIR):
    # NOTE(review): presumably overrides platform MIME tables that may map
    # ".js" to a type browsers refuse to execute — confirm.
    mimetypes.add_type("text/javascript", ".js")
    # A mount at "/" matches every path, so keep it after all route
    # definitions (routes are matched in registration order).
    app.mount(
        path="/",
        app=SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )
else:
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )