main.py 10.7 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
24
25
26
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    ENABLE_MODEL_FILTER,
28
    MODEL_FILTER_LIST,
29
    AppConfig,
30
)
Timothy J. Baek's avatar
Timothy J. Baek committed
31
32
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
33
34
35

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
36

37
38
39
# Module logger; its level comes from the "OPENAI" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

app = FastAPI()
# CORS is configured fully open: any origin, method and header, with
# credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Base URLs and API keys are index-aligned lists: the handlers below read
# OPENAI_API_KEYS[i] as the key for OPENAI_API_BASE_URLS[i].
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Model cache keyed by model id; populated by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
59

Timothy J. Baek's avatar
Timothy J. Baek committed
60
61
62
63
64
65
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Fill the model cache before handling a request while it is still empty.

    Runs for every HTTP request. As long as `app.state.MODELS` is empty,
    `get_all_models()` is awaited first; the request is then handed to the
    next handler and its response is returned unchanged.
    """
    if not app.state.MODELS:
        await get_all_models()

    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
69
70


Timothy J. Baek's avatar
Timothy J. Baek committed
71
72
class UrlsUpdateForm(BaseModel):
    """Request body for `POST /urls/update`: the replacement list of base URLs."""
    # Stored verbatim as app.state.config.OPENAI_API_BASE_URLS by the handler.
    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
73
74


Timothy J. Baek's avatar
Timothy J. Baek committed
75
76
class KeysUpdateForm(BaseModel):
    """Request body for `POST /keys/update`: the replacement list of API keys."""
    # Stored verbatim as app.state.config.OPENAI_API_KEYS by the handler.
    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
77
78


Timothy J. Baek's avatar
Timothy J. Baek committed
79
80
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the currently configured `OPENAI_API_BASE_URLS` list.

    The handler only runs once the `get_admin_user` dependency resolves.
    """
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
82

Timothy J. Baek's avatar
Timothy J. Baek committed
83

Timothy J. Baek's avatar
Timothy J. Baek committed
84
85
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured base URLs with `form_data.urls` and return them.

    The handler only runs once the `get_admin_user` dependency resolves.
    """
    # NOTE(review): the model refresh runs *before* the new URLs are stored,
    # so it is computed from the previous URL list. Confirm whether the
    # refresh should happen after the assignment instead.
    await get_all_models()

    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
89
90


Timothy J. Baek's avatar
Timothy J. Baek committed
91
92
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the currently configured `OPENAI_API_KEYS` list.

    The handler only runs once the `get_admin_user` dependency resolves.
    """
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
94
95
96
97


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured API keys with `form_data.keys` and return them.

    The handler only runs once the `get_admin_user` dependency resolves.
    """
    app.state.config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
100
101


Timothy J. Baek's avatar
Timothy J. Baek committed
102
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Forward a speech request to `https://api.openai.com/v1` and cache the audio.

    The raw request body is hashed with SHA-256; the digest names the cached
    `<digest>.mp3` (audio) and `<digest>.json` (request body) files under
    `CACHE_DIR/audio/speech/`. A cached audio file is returned directly
    without contacting the upstream API.

    Raises:
        HTTPException: 401 when `https://api.openai.com/v1` is not one of the
            configured base URLs; the upstream status code when the upstream
            API answers with an error status; 500 for any other failure.
    """
    # Only the URL lookup can legitimately raise ValueError, so keep the
    # handler for it as narrow as possible.
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
    # The download goes to a temporary name first so that an interrupted
    # transfer can never be mistaken for a complete cache entry.
    tmp_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3.part")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE: `requests` is synchronous, so this call blocks while waiting
        # for the upstream response.
        r = requests.post(
            url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Save the streaming content, then publish it under its final name.
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        tmp_path.replace(file_path)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)

        # Drop a partially written download, if any.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # `bool(response)` is `response.ok`, so a plain truthiness test would
        # discard every upstream error status. Only forward real error codes;
        # a local failure after a successful upstream call is a 500.
        upstream_failed = r is not None and r.status_code >= 400
        raise HTTPException(
            status_code=r.status_code if upstream_failed else 500,
            detail=error_detail,
        )
    finally:
        if r is not None:
            r.close()
Timothy J. Baek's avatar
Timothy J. Baek committed
162
163


Timothy J. Baek's avatar
Timothy J. Baek committed
164
async def fetch_url(url, key):
    """GET `url` with `key` as a bearer token and return the decoded JSON body.

    Any exception raised while connecting, reading or decoding is logged and
    turned into a `None` return value, so callers must handle `None`.
    """
    try:
        headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
        # Best effort: a failing backend is reported as None, not raised.
        log.error(f"Connection error: {e}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-backend model lists into one list tagged with `urlIdx`.

    Args:
        model_lists: one entry per configured base URL, in the same order.
            Entries that are `None` or contain `"error"` are skipped.

    Returns:
        A list of copies of the model dicts, each with an added `"urlIdx"`
        key holding the position of its backend in `model_lists`. For a
        backend whose base URL contains `api.openai.com`, only models whose
        id contains `"gpt"` are kept.
    """
    log.info(f"merge_models_lists {model_lists}")
    merged_list = []

    for idx, models in enumerate(model_lists):
        if models is None or "error" in models:
            continue

        # The URL test does not depend on the model, so evaluate it once.
        is_openai = "api.openai.com" in app.state.config.OPENAI_API_BASE_URLS[idx]
        merged_list.extend(
            {**model, "urlIdx": idx}
            for model in models
            if not is_openai or "gpt" in model["id"]
        )

    return merged_list
Timothy J. Baek's avatar
Timothy J. Baek committed
193
194


Timothy J. Baek's avatar
Timothy J. Baek committed
195
def _extract_model_list(response):
    """Return the model list carried by one `/models` response, or None."""
    if response and "data" in response:
        return response["data"]
    return response if isinstance(response, list) else None


async def get_all_models():
    """Fetch `/models` from every configured base URL and merge the results.

    Returns:
        `{"data": [...]}` with the merged model list. When the only
        configured key is the empty string, no backend is contacted and
        `{"data": []}` is returned.

    Side effects:
        After a fetch, `app.state.MODELS` is replaced by a dict mapping each
        model id to its merged model entry. It is left untouched when no
        backend is contacted.
    """
    log.info("get_all_models()")

    keys = app.state.config.OPENAI_API_KEYS
    if len(keys) == 1 and keys[0] == "":
        # Previously this branch fell off the end of the function and
        # returned None instead of the empty result it had just built.
        return {"data": []}

    # One request per URL keeps responses index-aligned with the URL list.
    # A URL without a matching key is queried with an empty key rather than
    # raising IndexError.
    tasks = [
        fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
    ]

    responses = await asyncio.gather(*tasks)
    log.info(f"get_all_models:responses() {responses}")

    models = {
        "data": merge_models_lists(
            [_extract_model_list(response) for response in responses]
        )
    }

    log.info(f"models: {models}")
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
231

Timothy J. Baek's avatar
Timothy J. Baek committed
232
233
234

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List models from all backends, or from the backend at `url_idx`.

    Without `url_idx` the merged list from `get_all_models()` is returned;
    when `app.state.ENABLE_MODEL_FILTER` is set and `user.role` is `"user"`,
    it is restricted to ids in `app.state.MODEL_FILTER_LIST`.

    With `url_idx` the `/models` endpoint of that base URL is queried
    directly. For a URL containing `api.openai.com`, only models whose id
    contains `"gpt"` are kept.

    Raises:
        HTTPException: the upstream status code when the backend answers
            with an error status, otherwise 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if models is None:
            # Defensive: never hand a None result to the filter or the client.
            models = {"data": []}

        if app.state.ENABLE_MODEL_FILTER and user.role == "user":
            models["data"] = [
                model
                for model in models["data"]
                if model["id"] in app.state.MODEL_FILTER_LIST
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]

    # Send the key configured for this URL, as the aggregated listing does;
    # without it the request is unauthenticated.
    try:
        key = app.state.config.OPENAI_API_KEYS[url_idx]
    except IndexError:
        key = ""
    headers = {"Authorization": f"Bearer {key}"} if key else {}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # `bool(response)` is `response.ok`, so a truthiness test would turn
        # every upstream error status into 500. Forward real error codes only.
        upstream_failed = r is not None and r.status_code >= 400
        raise HTTPException(
            status_code=r.status_code if upstream_failed else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
279
280


Timothy J. Baek's avatar
Timothy J. Baek committed
281
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward a request to the backend that serves the requested model.

    The backend is chosen from the `urlIdx` recorded in `app.state.MODELS`
    for the body's `"model"`; requests without a known model go to the first
    configured base URL. JSON object bodies are adjusted before forwarding
    (see the inline comments); any other body is forwarded unchanged.

    Returns:
        A `StreamingResponse` when the upstream content type contains
        `text/event-stream`, otherwise the decoded upstream JSON.

    Raises:
        HTTPException: 401 when no key is configured for the chosen backend;
            the upstream status code when the backend answers with an error
            status; 500 for any other failure.
    """
    idx = 0

    body = await request.body()
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        log.error("Error loading request body into a dictionary: %s", e)
        payload = None

    # Only JSON objects are inspected; anything else is forwarded as received.
    if isinstance(payload, dict):
        model_id = payload.get("model")
        model = app.state.MODELS.get(model_id) if isinstance(model_id, str) else None
        if model is not None:
            idx = model["urlIdx"]

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if model_id == "gpt-4-vision-preview":
            if "max_tokens" not in payload:
                payload["max_tokens"] = 4000
            log.debug("Modified body_dict: %s", payload)

        # Fix for ChatGPT calls failing because the num_ctx key is in body.
        # Leaving it there generates an error with the OpenAI API (Feb 2024).
        payload.pop("num_ctx", None)

        # Convert the modified body back to JSON
        body = json.dumps(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    # A backend without a configured key is handled like an empty key.
    try:
        key = app.state.config.OPENAI_API_KEYS[idx]
    except IndexError:
        key = ""

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error = res["error"]
                    # `error` may be an object with a "message" or a bare value.
                    if isinstance(error, dict) and "message" in error:
                        error = error["message"]
                    error_detail = f"External: {error}"
            except Exception:
                error_detail = f"External: {e}"

        # `bool(response)` is `response.ok`, so a truthiness test would turn
        # every upstream error status into 500. Forward real error codes only.
        upstream_failed = r is not None and r.status_code >= 400
        raise HTTPException(
            status_code=r.status_code if upstream_failed else 500,
            detail=error_detail,
        )