Unverified commit 4e7c6afd, authored by Graham King, committed by GitHub
Browse files

fix(kv-router): Increase JSON body to 8 MiB to allow 1M tokens (#8315)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent c69e19e8
...@@ -2374,6 +2374,7 @@ dependencies = [ ...@@ -2374,6 +2374,7 @@ dependencies = [
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tokio-util", "tokio-util",
"tower 0.5.3",
"tracing", "tracing",
"uuid", "uuid",
"validator", "validator",
......
...@@ -60,3 +60,4 @@ rstest_reuse = "0.7.0" ...@@ -60,3 +60,4 @@ rstest_reuse = "0.7.0"
serde_json = { workspace = true } serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt", "macros", "time", "test-util"] } tokio = { workspace = true, features = ["rt", "macros", "time", "test-util"] }
dynamo-tokens = { workspace = true } dynamo-tokens = { workspace = true }
tower = { version = "0.5", features = ["util"] }
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State; use axum::extract::{DefaultBodyLimit, State};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::{get, post}; use axum::routing::{get, post};
...@@ -17,6 +17,9 @@ use crate::protocols::{BlockHashOptions, LocalBlockHash, WorkerId, compute_block ...@@ -17,6 +17,9 @@ use crate::protocols::{BlockHashOptions, LocalBlockHash, WorkerId, compute_block
use super::registry::{IndexerKey, ListenerControlError, WorkerRegistry}; use super::registry::{IndexerKey, ListenerControlError, WorkerRegistry};
/// Upper bound, in bytes, on the request body accepted by the `/query` route (8 MiB).
///
/// Chosen so that one million token IDs serialized as JSON text still fit.
const QUERY_REQUEST_BODY_LIMIT_BYTES: usize = {
    const MIB: usize = 1024 * 1024;
    8 * MIB
};
pub struct AppState { pub struct AppState {
pub registry: Arc<WorkerRegistry>, pub registry: Arc<WorkerRegistry>,
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
...@@ -401,7 +404,10 @@ pub fn create_router(state: Arc<AppState>) -> Router { ...@@ -401,7 +404,10 @@ pub fn create_router(state: Arc<AppState>) -> Router {
.route("/register", post(register)) .route("/register", post(register))
.route("/unregister", post(unregister)) .route("/unregister", post(unregister))
.route("/workers", get(list_workers)) .route("/workers", get(list_workers))
.route("/query", post(query)) .route(
"/query",
post(query).layer(DefaultBodyLimit::max(QUERY_REQUEST_BODY_LIMIT_BYTES)),
)
.route("/query_by_hash", post(query_by_hash)) .route("/query_by_hash", post(query_by_hash))
.route("/dump", get(dump_events)) .route("/dump", get(dump_events))
.route("/register_peer", post(register_peer)) .route("/register_peer", post(register_peer))
...@@ -428,3 +434,50 @@ pub fn create_router(state: Arc<AppState>) -> Router { ...@@ -428,3 +434,50 @@ pub fn create_router(state: Arc<AppState>) -> Router {
router router
} }
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode, header};
    use tower::ServiceExt;

    /// Produces a JSON document shaped like a `/query` request whose `token_ids`
    /// array is padded with `0` entries until the text is strictly longer than
    /// `QUERY_REQUEST_BODY_LIMIT_BYTES`.
    fn oversized_query_body() -> String {
        const PREFIX: &str = r#"{"token_ids":["#;
        const SUFFIX: &str = r#"],"model_name":"model"}"#;

        // Reserve the whole buffer up front; the padding loop overshoots the
        // limit by at most two bytes, which the prefix length more than covers.
        let mut payload =
            String::with_capacity(QUERY_REQUEST_BODY_LIMIT_BYTES + PREFIX.len() + SUFFIX.len());
        payload.push_str(PREFIX);

        // The first array element carries no leading comma.
        if payload.len() <= QUERY_REQUEST_BODY_LIMIT_BYTES {
            payload.push('0');
        }
        // Every further element is comma-separated; stop once the limit is exceeded.
        while payload.len() <= QUERY_REQUEST_BODY_LIMIT_BYTES {
            payload.push_str(",0");
        }

        payload.push_str(SUFFIX);
        payload
    }

    /// A POST to `/query` with a body larger than the configured limit must be
    /// answered with `413 Payload Too Large`.
    #[tokio::test]
    async fn query_rejects_request_bodies_over_limit() {
        let state = Arc::new(AppState {
            registry: Arc::new(WorkerRegistry::new(1)),
            #[cfg(feature = "metrics")]
            prom_registry: prometheus::Registry::new(),
        });

        let request = Request::builder()
            .method("POST")
            .uri("/query")
            .header(header::CONTENT_TYPE, "application/json")
            .body(Body::from(oversized_query_body()))
            .unwrap();

        // Drive the router with exactly one request, without binding a socket.
        let response = create_router(state).oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment