"...composable_kernel.git" did not exist on "00c1016e4979b2c121bf76ff20d4c2b7c29845a1"
Commit 3aa6b0fe authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

fix: model filter issue

parent 3890ea14
...@@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file: ...@@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file) litellm_config = yaml.safe_load(file)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.ENABLE = ENABLE_LITELLM app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config app.state.CONFIG = litellm_config
...@@ -151,10 +155,6 @@ async def shutdown_litellm_background(): ...@@ -151,10 +155,6 @@ async def shutdown_litellm_background():
background_process = None background_process = None
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return {"status": True} return {"status": True}
......
...@@ -64,8 +64,8 @@ app.add_middleware( ...@@ -64,8 +64,8 @@ app.add_middleware(
app.state.config = AppConfig() app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {} app.state.MODELS = {}
...@@ -178,11 +178,12 @@ async def get_ollama_tags( ...@@ -178,11 +178,12 @@ async def get_ollama_tags(
if url_idx == None: if url_idx == None:
models = await get_all_models() models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models["models"] = list( models["models"] = list(
filter( filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST, lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"], models["models"],
) )
) )
...@@ -1046,11 +1047,12 @@ async def get_openai_models( ...@@ -1046,11 +1047,12 @@ async def get_openai_models(
if url_idx == None: if url_idx == None:
models = await get_all_models() models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models["models"] = list( models["models"] = list(
filter( filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST, lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"], models["models"],
) )
) )
......
...@@ -47,10 +47,11 @@ app.add_middleware( ...@@ -47,10 +47,11 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig() app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
...@@ -259,11 +260,11 @@ async def get_all_models(): ...@@ -259,11 +260,11 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None: if url_idx == None:
models = await get_all_models() models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models["data"] = list( models["data"] = list(
filter( filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST, lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
models["data"], models["data"],
) )
) )
......
...@@ -292,11 +292,11 @@ async def update_model_filter_config( ...@@ -292,11 +292,11 @@ async def update_model_filter_config(
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models app.state.config.MODEL_FILTER_LIST = form_data.models
ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment