Commit 144581a7 authored by Michael Poluektov's avatar Michael Poluektov
Browse files

refac: get_sorted_pipelines()

parent 7ffd75b9
...@@ -764,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware) ...@@ -764,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware)
################################## ##################################
def filter_pipeline(payload, user): def get_sorted_filters(model_id):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
filters = [ filters = [
model model
for model in app.state.MODELS.values() for model in app.state.MODELS.values()
...@@ -782,6 +780,13 @@ def filter_pipeline(payload, user): ...@@ -782,6 +780,13 @@ def filter_pipeline(payload, user):
) )
] ]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
return sorted_filters
def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id)
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
...@@ -814,19 +819,12 @@ def filter_pipeline(payload, user): ...@@ -814,19 +819,12 @@ def filter_pipeline(payload, user):
print(f"Connection error: {e}") print(f"Connection error: {e}")
if r is not None: if r is not None:
try: res = r.json()
res = r.json()
except:
pass
if "detail" in res: if "detail" in res:
raise Exception(r.status_code, res["detail"]) raise Exception(r.status_code, res["detail"])
else: if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
pass del payload["task"]
if "pipeline" not in app.state.MODELS[model_id]:
if "task" in payload:
del payload["task"]
return payload return payload
...@@ -1061,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1061,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
filters = [ sorted_filters = get_sorted_filters(model_id)
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
if "pipeline" in model: if "pipeline" in model:
sorted_filters = [model] + sorted_filters sorted_filters = [model] + sorted_filters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment