Unverified Commit 66231cf0 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: reduce / revert routing overheads, do not consider output tokens (#2182)

parent dbd33df6
......@@ -31,8 +31,8 @@ use crate::{
kv_router::{
approx::ApproxKvIndexer,
indexer::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
KvRouterError, OverlapScores, RouterEvent,
},
// metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
......@@ -71,7 +71,8 @@ pub struct KvRouterConfig {
pub use_kv_events: bool,
// note: this is not actually used for now
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
}
......@@ -231,25 +232,25 @@ impl KvRouter {
let _guard = self.find_best_match_mutex.lock().await;
let isl_tokens = tokens.len();
let block_size = self.block_size;
let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let best_worker_id = self
.scheduler
.schedule(
context_id.to_string(),
isl_tokens,
block_size,
tokens,
seq_hashes.clone(),
overlap_scores.clone(),
)
.await?;
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision_for_request(tokens, best_worker_id)
.process_routing_decision(best_worker_id, block_hashes, seq_hashes)
.await
.unwrap();
};
......@@ -262,9 +263,9 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount))
}
/// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
self.scheduler.push(request_id, tokens).await
/// Free all blocks associated with a request
pub async fn mark_prefill_completed(&self, request_id: &String) {
self.scheduler.mark_prefill_completed(request_id).await
}
/// Free all blocks associated with a request
......@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = request.context().clone();
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
......@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let request_id = context_id.clone();
let block_size = chooser.block_size() as usize;
let wrapped_stream = Box::pin(async_stream::stream! {
let mut accumulated_tokens = Vec::new();
let mut total_output_length = 0usize;
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
let mut first_push_done = false;
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
let Some(ref output) = item.data else {
yield item;
continue;
};
if output.token_ids.is_empty() {
yield item;
continue;
}
// Add tokens to accumulator
accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len();
// Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
let should_push = (!first_push_done && total_output_length >= 1) ||
(first_push_done && current_block_index > last_block_index);
if should_push {
chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear();
last_block_index = current_block_index;
if !first_push_done {
first_push_done = true;
}
if let Some(first_item) = response_stream.next().await {
chooser.mark_prefill_completed(&context_id).await;
yield first_item;
}
while let Some(item) = response_stream.next().await {
yield item;
}
chooser.free(&request_id).await;
chooser.free(&context_id).await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
......
......@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use crate::tokens::TokenBlockSequence;
use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{
compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores,
......@@ -295,6 +295,26 @@ impl ApproxKvIndexer {
self.kv_block_size
}
/// Core function to process a routing decision with pre-computed hashes
pub async fn process_routing_decision(
&self,
worker_id: WorkerId,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
self.route_tx
.send(RouterResult {
worker_id,
local_hashes,
sequence_hashes,
})
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
}
/// Wrapper function that computes hashes from tokens and calls the core function
pub async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
......@@ -309,16 +329,8 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.route_tx
.send(RouterResult {
worker_id,
local_hashes,
sequence_hashes,
})
self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
}
}
......
......@@ -63,6 +63,7 @@ use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::protocols::*;
use crate::tokens::SequenceHash;
/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)]
......@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc
.collect()
}
/// Compute rolling sequence hashes for a vector of block hashes.
///
/// This mirrors the behavior in tokens.rs where:
/// - The first block's sequence hash equals its block hash
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
///
/// ### Arguments
///
/// * `block_hashes` - A vector of `LocalBlockHash` values representing the block hashes.
///
/// ### Returns
///
/// A vector of u64 values representing the sequence hashes for each block.
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
if block_hashes.is_empty() {
return Vec::new();
}
let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
sequence_hashes.push(block_hashes[0].0);
for i in 1..block_hashes.len() {
let parent_seq_hash = sequence_hashes[i - 1];
let current_block_hash = block_hashes[i].0;
let combined = [parent_seq_hash, current_block_hash];
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
let seq_hash = compute_hash(&bytes);
sequence_hashes.push(seq_hash);
}
sequence_hashes
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterEvent {
......
......@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence;
use crate::tokens::SequenceHash;
use dynamo_runtime::component::Instance;
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -217,15 +217,13 @@ impl KvScheduler {
&self,
request_id: String,
isl_tokens: usize,
block_size: u32,
tokens: &[u32],
token_seq: Vec<SequenceHash>,
overlaps: OverlapScores,
) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let (potential_blocks, potential_tokens) =
sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone());
sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone());
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
......@@ -247,10 +245,10 @@ impl KvScheduler {
sequences.update_workers(new_worker_ids);
}
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request(
request_id,
token_sequence,
token_seq,
isl_tokens,
response.overlap_blocks,
response.best_worker_id,
);
......@@ -258,10 +256,9 @@ impl KvScheduler {
Ok(response.best_worker_id)
}
/// Push tokens to a specific request's sequence
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
pub async fn mark_prefill_completed(&self, request_id: &String) {
let mut sequences = self.sequences.lock().await;
sequences.push(request_id, tokens)
sequences.mark_prefill_completed(request_id)
}
/// Free all blocks associated with a request
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
pub use crate::kv_router::protocols::ForwardPassMetrics;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::pipeline::network::{
ingress::push_endpoint::PushEndpoint,
PushWorkHandler,
};
use dynamo_runtime::transports::nats::{self, ServiceExt};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing as log;
#[derive(Builder)]
pub struct KvRoutedIngress {
#[builder(setter(into))]
pub service_name: String,
#[builder(setter(into))]
pub worker_id: String,
pub nats: nats::Client,
pub service_handler: Arc<dyn PushWorkHandler>,
pub metrics_rx: watch::Receiver<Arc<ForwardPassMetrics>>,
pub cancellation_token: CancellationToken,
}
/// version of crate
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
impl KvRoutedIngress {
pub fn builder() -> KvRoutedIngressBuilder {
KvRoutedIngressBuilder::default()
}
pub async fn start(self) -> Result<()> {
let worker_id = self.worker_id;
log::trace!(
worker_id,
"Starting nats service: {}:{}",
self.service_name,
VERSION
);
let mut metrics_rx = self.metrics_rx;
let worker_id_clone = worker_id.clone();
let service = self
.nats
.client()
.service_builder()
.description("A handy min max service")
.stats_handler(move |name, stats| {
log::debug!(
worker_id = worker_id_clone.as_str(),
"[IN worker?] Stats for service {}: {:?}",
name,
stats
);
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})
.start(self.service_name.as_str(), VERSION)
.await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
let group = service.group(self.service_name.as_str());
log::trace!(worker_id, "Starting endpoint: {}", worker_id);
// creates an endpoint for the service
let service_endpoint = group
.endpoint(worker_id.clone())
.await
.map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
let push_endpoint = PushEndpoint::builder()
.service_handler(self.service_handler)
.cancellation_token(self.cancellation_token)
.build()
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
push_endpoint.start(service_endpoint).await
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment