main.py 12 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

13
from apps.web.models.models import Models
Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
24
    ENABLE_OPENAI_API,
25
26
27
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
28
    ENABLE_MODEL_FILTER,
29
    MODEL_FILTER_LIST,
30
    AppConfig,
31
)
Timothy J. Baek's avatar
Timothy J. Baek committed
32
33
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
36

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
37

38
39
40
# Module logger; its level comes from the "OPENAI" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

# FastAPI sub-application for the OpenAI-compatible backends, with CORS open
# to any origin, method, and header (credentials allowed).
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Runtime configuration container, seeded with the values assigned below.
app.state.config = AppConfig()

# Model visibility filter, applied to role "user" in get_models().
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Stored per-model customizations for the "openai" source (as forms);
# matched by id in add_custom_info_to_model().
app.state.MODEL_CONFIG = list(
    map(lambda model: model.to_form(), Models.get_all_models_by_source("openai"))
)

# Backend toggle plus the index-aligned base URLs and API keys.
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Model id -> model dict cache; filled by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
66

Timothy J. Baek's avatar
Timothy J. Baek committed
67
68
69
70
71
72
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Populate the model cache on demand, then hand the request on unchanged.

    Whenever ``app.state.MODELS`` is empty, ``get_all_models()`` is awaited
    first so that later handlers (e.g. ``proxy``) can resolve model ids.
    """
    if not app.state.MODELS:
        await get_all_models()
    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
76
77


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Report whether the OpenAI API integration is enabled (admin only)."""
    enabled = app.state.config.ENABLE_OPENAI_API
    return {"ENABLE_OPENAI_API": enabled}


class OpenAIConfigForm(BaseModel):
    # Request body for POST /config/update.
    # New value for the ENABLE_OPENAI_API toggle; the field may be omitted (None).
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    """Update the OpenAI API toggle (admin only) and return the resulting value.

    ``enable_openai_api`` is optional in the form. When it is omitted
    (``None``), the current setting is left as-is rather than being
    overwritten with ``None``.
    """
    if form_data.enable_openai_api is not None:
        app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


Timothy J. Baek's avatar
Timothy J. Baek committed
93
94
class UrlsUpdateForm(BaseModel):
    # Request body for POST /urls/update: the full replacement list of base URLs.
    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
95
96


Timothy J. Baek's avatar
Timothy J. Baek committed
97
98
class KeysUpdateForm(BaseModel):
    # Request body for POST /keys/update: the full replacement list of API keys,
    # index-aligned with the base URLs.
    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
99
100


Timothy J. Baek's avatar
Timothy J. Baek committed
101
102
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URLs (admin only)."""
    base_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
104

Timothy J. Baek's avatar
Timothy J. Baek committed
105

Timothy J. Baek's avatar
Timothy J. Baek committed
106
107
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the list of OpenAI-compatible base URLs (admin only) and echo it.

    NOTE(review): the model cache is refreshed *before* the new URLs are
    stored, so the refresh runs against the previous URL list. Moving it after
    the assignment would index OPENAI_API_KEYS with the new URL list, which can
    fail until the keys are updated as well. Confirm the intended ordering
    before changing it.
    """
    await get_all_models()
    config = app.state.config
    config.OPENAI_API_BASE_URLS = form_data.urls
    return {"OPENAI_API_BASE_URLS": config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
111
112


Timothy J. Baek's avatar
Timothy J. Baek committed
113
114
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys, index-aligned with the base URLs (admin only)."""
    return dict(OPENAI_API_KEYS=app.state.config.OPENAI_API_KEYS)
Timothy J. Baek's avatar
Timothy J. Baek committed
116
117
118
119


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the stored API keys with ``form_data.keys`` and echo them (admin only)."""
    config = app.state.config
    config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
122
123


Timothy J. Baek's avatar
Timothy J. Baek committed
124
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to the OpenAI API, caching the audio on disk.

    The raw request body is forwarded to ``{base_url}/audio/speech`` of the
    ``https://api.openai.com/v1`` entry in ``OPENAI_API_BASE_URLS``, using the
    API key at the same index. The audio is cached as
    ``CACHE_DIR/audio/speech/<sha256(body)>.mp3``, with a ``.json`` copy of the
    request next to it. Identical later requests are served from the cache.

    Raises:
        HTTPException: 401 when no ``https://api.openai.com/v1`` base URL is
            configured. When the upstream call or caching fails, the upstream
            status is used for HTTP error responses and 500 for anything else.
    """
    import tempfile

    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    base_url = app.state.config.OPENAI_API_BASE_URLS[idx]
    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    r = None
    tmp_path = None
    try:
        r = requests.post(
            url=f"{base_url}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Stream into a uniquely named temp file in the cache directory and
        # move it into place only once it is complete. An interrupted download
        # then never leaves a truncated mp3 that later requests would be served.
        with tempfile.NamedTemporaryFile(
            dir=SPEECH_CACHE_DIR, suffix=".part", delete=False
        ) as f:
            tmp_path = Path(f.name)
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        tmp_path.replace(file_path)
        tmp_path = None

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Return the saved file
        return FileResponse(file_path)
    except Exception as e:
        log.exception(e)

        if tmp_path is not None:
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup of the partial download

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Test ``r is not None`` rather than ``r``: a requests.Response is falsy
        # for 4xx/5xx (bool(r) == r.ok). Upstream HTTP errors keep their status;
        # local failures after a successful response are reported as 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
186
187


Timothy J. Baek's avatar
Timothy J. Baek committed
188
async def fetch_url(url, key):
    """GET ``url`` with ``Bearer <key>`` auth and return the decoded JSON body.

    Returns ``None`` without making a request when ``key`` is the empty string.
    Also returns ``None`` (after logging the error) when anything in the
    request or JSON decoding raises. The request is bounded by a 5-second
    total timeout.
    """
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        if key == "":
            return None

        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=auth_headers) as response:
                return await response.json()
    except Exception as err:
        # Handle connection error here
        log.error(f"Connection error: {err}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-backend model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` holds the models returned by the backend at
    ``OPENAI_API_BASE_URLS[i]``. Entries that are ``None`` or contain
    ``"error"`` are skipped. Each kept model is shallow-copied with an extra
    ``"urlIdx": i`` key. For ``api.openai.com`` backends, only models whose
    id contains ``"gpt"`` are kept.
    """
    log.info(f"merge_models_lists {model_lists}")

    merged = []
    for url_idx, backend_models in enumerate(model_lists):
        if backend_models is None or "error" in backend_models:
            continue
        for model in backend_models:
            # Only inspect the model id when the backend is api.openai.com.
            if (
                "api.openai.com" in app.state.config.OPENAI_API_BASE_URLS[url_idx]
                and "gpt" not in model["id"]
            ):
                continue
            merged.append({**model, "urlIdx": url_idx})

    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
221
222


Timothy J. Baek's avatar
Timothy J. Baek committed
223
def _extract_models_list(response):
    """Normalize one ``/models`` response to a list of model dicts, or ``None``.

    Accepts an OpenAI-style ``{"data": [...]}`` envelope or a bare list.
    Anything else, including ``None`` from a failed fetch, yields ``None``.
    """
    if isinstance(response, dict) and "data" in response:
        return response["data"]
    if isinstance(response, list):
        return response
    return None


async def get_all_models():
    """Fetch, merge and cache the model lists of all configured backends.

    Returns ``{"data": [...]}``. When the OpenAI API is disabled, or the only
    configured key is empty, no request is made. In that case an empty list is
    returned and ``app.state.MODELS`` is cleared, so models from an earlier
    refresh do not stay routable.

    Otherwise each base URL's ``/models`` endpoint is queried concurrently,
    paired with the key at the same index. A missing key is treated as empty,
    which makes ``fetch_url`` skip that URL. The results are merged, annotated
    with ``custom_info``, and cached in ``app.state.MODELS`` keyed by model id.
    """
    log.info("get_all_models()")

    config = app.state.config
    keys = config.OPENAI_API_KEYS

    if (len(keys) == 1 and keys[0] == "") or not config.ENABLE_OPENAI_API:
        app.state.MODELS = {}
        return {"data": []}

    tasks = [
        fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
        for idx, url in enumerate(config.OPENAI_API_BASE_URLS)
    ]
    responses = await asyncio.gather(*tasks)
    log.info(f"get_all_models:responses() {responses}")

    models = {
        "data": merge_models_lists(
            [_extract_models_list(response) for response in responses]
        )
    }

    for model in models["data"]:
        add_custom_info_to_model(model)

    log.info(f"models: {models}")
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models


def add_custom_info_to_model(model: dict):
    """Attach the stored per-model config (or ``None``) as ``model["custom_info"]``.

    Uses the first entry of ``app.state.MODEL_CONFIG`` whose ``id`` equals
    ``model["id"]``. Mutates ``model`` in place and returns ``None``.
    """
    match = None
    for item in app.state.MODEL_CONFIG:
        if item.id == model["id"]:
            match = item
            break
    model["custom_info"] = match
Timothy J. Baek's avatar
Timothy J. Baek committed
268

Timothy J. Baek's avatar
Timothy J. Baek committed
269
270
271

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``: return the merged list from all backends and refresh
    the cache. When the model filter is enabled, users with role ``"user"``
    only see ids listed in ``MODEL_FILTER_LIST``.

    With ``url_idx``: query ``{base_url}/models`` of that one backend directly.
    The request is authenticated with the key at the same index when one is
    set. For ``api.openai.com``, only ids containing ``"gpt"`` are kept.

    Raises:
        HTTPException: the single-backend request failed. Upstream HTTP errors
            keep their status code; any other failure is a 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
            allowed = app.state.config.MODEL_FILTER_LIST
            models["data"] = [
                model for model in models["data"] if model["id"] in allowed
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    try:
        key = app.state.config.OPENAI_API_KEYS[url_idx]
    except IndexError:
        key = ""

    # The backend's /models endpoint generally requires the same Bearer key
    # used for every other call to that backend.
    headers = {"Authorization": f"Bearer {key}"} if key else {}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test for None explicitly
        # to propagate the upstream status instead of always returning 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
316
317


Timothy J. Baek's avatar
Timothy J. Baek committed
318
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the backend that serves the requested model.

    The backend index comes from ``app.state.MODELS[body["model"]]["urlIdx"]``.
    Requests whose body is not a JSON object, or that name a model missing from
    the cache, go to backend 0. JSON object bodies are patched before
    forwarding (see inline comments). SSE responses are streamed back; any
    other response is returned as parsed JSON.

    Raises:
        HTTPException: 401 when no API key is configured for the chosen
            backend. When the upstream call fails, upstream HTTP errors keep
            their status code and any other failure is a 500.
    """
    idx = 0

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Not JSON (e.g. an empty GET body): forward the raw bytes unchanged.
        log.error("Error loading request body into a dictionary: %s", e)
        payload = None

    if isinstance(payload, dict):
        model_id = payload.get("model")
        if model_id in app.state.MODELS:
            idx = app.state.MODELS[model_id]["urlIdx"]

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if model_id == "gpt-4-vision-preview":
            if "max_tokens" not in payload:
                payload["max_tokens"] = 4000
            log.debug("Modified body_dict: %s", payload)

        # Fix for ChatGPT calls failing because the num_ctx key is in body.
        # Leaving it there generates an error with the OpenAI API (Feb 2024).
        payload.pop("num_ctx", None)

        # Convert the modified body back to JSON
        body = json.dumps(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    try:
        key = app.state.config.OPENAI_API_KEYS[idx]
    except IndexError:
        key = ""

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    err = res["error"]
                    # OpenAI-style errors are objects with a "message"; some
                    # backends return a plain string instead.
                    message = err.get("message", err) if isinstance(err, dict) else err
                    error_detail = f"External: {message}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test for None explicitly
        # to propagate the upstream status instead of always returning 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )