import logging
import asyncio
import subprocess
import time

import requests
import yaml

from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse

from pydantic import BaseModel, ConfigDict
from typing import Optional, List

from litellm.utils import get_llm_provider

from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import (
    SRC_LOG_LEVELS,
    DATA_DIR,
    LITELLM_PROXY_PORT,
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)
from constants import MESSAGES

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])


app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"

with open(LITELLM_CONFIG_DIR, "r") as file:
    litellm_config = yaml.safe_load(file)

app.state.CONFIG = litellm_config
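
# The file loaded above follows LiteLLM's proxy config schema; the endpoints
# below read and rewrite it wholesale. A minimal illustrative shape (values
# are placeholders, not taken from this repo):
#
#   model_list:
#     - model_name: gpt-3.5-turbo
#       litellm_params:
#         model: openai/gpt-3.5-turbo
#   litellm_settings: {}
#   general_settings: {}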

# Global variable to store the subprocess reference
background_process = None

async def run_background_process(command):
    global background_process
    log.info("run_background_process")

    try:
        # Log the command to be executed
        log.info(f"Executing command: {command}")
        # Execute the command and create a subprocess
        process = await asyncio.create_subprocess_exec(
            *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        background_process = process
        log.info("Subprocess started successfully.")

        # Capture STDERR for debugging purposes
        stderr_output = await process.stderr.read()
        stderr_text = stderr_output.decode().strip()
        if stderr_text:
            log.info(f"Subprocess STDERR: {stderr_text}")
        # log.info output line by line
        async for line in process.stdout:
            log.info(line.decode().strip())

        # Wait for the process to finish
        returncode = await process.wait()
        log.info(f"Subprocess exited with return code {returncode}")
    except Exception as e:
        log.error(f"Failed to start subprocess: {e}")
        raise  # Re-raise so startup failures propagate to the caller


async def start_litellm_background():
    log.info("start_litellm_background")
    # Command to run in the background
    command = f"litellm --port {LITELLM_PROXY_PORT} --telemetry False --config ./data/litellm/config.yaml"
    await run_background_process(command)


async def shutdown_litellm_background():
    log.info("shutdown_litellm_background")
    global background_process
    if background_process:
        background_process.terminate()
        await background_process.wait()  # Ensure the process has terminated
        log.info("Subprocess terminated")
        background_process = None


@app.on_event("startup")
async def startup_event():
    log.info("startup_event")
    # TODO: Check config.yaml file and create one
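    # Fire-and-forget: launching the proxy as a task keeps app startup from
    # blocking; run_background_process() stores the subprocess handle in the
    # module-level background_process global.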
    asyncio.create_task(start_litellm_background())


app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
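
# When MODEL_FILTER_ENABLED is set, users with role "user" only see models
# whose ids appear in MODEL_FILTER_LIST (enforced in get_models() below).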


@app.get("/")
async def get_status():
    return {"status": True}


async def restart_litellm():
    """
    Endpoint to restart the litellm background service.
    """
    log.info("Requested restart of litellm service.")
    try:
        # Shut down the existing process if it is running
        await shutdown_litellm_background()
        log.info("litellm service shutdown complete.")

        # Restart the background service
        asyncio.create_task(start_litellm_background())
        log.info("litellm service restart complete.")

        return {
            "status": "success",
            "message": "litellm service restarted successfully.",
        }
    except Exception as e:
        log.info(f"Error restarting litellm service: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )


@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
    return await restart_litellm()


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    return app.state.CONFIG


class LiteLLMConfigForm(BaseModel):
    general_settings: Optional[dict] = None
    litellm_settings: Optional[dict] = None
    model_list: Optional[List[dict]] = None
    router_settings: Optional[dict] = None

    model_config = ConfigDict(protected_namespaces=())


@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
    app.state.CONFIG = form_data.model_dump(exclude_none=True)

    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(app.state.CONFIG, file)

    await restart_litellm()
    return app.state.CONFIG
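
# Illustrative /config/update payload (placeholder model names, not from this
# repo). Because model_dump(exclude_none=True) feeds a full-file rewrite above,
# any top-level section omitted from the request is dropped from config.yaml:
#
#   {
#       "model_list": [
#           {
#               "model_name": "gpt-3.5-turbo",
#               "litellm_params": {"model": "openai/gpt-3.5-turbo"},
#           }
#       ]
#   }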


@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
    while not background_process:
        await asyncio.sleep(0.1)
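    # Note: this only waits for the subprocess handle to exist, not for the
    # proxy to start accepting connections, so the request below can still
    # race the proxy's startup.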

    url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models")
        r.raise_for_status()
        data = r.json()
        if app.state.MODEL_FILTER_ENABLED:
            if user and user.role == "user":
                data["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                        data["data"],
                    )
                )
        return data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"
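        # Fall back to synthesizing an OpenAI-style model list from the local
        # config so clients still get a usable response while the proxy is
        # unreachable.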
        return {
            "data": [
                {
                    "id": model["model_name"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in app.state.CONFIG["model_list"]
            ],
            "object": "list",
        }


@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
    return {"data": app.state.CONFIG["model_list"]}


class AddLiteLLMModelForm(BaseModel):
    model_name: str
    litellm_params: dict

    model_config = ConfigDict(protected_namespaces=())


@app.post("/model/new")
async def add_model_to_config(
    form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
    try:
        get_llm_provider(model=form_data.model_name)
        app.state.CONFIG["model_list"].append(form_data.model_dump())
        with open(LITELLM_CONFIG_DIR, "w") as file:
            yaml.dump(app.state.CONFIG, file)
        await restart_litellm()
        return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )
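
# Illustrative /model/new payload (placeholder values). Note that
# get_llm_provider() above validates model_name itself, so it must be a name
# LiteLLM can map to a provider, independent of litellm_params["model"]:
#
#   {
#       "model_name": "gpt-3.5-turbo",
#       "litellm_params": {"model": "openai/gpt-3.5-turbo"},
#   }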


class DeleteLiteLLMModelForm(BaseModel):
    id: str


@app.post("/model/delete")
async def delete_model_from_config(
    form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
    app.state.CONFIG["model_list"] = [
        model
        for model in app.state.CONFIG["model_list"]
        if model["model_name"] != form_data.id
    ]

    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(app.state.CONFIG, file)

    await restart_litellm()

    return {"message": MESSAGES.MODEL_DELETED(form_data.id)}


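# Note: this catch-all proxy route is declared last on purpose; routes are
# matched in declaration order, so declaring it earlier would shadow the
# specific endpoints above.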
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    url = f"http://localhost:{LITELLM_PROXY_PORT}"
    target_url = f"{url}/{path}"
    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"
    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Use an explicit None check: a Response with a 4xx/5xx status is
        # falsy, which would otherwise mask the real upstream status as 500.
        raise HTTPException(
            status_code=r.status_code if r is not None else 500, detail=error_detail
        )