main.py 11.6 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
24
    ENABLE_OPENAI_API,
25
26
27
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
28
    ENABLE_MODEL_FILTER,
29
    MODEL_FILTER_LIST,
30
    AppConfig,
31
)
Timothy J. Baek's avatar
Timothy J. Baek committed
32
33
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
36

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
37

38
39
40
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

app = FastAPI()

# Fully permissive CORS: any origin, method and header; credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Model cache (model id -> model dict, rebuilt by get_all_models) and the
# model-filter settings live directly on app.state.
app.state.MODELS = {}
app.state.ENABLE_MODEL_FILTER, app.state.MODEL_FILTER_LIST = (
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
)

# Runtime-adjustable OpenAI settings, seeded from the values imported from
# the `config` module.
app.state.config = AppConfig()
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

Timothy J. Baek's avatar
Timothy J. Baek committed
62

Timothy J. Baek's avatar
Timothy J. Baek committed
63
64
65
66
67
68
@app.middleware("http")
async def check_url(request: Request, call_next):
    """HTTP middleware that lazily fills the model cache before each request.

    When ``app.state.MODELS`` is empty, ``get_all_models()`` is awaited first;
    the request is then handed to the next handler and its response returned.
    """
    if not app.state.MODELS:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
72
73


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Report whether the OpenAI integration is enabled (admin only)."""
    enabled = app.state.config.ENABLE_OPENAI_API
    return {"ENABLE_OPENAI_API": enabled}


class OpenAIConfigForm(BaseModel):
    """Request body for POST /config/update."""

    # New value for ENABLE_OPENAI_API; None when the field is omitted.
    enable_openai_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
    """Store the submitted OpenAI enable flag and echo it back (admin only).

    The value is assigned as-is, so an omitted field stores None.
    """
    config = app.state.config
    config.ENABLE_OPENAI_API = form_data.enable_openai_api
    return {"ENABLE_OPENAI_API": config.ENABLE_OPENAI_API}


Timothy J. Baek's avatar
Timothy J. Baek committed
89
90
class UrlsUpdateForm(BaseModel):
    """Request body for POST /urls/update."""

    # Replacement list of OpenAI-compatible base URLs.
    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
91
92


Timothy J. Baek's avatar
Timothy J. Baek committed
93
94
class KeysUpdateForm(BaseModel):
    """Request body for POST /keys/update."""

    # Replacement list of API keys (presumably index-aligned with the base
    # URLs, since other handlers index both lists with the same idx).
    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
95
96


Timothy J. Baek's avatar
Timothy J. Baek committed
97
98
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured OpenAI base URLs (admin only)."""
    urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": urls}
100

Timothy J. Baek's avatar
Timothy J. Baek committed
101

Timothy J. Baek's avatar
Timothy J. Baek committed
102
103
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the OpenAI base URLs, then refresh the model cache (admin only).

    The URLs are stored *before* refreshing so the cached models (and their
    ``urlIdx`` values) are fetched from the new URL list rather than the old
    one.

    Returns:
        ``{"OPENAI_API_BASE_URLS": [...]}`` with the stored list.
    """
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls

    try:
        await get_all_models()
    except Exception as e:
        # Best-effort refresh: e.g. the key list may not yet have been updated
        # to match the new URLs. The URLs are already saved, so don't fail the
        # request; the existing cache is kept until a later refresh succeeds.
        log.exception(e)

    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
107
108


Timothy J. Baek's avatar
Timothy J. Baek committed
109
110
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured OpenAI API keys (admin only)."""
    keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
