import sys
import os
import json
import time
import logging
import asyncio
import subprocess

import requests
import yaml

from fastapi import FastAPI, Request, Depends, HTTPException, status, Response
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse

from pydantic import BaseModel, ConfigDict
from typing import Optional, List

from litellm.utils import get_llm_provider

from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import (
    SRC_LOG_LEVELS,
    ENV,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    DATA_DIR,
    LITELLM_PROXY_PORT,
    LITELLM_PROXY_HOST,
)
from constants import MESSAGES

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])


app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Path to the LiteLLM proxy config file (a file path, despite the "DIR" name)
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"

with open(LITELLM_CONFIG_DIR, "r") as file:
    litellm_config = yaml.safe_load(file)

app.state.CONFIG = litellm_config
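
# The YAML loaded above follows LiteLLM's proxy config layout. A minimal
# sketch (section names match LiteLLMConfigForm below; values are illustrative):
#
#   model_list:
#     - model_name: my-gpt-3.5
#       litellm_params:
#         model: openai/gpt-3.5-turbo
#   litellm_settings: {}
#   general_settings: {}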

# Global variable to store the subprocess reference
background_process = None

CONFLICT_ENV_VARS = [
    # Uvicorn uses PORT, so LiteLLM might use it as well
    "PORT",
    # LiteLLM uses DATABASE_URL for Prisma connections
    "DATABASE_URL",
]


async def run_background_process(command):
    global background_process
    log.info("run_background_process")

    try:
        # Log the command to be executed
        log.info(f"Executing command: {command}")
        # Filter environment variables known to conflict with litellm
        env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
        # Execute the command and create a subprocess
        process = await asyncio.create_subprocess_exec(
            *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )
        background_process = process
        log.info("Subprocess started successfully.")

        # Capture STDERR for debugging purposes; note this reads STDERR to EOF
        # before streaming STDOUT, so startup errors surface first
        stderr_output = await process.stderr.read()
        stderr_text = stderr_output.decode().strip()
        if stderr_text:
            log.info(f"Subprocess STDERR: {stderr_text}")

        # Log STDOUT line by line
        async for line in process.stdout:
            log.info(line.decode().strip())

        # Wait for the process to finish
        returncode = await process.wait()
        log.info(f"Subprocess exited with return code {returncode}")
    except Exception as e:
        log.error(f"Failed to start subprocess: {e}")
        raise  # Re-raise so the failure propagates to the caller


async def start_litellm_background():
    log.info("start_litellm_background")
    # Command to run in the background
    command = [
        "litellm",
        "--port",
        str(LITELLM_PROXY_PORT),
        "--host",
        LITELLM_PROXY_HOST,
        "--telemetry",
        "False",
        "--config",
        LITELLM_CONFIG_DIR,
    ]

    await run_background_process(command)
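
# For reference, the launched command is roughly the following (a sketch;
# the actual values come from config):
#   litellm --port <LITELLM_PROXY_PORT> --host <LITELLM_PROXY_HOST> \
#       --telemetry False --config <DATA_DIR>/litellm/config.yaml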


async def shutdown_litellm_background():
    log.info("shutdown_litellm_background")
    global background_process
    if background_process:
        background_process.terminate()  # sends SIGTERM on POSIX systems
        await background_process.wait()  # Ensure the process has terminated
        log.info("Subprocess terminated")
        background_process = None


@app.on_event("startup")
async def startup_event():
    log.info("startup_event")
    # TODO: Check that config.yaml exists and create one if it does not
    asyncio.create_task(start_litellm_background())


app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


@app.get("/")
async def get_status():
    return {"status": True}


async def restart_litellm():
    """
    Restart the litellm background service.
    """
    log.info("Requested restart of litellm service.")
    try:
        # Shut down the existing process if it is running
        await shutdown_litellm_background()
        log.info("litellm service shutdown complete.")

        # Restart the background service
        asyncio.create_task(start_litellm_background())
        log.info("litellm service restart complete.")

        return {
            "status": "success",
            "message": "litellm service restarted successfully.",
        }
    except Exception as e:
        log.error(f"Error restarting litellm service: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )


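# A quick manual check of the restart route below (a sketch; the base URL and
# admin token depend on your deployment):
#   curl -H "Authorization: Bearer <ADMIN_TOKEN>" http://<host>/restart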
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
    return await restart_litellm()


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    return app.state.CONFIG


class LiteLLMConfigForm(BaseModel):
    general_settings: Optional[dict] = None
    litellm_settings: Optional[dict] = None
    model_list: Optional[List[dict]] = None
    router_settings: Optional[dict] = None

    model_config = ConfigDict(protected_namespaces=())

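# Example JSON body for the route below (a sketch; all sections are optional,
# and omitted sections are dropped from the saved file via exclude_none):
#   {"model_list": [{"model_name": "my-gpt-3.5",
#                    "litellm_params": {"model": "openai/gpt-3.5-turbo"}}]}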

@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
    app.state.CONFIG = form_data.model_dump(exclude_none=True)

    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(app.state.CONFIG, file)

    await restart_litellm()
    return app.state.CONFIG


@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
    # Wait until the LiteLLM subprocess has been started before querying it
    while not background_process:
        await asyncio.sleep(0.1)

    url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"

    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/models")
        r.raise_for_status()

        data = r.json()

        if app.state.ENABLE_MODEL_FILTER:
            if user and user.role == "user":
                data["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                        data["data"],
                    )
                )

        return data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Fall back to the model list from the local config if the proxy
        # is unreachable
        return {
            "data": [
                {
                    "id": model["model_name"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in app.state.CONFIG["model_list"]
            ],
            "object": "list",
        }
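
# Both branches above return an OpenAI-style model list, e.g.
# {"object": "list", "data": [{"id": "...", "object": "model", ...}]},
# so clients can treat this app as an OpenAI-compatible endpoint.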


@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
    return {"data": app.state.CONFIG["model_list"]}


class AddLiteLLMModelForm(BaseModel):
    model_name: str
    litellm_params: dict

    model_config = ConfigDict(protected_namespaces=())


@app.post("/model/new")
async def add_model_to_config(
    form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
    try:
        # Validate that LiteLLM can resolve a provider for this model name
        get_llm_provider(model=form_data.model_name)
        app.state.CONFIG["model_list"].append(form_data.model_dump())

        with open(LITELLM_CONFIG_DIR, "w") as file:
            yaml.dump(app.state.CONFIG, file)

        await restart_litellm()

        return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )
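
# Example request body for /model/new (a sketch; litellm_params keys vary by
# provider -- "model" is the field LiteLLM always expects):
#   {
#       "model_name": "my-gpt-3.5",
#       "litellm_params": {"model": "openai/gpt-3.5-turbo", "api_key": "..."}
#   }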


class DeleteLiteLLMModelForm(BaseModel):
    id: str


@app.post("/model/delete")
async def delete_model_from_config(
    form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
    # The form's "id" is matched against each entry's "model_name"
    app.state.CONFIG["model_list"] = [
        model
        for model in app.state.CONFIG["model_list"]
        if model["model_name"] != form_data.id
    ]

    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(app.state.CONFIG, file)

    await restart_litellm()

    return {"message": MESSAGES.MODEL_DELETED(form_data.id)}


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    body = await request.body()

    url = f"http://localhost:{LITELLM_PROXY_PORT}"
    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None

    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if the response is SSE and, if so, relay it as a stream
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # Use "r is not None" rather than truthiness: requests.Response is
        # falsy for 4xx/5xx status codes
        raise HTTPException(
            status_code=r.status_code if r is not None else 500, detail=error_detail
        )
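
# Example flow through the catch-all proxy above (a sketch; path and payload
# are illustrative): POST /v1/chat/completions with
# {"model": "my-gpt-3.5", "messages": [...]} is forwarded to
# http://localhost:<LITELLM_PROXY_PORT>/v1/chat/completions, and SSE
# responses are streamed back to the client chunk by chunk.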