Unverified Commit d0d364e3 authored by jain-ria's avatar jain-ria Committed by GitHub
Browse files

feat: add endpoint to clear all kv blocks in vllm v1 (#1384)

parent a4600ba1
...@@ -29,12 +29,14 @@ pub async fn run( ...@@ -29,12 +29,14 @@ pub async fn run(
engine_config: EngineConfig, engine_config: EngineConfig,
template: Option<RequestTemplate>, template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let http_service = service_v2::HttpService::builder() let http_service = service_v2::HttpService::builder()
.port(flags.http_port) .port(flags.http_port)
.enable_chat_endpoints(true) .enable_chat_endpoints(true)
.enable_cmpl_endpoints(true) .enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true) .enable_embeddings_endpoints(true)
.with_request_template(template) .with_request_template(template)
.runtime(Some(Arc::new(distributed_runtime)))
.build()?; .build()?;
match engine_config { match engine_config {
EngineConfig::Dynamic => { EngineConfig::Dynamic => {
......
...@@ -116,7 +116,7 @@ class StatLoggerFactory: ...@@ -116,7 +116,7 @@ class StatLoggerFactory:
class RequestHandler: class RequestHandler:
""" """
Request handler for the generate endpoint Request handler for the generate and clear_kv_blocks endpoints.
""" """
def __init__(self, component, engine, default_sampling_params): def __init__(self, component, engine, default_sampling_params):
...@@ -124,6 +124,13 @@ class RequestHandler: ...@@ -124,6 +124,13 @@ class RequestHandler:
self.engine_client = engine self.engine_client = engine
self.default_sampling_params = default_sampling_params self.default_sampling_params = default_sampling_params
async def clear_kv_blocks(self, request=None):
try:
self.engine_client.reset_prefix_cache()
yield {"status": "success", "message": "KV cache cleared"}
except Exception as e:
yield {"status": "error", "message": str(e)}
async def generate(self, request): async def generate(self, request):
request_id = str(uuid.uuid4().hex) request_id = str(uuid.uuid4().hex)
...@@ -175,13 +182,16 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -175,13 +182,16 @@ async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
await component.create_service() await component.create_service()
endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
await register_llm( await register_llm(
ModelType.Backend, ModelType.Backend,
endpoint, generate_endpoint,
config.model_path, config.model_path,
config.model_name, config.model_name,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
...@@ -249,16 +259,21 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -249,16 +259,21 @@ async def init(runtime: DistributedRuntime, config: Config):
logger.info("VllmWorker has been initialized") logger.info("VllmWorker has been initialized")
zmq_config = ZmqKvEventPublisherConfig( zmq_config = ZmqKvEventPublisherConfig(
worker_id=endpoint.lease_id(), kv_block_size=engine_args.block_size worker_id=generate_endpoint.lease_id(), kv_block_size=engine_args.block_size
) )
_ = ZmqKvEventPublisher(component=component, config=zmq_config) _ = ZmqKvEventPublisher(component=component, config=zmq_config)
handler = RequestHandler(component, engine_client, default_sampling_params) handler = RequestHandler(component, engine_client, default_sampling_params)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) try:
# after the lease is revoked await asyncio.gather(
await endpoint.serve_endpoint(handler.generate) generate_endpoint.serve_endpoint(handler.generate),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
def cmd_line_args(): def cmd_line_args():
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
mod openai; mod openai;
pub mod clear_kv_blocks;
pub mod error; pub mod error;
pub mod health; pub mod health;
pub mod metrics; pub mod metrics;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{service_v2, RouteDoc};
use axum::{http::Method, response::IntoResponse, routing::post, Json, Router};
use serde_json::json;
use std::sync::Arc;
use dynamo_runtime::{pipeline::PushRouter, stream::StreamExt};
/// Builds the axum router that exposes the KV-block clearing endpoint.
///
/// `path` overrides the route; when it is `None` the route is `/clear_kv_blocks`.
/// The route accepts `POST` only and is served by `clear_kv_blocks_handler`.
///
/// Returns the route documentation entries together with the router, which has
/// `state` attached as its shared state.
pub fn clear_kv_blocks_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let route_path = match path {
        Some(custom) => custom,
        None => String::from("/clear_kv_blocks"),
    };

    let route_docs: Vec<RouteDoc> = vec![RouteDoc::new(Method::POST, &route_path)];

    let app = Router::new()
        .route(&route_path, post(clear_kv_blocks_handler))
        .with_state(state);

    (route_docs, app)
}
/// HTTP handler that asks every worker instance to clear its KV blocks.
///
/// For each model entry known to the model manager, the handler resolves the
/// entry's namespace and component in the distributed runtime, lists the
/// component's instances, keeps those registered on the `clear_kv_blocks`
/// endpoint, and sends one request to each of them.
///
/// The response body is always JSON:
/// - `{"message": ...}` when there are no model entries or no runtime is attached
///   to the service state;
/// - otherwise `{"cleared_workers": [...], "failed_workers": [...]}`, one element
///   per attempted worker group or instance.
///
/// Failures for one worker group never abort the whole request; they are recorded
/// in `failed_workers` and processing continues with the next entry.
async fn clear_kv_blocks_handler(
    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse {
    let model_entries = state.manager().get_model_entries();

    // Without registered model entries there is nothing to send the request to.
    if model_entries.is_empty() {
        return Json(json!({
            "message": "No active worker groups found"
        }));
    }

    let distributed = match state.runtime() {
        Some(runtime) => runtime,
        None => {
            return Json(json!({
                "message": "Failed to create distributed runtime",
            }));
        }
    };

    let mut cleared_workers = Vec::new();
    let mut failed_workers = Vec::new();

    // Records one outcome. Successful outcomes go to `cleared_workers` with the
    // optional message under "response"; failures go to `failed_workers` with the
    // optional message under "error".
    let mut add_worker_result = |success: bool,
                                 name: String,
                                 status: &str,
                                 ns: &str,
                                 comp: &str,
                                 message: Option<String>| {
        let mut result = json!({
            "name": name,
            "endpoint": format!("{}/{}/clear_kv_blocks", ns, comp),
            "status": status,
        });
        if success {
            if let Some(m) = message {
                result["response"] = json!(m);
            }
            cleared_workers.push(result);
        } else {
            if let Some(m) = message {
                result["error"] = json!(m);
            }
            failed_workers.push(result);
        }
    };

    // Build a client/router per model entry and fan the request out to its instances.
    for entry in &model_entries {
        let namespace = &entry.endpoint.namespace;
        let component = &entry.endpoint.component;
        let entry_name = entry.name.to_string();

        tracing::debug!("Processing worker group: {}/{}", namespace, component);

        let namespace_obj = match distributed.namespace(namespace) {
            Ok(ns) => ns,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get namespace",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

        let component_obj = match namespace_obj.component(component) {
            Ok(comp) => comp,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get component",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

        let endpoint: dynamo_runtime::component::Endpoint =
            component_obj.endpoint("clear_kv_blocks");

        let client = match endpoint.client().await {
            Ok(c) => c,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get client",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

        // The client is not needed afterwards, so it is moved into the router.
        let router =
            match PushRouter::<(), serde_json::Value>::from_client(client, Default::default())
                .await
            {
                Ok(r) => r,
                Err(e) => {
                    add_worker_result(
                        false,
                        entry_name,
                        "Failed to create router",
                        namespace,
                        component,
                        Some(e.to_string()),
                    );
                    continue;
                }
            };

        let instances = match component_obj.list_instances().await {
            Ok(instances) => instances,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get instances for worker group",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

        if instances.is_empty() {
            add_worker_result(
                false,
                entry_name,
                "No instances found for worker group",
                namespace,
                component,
                None,
            );
            continue;
        }

        // Only instances registered on the `clear_kv_blocks` endpoint can serve the request.
        let instances_filtered = instances
            .iter()
            .filter(|instance| instance.endpoint == "clear_kv_blocks")
            .collect::<Vec<_>>();

        if instances_filtered.is_empty() {
            let found_endpoints: Vec<String> = instances
                .iter()
                .map(|instance| instance.endpoint.clone())
                .collect();
            add_worker_result(
                false,
                entry_name,
                &format!(
                    "Worker group doesn't support clear_kv_blocks. Supported endpoints: {}",
                    found_endpoints.join(", ")
                ),
                namespace,
                component,
                None,
            );
            continue;
        }

        for instance in &instances_filtered {
            let instance_name = format!("{}-instance-{}", entry.name, instance.id());

            // Address the instance explicitly. Round-robin routing does not guarantee
            // that the request reaches the instance this result is labelled with, nor
            // that every instance is visited exactly once.
            match router.direct(().into(), instance.id()).await {
                Ok(mut stream) => match stream.next().await {
                    Some(response) => {
                        // A worker may answer with `{"status": "error", ...}`; such a
                        // reply must not be reported as a cleared worker.
                        let reported_error = response
                            .get("status")
                            .and_then(serde_json::Value::as_str)
                            == Some("error");
                        let status = if reported_error {
                            "Instance reported an error while clearing kv blocks"
                        } else {
                            "Successfully cleared kv blocks for instance"
                        };
                        add_worker_result(
                            !reported_error,
                            instance_name,
                            status,
                            namespace,
                            component,
                            Some(response.to_string()),
                        );
                    }
                    None => {
                        add_worker_result(
                            false,
                            instance_name,
                            "No response from instance",
                            namespace,
                            component,
                            None,
                        );
                    }
                },
                Err(e) => {
                    add_worker_result(
                        false,
                        instance_name,
                        "Failed to send request for instance",
                        namespace,
                        component,
                        Some(e.to_string()),
                    );
                }
            }
        }
    }

    Json(json!({
        "cleared_workers": cleared_workers,
        "failed_workers": failed_workers
    }))
}
...@@ -11,6 +11,7 @@ use crate::discovery::ModelManager; ...@@ -11,6 +11,7 @@ use crate::discovery::ModelManager;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -18,6 +19,7 @@ use tokio_util::sync::CancellationToken; ...@@ -18,6 +19,7 @@ use tokio_util::sync::CancellationToken;
pub struct State { pub struct State {
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
manager: Arc<ModelManager>, manager: Arc<ModelManager>,
runtime: Option<Arc<DistributedRuntime>>,
} }
impl State { impl State {
...@@ -25,6 +27,15 @@ impl State { ...@@ -25,6 +27,15 @@ impl State {
Self { Self {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
runtime: None,
}
}
pub fn with_runtime(manager: Arc<ModelManager>, runtime: Arc<DistributedRuntime>) -> Self {
Self {
manager,
metrics: Arc::new(Metrics::default()),
runtime: Some(runtime),
} }
} }
...@@ -41,6 +52,11 @@ impl State { ...@@ -41,6 +52,11 @@ impl State {
self.manager.clone() self.manager.clone()
} }
/// Get the DistributedRuntime if available
pub fn runtime(&self) -> Option<&DistributedRuntime> {
self.runtime.as_ref().map(|r| r.as_ref())
}
// TODO // TODO
pub fn sse_keep_alive(&self) -> Option<Duration> { pub fn sse_keep_alive(&self) -> Option<Duration> {
None None
...@@ -80,6 +96,9 @@ pub struct HttpServiceConfig { ...@@ -80,6 +96,9 @@ pub struct HttpServiceConfig {
#[builder(default = "None")] #[builder(default = "None")]
request_template: Option<RequestTemplate>, request_template: Option<RequestTemplate>,
#[builder(default = "None")]
runtime: Option<Arc<DistributedRuntime>>,
} }
impl HttpService { impl HttpService {
...@@ -134,7 +153,11 @@ impl HttpServiceConfigBuilder { ...@@ -134,7 +153,11 @@ impl HttpServiceConfigBuilder {
let config: HttpServiceConfig = self.build_internal()?; let config: HttpServiceConfig = self.build_internal()?;
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let state = Arc::new(State::new(model_manager)); let state = if let Some(runtime) = config.runtime {
Arc::new(State::with_runtime(model_manager, runtime))
} else {
Arc::new(State::new(model_manager))
};
// enable prometheus metrics // enable prometheus metrics
let registry = metrics::Registry::new(); let registry = metrics::Registry::new();
...@@ -148,6 +171,7 @@ impl HttpServiceConfigBuilder { ...@@ -148,6 +171,7 @@ impl HttpServiceConfigBuilder {
metrics::router(registry, None), metrics::router(registry, None),
super::openai::list_models_router(state.clone(), None), super::openai::list_models_router(state.clone(), None),
super::health::health_check_router(state.clone(), None), super::health::health_check_router(state.clone(), None),
super::clear_kv_blocks::clear_kv_blocks_router(state.clone(), None),
]; ];
if config.enable_chat_endpoints { if config.enable_chat_endpoints {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment