main.py 3.64 KB
Newer Older
1
2
3
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
Timothy J. Baek's avatar
Timothy J. Baek committed
4

5
import logging
6
from fastapi import FastAPI, Request, Depends, status, Response
Timothy J. Baek's avatar
Timothy J. Baek committed
7
from fastapi.responses import JSONResponse
8
9
10
11
12

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json

Timothy J. Baek's avatar
Timothy J. Baek committed
13
from utils.utils import get_http_authorization_cred, get_current_user
14
15
16
17
from config import SRC_LOG_LEVELS, ENV

# Module-level logger; its level is taken from the "LITELLM" entry of
# SRC_LOG_LEVELS (imported from config).
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
Timothy J. Baek's avatar
Timothy J. Baek committed
18

19
20
21
22
23
24
25

from config import (
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)


26
27
import asyncio
import subprocess
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29


30
# Application object for this module; the middleware, startup hook and routes
# defined in this file are all registered on it.
app = FastAPI()
Timothy J. Baek's avatar
Timothy J. Baek committed
31

32
# Origins handed to CORSMiddleware as allow_origins; "*" is the wildcard entry.
origins = ["*"]
Timothy J. Baek's avatar
Timothy J. Baek committed
33

34
35
36
37
38
39
40
# Register CORS handling: origins come from `origins`, and every method and
# header is allowed, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
Timothy J. Baek's avatar
Timothy J. Baek committed
41

42
43
44
45
46
47
48
49
50
51
52
53

async def run_background_process(command):
    """Spawn ``command`` as a child process and return without waiting for it.

    Args:
        command: Command line to execute. A ``str`` is tokenised with
            ``shlex.split`` so that quoted arguments (for example a path
            containing spaces) stay in one piece. Any other value is split on
            whitespace via its own ``split()`` method, as before.

    Returns:
        The ``asyncio.subprocess.Process`` for the started child. Its stdout
        and stderr are connected to pipes.

    Raises:
        ValueError: If a string command contains unbalanced quotes.
    """
    import os
    import shlex

    if isinstance(command, str):
        # Non-POSIX parsing on Windows leaves backslashes in paths untouched.
        args = shlex.split(command, posix=os.name != "nt")
    else:
        args = command.split()

    # NOTE: nothing in this function reads the pipes. A child that produces a
    # lot of output can stall once a pipe buffer fills, so a caller that keeps
    # the process around should drain them (e.g. ``process.communicate()``).
    process = await asyncio.create_subprocess_exec(
        *args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    return process


async def start_litellm_background():
    """Launch the ``litellm`` CLI with ``./data/litellm/config.yaml``.

    The command is handed to ``run_background_process``; the process handle it
    returns is not kept.
    """
    await run_background_process("litellm --config ./data/litellm/config.yaml")
Timothy J. Baek's avatar
Timothy J. Baek committed
54
55
56


@app.on_event("startup")
async def startup_event():
    """Kick off the LiteLLM launcher when the application starts.

    The launcher is scheduled as its own asyncio task, so this startup hook
    returns without waiting for it.
    """
    # The event loop keeps only weak references to tasks, so a task nobody
    # references may be garbage-collected before it completes. Storing it on
    # the application state keeps it alive.
    app.state.litellm_task = asyncio.create_task(start_litellm_background())
Timothy J. Baek's avatar
Timothy J. Baek committed
59
60


61
62
63
64
# Copy the model-filter settings imported from config onto app.state, where
# ModifyModelsResponseMiddleware reads them at request time.
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


Timothy J. Baek's avatar
Timothy J. Baek committed
65
66
67
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Resolve the requesting user from the ``Authorization`` header.

    On success the user is stored on ``request.state.user`` and the request is
    passed on via ``call_next``. If resolving the user raises, the request is
    answered immediately with a 400 JSON body ``{"detail": <error text>}`` and
    ``call_next`` is never invoked.
    """
    # Set the attribute up front so that code reading request.state.user
    # always finds it, even when the request is rejected below.
    request.state.user = None
    auth_header = request.headers.get("Authorization", "")

    # Only the credential lookup is guarded, so unrelated errors are not
    # reported to the client as authentication failures.
    try:
        user = get_current_user(get_http_authorization_cred(auth_header))
    except Exception as e:
        # NOTE(review): every failure (including a missing header) is reported
        # as 400; 401 may be the more accurate status — confirm with API
        # clients before changing it.
        return JSONResponse(status_code=400, content={"detail": str(e)})

    # Lazy %-formatting: the user is only rendered when DEBUG is enabled.
    log.debug("user: %s", user)
    request.state.user = user

    return await call_next(request)
79
80


81
82
83
84
85
@app.get("/")
async def get_status():
    # Root endpoint: always answers with a constant, truthy status flag.
    status_payload = {"status": True}
    return status_payload


86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
    """Post-process JSON responses for request paths containing ``/models``.

    For such requests, when the downstream response is a ``StreamingResponse``
    whose body is a JSON object, the middleware:

    * restricts ``data["data"]`` to the models whose ``id`` is listed in
      ``app.state.MODEL_FILTER_LIST`` when ``app.state.MODEL_FILTER_ENABLED``
      is truthy and the requesting user's role is ``"user"``;
    * sets ``data["modified"] = True``;
    * re-emits the payload as a ``JSONResponse`` that keeps the downstream
      status code and headers.

    Every other response is returned as received; bodies that are not a JSON
    object are re-sent byte-for-byte.
    """

    # Headers describing the old body; JSONResponse recomputes them.
    _BODY_HEADERS = ("content-length", "content-type")

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)

        if "/models" not in request.url.path or not isinstance(
            response, StreamingResponse
        ):
            return response

        # Drain the stream; join once instead of repeated bytes concatenation.
        body = b"".join([chunk async for chunk in response.body_iterator])

        try:
            data = json.loads(body.decode("utf-8"))
        except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
            data = None

        if not isinstance(data, dict):
            # The iterator is already consumed, so the original response can
            # no longer be sent; rebuild it with the same bytes, status and
            # headers instead of failing the request.
            return Response(
                content=body,
                status_code=response.status_code,
                headers=dict(response.headers),
            )

        user = getattr(request.state, "user", None)
        models = data.get("data")
        if (
            app.state.MODEL_FILTER_ENABLED
            and user
            and user.role == "user"
            # Error payloads may carry no model list; leave those untouched.
            and isinstance(models, list)
        ):
            allowed_ids = app.state.MODEL_FILTER_LIST
            data["data"] = [model for model in models if model["id"] in allowed_ids]

        # Modified Flag
        data["modified"] = True

        # Carry over the downstream headers, minus the ones that describe the
        # old body. NOTE: repeated header names collapse to a single entry.
        passthrough_headers = {
            key: value
            for key, value in response.headers.items()
            if key.lower() not in self._BODY_HEADERS
        }
        return JSONResponse(
            content=data,
            status_code=response.status_code,
            headers=passthrough_headers,
        )


# Install the /models response rewriter on the application.
app.add_middleware(ModifyModelsResponseMiddleware)
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143


# from litellm.proxy.proxy_server import ProxyConfig, initialize
# from litellm.proxy.proxy_server import app

# proxy_config = ProxyConfig()


# async def config():
#     router, model_list, general_settings = await proxy_config.load_config(
#         router=None, config_file_path="./data/litellm/config.yaml"
#     )

#     await initialize(config="./data/litellm/config.yaml", telemetry=False)


# async def startup():
#     await config()


# @app.on_event("startup")
# async def on_startup():
#     await startup()