112
113
114
115


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured OpenAI API keys and echo them back (admin only).

    The model cache is not refreshed by this endpoint.
    """
    config = app.state.config
    config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
118
119


Timothy J. Baek's avatar
Timothy J. Baek committed
120
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to api.openai.com, caching the audio.

    The raw request body is hashed (SHA-256). The resulting ``<hash>.mp3`` is
    served from ``CACHE_DIR/audio/speech`` when present. Otherwise the body is
    forwarded to ``<base>/audio/speech``. The audio is stored next to a
    ``<hash>.json`` copy of the request and then returned.

    Raises:
        HTTPException: 401 when no "https://api.openai.com/v1" base URL is
            configured. On upstream failure it carries the upstream error
            status, or 500 for connection or processing errors.
    """
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(
            status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND
        ) from None

    base_url = app.state.config.OPENAI_API_BASE_URLS[idx]

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    speech_cache_dir = Path(CACHE_DIR).joinpath("./audio/speech/")
    speech_cache_dir.mkdir(parents=True, exist_ok=True)
    file_path = speech_cache_dir.joinpath(f"{name}.mp3")
    file_body_path = speech_cache_dir.joinpath(f"{name}.json")
    # Partial download target; renamed to file_path only once complete so an
    # interrupted transfer never leaves a truncated .mp3 in the cache.
    tmp_path = speech_cache_dir.joinpath(f"{name}.mp3.part")

    # Serve from the cache when this exact request was answered before.
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE(review): requests is synchronous and blocks the event loop for
        # the duration of the upstream call.
        r = requests.post(
            url=f"{base_url}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Stream the audio to the temporary file, then publish it atomically.
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        tmp_path.replace(file_path)

        # Keep a copy of the request alongside the cached audio.
        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)
        tmp_path.unlink(missing_ok=True)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # requests.Response is falsy for 4xx/5xx (its __bool__ is `.ok`), so
        # compare with None explicitly and only forward real error statuses.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail)
Timothy J. Baek's avatar
Timothy J. Baek committed
182
183


Timothy J. Baek's avatar
Timothy J. Baek committed
184
async def fetch_url(url, key):
    """GET ``url`` with a Bearer ``key`` and return the decoded JSON body.

    Returns None without sending anything when ``key`` is the empty string.
    Any exception during the request or JSON decoding is logged and turned
    into a None result. The request uses a 5-second total timeout.
    """
    client_timeout = aiohttp.ClientTimeout(total=5)
    try:
        if key == "":
            return None

        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url, headers=auth_headers) as resp:
                return await resp.json()
    except Exception as e:
        # Connection / decoding failures are reported as "no data".
        log.error(f"Connection error: {e}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-endpoint model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` holds the models returned by
    ``OPENAI_API_BASE_URLS[i]``. Entries that are None, or that contain
    "error", are skipped. For endpoints whose URL contains "api.openai.com",
    only models whose id contains "gpt" are kept. Every kept model is copied
    with an added ``"urlIdx": i``.
    """
    log.info(f"merge_models_lists {model_lists}")

    merged_list = []
    for url_idx, endpoint_models in enumerate(model_lists):
        if not endpoint_models or "error" in endpoint_models:
            continue

        for model in endpoint_models:
            base_url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
            if "api.openai.com" in base_url and "gpt" not in model["id"]:
                continue
            merged_list.append({**model, "urlIdx": url_idx})

    return merged_list
Timothy J. Baek's avatar
Timothy J. Baek committed
217
218


Timothy J. Baek's avatar
Timothy J. Baek committed
219
async def get_all_models():
    """Fetch, merge and cache the model lists of all configured endpoints.

    When the OpenAI integration is disabled, or the only configured key is
    the empty string, no request is made and the model list is empty.
    On every path ``app.state.MODELS`` is rebuilt as a mapping of model id to
    model dict. Each dict is tagged with ``urlIdx`` by merge_models_lists.

    Returns:
        ``{"data": [model, ...]}`` (never None; callers index ``["data"]``).
    """
    log.info("get_all_models()")

    keys = app.state.config.OPENAI_API_KEYS
    if (len(keys) == 1 and keys[0] == "") or not app.state.config.ENABLE_OPENAI_API:
        models = {"data": []}
    else:

        def _as_model_list(response):
            # OpenAI-style payloads wrap models in "data"; some servers return
            # a bare list. Anything else (errors, None) is dropped as None.
            if response and "data" in response:
                return response["data"]
            return response if isinstance(response, list) else None

        tasks = [
            # A URL with no matching key gets "", which fetch_url skips,
            # instead of raising IndexError.
            fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
        ]
        responses = await asyncio.gather(*tasks)
        log.info(f"get_all_models:responses() {responses}")

        models = {
            "data": merge_models_lists(
                [_as_model_list(response) for response in responses]
            )
        }

    # Rebuild the cache and return on every path, including the disabled one,
    # so callers never receive None and stale models are not kept.
    log.info(f"models: {models}")
    app.state.MODELS = {model["id"]: model for model in models["data"]}
    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
255

Timothy J. Baek's avatar
Timothy J. Baek committed
256
257
258

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List models, merged across all endpoints or from a single endpoint.

    Without ``url_idx`` the merged list from get_all_models() is returned.
    When model filtering is enabled, users with role "user" only see models
    whose id is in ``MODEL_FILTER_LIST``. With ``url_idx``, ``<base>/models``
    of that endpoint is queried directly, using its API key when one is set.
    For api.openai.com only models whose id contains "gpt" are kept.

    Raises:
        HTTPException: on upstream failure, with the upstream error status, or
            500 for connection/parse errors.
    """
    if url_idx is None:
        # Guard against a None result so the filter below cannot crash.
        models = await get_all_models() or {"data": []}
        if app.state.ENABLE_MODEL_FILTER and user.role == "user":
            models["data"] = [
                model
                for model in models["data"]
                if model["id"] in app.state.MODEL_FILTER_LIST
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    keys = app.state.config.OPENAI_API_KEYS
    key = keys[url_idx] if url_idx < len(keys) else ""

    # OpenAI's /models endpoint requires authentication.
    headers = {}
    if key != "":
        headers["Authorization"] = f"Bearer {key}"

    r = None
    try:
        # NOTE(review): requests is synchronous and blocks the event loop for
        # the duration of the upstream call.
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # requests.Response is falsy for 4xx/5xx (its __bool__ is `.ok`), so
        # compare with None explicitly and only forward real error statuses.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail)
Timothy J. Baek's avatar
Timothy J. Baek committed
303
304


Timothy J. Baek's avatar
Timothy J. Baek committed
305
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the OpenAI-compatible endpoint for its model.

    If the body is a JSON object whose "model" is in ``app.state.MODELS``, that
    model's ``urlIdx`` selects the base URL and key. Otherwise endpoint 0 is
    used. Two compatibility fixes are applied to JSON object bodies before
    they are re-serialized. Non-JSON bodies are forwarded unchanged.
    Server-sent-event responses are streamed back; any other response is
    parsed as JSON and returned.

    Raises:
        HTTPException: 401 when the selected endpoint has no API key. On
            upstream failure it carries the upstream error status, or 500 for
            connection/parse errors.
    """
    idx = 0
    body = await request.body()

    payload = None
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Empty (e.g. GET) or non-JSON bodies are forwarded as raw bytes.
        if body:
            log.error("Error loading request body into a dictionary: %s", e)

    if isinstance(payload, dict):
        model_id = payload.get("model")
        model_info = (
            app.state.MODELS.get(model_id) if isinstance(model_id, str) else None
        )
        if model_info is not None:
            idx = model_info["urlIdx"]
        elif model_id is not None:
            log.warning("Model %r not in model cache; using endpoint 0", model_id)

        # TODO: Remove below after gpt-4-vision fix from Open AI
        # gpt-4-vision-preview needs an explicit max_tokens; default it to 4000.
        if model_id == "gpt-4-vision-preview":
            if "max_tokens" not in payload:
                payload["max_tokens"] = 4000
            log.debug("Modified body_dict: %s", payload)

        # The OpenAI API rejects the Ollama-specific "num_ctx" key (Feb 2024).
        payload.pop("num_ctx", None)

        body = json.dumps(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]
    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE(review): requests is synchronous and blocks the event loop for
        # the duration of the upstream call.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Stream server-sent events through; everything else is JSON.
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        return r.json()
    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    # OpenAI errors are objects with a "message"; other
                    # servers may return a plain string.
                    if isinstance(error, dict) and "message" in error:
                        error = error["message"]
                    error_detail = f"External: {error}"
            except Exception:
                error_detail = f"External: {e}"

        # requests.Response is falsy for 4xx/5xx (its __bool__ is `.ok`), so
        # compare with None explicitly and only forward real error statuses.
        status_code = r.status_code if r is not None and not r.ok else 500
        raise HTTPException(status_code=status_code, detail=error_detail)