Unverified Commit 66231cf0 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: reduce / revert routing overheads, do not consider output tokens (#2182)

parent dbd33df6
...@@ -31,8 +31,8 @@ use crate::{ ...@@ -31,8 +31,8 @@ use crate::{
kv_router::{ kv_router::{
approx::ApproxKvIndexer, approx::ApproxKvIndexer,
indexer::{ indexer::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError, compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
OverlapScores, RouterEvent, KvRouterError, OverlapScores, RouterEvent,
}, },
// metrics_aggregator::EndpointCollector, // metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
...@@ -71,7 +71,8 @@ pub struct KvRouterConfig { ...@@ -71,7 +71,8 @@ pub struct KvRouterConfig {
pub use_kv_events: bool, pub use_kv_events: bool,
// note: this is not actually used for now // TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32, pub max_num_batched_tokens: u32,
} }
...@@ -231,25 +232,25 @@ impl KvRouter { ...@@ -231,25 +232,25 @@ impl KvRouter {
let _guard = self.find_best_match_mutex.lock().await; let _guard = self.find_best_match_mutex.lock().await;
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_size = self.block_size;
let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?; let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let best_worker_id = self let best_worker_id = self
.scheduler .scheduler
.schedule( .schedule(
context_id.to_string(), context_id.to_string(),
isl_tokens, isl_tokens,
block_size, seq_hashes.clone(),
tokens,
overlap_scores.clone(), overlap_scores.clone(),
) )
.await?; .await?;
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer indexer
.process_routing_decision_for_request(tokens, best_worker_id) .process_routing_decision(best_worker_id, block_hashes, seq_hashes)
.await .await
.unwrap(); .unwrap();
}; };
...@@ -262,9 +263,9 @@ impl KvRouter { ...@@ -262,9 +263,9 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount)) Ok((best_worker_id, overlap_amount))
} }
/// Push tokens to a specific request's sequence /// Mark the prefill phase of a request as completed
pub async fn push(&self, request_id: &String, tokens: &[u32]) { pub async fn mark_prefill_completed(&self, request_id: &String) {
self.scheduler.push(request_id, tokens).await self.scheduler.mark_prefill_completed(request_id).await
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
...@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = request.context().clone(); let stream_context = request.context().clone();
// Update the request with the estimated prefix hit blocks // Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
...@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream = stream::iter(vec![response]); let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context)); return Ok(ResponseStream::new(Box::pin(stream), stream_context));
} }
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context(); let stream_context = response_stream.context();
let chooser = self.chooser.clone(); let chooser = self.chooser.clone();
let request_id = context_id.clone();
let block_size = chooser.block_size() as usize;
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut accumulated_tokens = Vec::new(); if let Some(first_item) = response_stream.next().await {
let mut total_output_length = 0usize; chooser.mark_prefill_completed(&context_id).await;
let mut last_block_index = (isl.saturating_sub(1)) / block_size; yield first_item;
let mut first_push_done = false;
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
let Some(ref output) = item.data else {
yield item;
continue;
};
if output.token_ids.is_empty() {
yield item;
continue;
}
// Add tokens to accumulator
accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len();
// Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
let should_push = (!first_push_done && total_output_length >= 1) ||
(first_push_done && current_block_index > last_block_index);
if should_push {
chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear();
last_block_index = current_block_index;
if !first_push_done {
first_push_done = true;
}
} }
while let Some(item) = response_stream.next().await {
yield item; yield item;
} }
chooser.free(&request_id).await; chooser.free(&context_id).await;
}); });
Ok(ResponseStream::new(wrapped_stream, stream_context)) Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
......
...@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot}; ...@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::tokens::TokenBlockSequence; use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{ use crate::kv_router::indexer::{
compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores,
...@@ -295,6 +295,26 @@ impl ApproxKvIndexer { ...@@ -295,6 +295,26 @@ impl ApproxKvIndexer {
self.kv_block_size self.kv_block_size
} }
/// Hands a routing decision, with its hashes already computed, to the indexer.
///
/// The arguments are packed into a [`RouterResult`] and sent on `route_tx`.
///
/// ### Arguments
///
/// * `worker_id` - The worker the request was routed to.
/// * `local_hashes` - Per-block hashes of the routed request.
/// * `sequence_hashes` - Sequence hashes of the routed request, passed through unchanged.
///
/// ### Errors
///
/// Returns [`KvRouterError::IndexerDroppedRequest`] when the send on `route_tx` fails.
pub async fn process_routing_decision(
    &self,
    worker_id: WorkerId,
    local_hashes: Vec<LocalBlockHash>,
    sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
    let decision = RouterResult {
        worker_id,
        local_hashes,
        sequence_hashes,
    };

    match self.route_tx.send(decision).await {
        Ok(_) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
/// Wrapper function that computes hashes from tokens and calls the core function
pub async fn process_routing_decision_for_request( pub async fn process_routing_decision_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
...@@ -309,16 +329,8 @@ impl ApproxKvIndexer { ...@@ -309,16 +329,8 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash()) .map(|b| b.sequence_hash())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.route_tx self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
.send(RouterResult {
worker_id,
local_hashes,
sequence_hashes,
})
.await .await
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
} }
} }
......
...@@ -63,6 +63,7 @@ use xxhash_rust::xxh3; ...@@ -63,6 +63,7 @@ use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337; pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::protocols::*; use crate::kv_router::protocols::*;
use crate::tokens::SequenceHash;
/// Errors that can occur in the KV Router. /// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
...@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc ...@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc
.collect() .collect()
} }
/// Compute rolling sequence hashes for a slice of block hashes.
///
/// This mirrors the behavior in tokens.rs where:
/// - The first block's sequence hash equals its block hash
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
///
/// The bytes fed to `compute_hash` for each block are the little-endian bytes of the
/// parent sequence hash followed by the little-endian bytes of the current block hash.
///
/// ### Arguments
///
/// * `block_hashes` - A slice of `LocalBlockHash` values representing the block hashes.
///
/// ### Returns
///
/// A vector with one sequence hash per input block, in the same order. An empty input
/// yields an empty vector.
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
    let Some((first, rest)) = block_hashes.split_first() else {
        return Vec::new();
    };

    let mut sequence_hashes = Vec::with_capacity(block_hashes.len());

    // The chain starts at the first block, whose sequence hash is its own block hash.
    let mut parent_seq_hash = first.0;
    sequence_hashes.push(parent_seq_hash);

    // One scratch buffer reused for every block, instead of a fresh allocation per block.
    let mut bytes: Vec<u8> = Vec::new();
    for block_hash in rest {
        bytes.clear();
        bytes.extend_from_slice(&parent_seq_hash.to_le_bytes());
        bytes.extend_from_slice(&block_hash.0.to_le_bytes());

        parent_seq_hash = compute_hash(&bytes);
        sequence_hashes.push(parent_seq_hash);
    }

    sequence_hashes
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`]. /// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterEvent { pub struct RouterEvent {
......
...@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics; ...@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::sequence::ActiveSequencesMultiWorker; use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig; use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT; use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence; use crate::tokens::SequenceHash;
use dynamo_runtime::component::Instance; use dynamo_runtime::component::Instance;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
...@@ -217,15 +217,13 @@ impl KvScheduler { ...@@ -217,15 +217,13 @@ impl KvScheduler {
&self, &self,
request_id: String, request_id: String,
isl_tokens: usize, isl_tokens: usize,
block_size: u32, token_seq: Vec<SequenceHash>,
tokens: &[u32],
overlaps: OverlapScores, overlaps: OverlapScores,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await; let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let (potential_blocks, potential_tokens) = let (potential_blocks, potential_tokens) =
sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone()); sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone());
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
...@@ -247,10 +245,10 @@ impl KvScheduler { ...@@ -247,10 +245,10 @@ impl KvScheduler {
sequences.update_workers(new_worker_ids); sequences.update_workers(new_worker_ids);
} }
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request( sequences.add_request(
request_id, request_id,
token_sequence, token_seq,
isl_tokens,
response.overlap_blocks, response.overlap_blocks,
response.best_worker_id, response.best_worker_id,
); );
...@@ -258,10 +256,9 @@ impl KvScheduler { ...@@ -258,10 +256,9 @@ impl KvScheduler {
Ok(response.best_worker_id) Ok(response.best_worker_id)
} }
/// Push tokens to a specific request's sequence pub async fn mark_prefill_completed(&self, request_id: &String) {
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
let mut sequences = self.sequences.lock().await; let mut sequences = self.sequences.lock().await;
sequences.push(request_id, tokens) sequences.mark_prefill_completed(request_id)
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
pub use crate::kv_router::protocols::ForwardPassMetrics;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::pipeline::network::{
ingress::push_endpoint::PushEndpoint,
PushWorkHandler,
};
use dynamo_runtime::transports::nats::{self, ServiceExt};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing as log;
/// Configuration for a worker-side ingress that exposes a push endpoint as a NATS service.
///
/// Construct it through [`KvRoutedIngress::builder`] and consume it with
/// [`KvRoutedIngress::start`].
#[derive(Builder)]
pub struct KvRoutedIngress {
/// Name used both for the NATS service and for the service group the endpoint is created in.
#[builder(setter(into))]
pub service_name: String,
/// Name of the endpoint created inside the service group; also attached to log events.
#[builder(setter(into))]
pub worker_id: String,
/// NATS client wrapper; `start` builds the service from `nats.client()`.
pub nats: nats::Client,
/// Handler passed to the `PushEndpoint` builder as its `service_handler`.
pub service_handler: Arc<dyn PushWorkHandler>,
/// Watch channel whose current value is serialized to JSON by the service's stats handler.
pub metrics_rx: watch::Receiver<Arc<ForwardPassMetrics>>,
/// Token passed to the `PushEndpoint` builder as its `cancellation_token`.
pub cancellation_token: CancellationToken,
}
/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
///
/// It is passed as the version when the NATS service is started.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
impl KvRoutedIngress {
pub fn builder() -> KvRoutedIngressBuilder {
KvRoutedIngressBuilder::default()
}
pub async fn start(self) -> Result<()> {
let worker_id = self.worker_id;
log::trace!(
worker_id,
"Starting nats service: {}:{}",
self.service_name,
VERSION
);
let mut metrics_rx = self.metrics_rx;
let worker_id_clone = worker_id.clone();
let service = self
.nats
.client()
.service_builder()
.description("A handy min max service")
.stats_handler(move |name, stats| {
log::debug!(
worker_id = worker_id_clone.as_str(),
"[IN worker?] Stats for service {}: {:?}",
name,
stats
);
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})
.start(self.service_name.as_str(), VERSION)
.await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
let group = service.group(self.service_name.as_str());
log::trace!(worker_id, "Starting endpoint: {}", worker_id);
// creates an endpoint for the service
let service_endpoint = group
.endpoint(worker_id.clone())
.await
.map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
let push_endpoint = PushEndpoint::builder()
.service_handler(self.service_handler)
.cancellation_token(self.cancellation_token)
.build()
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
push_endpoint.start(service_endpoint).await
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment