main.py 9.22 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
from fastapi import FastAPI, Depends, HTTPException
2
3
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
4

5
import logging
6
from fastapi import FastAPI, Request, Depends, status, Response
Timothy J. Baek's avatar
Timothy J. Baek committed
7
from fastapi.responses import JSONResponse
8
9
10
11

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
12
import time
Timothy J. Baek's avatar
Timothy J. Baek committed
13
import requests
14

15
from pydantic import BaseModel, ConfigDict
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
from typing import Optional, List

18
from utils.utils import get_verified_user, get_current_user, get_admin_user
19
from config import SRC_LOG_LEVELS, ENV
20
from constants import MESSAGES
21
22
23

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
Timothy J. Baek's avatar
Timothy J. Baek committed
24

25

Timothy J. Baek's avatar
Timothy J. Baek committed
26
from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
27

28
from litellm.utils import get_llm_provider
29

30
31
import asyncio
import subprocess
Timothy J. Baek's avatar
Timothy J. Baek committed
32
import yaml
Timothy J. Baek's avatar
Timothy J. Baek committed
33

34
# Sub-application that fronts the local litellm proxy process.
app = FastAPI()

# Allow every origin; this app sits behind the main Open WebUI server.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
Timothy J. Baek's avatar
Timothy J. Baek committed
45

46

Timothy J. Baek's avatar
Timothy J. Baek committed
47
48
49
50
51
52
53
# Path of the litellm YAML config this app manages (despite the _DIR suffix,
# it is a file path). Read once at import time; /config/update rewrites it.
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"

with open(LITELLM_CONFIG_DIR, "r") as config_file:
    litellm_config = yaml.safe_load(config_file)

app.state.CONFIG = litellm_config

54
55
56
# Global handle to the litellm subprocess (an asyncio.subprocess.Process once
# started), or None while no proxy is running. Set by run_background_process,
# cleared by shutdown_litellm_background.
background_process = None

Timothy J. Baek's avatar
Timothy J. Baek committed
57

58
59
async def run_background_process(command):
    """Spawn *command* as an asyncio subprocess, publish it via the
    ``background_process`` global, and stream its stdout/stderr into the log
    until the process exits.

    Re-raises any exception after logging it, so callers can observe failure.
    """
    global background_process
    log.info("run_background_process")

    try:
        # Log the command to be executed
        log.info(f"Executing command: {command}")

        # NOTE(review): command.split() performs no shell-style quoting; it is
        # adequate for the fixed commands used in this module only.
        process = await asyncio.create_subprocess_exec(
            *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        background_process = process
        log.info("Subprocess started successfully.")

        async def _pump(stream, prefix):
            # Forward one pipe to the log line-by-line. Both pipes must be
            # drained concurrently: the previous implementation awaited
            # stderr.read() (i.e. stderr EOF) before touching stdout, which
            # stalled all stdout logging for a long-running process and could
            # deadlock once the stdout pipe buffer filled.
            async for raw_line in stream:
                text = raw_line.decode().strip()
                if text:
                    log.info(f"{prefix}{text}")

        await asyncio.gather(
            _pump(process.stdout, ""),
            _pump(process.stderr, "Subprocess STDERR: "),
        )

        # Wait for the process to finish
        returncode = await process.wait()
        log.info(f"Subprocess exited with return code {returncode}")
    except Exception as e:
        log.error(f"Failed to start subprocess: {e}")
        raise  # Optionally re-raise the exception if you want it to propagate
88
89
90


async def start_litellm_background():
    """Launch the litellm proxy server as a managed background subprocess."""
    log.info("start_litellm_background")

    # The proxy binds a fixed local port and reads the same config file this
    # app rewrites via /config/update.
    cmd = "litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml"
    await run_background_process(cmd)
Timothy J. Baek's avatar
Timothy J. Baek committed
98
99


100
async def shutdown_litellm_background():
    """Terminate the running litellm subprocess, if any, and wait for it to exit."""
    global background_process
    log.info("shutdown_litellm_background")

    if not background_process:
        return

    background_process.terminate()
    # Block until the process is fully gone before clearing the handle.
    await background_process.wait()
    log.info("Subprocess terminated")
    background_process = None
108
109


Timothy J. Baek's avatar
Timothy J. Baek committed
110
@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: kick off the litellm proxy without blocking app startup."""
    log.info("startup_event")
    # TODO: Check config.yaml file and create one
    # Fire-and-forget: the task streams subprocess output for its lifetime.
    asyncio.create_task(start_litellm_background())
Timothy J. Baek's avatar
Timothy J. Baek committed
116
117


118
119
120
121
# Snapshot of the model-filter settings from config; /models consults these
# to restrict which models non-admin users can see.
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


122
123
124
125
126
@app.get("/")
async def get_status():
    """Liveness probe for this sub-app."""
    status_payload = {"status": True}
    return status_payload


Timothy J. Baek's avatar
Timothy J. Baek committed
127
async def restart_litellm():
    """
    Restart the litellm background service: shut down any running subprocess,
    then schedule a fresh one.

    Note: this is an internal helper, not an endpoint itself — the /restart
    route and the config-mutation endpoints call it.

    Returns a success dict; raises HTTPException(500) on failure.
    """
    log.info("Requested restart of litellm service.")
    try:
        # Shut down the existing process if it is running
        await shutdown_litellm_background()
        log.info("litellm service shutdown complete.")

        # Restart the background service (fire-and-forget; startup errors
        # surface in the background task's own logging).
        asyncio.create_task(start_litellm_background())

        log.info("litellm service restart complete.")

        return {
            "status": "success",
            "message": "litellm service restarted successfully.",
        }
    except Exception as e:
        # was log.info: a restart failure is an error and deserves a traceback
        log.exception(f"Error restarting litellm service: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
    """Admin-only endpoint that bounces the litellm subprocess."""
    result = await restart_litellm()
    return result


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Return the in-memory litellm config (admin only)."""
    current_config = app.state.CONFIG
    return current_config


class LiteLLMConfigForm(BaseModel):
    """Payload for /config/update: mirrors the top-level sections of a litellm
    config.yaml. Sections left as None are dropped on dump (exclude_none)."""

    general_settings: Optional[dict] = None
    litellm_settings: Optional[dict] = None
    model_list: Optional[List[dict]] = None
    router_settings: Optional[dict] = None

    # model_list/model_config start with "model_", which pydantic v2 reserves;
    # clearing protected_namespaces silences that conflict.
    model_config = ConfigDict(protected_namespaces=())

Timothy J. Baek's avatar
Timothy J. Baek committed
171
172
173
174
175
176
177
178
179
180
181
182

@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
    """Replace the litellm config with the submitted form, persist it to disk,
    and restart the proxy so it takes effect."""
    new_config = form_data.model_dump(exclude_none=True)
    app.state.CONFIG = new_config

    # Persist so the restarted subprocess reads the updated file.
    with open(LITELLM_CONFIG_DIR, "w") as f:
        yaml.dump(new_config, f)

    await restart_litellm()
    return app.state.CONFIG


Timothy J. Baek's avatar
Timothy J. Baek committed
183
184
185
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
    """
    Proxy the model list from the local litellm server, applying the
    configured model filter for non-admin users. Falls back to the models
    declared in the local config file if the proxy cannot be reached.
    """
    # Wait until the litellm subprocess has been launched.
    # NOTE(review): this loops forever if the subprocess never starts —
    # consider bounding the wait with a timeout.
    while not background_process:
        await asyncio.sleep(0.1)

    url = "http://localhost:14365/v1"

    r = None
    try:
        # NOTE(review): requests is blocking; this ties up the event loop for
        # the duration of the upstream call.
        r = requests.get(f"{url}/models")
        r.raise_for_status()

        data = r.json()

        if app.state.MODEL_FILTER_ENABLED:
            if user and user.role == "user":
                data["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                        data["data"],
                    )
                )

        return data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                # was bare except: must not swallow SystemExit/KeyboardInterrupt
                error_detail = f"External: {e}"

        # Fall back to the model names declared in the local config so the UI
        # still has something to show while the proxy is down.
        return {
            "data": [
                {
                    "id": model["model_name"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in app.state.CONFIG["model_list"]
            ],
            "object": "list",
        }
231
232


233
234
235
236
237
238
239
240
241
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
    """Return the raw model_list section of the litellm config (admin only)."""
    model_list = app.state.CONFIG["model_list"]
    return {"data": model_list}


class AddLiteLLMModelForm(BaseModel):
    """Payload for /model/new: one entry of a litellm config model_list."""

    model_name: str
    litellm_params: dict

    # Field names starting with "model_" collide with pydantic v2's protected
    # namespace; clearing it avoids the warning/error.
    model_config = ConfigDict(protected_namespaces=())

244
245
246
247
248

@app.post("/model/new")
async def add_model_to_config(
    form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
    """
    Append a model entry to the litellm config, persist it, and restart the
    proxy.

    Raises HTTPException(500) if litellm cannot resolve a provider for the
    model name or the update fails.
    """
    try:
        # Raises if litellm does not recognize a provider for this model name,
        # rejecting the entry before it is written to the config.
        get_llm_provider(model=form_data.model_name)
        app.state.CONFIG["model_list"].append(form_data.model_dump())

        with open(LITELLM_CONFIG_DIR, "w") as file:
            yaml.dump(app.state.CONFIG, file)

        await restart_litellm()

        return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
    except Exception as e:
        # was print(e): route errors through the module logger with traceback
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284


class DeleteLiteLLMModelForm(BaseModel):
    """Payload for /model/delete: the model_name of the entry to remove."""

    # Matched against each config entry's "model_name" (not an opaque id).
    id: str


@app.post("/model/delete")
async def delete_model_from_config(
    form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
    """Drop every config entry whose model_name matches form_data.id, persist
    the config, and restart the proxy."""
    remaining = [
        entry
        for entry in app.state.CONFIG["model_list"]
        if entry["model_name"] != form_data.id
    ]
    app.state.CONFIG["model_list"] = remaining

    with open(LITELLM_CONFIG_DIR, "w") as f:
        yaml.dump(app.state.CONFIG, f)

    await restart_litellm()

    return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
286
287


Timothy J. Baek's avatar
Timothy J. Baek committed
288
289
290
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """
    Forward any other request to the local litellm proxy, streaming SSE
    responses through unchanged and passing JSON responses back as-is.

    Raises HTTPException mirroring the upstream status (500 when the proxy
    could not be reached at all).
    """
    body = await request.body()

    url = "http://localhost:14365"
    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        # NOTE(review): requests is blocking; inside an async handler this
        # ties up the event loop for the duration of the upstream call.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                # was bare except: must not swallow SystemExit/KeyboardInterrupt
                error_detail = f"External: {e}"

        # was `if r`: requests.Response.__bool__ is False for 4xx/5xx, so the
        # old check rewrote every upstream error status to 500. Test identity
        # against None to propagate the real status code.
        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )