"llama.cpp/gguf-py/tests/test_metadata.py" did not exist on "7aa90c0ea35f88a5ef227b773e5a9fe3a0fd7eb2"
main.py 8.85 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
from fastapi import FastAPI, Depends, HTTPException
2
3
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
4

5
import logging
6
from fastapi import FastAPI, Request, Depends, status, Response
Timothy J. Baek's avatar
Timothy J. Baek committed
7
from fastapi.responses import JSONResponse
8
9
10
11

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
12
import time
Timothy J. Baek's avatar
Timothy J. Baek committed
13
import requests
14

Timothy J. Baek's avatar
Timothy J. Baek committed
15
16
17
from pydantic import BaseModel
from typing import Optional, List

18
from utils.utils import get_verified_user, get_current_user, get_admin_user
19
from config import SRC_LOG_LEVELS, ENV
20
from constants import MESSAGES
21
22
23

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
Timothy J. Baek's avatar
Timothy J. Baek committed
24

25

Timothy J. Baek's avatar
Timothy J. Baek committed
26
from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
27
28


29
30
import asyncio
import subprocess
Timothy J. Baek's avatar
Timothy J. Baek committed
31
import yaml
Timothy J. Baek's avatar
Timothy J. Baek committed
32

33
# FastAPI sub-application that manages the embedded LiteLLM proxy process.
app = FastAPI()

# Allow every origin so the main Open WebUI frontend can reach this app.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Path of the LiteLLM YAML config inside the data directory.
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"

# Load the proxy configuration once at import time and keep it on app.state
# so the /config endpoints can read and mutate it.
with open(LITELLM_CONFIG_DIR, "r") as file:
    litellm_config = yaml.safe_load(file)

app.state.CONFIG = litellm_config

# Global variable to store the subprocess reference
background_process = None

Timothy J. Baek's avatar
Timothy J. Baek committed
56

57
58
async def run_background_process(command):
    """Start *command* as a subprocess and stream its output to the log.

    The process handle is stored in the module-level ``background_process``
    so that ``shutdown_litellm_background`` can terminate it later.

    Args:
        command: Full command line; split on whitespace (arguments that
            contain spaces are not supported).

    Raises:
        Exception: Re-raised if the subprocess could not be started.
    """
    global background_process
    log.info("run_background_process")

    try:
        # Log the command to be executed
        log.info(f"Executing command: {command}")

        # NOTE: command.split() does not honor shell quoting; acceptable for
        # the fixed litellm command line used by this module.
        process = await asyncio.create_subprocess_exec(
            *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        background_process = process
        log.info("Subprocess started successfully.")

        async def _log_stream(stream, prefix=""):
            # Forward one output stream to the log, line by line.
            async for line in stream:
                text = line.decode().strip()
                if text:
                    log.info(f"{prefix}{text}")

        # Read STDOUT and STDERR concurrently. The previous implementation
        # drained STDERR to EOF before touching STDOUT, which could deadlock
        # once the OS pipe buffer for STDOUT filled up.
        await asyncio.gather(
            _log_stream(process.stderr, "Subprocess STDERR: "),
            _log_stream(process.stdout),
        )

        # Wait for the process to finish
        returncode = await process.wait()
        log.info(f"Subprocess exited with return code {returncode}")
    except Exception as e:
        log.error(f"Failed to start subprocess: {e}")
        raise  # Re-raise so callers can observe startup failures
87
88
89


async def start_litellm_background():
    """Launch the LiteLLM proxy on port 14365 as a background subprocess."""
    log.info("start_litellm_background")
    # Command to run in the background.
    # CONSISTENCY FIX: use the same LITELLM_CONFIG_DIR path the /config
    # endpoints write to, instead of a separately hard-coded
    # "./data/litellm/config.yaml" that silently diverges if DATA_DIR moves.
    command = (
        f"litellm --port 14365 --telemetry False --config {LITELLM_CONFIG_DIR}"
    )

    await run_background_process(command)
Timothy J. Baek's avatar
Timothy J. Baek committed
97
98


99
async def shutdown_litellm_background():
    """Terminate the background LiteLLM subprocess, if one is running."""
    global background_process
    log.info("shutdown_litellm_background")

    if background_process:
        try:
            background_process.terminate()
        except ProcessLookupError:
            # The process already exited on its own; nothing to terminate.
            pass
        await background_process.wait()  # Ensure the process has terminated
        log.info("Subprocess terminated")
        background_process = None
107
108


Timothy J. Baek's avatar
Timothy J. Baek committed
109
@app.on_event("startup")
async def startup_event():
    """On app startup, launch the LiteLLM proxy in the background."""
    log.info("startup_event")
    # TODO: Check config.yaml file and create one
    # Keep a reference to the task: the event loop holds only weak references
    # to tasks, so a bare create_task() result may be garbage-collected
    # mid-run and silently cancelled.
    app.state.LITELLM_TASK = asyncio.create_task(start_litellm_background())
Timothy J. Baek's avatar
Timothy J. Baek committed
115
116


117
118
119
120
# Snapshot of the model-filter settings on app.state so request handlers
# (e.g. get_models) can consult them without re-reading config.
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


121
122
123
124
125
@app.get("/")
async def get_status():
    """Liveness probe: report that the LiteLLM manager app is up."""
    return dict(status=True)


Timothy J. Baek's avatar
Timothy J. Baek committed
126
async def restart_litellm():
    """
    Restart the litellm background service.

    Shuts down the running subprocess (if any) and schedules a fresh start as
    a background task; this returns before the new process is fully up.
    (Not itself an endpoint — see the /restart handler.)
    """
    log.info("Requested restart of litellm service.")
    try:
        # Shut down the existing process if it is running
        await shutdown_litellm_background()
        log.info("litellm service shutdown complete.")

        # Restart the background service; keep a reference so the task is
        # not garbage-collected before it completes.
        app.state.LITELLM_TASK = asyncio.create_task(start_litellm_background())

        log.info("litellm service restart complete.")

        return {
            "status": "success",
            "message": "litellm service restarted successfully.",
        }
    except Exception as e:
        # log.exception records the traceback; plain log.info used to hide it.
        log.exception(f"Error restarting litellm service: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
    """Admin-only endpoint wrapping the internal restart routine."""
    result = await restart_litellm()
    return result


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Return the in-memory LiteLLM proxy configuration (admin only)."""
    current_config = app.state.CONFIG
    return current_config


class LiteLLMConfigForm(BaseModel):
    # Full LiteLLM config payload; any section may be omitted — None values
    # are dropped on update via model_dump(exclude_none=True).
    general_settings: Optional[dict] = None
    litellm_settings: Optional[dict] = None
    model_list: Optional[List[dict]] = None
    router_settings: Optional[dict] = None


@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
    """Replace the LiteLLM config, persist it to disk, and restart the proxy."""
    new_config = form_data.model_dump(exclude_none=True)
    app.state.CONFIG = new_config

    # Persist before restarting so the relaunched proxy picks up the change.
    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(new_config, file)

    await restart_litellm()
    return app.state.CONFIG


Timothy J. Baek's avatar
Timothy J. Baek committed
180
181
182
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
    """List models known to the LiteLLM proxy.

    Waits (bounded) for the background proxy process to be launched, queries
    it, and filters the result for non-admin users when the model filter is
    enabled. Falls back to an OpenAI-style list synthesized from the local
    config file when the proxy is unreachable.
    """
    # Wait for the subprocess handle to appear, but never hang forever: the
    # old unbounded loop stalled this endpoint indefinitely if the proxy
    # failed to start. After ~30s we fall through; the request below will
    # fail fast and the fallback list is served.
    deadline = time.time() + 30
    while not background_process and time.time() < deadline:
        await asyncio.sleep(0.1)

    url = "http://localhost:14365/v1"

    r = None
    try:
        # Timeout added so a wedged proxy cannot hang the request forever;
        # failures land in the fallback path below.
        r = requests.request(method="GET", url=f"{url}/models", timeout=5)
        r.raise_for_status()

        data = r.json()

        # Hide models that are not on the allow-list from plain users.
        if app.state.MODEL_FILTER_ENABLED:
            if user and user.role == "user":
                data["data"] = list(
                    filter(
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                        data["data"],
                    )
                )

        return data
    except Exception as e:
        log.exception(e)

        # Proxy unreachable or returned garbage: synthesize an OpenAI-style
        # model list from the local config so the UI still has data.
        # (The old code also built an `error_detail` string here that was
        # never used; it has been removed.)
        return {
            "data": [
                {
                    "id": model["model_name"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in app.state.CONFIG["model_list"]
            ],
            "object": "list",
        }
228
229


230
231
232
233
234
235
236
237
238
239
240
241
242
243
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
    """Expose the raw model_list section of the config (admin only)."""
    model_list = app.state.CONFIG["model_list"]
    return {"data": model_list}


class AddLiteLLMModelForm(BaseModel):
    # Payload for /model/new: display name plus the litellm_params dict
    # stored verbatim as a new entry in the config's model_list.
    model_name: str
    litellm_params: dict


@app.post("/model/new")
async def add_model_to_config(
    form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
    """Append a model entry to the config, persist it, and restart the proxy."""
    # TODO: Validate model form
    new_entry = form_data.model_dump()
    app.state.CONFIG["model_list"].append(new_entry)

    # Write the updated config before restarting so the proxy reloads it.
    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(app.state.CONFIG, file)

    await restart_litellm()

    return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274


class DeleteLiteLLMModelForm(BaseModel):
    # Identifier of the model to remove; matched against model_name entries.
    id: str


@app.post("/model/delete")
async def delete_model_from_config(
    form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
    """Remove every model whose model_name matches, persist, and restart."""
    kept_models = [
        model
        for model in app.state.CONFIG["model_list"]
        if model["model_name"] != form_data.id
    ]
    app.state.CONFIG["model_list"] = kept_models

    # Write the updated config before restarting so the proxy reloads it.
    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(app.state.CONFIG, file)

    await restart_litellm()

    return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
276
277


Timothy J. Baek's avatar
Timothy J. Baek committed
278
279
280
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """Forward any other request to the local LiteLLM proxy.

    Streams SSE (text/event-stream) responses through unchanged; otherwise
    returns the parsed JSON body. Requires a verified user.

    Raises:
        HTTPException: with the upstream status code (or 500 when the proxy
            could not be reached at all) and a best-effort error detail.
    """
    body = await request.body()

    url = "http://localhost:14365"
    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        # NOTE(review): deliberately no read timeout here — one would kill
        # long-lived event streams; confirm before adding.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        # BUG FIX: `if r` relied on requests' truthiness — Response.__bool__
        # is False for any 4xx/5xx status, so upstream error codes were
        # always collapsed to 500. Test explicitly against None instead.
        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )