Unverified commit 4e7c6afd, authored by Graham King and committed by GitHub
Browse files

fix(kv-router): Increase JSON body to 8 MiB to allow 1M tokens (#8315)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent c69e19e8
......@@ -2374,6 +2374,7 @@ dependencies = [
"thiserror 2.0.18",
"tokio",
"tokio-util",
"tower 0.5.3",
"tracing",
"uuid",
"validator",
......
......@@ -60,3 +60,4 @@ rstest_reuse = "0.7.0"
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt", "macros", "time", "test-util"] }
dynamo-tokens = { workspace = true }
tower = { version = "0.5", features = ["util"] }
......@@ -4,7 +4,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::State;
use axum::extract::{DefaultBodyLimit, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::{get, post};
......@@ -17,6 +17,9 @@ use crate::protocols::{BlockHashOptions, LocalBlockHash, WorkerId, compute_block
use super::registry::{IndexerKey, ListenerControlError, WorkerRegistry};
/// Maximum size, in bytes, of a request body accepted by the `/query` route (8 MiB).
///
/// Sized so that a request carrying one million tokens, serialized as JSON
/// text, still fits within the limit.
const QUERY_REQUEST_BODY_LIMIT_BYTES: usize = 8 * 1024 * 1024;
pub struct AppState {
pub registry: Arc<WorkerRegistry>,
#[cfg(feature = "metrics")]
......@@ -401,7 +404,10 @@ pub fn create_router(state: Arc<AppState>) -> Router {
.route("/register", post(register))
.route("/unregister", post(unregister))
.route("/workers", get(list_workers))
.route("/query", post(query))
.route(
"/query",
post(query).layer(DefaultBodyLimit::max(QUERY_REQUEST_BODY_LIMIT_BYTES)),
)
.route("/query_by_hash", post(query_by_hash))
.route("/dump", get(dump_events))
.route("/register_peer", post(register_peer))
......@@ -428,3 +434,50 @@ pub fn create_router(state: Arc<AppState>) -> Router {
router
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::{Request, StatusCode, header};
use tower::ServiceExt;
/// Builds a JSON `/query` payload that is guaranteed to be larger than
/// `QUERY_REQUEST_BODY_LIMIT_BYTES`.
///
/// The payload has the shape `{"token_ids":[0,0,...,0],"model_name":"model"}`;
/// zeros are appended to the array until the partial text alone exceeds the
/// limit, and only then is the closing part of the object added.
fn oversized_query_body() -> String {
    const OPENING: &str = r#"{"token_ids":["#;
    const CLOSING: &str = r#"],"model_name":"model"}"#;

    let mut payload = String::from(OPENING);
    // Empty before the first element, a comma before every later one.
    let mut separator = "";
    while payload.len() <= QUERY_REQUEST_BODY_LIMIT_BYTES {
        payload.push_str(separator);
        payload.push('0');
        separator = ",";
    }
    payload.push_str(CLOSING);
    payload
}
/// A `POST /query` whose JSON body is larger than
/// `QUERY_REQUEST_BODY_LIMIT_BYTES` must be answered with
/// `413 Payload Too Large`.
#[tokio::test]
async fn query_rejects_request_bodies_over_limit() {
    let state = Arc::new(AppState {
        registry: Arc::new(WorkerRegistry::new(1)),
        #[cfg(feature = "metrics")]
        prom_registry: prometheus::Registry::new(),
    });
    let app = create_router(state);

    let request = Request::builder()
        .method("POST")
        .uri("/query")
        .header(header::CONTENT_TYPE, "application/json")
        .body(Body::from(oversized_query_body()))
        .unwrap();

    // Drive the router with this single request, without binding a socket.
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment