main.py 11.9 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
24
    ENABLE_OPENAI_API,
25
26
27
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
28
    ENABLE_MODEL_FILTER,
29
    MODEL_FILTER_LIST,
30
    AppConfig,
31
)
Timothy J. Baek's avatar
Timothy J. Baek committed
32
33
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
36

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
37

38
39
40
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

# Sub-application serving the OpenAI-compatible routes; CORS accepts any
# origin, method and header.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Runtime settings live on an AppConfig instance attached to the app state and
# are seeded, in this order, from the values imported from `config`.
app.state.config = AppConfig()
for _setting_name, _setting_value in (
    ("ENABLE_MODEL_FILTER", ENABLE_MODEL_FILTER),
    ("MODEL_FILTER_LIST", MODEL_FILTER_LIST),
    ("ENABLE_OPENAI_API", ENABLE_OPENAI_API),
    ("OPENAI_API_BASE_URLS", OPENAI_API_BASE_URLS),
    ("OPENAI_API_KEYS", OPENAI_API_KEYS),
):
    setattr(app.state.config, _setting_name, _setting_value)
# Keep the module namespace free of the loop temporaries.
del _setting_name, _setting_value

# Cache of model id -> model record (each tagged with its "urlIdx"); starts
# empty and is filled by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
63

Timothy J. Baek's avatar
Timothy J. Baek committed
64
65
66
67
68
69
@app.middleware("http")
async def check_url(request: Request, call_next):
    """HTTP middleware that makes sure the model cache is populated.

    If ``app.state.MODELS`` is empty, ``get_all_models()`` is awaited before
    the request is handed on; otherwise the request passes straight through.
    """
    if not app.state.MODELS:
        await get_all_models()
    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
73
74


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Report whether the OpenAI API integration is enabled (admin only)."""
    enabled = app.state.config.ENABLE_OPENAI_API
    return {"ENABLE_OPENAI_API": enabled}


class OpenAIConfigForm(BaseModel):
    """Request body for ``POST /config/update``."""

    # Optional: the client may omit it, in which case the value is None.
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    """Toggle the OpenAI API integration (admin only).

    Args:
        form_data: ``enable_openai_api`` is optional; when it is omitted
            (None) the current setting is left unchanged.

    Returns:
        The resulting ``{"ENABLE_OPENAI_API": <bool>}``.
    """
    # The form field defaults to None; treat that as "not provided" instead of
    # overwriting the flag with None (which would silently disable the API).
    if form_data.enable_openai_api is not None:
        app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}


Timothy J. Baek's avatar
Timothy J. Baek committed
90
91
class UrlsUpdateForm(BaseModel):
    """Request body for ``POST /urls/update``: the complete list of base URLs."""

    # Replaces app.state.config.OPENAI_API_BASE_URLS wholesale.
    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
92
93


Timothy J. Baek's avatar
Timothy J. Baek committed
94
95
class KeysUpdateForm(BaseModel):
    """Request body for ``POST /keys/update``: the complete list of API keys."""

    # Replaces app.state.config.OPENAI_API_KEYS wholesale; the module indexes
    # keys by the same position as OPENAI_API_BASE_URLS.
    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
96
97


Timothy J. Baek's avatar
Timothy J. Baek committed
98
99
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URLs (admin only)."""
    base_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": base_urls}
101

Timothy J. Baek's avatar
Timothy J. Baek committed
102

Timothy J. Baek's avatar
Timothy J. Baek committed
103
104
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the list of OpenAI-compatible base URLs (admin only).

    The new URLs are stored first and the model cache is refreshed afterwards,
    so cached models (and their ``urlIdx`` values) describe the new list.
    The refresh is skipped while there are fewer keys than URLs, because
    ``get_all_models()`` indexes the key list by URL position.

    Returns:
        The resulting ``{"OPENAI_API_BASE_URLS": [...]}``.
    """
    config = app.state.config
    config.OPENAI_API_BASE_URLS = form_data.urls

    if len(config.OPENAI_API_KEYS) >= len(config.OPENAI_API_BASE_URLS):
        await get_all_models()

    return {"OPENAI_API_BASE_URLS": config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
108
109


Timothy J. Baek's avatar
Timothy J. Baek committed
110
111
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys, positionally matching the URLs (admin only)."""
    api_keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": api_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
113
114
115
116


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the list of API keys, one per base-URL position (admin only).

    After the keys are stored the model cache is refreshed, so endpoints whose
    key changed contribute an up-to-date model list. As with the URL update,
    the refresh is skipped while there are fewer keys than URLs, because
    ``get_all_models()`` indexes the key list by URL position.

    Returns:
        The resulting ``{"OPENAI_API_KEYS": [...]}``.
    """
    config = app.state.config
    config.OPENAI_API_KEYS = form_data.keys

    if len(config.OPENAI_API_KEYS) >= len(config.OPENAI_API_BASE_URLS):
        await get_all_models()

    return {"OPENAI_API_KEYS": config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
119
120


Timothy J. Baek's avatar
Timothy J. Baek committed
121
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to the OpenAI endpoint, caching the audio.

    The SHA-256 of the raw request body is the cache key. A cached
    ``<hash>.mp3`` is returned directly. Otherwise the body is POSTed to
    ``https://api.openai.com/v1/audio/speech`` (which must be one of the
    configured base URLs). The streamed audio is stored as ``<hash>.mp3``
    together with a ``<hash>.json`` copy of the request, and that file is
    returned.

    Raises:
        HTTPException: 401 if ``https://api.openai.com/v1`` is not configured;
            the upstream status code when the upstream returned an error
            response, otherwise 500, if the request or caching fails.
    """
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
    base_url = app.state.config.OPENAI_API_BASE_URLS[idx]

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
    # Audio is streamed here first and only moved to ``file_path`` once the
    # download has completed, so an interrupted stream never becomes a cache hit.
    partial_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.part")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        r = requests.post(
            url=f"{base_url}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Save the streaming content to the temporary file
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Move the completed audio into place (same directory, so a rename).
        partial_path.replace(file_path)

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)

        # Drop any half-written audio so it cannot be served on a later request.
        try:
            partial_path.unlink()
        except OSError:
            pass

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # bool(requests.Response) is Response.ok, so a plain `if r` would be
        # False for exactly the error responses whose status should be forwarded.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
183
184


Timothy J. Baek's avatar
Timothy J. Baek committed
185
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    Returns None without making a request when ``key`` is the empty string.
    Also returns None (after logging) if the request or JSON decoding raises.
    The request is bounded by a 5-second total timeout.
    """
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        if key == "":
            return None

        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=auth_headers) as response:
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-endpoint model lists into one list.

    ``model_lists[i]`` holds the models returned by base URL ``i`` (or None).
    Each kept model is copied with an added ``"urlIdx": i``. Entries that are
    None, or that contain the string ``"error"``, are skipped. For endpoints
    whose URL contains ``api.openai.com``, only models whose id contains
    ``"gpt"`` are kept.
    """
    log.info(f"merge_models_lists {model_lists}")
    merged_list = []

    for url_idx, url_models in enumerate(model_lists):
        if url_models is None or "error" in url_models:
            continue

        for model in url_models:
            # The official OpenAI endpoint also lists non-chat models; keep "gpt" ones.
            if (
                "api.openai.com" in app.state.config.OPENAI_API_BASE_URLS[url_idx]
                and "gpt" not in model["id"]
            ):
                continue
            merged_list.append({**model, "urlIdx": url_idx})

    return merged_list
Timothy J. Baek's avatar
Timothy J. Baek committed
218
219


Timothy J. Baek's avatar
Timothy J. Baek committed
220
async def get_all_models():
    """Fetch ``/models`` from every configured endpoint and rebuild the cache.

    All endpoints are queried concurrently via ``fetch_url``. Their results
    are merged with ``merge_models_lists``, and ``app.state.MODELS`` is
    replaced with a ``{model_id: model}`` mapping of the result.

    When the integration is disabled, or the only configured key is empty,
    no request is made and the model list (and cache) is empty.

    Returns:
        ``{"data": [model, ...]}``. It is always a dict, never None.
    """
    log.info("get_all_models()")

    urls = app.state.config.OPENAI_API_BASE_URLS
    keys = app.state.config.OPENAI_API_KEYS

    if (len(keys) == 1 and keys[0] == "") or not app.state.config.ENABLE_OPENAI_API:
        models = {"data": []}
    else:
        tasks = [
            # A URL with no key at its position is treated like an empty key,
            # which fetch_url skips, instead of raising IndexError.
            fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
            for idx, url in enumerate(urls)
        ]

        responses = await asyncio.gather(*tasks)
        log.info(f"get_all_models:responses() {responses}")

        # Normalise each response to a plain list of models (or None):
        # {"data": [...]} -> [...], a bare list stays as-is, anything else -> None.
        data_lists = [
            response["data"]
            if (response and "data" in response)
            else (response if isinstance(response, list) else None)
            for response in responses
        ]
        models = {"data": merge_models_lists(data_lists)}

        log.info(f"models: {models}")

    # Runs on both branches so a disabled integration clears stale entries
    # and callers always receive a dict.
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
256

Timothy J. Baek's avatar
Timothy J. Baek committed
257
258
259

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``: return the merged list from ``get_all_models()``.
    When the model filter is enabled, users with role ``"user"`` only see ids
    in ``MODEL_FILTER_LIST``.

    With ``url_idx``: query ``<base_url>/models`` of that single endpoint,
    authenticating with its key when one is set. For ``api.openai.com`` only
    models whose id contains ``"gpt"`` are kept.

    Raises:
        HTTPException: if the single-endpoint request fails; uses the upstream
            status when the upstream returned an error response, else 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if models is None:
            # Defensive: keep the response shape and the filter below well-defined.
            models = {"data": []}

        if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
            allowed = app.state.config.MODEL_FILTER_LIST
            models["data"] = [
                model for model in models["data"] if model["id"] in allowed
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    try:
        key = app.state.config.OPENAI_API_KEYS[url_idx]
    except IndexError:
        key = ""

    # Authenticate like the other upstream calls in this module; an empty key
    # means no key is configured for this endpoint.
    headers = {"Authorization": f"Bearer {key}"} if key else {}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # bool(requests.Response) is Response.ok, so `if r` would hide the
        # upstream error status; test for None explicitly.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
304
305


Timothy J. Baek's avatar
Timothy J. Baek committed
306
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to a configured OpenAI-compatible endpoint.

    The endpoint is chosen from the ``"model"`` field of a JSON-object body.
    A model present in ``app.state.MODELS`` selects its ``urlIdx``. Anything
    else (unknown or missing model, non-JSON body) goes to endpoint 0.

    Before forwarding, a JSON-object body is adjusted:

    * for models flagged ``"pipeline"``, the caller is attached as ``"user"``;
    * ``gpt-4-vision-preview`` requests get ``max_tokens=4000`` if unset;
    * ``num_ctx`` is removed.

    Server-sent-event responses are streamed back; any other response is
    returned as decoded JSON.

    Raises:
        HTTPException: 401 if the selected endpoint's key is empty; the
            upstream status when the upstream returned an error response,
            else 500, if forwarding fails.
    """
    idx = 0
    pipeline = False

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Not JSON (e.g. an empty GET body): the raw bytes are forwarded as-is.
        log.error("Error loading request body into a dictionary: %s", e)
        payload = None

    if isinstance(payload, dict):
        model_id = payload.get("model")
        # Unknown or missing model ids keep the default endpoint (index 0)
        # instead of failing with an unhandled KeyError.
        model = app.state.MODELS.get(model_id) if isinstance(model_id, str) else None
        if model is not None:
            idx = model["urlIdx"]
            pipeline = model.get("pipeline", False)

        if pipeline:
            payload["user"] = {"name": user.name, "id": user.id}

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if model_id == "gpt-4-vision-preview":
            payload.setdefault("max_tokens", 4000)
            log.debug("Modified body_dict: %s", payload)

        # Fix for ChatGPT calls failing because the num_ctx key is in body:
        # leaving it there generates an error with the OpenAI API (Feb 2024)
        payload.pop("num_ctx", None)

        # Convert the modified body back to JSON
        body = json.dumps(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    # OpenAI-style errors are objects with a "message"; other
                    # providers may send a bare string.
                    message = (
                        error.get("message", error) if isinstance(error, dict) else error
                    )
                    error_detail = f"External: {message}"
            except Exception:
                error_detail = f"External: {e}"

        # bool(requests.Response) is Response.ok, so `if r` would turn every
        # upstream error status into 500; test for None explicitly.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )