// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! HTTP endpoint for dynamically getting/setting the busy thresholds per model.
//!
//! The busy thresholds control when workers are marked as "busy" based on their
//! KV cache block utilization and prefill token utilization. When all workers
//! for a model exceed their thresholds, new requests are rejected with a 503
//! Service Unavailable response.
//!
//! ## Endpoints
//!
//! ### POST /busy_threshold
//!
//! Get or set a model's busy thresholds.
//!
//! **Set thresholds:**
//!
//! ```json
//! // Request
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! // Response
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! ```
//!
//! **Get thresholds (omit both threshold fields):**
//!
//! ```json
//! // Request
//! {"model": "llama-3-70b"}
//! // Response (if configured)
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! // Response (if not configured)
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": null, "active_prefill_tokens_threshold": null}
//! ```
//!
//! An `active_decode_blocks_threshold` outside `0.0..=1.0` is rejected with
//! 400 Bad Request. Setting thresholds for a model that has not been
//! discovered is rejected with 404 Not Found.
//!
//! ### GET /busy_threshold
//!
//! List all configured busy thresholds.
//!
//! ```json
//! // Response
//! {"thresholds": [{"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}]}
//! ```

use super::{RouteDoc, service_v2};
use axum::{
    Json, Router,
    extract::Request,
    http::{Method, StatusCode, header},
    middleware::Next,
    response::{IntoResponse, Response},
    routing::{get, post},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

56
/// Request body for getting or setting busy thresholds.
57
///
58
59
/// - If thresholds are provided: sets/creates the thresholds and returns the new values
/// - If thresholds are null/omitted: returns the existing thresholds if any
60
61
62
63
#[derive(Debug, Deserialize)]
pub struct BusyThresholdRequest {
    /// The model name
    pub model: String,
64
65
66
67
    /// The active decode blocks threshold value (0.0 to 1.0), or null to just get the current value
    pub active_decode_blocks_threshold: Option<f64>,
    /// The active prefill tokens threshold value (literal token count), or null to just get the current value
    pub active_prefill_tokens_threshold: Option<u64>,
68
69
70
71
72
73
74
}

/// Response for a threshold operation
#[derive(Debug, Serialize)]
pub struct BusyThresholdResponse {
    /// The model name
    pub model: String,
75
76
77
78
    /// The active decode blocks threshold value (null if no threshold is configured)
    pub active_decode_blocks_threshold: Option<f64>,
    /// The active prefill tokens threshold value (null if no threshold is configured)
    pub active_prefill_tokens_threshold: Option<u64>,
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
}

/// Response body for `GET`, listing all configured busy thresholds.
#[derive(Serialize, Debug)]
pub struct ListBusyThresholdsResponse {
    /// One entry per model with configured thresholds.
    pub thresholds: Vec<BusyThresholdResponse>,
}

/// JSON error body produced by this endpoint, shaped as `{"error": "..."}`.
#[derive(Serialize, Debug)]
pub struct ErrorResponse {
    /// Human-readable description of what went wrong.
    pub error: String,
}

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/// Middleware to convert 422 Unprocessable Entity responses (from JSON deserialization errors)
/// to JSON format instead of text/plain.
async fn json_error_middleware(request: Request, next: Next) -> Response {
    let response = next.run(request).await;

    if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
        let (_parts, body) = response.into_parts();
        let body_bytes = axum::body::to_bytes(body, usize::MAX)
            .await
            .unwrap_or_default();
        let error_message = String::from_utf8_lossy(&body_bytes).to_string();
        (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(serde_json::json!(ErrorResponse {
                error: error_message,
            })),
        )
            .into_response()
    } else {
        response
    }
}

117
118
119
120
121
122
123
124
125
126
127
128
129
130
pub fn busy_threshold_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let base_path = path.unwrap_or_else(|| "/busy_threshold".to_string());

    let docs: Vec<RouteDoc> = vec![
        RouteDoc::new(Method::POST, &base_path),
        RouteDoc::new(Method::GET, &base_path),
    ];

    let router = Router::new()
        .route(&base_path, post(busy_threshold_handler))
        .route(&base_path, get(list_busy_thresholds_handler))
131
        .layer(axum::middleware::from_fn(json_error_middleware))
132
133
134
135
136
137
138
139
140
        .with_state(state);

    (docs, router)
}

async fn busy_threshold_handler(
    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
    Json(request): Json<BusyThresholdRequest>,
) -> impl IntoResponse {
141
142
    // Validate active decode blocks threshold range if provided (must be 0.0-1.0)
    if let Some(threshold) = request.active_decode_blocks_threshold
143
144
145
146
147
        && !(0.0..=1.0).contains(&threshold)
    {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!(ErrorResponse {
148
149
150
151
                error: format!(
                    "active_decode_blocks_threshold must be between 0.0 and 1.0, got {}",
                    threshold
                ),
152
153
154
155
156
157
            })),
        );
    }

    let manager = state.manager();

158
159
160
161
162
    // Get or set the thresholds via the model's worker monitor
    let active_decode_blocks_threshold = manager
        .active_decode_blocks_threshold(&request.model, request.active_decode_blocks_threshold);
    let active_prefill_tokens_threshold = manager
        .active_prefill_tokens_threshold(&request.model, request.active_prefill_tokens_threshold);
163
164

    // If trying to SET but model has no monitor, return 404
165
166
167
168
169
170
    let is_setting = request.active_decode_blocks_threshold.is_some()
        || request.active_prefill_tokens_threshold.is_some();
    if is_setting
        && active_decode_blocks_threshold.is_none()
        && active_prefill_tokens_threshold.is_none()
    {
171
172
173
174
175
176
177
178
179
180
181
        return (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!(ErrorResponse {
                error: format!(
                    "Model '{}' not found. Thresholds can only be set for discovered models.",
                    request.model
                ),
            })),
        );
    }

182
    if is_setting {
183
184
        tracing::info!(
            model = %request.model,
185
186
187
            active_decode_blocks_threshold = ?active_decode_blocks_threshold,
            active_prefill_tokens_threshold = ?active_prefill_tokens_threshold,
            "Updated busy thresholds"
188
189
190
191
192
193
194
        );
    }

    (
        StatusCode::OK,
        Json(serde_json::json!(BusyThresholdResponse {
            model: request.model,
195
196
            active_decode_blocks_threshold,
            active_prefill_tokens_threshold,
197
198
199
200
201
202
203
204
205
206
207
208
209
        })),
    )
}

async fn list_busy_thresholds_handler(
    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse {
    let manager = state.manager();
    let thresholds = manager.list_busy_thresholds();

    let response = ListBusyThresholdsResponse {
        thresholds: thresholds
            .into_iter()
210
211
212
213
214
215
216
217
218
            .map(
                |(model, active_decode_blocks_threshold, active_prefill_tokens_threshold)| {
                    BusyThresholdResponse {
                        model,
                        active_decode_blocks_threshold: Some(active_decode_blocks_threshold),
                        active_prefill_tokens_threshold: Some(active_prefill_tokens_threshold),
                    }
                },
            )
219
220
221
222
223
            .collect(),
    };

    Json(serde_json::json!(response))
}