main.py 12.2 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

13
from apps.web.models.models import Models
Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
24
    ENABLE_OPENAI_API,
25
26
27
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
28
    ENABLE_MODEL_FILTER,
29
    MODEL_FILTER_LIST,
30
    AppConfig,
31
)
Timothy J. Baek's avatar
Timothy J. Baek committed
32
33
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
36

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
37

38
39
40
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

Timothy J. Baek's avatar
Timothy J. Baek committed
41
42
43
44
45
46
47
48
49
app = FastAPI()
# Accept cross-origin calls from any origin, for every method and header,
# with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Runtime-mutable settings object, seeded below from the values imported
# from `config`.
app.state.config = AppConfig()

# Model allow-list; GET /models applies it to users whose role is "user".
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Per-model custom metadata (from Models.get_all_models()); attached to each
# fetched model as `custom_info`.
app.state.MODEL_CONFIG = Models.get_all_models()

# OpenAI-compatible connection settings; the i-th key is used with the i-th
# base URL.
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Cache of model id -> model dict, filled lazily by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
64

Timothy J. Baek's avatar
Timothy J. Baek committed
65
66
67
68
69
70
@app.middleware("http")
async def check_url(request: Request, call_next):
    """HTTP middleware: ensure the model cache is populated, then continue.

    Whenever ``app.state.MODELS`` is empty at the start of a request it is
    rebuilt via ``get_all_models()`` before the request is handed on.
    """
    if not app.state.MODELS:
        await get_all_models()
    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
74
75


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Report whether the OpenAI backend is enabled (admin only)."""
    enabled = app.state.config.ENABLE_OPENAI_API
    return {"ENABLE_OPENAI_API": enabled}


class OpenAIConfigForm(BaseModel):
    """Request payload for ``POST /config/update``."""

    # Optional on the wire; None when the client omits it.
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    """Enable or disable the OpenAI backend (admin only).

    ``enable_openai_api`` is optional in the form. When it is omitted (None),
    the current setting is kept rather than overwritten with None, which
    would silently disable the backend.

    Returns:
        ``{"ENABLE_OPENAI_API": <current setting>}`` after the update.
    """
    if form_data.enable_openai_api is not None:
        app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


Timothy J. Baek's avatar
Timothy J. Baek committed
91
92
class UrlsUpdateForm(BaseModel):
    """Request payload for ``POST /urls/update``: the full replacement URL list."""

    urls: List[str]


class KeysUpdateForm(BaseModel):
    """Request payload for ``POST /keys/update``: the full replacement key list.

    Keys are paired with base URLs by position.
    """

    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
97
98


Timothy J. Baek's avatar
Timothy J. Baek committed
99
100
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URLs (admin only)."""
    base_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
102

Timothy J. Baek's avatar
Timothy J. Baek committed
103

Timothy J. Baek's avatar
Timothy J. Baek committed
104
105
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the base URL list (admin only) and refresh the model cache.

    Returns:
        ``{"OPENAI_API_BASE_URLS": <stored list>}``.
    """
    # Store the new URLs *before* refreshing: refreshing first rebuilt
    # app.state.MODELS (and each model's urlIdx) from the endpoint list that
    # was about to be replaced.
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
    try:
        await get_all_models()
    except IndexError:
        # Keys are updated by a separate request and may not have been resized
        # to match the new URL list yet; the cache is rebuilt on a later call.
        log.warning(
            "OPENAI_API_KEYS does not cover every base URL yet; model list not refreshed"
        )
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
109
110


Timothy J. Baek's avatar
Timothy J. Baek committed
111
112
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys, one per base URL (admin only)."""
    api_keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": api_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
114
115
116
117


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the API key list (admin only) and echo back the stored value."""
    settings = app.state.config
    settings.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": settings.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
120
121


Timothy J. Baek's avatar
Timothy J. Baek committed
122
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to OpenAI, caching the audio on disk.

    The raw request body is forwarded to ``<base_url>/audio/speech`` of the
    configured ``https://api.openai.com/v1`` connection, using that
    connection's key. The SHA-256 of the body is the cache key. The audio is
    stored as ``<hash>.mp3`` and the request payload as ``<hash>.json`` under
    ``CACHE_DIR/audio/speech/``. A repeated identical request is served from
    the cached ``.mp3`` without calling upstream.

    Raises:
        HTTPException(401): no ``https://api.openai.com/v1`` URL is configured.
        HTTPException: the upstream call or caching failed. The status mirrors
            the upstream error response when there is one, otherwise 500.
    """
    idx = None
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
        body = await request.body()
        name = hashlib.sha256(body).hexdigest()

        SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
        SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
        file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
        # Download target; only renamed to file_path once complete.
        tmp_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.part")

        # Check if the file already exists in the cache
        if file_path.is_file():
            return FileResponse(file_path)

        headers = {}
        headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
        headers["Content-Type"] = "application/json"
        if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
            headers["HTTP-Referer"] = "https://openwebui.com/"
            headers["X-Title"] = "Open WebUI"

        r = None
        try:
            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
                data=body,
                headers=headers,
                stream=True,
            )
            r.raise_for_status()

            # Stream into a temporary file and move it into place only after the
            # transfer finished. Otherwise an interrupted download would leave a
            # truncated .mp3 that the cache check above would keep serving.
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
            tmp_path.replace(file_path)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            # Drop any partial download so it is never mistaken for cached audio.
            tmp_path.unlink(missing_ok=True)

            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']}"
                except Exception:
                    error_detail = f"External: {e}"

            # `if r` would call Response.__bool__ (== r.ok), which is False for
            # exactly the error responses whose status we want to mirror.
            raise HTTPException(
                status_code=r.status_code if r is not None and not r.ok else 500,
                detail=error_detail,
            )

    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
Timothy J. Baek's avatar
Timothy J. Baek committed
184
185


Timothy J. Baek's avatar
Timothy J. Baek committed
186
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    Returns None without making a request when ``key`` is the empty string.
    Also returns None, after logging the error, when anything goes wrong.
    The whole request is limited to a 5-second total timeout.
    """
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        if key == "":
            return None
        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(timeout=timeout) as session, session.get(
            url, headers=auth_headers
        ) as resp:
            return await resp.json()
    except Exception as err:
        # Handle connection error here
        log.error(f"Connection error: {err}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-endpoint model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` holds the models returned by the i-th configured base
    URL, or None. Entries that are None (or contain ``"error"``) are skipped.
    For an ``api.openai.com`` endpoint, only models whose id contains
    ``"gpt"`` are kept. Each kept model is copied with ``urlIdx`` set to its
    endpoint index.
    """
    log.info(f"merge_models_lists {model_lists}")

    merged = []
    for url_idx, endpoint_models in enumerate(model_lists):
        if endpoint_models is None or "error" in endpoint_models:
            continue
        for model in endpoint_models:
            base_url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
            if "api.openai.com" not in base_url or "gpt" in model["id"]:
                merged.append({**model, "urlIdx": url_idx})

    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
219
220


Timothy J. Baek's avatar
Timothy J. Baek committed
221
async def get_all_models():
    """Fetch, merge and cache the model lists of all configured endpoints.

    When OpenAI is disabled, or the only configured key is empty, returns
    ``{"data": []}`` without touching the cache. Otherwise it queries
    ``<url>/models`` for every base URL concurrently and merges the results
    via ``merge_models_lists``. It then attaches ``custom_info`` to each
    model, stores ``app.state.MODELS`` as ``{model_id: model}``, and returns
    ``{"data": [...]}``.
    """
    log.info("get_all_models()")

    keys = app.state.config.OPENAI_API_KEYS
    if (len(keys) == 1 and keys[0] == "") or not app.state.config.ENABLE_OPENAI_API:
        models = {"data": []}
    else:
        # URLs and keys are updated through separate endpoints, so the key list
        # can briefly be shorter than the URL list. Treat a missing key as ""
        # (fetch_url skips it) instead of raising IndexError.
        tasks = [
            fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
        ]

        responses = await asyncio.gather(*tasks)
        log.info(f"get_all_models:responses() {responses}")

        models = {
            "data": merge_models_lists(
                list(
                    map(
                        # Accept either an OpenAI-style {"data": [...]} envelope
                        # or a bare list; anything else counts as no models.
                        lambda response: (
                            response["data"]
                            if (response and "data" in response)
                            else (response if isinstance(response, list) else None)
                        ),
                        responses,
                    )
                )
            )
        }

        for model in models["data"]:
            add_custom_info_to_model(model)

        log.info(f"models: {models}")
        app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models


def add_custom_info_to_model(model: dict):
    """Attach the matching ``app.state.MODEL_CONFIG`` entry as ``custom_info``.

    ``custom_info`` becomes the first config item whose ``id`` equals
    ``model["id"]``, or None when there is no match. ``model`` is mutated in
    place; nothing is returned.
    """
    custom_info = None
    for item in app.state.MODEL_CONFIG:
        if item.id == model["id"]:
            custom_info = item
            break
    model["custom_info"] = custom_info
Timothy J. Baek's avatar
Timothy J. Baek committed
266

Timothy J. Baek's avatar
Timothy J. Baek committed
267
268
269

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``: return the merged model list of all endpoints (via
    ``get_all_models``). When the model filter is enabled, users with role
    ``"user"`` only see ids listed in ``MODEL_FILTER_LIST``.

    With ``url_idx``: query ``<base_url>/models`` of that one endpoint,
    authenticating with its configured key. For ``api.openai.com`` only
    models whose id contains ``"gpt"`` are kept.

    Raises:
        HTTPException: the upstream call failed. The status mirrors the
            upstream error response when there is one, otherwise 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
            models["data"] = list(
                filter(
                    lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
                    models["data"],
                )
            )
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]

    # Authenticate like fetch_url() and proxy() do; keyed endpoints such as
    # api.openai.com reject an anonymous /models request.
    try:
        key = app.state.config.OPENAI_API_KEYS[url_idx]
    except IndexError:
        key = ""
    headers = {"Authorization": f"Bearer {key}"} if key else {}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = list(
                filter(lambda model: "gpt" in model["id"], response_data["data"])
            )

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # `if r` would call Response.__bool__ (== r.ok), which is False for
        # exactly the error responses whose status we want to mirror.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
314
315


Timothy J. Baek's avatar
Timothy J. Baek committed
316
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the OpenAI-compatible endpoint for the model.

    When the body is a UTF-8 JSON object, the endpoint index comes from
    ``app.state.MODELS[body["model"]]["urlIdx"]``. The body is then adjusted
    before forwarding:
    - user info is added for pipeline models;
    - ``max_tokens`` is set for ``gpt-4-vision-preview``;
    - ``num_ctx`` is dropped.
    Other bodies are forwarded unchanged to the first configured endpoint.
    A model id missing from the cache is not handled here.

    Server-sent-event responses are streamed back; anything else is returned
    as parsed JSON.

    Raises:
        HTTPException(401): the selected endpoint has no API key.
        HTTPException: the upstream call failed. The status mirrors the
            upstream error response when there is one, otherwise 500.
    """
    idx = 0
    pipeline = False

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Decode the body into a dict so it can be inspected and patched
    # (e.g. adding max_tokens for gpt-4-vision).
    try:
        body = body.decode("utf-8")
        body = json.loads(body)

        model = app.state.MODELS[body.get("model")]
        idx = model["urlIdx"]

        if "pipeline" in model:
            pipeline = model.get("pipeline")

        if pipeline:
            body["user"] = {"name": user.name, "id": user.id}

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if body.get("model") == "gpt-4-vision-preview":
            if "max_tokens" not in body:
                body["max_tokens"] = 4000
            # Pass `body` through a %s placeholder; an extra argument without
            # one makes logging report a formatting error instead.
            log.debug("Modified body_dict: %s", body)

        # Fix for ChatGPT calls failing because the num_ctx key is in body:
        # leaving it there generates an error with the OpenAI API (Feb 2024).
        if "num_ctx" in body:
            del body["num_ctx"]

        # Convert the modified body back to JSON
        body = json.dumps(body)
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Not a UTF-8 JSON document: forward the body as-is to endpoint 0.
        log.error("Error loading request body into a dictionary: %s", e)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None

    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # `if r` would call Response.__bool__ (== r.ok), which is False for
        # exactly the error responses whose status we want to mirror.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )