main.py 10.9 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
3
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5

import requests
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
import aiohttp
import asyncio
Timothy J. Baek's avatar
Timothy J. Baek committed
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
from pydantic import BaseModel

Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
21
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
22
from config import (
23
    SRC_LOG_LEVELS,
24
25
26
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    ENABLE_MODEL_FILTER,
28
    MODEL_FILTER_LIST,
29
30
    config_set,
    config_get,
31
)
Timothy J. Baek's avatar
Timothy J. Baek committed
32
33
from typing import List, Optional

Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
36

import hashlib
from pathlib import Path
Timothy J. Baek's avatar
Timothy J. Baek committed
37

38
39
40
# Module logger; its level comes from the "OPENAI" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

# FastAPI app serving the OpenAI-compatible proxy routes below. CORS is fully
# open: any origin, method and header, with credentials allowed.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Config-backed settings. Handlers in this module read them with config_get()
# and write them with config_set(); URLs and keys are paired by list position.
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

# Model id -> model entry cache; populated by get_all_models().
app.state.MODELS = {}

Timothy J. Baek's avatar
Timothy J. Baek committed
58

Timothy J. Baek's avatar
Timothy J. Baek committed
59
60
61
62
63
64
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Make sure the model cache is filled before any handler runs.

    Runs on every HTTP request to this app. The request itself is not
    inspected: when ``app.state.MODELS`` is still empty, get_all_models() is
    awaited first, then the request is passed down the stack unchanged.
    """
    if not app.state.MODELS:
        # NOTE(review): if the refresh yields no models (e.g. every endpoint
        # failed), the cache stays empty and this fetch repeats on the next
        # request.
        await get_all_models()
    return await call_next(request)
Timothy J. Baek's avatar
Timothy J. Baek committed
68
69


Timothy J. Baek's avatar
Timothy J. Baek committed
70
71
class UrlsUpdateForm(BaseModel):
    """Body of POST /urls/update: the complete replacement list of base URLs."""

    urls: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
72
73


Timothy J. Baek's avatar
Timothy J. Baek committed
74
75
class KeysUpdateForm(BaseModel):
    """Body of POST /keys/update: the complete replacement list of API keys."""

    keys: List[str]
Timothy J. Baek's avatar
Timothy J. Baek committed
76
77


Timothy J. Baek's avatar
Timothy J. Baek committed
78
79
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
    """Return the configured base URLs (route guarded by get_admin_user)."""
    base_urls = config_get(app.state.OPENAI_API_BASE_URLS)
    return {"OPENAI_API_BASE_URLS": base_urls}
81

Timothy J. Baek's avatar
Timothy J. Baek committed
82

Timothy J. Baek's avatar
Timothy J. Baek committed
83
84
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured base URLs, then rebuild the model cache.

    The cache records, for each model, the position of the base URL that
    serves it (``urlIdx``). It therefore has to be rebuilt from the *new* URL
    list; refreshing it before storing the URLs leaves it pointing at
    positions in the old list.

    Returns:
        ``{"OPENAI_API_BASE_URLS": [...]}`` with the stored URLs.
    """
    config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls)

    # Best-effort: keys are updated through a separate endpoint and may not
    # line up with the new URLs yet, so a failed refresh must not make the URL
    # update itself fail.
    try:
        await get_all_models()
    except Exception as e:
        log.exception(e)

    return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
Timothy J. Baek's avatar
Timothy J. Baek committed
88
89


Timothy J. Baek's avatar
Timothy J. Baek committed
90
91
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
    """Return the configured API keys (route guarded by get_admin_user).

    Keys are paired by position with OPENAI_API_BASE_URLS elsewhere in this
    module.
    """
    api_keys = config_get(app.state.OPENAI_API_KEYS)
    return {"OPENAI_API_KEYS": api_keys}
Timothy J. Baek's avatar
Timothy J. Baek committed
93
94
95
96


@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
    """Replace the configured API keys, then refresh the model cache.

    ``keys[i]`` is the bearer token used for base URL ``i``. Because the keys
    decide which endpoints are queried, and with which credentials, the cached
    model list can be stale after they change.

    Returns:
        ``{"OPENAI_API_KEYS": [...]}`` with the stored keys.
    """
    config_set(app.state.OPENAI_API_KEYS, form_data.keys)

    # Best-effort, like the URL update: the URL list is changed through a
    # separate endpoint and may not line up with the new keys yet.
    try:
        await get_all_models()
    except Exception as e:
        log.exception(e)

    return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
Timothy J. Baek's avatar
Timothy J. Baek committed
99
100


Timothy J. Baek's avatar
Timothy J. Baek committed
101
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to api.openai.com, caching the audio.

    The raw request body is forwarded unchanged to ``{base_url}/audio/speech``
    of the configured "https://api.openai.com/v1" entry, using the API key at
    the same position. Results are cached under ``CACHE_DIR/audio/speech``,
    keyed by the SHA-256 of the request body:
    - ``<hash>.mp3`` holds the returned audio;
    - ``<hash>.json`` holds the request that produced it.
    A cache hit is served without contacting the upstream.

    Raises:
        HTTPException(401): no "https://api.openai.com/v1" base URL is
            configured.
        HTTPException: the upstream call failed. It carries the upstream
            error status when one was received, otherwise 500.
    """
    try:
        idx = config_get(app.state.OPENAI_API_BASE_URLS).index(
            "https://api.openai.com/v1"
        )
    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    speech_cache_dir = Path(CACHE_DIR).joinpath("./audio/speech/")
    speech_cache_dir.mkdir(parents=True, exist_ok=True)
    file_path = speech_cache_dir.joinpath(f"{name}.mp3")
    file_body_path = speech_cache_dir.joinpath(f"{name}.json")
    # Download target; it is renamed to file_path only once complete.
    part_path = speech_cache_dir.joinpath(f"{name}.mp3.part")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {
        "Authorization": f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}",
        "Content-Type": "application/json",
    }

    r = None
    try:
        r = requests.post(
            url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Stream to a side file first. If the transfer dies midway, no
        # truncated .mp3 is left behind to be served later as a cache hit.
        with open(part_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Publish the finished audio into the cache with a single rename.
        part_path.replace(file_path)

        # Return the saved file
        return FileResponse(file_path)
    except Exception as e:
        log.exception(e)

        # Discard any partially downloaded audio; it may not exist at all.
        try:
            part_path.unlink()
        except OSError:
            pass

        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test `is not None`
        # explicitly. Upstream error statuses pass through; a failure after a
        # successful reply is our own 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        ) from e
    finally:
        if r is not None:
            r.close()
Timothy J. Baek's avatar
Timothy J. Baek committed
165
166


Timothy J. Baek's avatar
Timothy J. Baek committed
167
async def fetch_url(url, key):
    """GET ``url`` with a bearer token and return the decoded JSON body.

    The HTTP status is not checked, so an error reply whose body decodes as
    JSON is returned as-is. Any exception (connection problems, a body that
    cannot be decoded as JSON, ...) is logged and turned into ``None``. This
    lets callers that gather several endpoints carry on.
    """
    try:
        auth_headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as session, session.get(
            url, headers=auth_headers
        ) as resp:
            return await resp.json()
    except Exception as err:
        # Handle connection error here
        log.error(f"Connection error: {err}")
        return None


def merge_models_lists(model_lists):
    """Flatten per-endpoint model lists into one list tagged with ``urlIdx``.

    ``model_lists[i]`` holds the models reported by the i-th configured base
    URL, or ``None`` when that endpoint produced nothing usable. Each kept
    model is copied with an extra ``"urlIdx": i`` key, so later requests can
    be routed back to the endpoint that serves it. For endpoints whose URL
    contains "api.openai.com", only models whose id contains "gpt" are kept.

    Args:
        model_lists: sequence of model lists (or ``None``), positionally
            aligned with OPENAI_API_BASE_URLS.

    Returns:
        A new list of model dicts, each with ``urlIdx`` added.
    """
    log.info(f"merge_models_lists {model_lists}")

    # Resolve the URL list once rather than once per model.
    base_urls = config_get(app.state.OPENAI_API_BASE_URLS)
    merged_list = []

    for idx, models in enumerate(model_lists):
        # Skip failed/empty endpoints and anything flagged as an error.
        if not models or "error" in models:
            continue

        # api.openai.com is restricted to the GPT family (presumably to hide
        # its non-chat models).
        only_gpt = "api.openai.com" in base_urls[idx]
        merged_list.extend(
            {**model, "urlIdx": idx}
            for model in models
            if not only_gpt or "gpt" in model["id"]
        )

    return merged_list
Timothy J. Baek's avatar
Timothy J. Baek committed
196
197


Timothy J. Baek's avatar
Timothy J. Baek committed
198
async def get_all_models():
    """Fetch ``/models`` from every configured endpoint and refresh the cache.

    All endpoints are queried concurrently via fetch_url() and their results
    merged by merge_models_lists(). When the only configured key is the empty
    string, no endpoint is queried and the result is empty.

    Side effects:
        Replaces ``app.state.MODELS`` with a mapping of model id -> model
        entry. This happens on every path, including the empty one.

    Returns:
        ``{"data": [...]}``, a list of model dicts, each carrying ``urlIdx``.
        Always a dict: previously the empty-key path fell through and
        returned ``None``.
    """
    log.info("get_all_models()")

    api_keys = config_get(app.state.OPENAI_API_KEYS)

    if len(api_keys) == 1 and api_keys[0] == "":
        models = {"data": []}
    else:
        base_urls = config_get(app.state.OPENAI_API_BASE_URLS)
        tasks = [
            # Fewer keys than URLs: send an empty token rather than aborting
            # the whole refresh with an IndexError.
            fetch_url(f"{url}/models", api_keys[idx] if idx < len(api_keys) else "")
            for idx, url in enumerate(base_urls)
        ]

        responses = await asyncio.gather(*tasks)
        log.info(f"get_all_models:responses() {responses}")

        # Normalise each reply to a list of models: OpenAI-style
        # {"data": [...]}, a bare list, or None for anything else.
        model_lists = [
            (
                response["data"]
                if response and "data" in response
                else (response if isinstance(response, list) else None)
            )
            for response in responses
        ]
        models = {"data": merge_models_lists(model_lists)}

    log.info(f"models: {models}")
    app.state.MODELS = {model["id"]: model for model in models["data"]}

    return models
Timothy J. Baek's avatar
Timothy J. Baek committed
234

Timothy J. Baek's avatar
Timothy J. Baek committed
235
236
237

@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    """List available models.

    Without ``url_idx``, return the merged list from every endpoint (via
    get_all_models). When model filtering is enabled, users whose role is
    "user" only see ids contained in MODEL_FILTER_LIST.

    With ``url_idx``, query ``{base_url}/models`` on that single endpoint,
    authenticating with its configured key, and return the upstream payload.
    The payload is restricted to "gpt" models for api.openai.com.

    Raises:
        HTTPException: the single-endpoint request failed. It carries the
            upstream error status when one was received, otherwise 500.
    """
    if url_idx is None:
        # Treat a missing aggregate result as an empty listing so the filter
        # below never indexes into None.
        models = await get_all_models() or {"data": []}
        if config_get(app.state.ENABLE_MODEL_FILTER) and user.role == "user":
            allowed_ids = config_get(app.state.MODEL_FILTER_LIST)
            models["data"] = [
                model for model in models["data"] if model["id"] in allowed_ids
            ]
        return models

    # NOTE(review): the model filter above is not applied on this
    # per-endpoint path — confirm that is intended.
    url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx]
    try:
        key = config_get(app.state.OPENAI_API_KEYS)[url_idx]
    except IndexError:
        key = ""
    # Authenticate the same way fetch_url does; keyless endpoints get no header.
    headers = {"Authorization": f"Bearer {key}"} if key else {}

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models", headers=headers)
        r.raise_for_status()

        response_data = r.json()
        if "api.openai.com" in url:
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test `is not None`
        # explicitly. Upstream error statuses pass through; a failure after a
        # successful reply is our own 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        ) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
283
284


Timothy J. Baek's avatar
Timothy J. Baek committed
285
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Catch-all route forwarding the request to ``{base_url}/{path}``.

    The endpoint is chosen from the ``model`` field of a JSON-object body,
    using the ``urlIdx`` recorded in ``app.state.MODELS``. Requests without a
    JSON object body, or naming a model not in the cache, go to the first
    configured endpoint. JSON-object bodies are patched before forwarding
    (see the comments below); any other body is forwarded as raw bytes.
    Server-sent-event replies are streamed back; everything else is returned
    as decoded JSON.

    Raises:
        HTTPException(401): the selected endpoint has no API key.
        HTTPException: the upstream call failed. It carries the upstream
            error status when one was received, otherwise 500.
    """
    idx = 0

    body = await request.body()

    payload = None
    if body:
        try:
            payload = json.loads(body.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            log.error("Error loading request body into a dictionary: %s", e)

    if isinstance(payload, dict):
        # Route to the endpoint that reported this model; unknown models keep
        # the default endpoint instead of failing with a KeyError.
        model_info = app.state.MODELS.get(payload.get("model"))
        if model_info is not None:
            idx = model_info["urlIdx"]

        # TODO: Remove below after gpt-4-vision fix from Open AI
        # "gpt-4-vision-preview" needs an explicit max_tokens; default it to
        # 4000 as a workaround until OpenAI fixes the issue with this model.
        if payload.get("model") == "gpt-4-vision-preview":
            if "max_tokens" not in payload:
                payload["max_tokens"] = 4000
            log.debug("Modified body_dict: %s", payload)

        # Fix for ChatGPT calls failing because the num_ctx key is in body:
        # leaving it there generates an error with the OpenAI API (Feb 2024).
        payload.pop("num_ctx", None)

        # Convert the modified body back to JSON
        body = json.dumps(payload)

    url = config_get(app.state.OPENAI_API_BASE_URLS)[idx]
    key = config_get(app.state.OPENAI_API_KEYS)[idx]

    target_url = f"{url}/{path}"

    if key == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    # NOTE(review): `requests` is synchronous, so this call blocks the event
    # loop for the duration of the upstream request.
    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # NOTE(review): upstream headers are copied verbatim (including
            # any Content-Length/Content-Encoding) — confirm they stay valid
            # for the re-streamed body.
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # A requests.Response is falsy for 4xx/5xx, so test `is not None`
        # explicitly. Upstream error statuses pass through; a failure after a
        # successful reply is our own 500.
        raise HTTPException(
            status_code=r.status_code if r is not None and not r.ok else 500,
            detail=error_detail,
        ) from e