main.py 10.9 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
24
25
26
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    ENABLE_MODEL_FILTER,
28
    MODEL_FILTER_LIST,
29
    AppConfig,
30
)
Timothy J. Baek's avatar
Timothy J. Baek committed
31
32
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
33
34
35

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
36

37
38
39
# Module logger, using the level configured for the "OPENAI" subsystem.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

# Sub-application that fronts one or more OpenAI-compatible backends.
app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed cross-origin calls; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Runtime-editable settings; /urls/update and /keys/update overwrite these.
app.state.config = AppConfig()
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

# Allow-list applied in GET /models to users whose role is "user".
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Model id -> model dict (with "urlIdx"); populated by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
59

Timothy J. Baek's avatar
Timothy J. Baek committed
60
61
62
63
64
65
@app.middleware("http")
async def check_url(request: Request, call_next):
    """HTTP middleware: populate the model cache if it is empty, then continue.

    While ``app.state.MODELS`` holds no entries, every incoming request first
    triggers ``get_all_models()`` before being passed down the stack.
    """
    if not app.state.MODELS:
        await get_all_models()
    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
69
70


Timothy J. Baek's avatar
Timothy J. Baek committed
71
72
class UrlsUpdateForm(BaseModel):
    """Request body of ``POST /urls/update``: the full replacement list of base URLs."""

    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
73
74


Timothy J. Baek's avatar
Timothy J. Baek committed
75
76
class KeysUpdateForm(BaseModel):
    """Request body of ``POST /keys/update``: the full replacement list of API keys.

    Keys are matched to base URLs by list position.
    """

    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
77
78


Timothy J. Baek's avatar
Timothy J. Baek committed
79
80
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Admin-only: report the configured OpenAI-compatible base URLs."""
    configured_urls = app.state.config.OPENAI_API_BASE_URLS
    return {"OPENAI_API_BASE_URLS": configured_urls}
82

Timothy J. Baek's avatar
Timothy J. Baek committed
83

Timothy J. Baek's avatar
Timothy J. Baek committed
84
85
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Admin-only: replace the base-URL list and rebuild the model cache.

    Each cached model records the index of the URL it came from (``urlIdx``),
    and ``proxy`` routes on that index. The cache must therefore be rebuilt
    *after* the new list is stored. Refreshing first left it indexing into
    the old list.

    Returns:
        ``{"OPENAI_API_BASE_URLS": <stored list>}``
    """
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls

    try:
        await get_all_models()
    except Exception as e:
        # The new URLs are already saved. A failed refresh (for example, the
        # key list has not been resized to match yet) must not fail the
        # update; the next successful refresh will rebuild the cache.
        log.exception(e)

    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
Timothy J. Baek's avatar
Timothy J. Baek committed
89
90


Timothy J. Baek's avatar
Timothy J. Baek committed
91
92
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Admin-only: return the configured API keys (paired with base URLs by position)."""
    configured_keys = app.state.config.OPENAI_API_KEYS
    return {"OPENAI_API_KEYS": configured_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
94
95
96
97


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Admin-only: overwrite the API key list and echo back what was stored."""
    config = app.state.config
    config.OPENAI_API_KEYS = form_data.keys
    return {"OPENAI_API_KEYS": config.OPENAI_API_KEYS}
Timothy J. Baek's avatar
Timothy J. Baek committed
100
101


Timothy J. Baek's avatar
Timothy J. Baek committed
102
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to the official OpenAI API, with a disk cache.

    The SHA-256 of the raw request body is the cache key.

    On a cache hit, the stored ``<hash>.mp3`` is returned immediately.
    Otherwise:

    - the body is forwarded to ``https://api.openai.com/v1/audio/speech``
      using the key paired with that base URL;
    - the streamed audio is saved as ``<hash>.mp3``, and the request JSON as
      ``<hash>.json``;
    - the saved file is returned.

    Raises:
        HTTPException(401): ``https://api.openai.com/v1`` is not among the
            configured base URLs.
        HTTPException: the upstream call or caching failed. The status is
            the upstream error status when there is one, else 500.
    """
    try:
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    # Always "https://api.openai.com/v1" because of the lookup above.
    url = app.state.config.OPENAI_API_BASE_URLS[idx]

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    speech_cache_dir = Path(CACHE_DIR).joinpath("./audio/speech/")
    speech_cache_dir.mkdir(parents=True, exist_ok=True)
    file_path = speech_cache_dir.joinpath(f"{name}.mp3")
    file_body_path = speech_cache_dir.joinpath(f"{name}.json")
    # Audio is streamed here first and renamed to file_path only once complete,
    # so an interrupted download can never be served as a cache hit.
    tmp_path = speech_cache_dir.joinpath(f"{name}.mp3.part")

    if file_path.is_file():
        return FileResponse(file_path)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        # NOTE(review): requests is synchronous, so this blocks the event loop
        # for the duration of the upstream call.
        r = requests.post(
            url=f"{url}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        tmp_path.replace(file_path)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        return FileResponse(file_path)

    except Exception as e:
        log.exception(e)

        # Drop any half-written download.
        try:
            tmp_path.unlink()
        except OSError:
            pass

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    err = res["error"]
                    if isinstance(err, dict) and "message" in err:
                        err = err["message"]
                    error_detail = f"External: {err}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test for None explicitly
        # and mirror the upstream status only when it is an error status.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
164
165


Timothy J. Baek's avatar
Timothy J. Baek committed
166
async def fetch_url(url, key):
    """GET ``url`` with ``Authorization: Bearer <key>`` and return the decoded JSON.

    Any exception raised while connecting or decoding is logged and reported
    as ``None`` rather than propagated.
    """
    try:
        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=auth_headers) as resp:
                payload = await resp.json()
        return payload
    except Exception as err:
        log.error(f"Connection error: {err}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-backend model lists into one list tagged with their source index.

    ``model_lists[i]`` holds what the i-th configured base URL reported, or
    ``None``. Entries that are ``None`` or contain ``"error"`` are skipped.
    Each kept model is copied with an extra ``"urlIdx": i`` key. For base
    URLs containing ``api.openai.com``, only models whose id contains
    ``"gpt"`` are kept.
    """
    log.info(f"merge_models_lists {model_lists}")

    merged = []
    for url_idx, models in enumerate(model_lists):
        if models is None or "error" in models:
            continue
        for model in models:
            from_openai = (
                "api.openai.com" in app.state.config.OPENAI_API_BASE_URLS[url_idx]
            )
            if not from_openai or "gpt" in model["id"]:
                merged.append({**model, "urlIdx": url_idx})
    return merged
Timothy J. Baek's avatar
Timothy J. Baek committed
195
196


Timothy J. Baek's avatar
Timothy J. Baek committed
197
async def get_all_models():
    """Query ``/models`` on every configured backend and rebuild the model cache.

    Returns:
        ``{"data": [...]}``: the merged model list (see ``merge_models_lists``).
        The same models are stored in ``app.state.MODELS``, keyed by model id.
        When the only configured key is the empty string, no backend is
        queried and the list is empty.
    """
    log.info("get_all_models()")

    urls = app.state.config.OPENAI_API_BASE_URLS
    keys = app.state.config.OPENAI_API_KEYS

    if len(keys) == 1 and keys[0] == "":
        models = {"data": []}
    else:
        # The key list can be shorter than the URL list (e.g. right after the
        # URLs were updated but before the keys were). Treat a missing key as
        # empty instead of raising IndexError.
        tasks = [
            fetch_url(f"{url}/models", keys[idx] if idx < len(keys) else "")
            for idx, url in enumerate(urls)
        ]
        responses = await asyncio.gather(*tasks)
        log.info(f"get_all_models:responses() {responses}")

        # Backends answer either {"data": [...]} or a bare list. Anything else
        # (including a failed fetch, which yields None) becomes None.
        model_lists = [
            response["data"]
            if (response and "data" in response)
            else (response if isinstance(response, list) else None)
            for response in responses
        ]
        models = {"data": merge_models_lists(model_lists)}
        log.info(f"models: {models}")

    # Previously nested in the else-branch, so the no-key case returned None
    # and left the cache untouched.
    app.state.MODELS = {model["id"]: model for model in models["data"]}
    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
233

Timothy J. Baek's avatar
Timothy J. Baek committed
234
235
236

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``: return the merged list from all backends, refreshing
    the model cache. When ``ENABLE_MODEL_FILTER`` is on, users whose role is
    ``"user"`` only see ids listed in ``MODEL_FILTER_LIST``.

    With ``url_idx``: query ``<base url>/models`` of that one backend directly,
    authenticated with its paired key. For ``api.openai.com`` URLs only ids
    containing ``"gpt"`` are kept.

    Raises:
        HTTPException: the single-backend request failed. The status is the
            upstream error status when there is one, else 500.
    """
    if url_idx is None:
        models = await get_all_models()
        if app.state.ENABLE_MODEL_FILTER and user.role == "user":
            models["data"] = [
                model
                for model in models["data"]
                if model["id"] in app.state.MODEL_FILTER_LIST
            ]
        return models

    url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
    try:
        key = app.state.config.OPENAI_API_KEYS[url_idx]
    except IndexError:
        key = ""  # fewer keys than URLs: treat as no key

    # Authenticate like fetch_url does; the request previously carried no
    # Authorization header.
    headers = {"Authorization": f"Bearer {key}"}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]
        return response_data

    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    err = res["error"]
                    if isinstance(err, dict) and "message" in err:
                        err = err["message"]
                    error_detail = f"External: {err}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test for None explicitly
        # and mirror the upstream status only when it is an error status.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
281
282


Timothy J. Baek's avatar
Timothy J. Baek committed
283
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to an OpenAI-compatible backend.

    Routing: if the body is a JSON object whose ``model`` is in the model
    cache, the request goes to that model's backend (``urlIdx``). Otherwise
    it goes to the first configured base URL.

    JSON-object bodies are rewritten before forwarding:

    - ``max_tokens`` defaults to 4000 for ``gpt-4-vision-preview``;
    - any ``num_ctx`` field is removed.

    Other bodies are forwarded as the original bytes.

    Server-sent-event responses are streamed back; anything else is returned
    as parsed JSON.

    Raises:
        HTTPException(401): the selected backend has no API key.
        HTTPException: the upstream call failed. The status is the upstream
            error status when there is one, else 500.
    """
    idx = 0

    body = await request.body()
    try:
        payload = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Not UTF-8 JSON: forward the raw bytes unchanged to the default backend.
        log.error("Error loading request body into a dictionary: %s", e)
        payload = None

    if isinstance(payload, dict):
        # Unknown models fall back to the default backend, which then reports
        # the problem itself (previously a KeyError -> opaque 500).
        model = app.state.MODELS.get(payload.get("model"))
        if model is not None:
            idx = model["urlIdx"]

        # TODO: Remove after gpt-4-vision fix from OpenAI. That model needs an
        # explicit max_tokens.
        if payload.get("model") == "gpt-4-vision-preview":
            payload.setdefault("max_tokens", 4000)
            log.debug("Modified body_dict: %s", payload)

        # num_ctx is not an OpenAI parameter; leaving it in makes OpenAI reject
        # the call (observed Feb 2024).
        payload.pop("num_ctx", None)

        body = json.dumps(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    try:
        key = app.state.config.OPENAI_API_KEYS[idx]
    except IndexError:
        key = ""  # fewer keys than URLs: reported as a missing key below

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        # NOTE(review): requests is synchronous, so this blocks the event loop
        # while waiting on the backend.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # NOTE(review): every upstream header is copied verbatim. Confirm
            # Content-Length/Content-Encoding still match what iter_content yields.
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        return r.json()

    except Exception as e:
        log.exception(e)

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    err = res["error"]
                    if isinstance(err, dict) and "message" in err:
                        err = err["message"]
                    error_detail = f"External: {err}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test for None explicitly
        # and mirror the upstream status only when it is an error status.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        )