Unverified commit 050906b5, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: track output tokens / blocks in the Router (optional) (#5452)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 9c08a2aa
...@@ -39,11 +39,18 @@ def get_aiperf_cmd( ...@@ -39,11 +39,18 @@ def get_aiperf_cmd(
num_prefix_prompts, num_prefix_prompts,
artifact_dir, artifact_dir,
url="http://localhost:8888", url="http://localhost:8888",
use_expected_osl=False,
): ):
"""Build aiperf command based on prefix ratio""" """Build aiperf command based on prefix ratio"""
prefix_length = int(isl * prefix_ratio) prefix_length = int(isl * prefix_ratio)
synthetic_input_length = int(isl * (1 - prefix_ratio)) synthetic_input_length = int(isl * (1 - prefix_ratio))
# Build nvext JSON with optional expected_output_tokens
nvext_dict = {"ignore_eos": True}
if use_expected_osl:
nvext_dict["expected_output_tokens"] = osl
nvext_json = json.dumps({"nvext": nvext_dict})
return [ return [
"aiperf", "aiperf",
"profile", "profile",
...@@ -69,7 +76,7 @@ def get_aiperf_cmd( ...@@ -69,7 +76,7 @@ def get_aiperf_cmd(
"--extra-inputs", "--extra-inputs",
"ignore_eos:true", "ignore_eos:true",
"--extra-inputs", "--extra-inputs",
'{"nvext":{"ignore_eos":true}}', nvext_json,
"--concurrency", "--concurrency",
str(concurrency), str(concurrency),
"--request-count", "--request-count",
...@@ -122,6 +129,7 @@ def run_benchmark_single_url( ...@@ -122,6 +129,7 @@ def run_benchmark_single_url(
num_prefix_prompts, num_prefix_prompts,
artifact_dir, artifact_dir,
url, url,
use_expected_osl=False,
) -> Optional[Dict]: ) -> Optional[Dict]:
"""Run aiperf benchmark for a single URL""" """Run aiperf benchmark for a single URL"""
aiperf_cmd = get_aiperf_cmd( aiperf_cmd = get_aiperf_cmd(
...@@ -136,6 +144,7 @@ def run_benchmark_single_url( ...@@ -136,6 +144,7 @@ def run_benchmark_single_url(
num_prefix_prompts, num_prefix_prompts,
artifact_dir, artifact_dir,
url, url,
use_expected_osl,
) )
logger.info(f"Running command for URL {url}: {' '.join(aiperf_cmd)}") logger.info(f"Running command for URL {url}: {' '.join(aiperf_cmd)}")
...@@ -201,6 +210,7 @@ def run_benchmark( ...@@ -201,6 +210,7 @@ def run_benchmark(
num_prefix_prompts, num_prefix_prompts,
output_dir, output_dir,
urls, urls,
use_expected_osl=False,
) -> Optional[Dict]: ) -> Optional[Dict]:
"""Run aiperf benchmark for a specific prefix ratio""" """Run aiperf benchmark for a specific prefix ratio"""
logger.info( logger.info(
...@@ -227,6 +237,7 @@ def run_benchmark( ...@@ -227,6 +237,7 @@ def run_benchmark(
num_prefix_prompts, num_prefix_prompts,
artifact_dir, artifact_dir,
urls[0], urls[0],
use_expected_osl,
) )
# Multiple URLs: split requests and concurrency # Multiple URLs: split requests and concurrency
...@@ -259,6 +270,7 @@ def run_benchmark( ...@@ -259,6 +270,7 @@ def run_benchmark(
num_prefix_prompts, num_prefix_prompts,
artifact_dir, artifact_dir,
url, url,
use_expected_osl,
) )
logger.info(f"Launching process for URL {url}: {' '.join(aiperf_cmd)}") logger.info(f"Launching process for URL {url}: {' '.join(aiperf_cmd)}")
...@@ -332,6 +344,11 @@ def main(): ...@@ -332,6 +344,11 @@ def main():
default=[0.1, 0.3, 0.5, 0.7, 0.9], default=[0.1, 0.3, 0.5, 0.7, 0.9],
help="List of prefix ratios to test", help="List of prefix ratios to test",
) )
parser.add_argument(
"--use-expected-osl",
action="store_true",
help="Pass expected_output_tokens to nvext for router tracking",
)
args = parser.parse_args() args = parser.parse_args()
...@@ -363,6 +380,7 @@ def main(): ...@@ -363,6 +380,7 @@ def main():
args.num_prefix_prompts, args.num_prefix_prompts,
args.output_dir, args.output_dir,
args.url, # Now passing list of URLs args.url, # Now passing list of URLs
args.use_expected_osl,
) )
if result is not None: if result is not None:
......
...@@ -199,6 +199,11 @@ def main(): ...@@ -199,6 +199,11 @@ def main():
default=0, default=0,
help="Random seed for reproducibility (default: 0)", help="Random seed for reproducibility (default: 0)",
) )
parser.add_argument(
"--use-expected-osl",
action="store_true",
help="Pass expected_output_tokens to nvext for router tracking",
)
args = parser.parse_args() args = parser.parse_args()
...@@ -223,12 +228,38 @@ def main(): ...@@ -223,12 +228,38 @@ def main():
or args.max_osl is not None or args.max_osl is not None
) )
if not needs_synthesis: if not needs_synthesis and not args.use_expected_osl:
# No synthesis needed, use original dataset # No synthesis or modification needed, use original dataset
trace_dataset_path = args.input_dataset trace_dataset_path = args.input_dataset
logger.info( logger.info(
f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}" f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}"
) )
elif not needs_synthesis and args.use_expected_osl:
# Only inject expected_output_tokens into nvext, no other synthesis
logger.info("Injecting expected_output_tokens into original trace dataset...")
# Read original dataset
requests = []
with open(args.input_dataset, "r") as f:
for line in f:
requests.append(json.loads(line.strip()))
# Inject expected_output_tokens into nvext for each request
for request in requests:
osl = request.get("output_tokens", 0)
if "nvext" not in request:
request["nvext"] = {}
request["nvext"]["expected_output_tokens"] = osl
# Write modified data to output directory
trace_dataset_path = os.path.join(
args.output_dir, "trace_with_expected_osl.jsonl"
)
with open(trace_dataset_path, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
logger.info(f"Modified trace data saved to: {trace_dataset_path}")
else: else:
# Generate synthetic data based on input dataset # Generate synthetic data based on input dataset
logger.info("Generating synthetic trace data...") logger.info("Generating synthetic trace data...")
...@@ -290,6 +321,17 @@ def main(): ...@@ -290,6 +321,17 @@ def main():
synthetic_trace_filename = "synthetic_trace.jsonl" synthetic_trace_filename = "synthetic_trace.jsonl"
trace_dataset_path = os.path.join(args.output_dir, synthetic_trace_filename) trace_dataset_path = os.path.join(args.output_dir, synthetic_trace_filename)
# Optionally inject expected_output_tokens into nvext for each request
if args.use_expected_osl:
for request in requests:
# Get the output_tokens (OSL) for this request
osl = request.get("output_tokens", 0)
# Initialize or update nvext with expected_output_tokens
if "nvext" not in request:
request["nvext"] = {}
request["nvext"]["expected_output_tokens"] = osl
logger.info("Injected expected_output_tokens into nvext for each request")
# Write synthetic data to file # Write synthetic data to file
with open(trace_dataset_path, "w") as f: with open(trace_dataset_path, "w") as f:
for request in requests: for request in requests:
......
...@@ -211,6 +211,13 @@ def parse_args(): ...@@ -211,6 +211,13 @@ def parse_args():
default=True, default=True,
help="KV Router: When tracking active blocks, do not assume KV cache reuse (generate random hashes instead of computing actual block hashes). Useful when KV cache reuse is not expected. By default, KV cache reuse is assumed.", help="KV Router: When tracking active blocks, do not assume KV cache reuse (generate random hashes instead of computing actual block hashes). Useful when KV cache reuse is not expected. By default, KV cache reuse is assumed.",
) )
parser.add_argument(
"--track-output-blocks",
action="store_true",
dest="router_track_output_blocks",
default=False,
help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected_output_tokens. By default, output blocks are not tracked.",
)
parser.add_argument( parser.add_argument(
"--enforce-disagg", "--enforce-disagg",
action="store_true", action="store_true",
...@@ -354,6 +361,7 @@ async def async_main(): ...@@ -354,6 +361,7 @@ async def async_main():
use_kv_events=flags.use_kv_events, use_kv_events=flags.use_kv_events,
router_replica_sync=flags.router_replica_sync, router_replica_sync=flags.router_replica_sync,
router_track_active_blocks=flags.router_track_active_blocks, router_track_active_blocks=flags.router_track_active_blocks,
router_track_output_blocks=flags.router_track_output_blocks,
router_assume_kv_reuse=flags.router_assume_kv_reuse, router_assume_kv_reuse=flags.router_assume_kv_reuse,
router_snapshot_threshold=flags.router_snapshot_threshold, router_snapshot_threshold=flags.router_snapshot_threshold,
router_reset_states=flags.router_reset_states, router_reset_states=flags.router_reset_states,
......
...@@ -227,6 +227,14 @@ def parse_args(): ...@@ -227,6 +227,14 @@ def parse_args():
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)", help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)",
) )
parser.add_argument(
"--track-output-blocks",
action="store_true",
dest="router_track_output_blocks",
default=False,
help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected_output_tokens (default: False)",
)
parser.add_argument( parser.add_argument(
"--router-ttl-secs", "--router-ttl-secs",
type=float, type=float,
...@@ -275,6 +283,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -275,6 +283,7 @@ async def worker(runtime: DistributedRuntime):
f"router_replica_sync={args.router_replica_sync}, " f"router_replica_sync={args.router_replica_sync}, "
f"router_reset_states={args.router_reset_states}, " f"router_reset_states={args.router_reset_states}, "
f"router_track_active_blocks={args.router_track_active_blocks}, " f"router_track_active_blocks={args.router_track_active_blocks}, "
f"router_track_output_blocks={args.router_track_output_blocks}, "
f"router_ttl_secs={args.router_ttl_secs}, " f"router_ttl_secs={args.router_ttl_secs}, "
f"router_max_tree_size={args.router_max_tree_size}, " f"router_max_tree_size={args.router_max_tree_size}, "
f"router_prune_target_ratio={args.router_prune_target_ratio}" f"router_prune_target_ratio={args.router_prune_target_ratio}"
...@@ -289,6 +298,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -289,6 +298,7 @@ async def worker(runtime: DistributedRuntime):
router_snapshot_threshold=args.router_snapshot_threshold, router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states, router_reset_states=args.router_reset_states,
router_track_active_blocks=args.router_track_active_blocks, router_track_active_blocks=args.router_track_active_blocks,
router_track_output_blocks=args.router_track_output_blocks,
router_ttl_secs=args.router_ttl_secs, router_ttl_secs=args.router_ttl_secs,
router_max_tree_size=args.router_max_tree_size, router_max_tree_size=args.router_max_tree_size,
router_prune_target_ratio=args.router_prune_target_ratio, router_prune_target_ratio=args.router_prune_target_ratio,
......
...@@ -186,6 +186,7 @@ impl Flags { ...@@ -186,6 +186,7 @@ impl Flags {
self.use_kv_events, self.use_kv_events,
self.router_replica_sync, self.router_replica_sync,
self.router_track_active_blocks, self.router_track_active_blocks,
None, // track_output_blocks
// defaulting below args (no longer maintaining new flags for dynamo-run) // defaulting below args (no longer maintaining new flags for dynamo-run)
None, // assume_kv_reuse None, // assume_kv_reuse
None, None,
......
...@@ -513,6 +513,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline( ...@@ -513,6 +513,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
Some(use_kv_events), Some(use_kv_events),
Some(router_replica_sync), Some(router_replica_sync),
None, // track_active_blocks None, // track_active_blocks
None, // track_output_blocks
None, // assume_kv_reuse None, // assume_kv_reuse
None, // router_snapshot_threshold None, // router_snapshot_threshold
None, // router_reset_states None, // router_reset_states
...@@ -930,7 +931,13 @@ pub unsafe extern "C" fn dynamo_router_add_request( ...@@ -930,7 +931,13 @@ pub unsafe extern "C" fn dynamo_router_add_request(
}; };
kv_router kv_router
.add_request(request_id_clone.clone(), &tokens, overlap_blocks, worker) .add_request(
request_id_clone.clone(),
&tokens,
overlap_blocks,
None,
worker,
)
.await; .await;
tracing::debug!( tracing::debug!(
......
...@@ -50,7 +50,7 @@ impl KvRouterConfig { ...@@ -50,7 +50,7 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
...@@ -58,6 +58,7 @@ impl KvRouterConfig { ...@@ -58,6 +58,7 @@ impl KvRouterConfig {
use_kv_events: bool, use_kv_events: bool,
router_replica_sync: bool, router_replica_sync: bool,
router_track_active_blocks: bool, router_track_active_blocks: bool,
router_track_output_blocks: bool,
router_assume_kv_reuse: bool, router_assume_kv_reuse: bool,
router_snapshot_threshold: Option<u32>, router_snapshot_threshold: Option<u32>,
router_reset_states: bool, router_reset_states: bool,
...@@ -72,6 +73,7 @@ impl KvRouterConfig { ...@@ -72,6 +73,7 @@ impl KvRouterConfig {
use_kv_events, use_kv_events,
router_replica_sync, router_replica_sync,
router_track_active_blocks, router_track_active_blocks,
router_track_output_blocks,
router_assume_kv_reuse, router_assume_kv_reuse,
router_snapshot_threshold, router_snapshot_threshold,
router_reset_states, router_reset_states,
......
...@@ -1088,6 +1088,7 @@ class KvRouterConfig: ...@@ -1088,6 +1088,7 @@ class KvRouterConfig:
use_kv_events: bool = True, use_kv_events: bool = True,
router_replica_sync: bool = False, router_replica_sync: bool = False,
router_track_active_blocks: bool = True, router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False,
router_assume_kv_reuse: bool = True, router_assume_kv_reuse: bool = True,
router_snapshot_threshold: Optional[int] = 1000000, router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False, router_reset_states: bool = False,
...@@ -1104,6 +1105,9 @@ class KvRouterConfig: ...@@ -1104,6 +1105,9 @@ class KvRouterConfig:
use_kv_events: Whether to use KV events from workers (default: True) use_kv_events: Whether to use KV events from workers (default: True)
router_replica_sync: Enable replica synchronization (default: False) router_replica_sync: Enable replica synchronization (default: False)
router_track_active_blocks: Track active blocks for load balancing (default: True) router_track_active_blocks: Track active blocks for load balancing (default: True)
router_track_output_blocks: Track output blocks during generation (default: False).
When enabled, the router adds placeholder blocks as tokens are generated
and applies fractional decay based on progress toward expected_output_tokens.
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True). router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
When True, computes actual block hashes. When False, generates random hashes. When True, computes actual block hashes. When False, generates random hashes.
router_snapshot_threshold: Number of messages before snapshot (default: 1000000) router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
......
...@@ -138,6 +138,11 @@ pub struct KvRouterConfig { ...@@ -138,6 +138,11 @@ pub struct KvRouterConfig {
/// Whether to track active blocks in the router (default: true) /// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool, pub router_track_active_blocks: bool,
/// Whether to track output blocks during generation (default: false)
/// When enabled, the router adds placeholder blocks as tokens are generated
/// and applies fractional decay based on progress toward expected_output_tokens.
pub router_track_output_blocks: bool,
/// Whether to assume KV cache reuse when tracking active blocks (default: true). /// Whether to assume KV cache reuse when tracking active blocks (default: true).
/// When true, computes actual block hashes for sequence tracking. /// When true, computes actual block hashes for sequence tracking.
/// When false, generates random hashes (assuming no KV cache reuse). /// When false, generates random hashes (assuming no KV cache reuse).
...@@ -167,6 +172,7 @@ impl Default for KvRouterConfig { ...@@ -167,6 +172,7 @@ impl Default for KvRouterConfig {
use_kv_events: true, use_kv_events: true,
router_replica_sync: false, router_replica_sync: false,
router_track_active_blocks: true, router_track_active_blocks: true,
router_track_output_blocks: false,
router_assume_kv_reuse: true, router_assume_kv_reuse: true,
router_snapshot_threshold: Some(1000000), router_snapshot_threshold: Some(1000000),
router_reset_states: false, router_reset_states: false,
...@@ -187,6 +193,7 @@ impl KvRouterConfig { ...@@ -187,6 +193,7 @@ impl KvRouterConfig {
use_kv_events: Option<bool>, use_kv_events: Option<bool>,
replica_sync: Option<bool>, replica_sync: Option<bool>,
track_active_blocks: Option<bool>, track_active_blocks: Option<bool>,
track_output_blocks: Option<bool>,
assume_kv_reuse: Option<bool>, assume_kv_reuse: Option<bool>,
router_snapshot_threshold: Option<Option<u32>>, router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>, router_reset_states: Option<bool>,
...@@ -202,6 +209,8 @@ impl KvRouterConfig { ...@@ -202,6 +209,8 @@ impl KvRouterConfig {
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync), router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks), .unwrap_or(default.router_track_active_blocks),
router_track_output_blocks: track_output_blocks
.unwrap_or(default.router_track_output_blocks),
router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse), router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse),
router_snapshot_threshold: router_snapshot_threshold router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold), .unwrap_or(default.router_snapshot_threshold),
...@@ -522,6 +531,7 @@ impl KvRouter { ...@@ -522,6 +531,7 @@ impl KvRouter {
request_id: String, request_id: String,
tokens: &[u32], tokens: &[u32],
overlap_blocks: u32, overlap_blocks: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) { ) {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
...@@ -537,6 +547,7 @@ impl KvRouter { ...@@ -537,6 +547,7 @@ impl KvRouter {
maybe_seq_hashes, maybe_seq_hashes,
isl_tokens, isl_tokens,
overlap_blocks, overlap_blocks,
expected_output_tokens,
worker, worker,
) )
.await .await
...@@ -553,6 +564,16 @@ impl KvRouter { ...@@ -553,6 +564,16 @@ impl KvRouter {
self.scheduler.free(request_id).await self.scheduler.free(request_id).await
} }
/// Delegates output-block tracking for `request_id` to the scheduler.
///
/// Both arguments are handed to the scheduler's `add_output_block`
/// unchanged, and whatever it returns is returned to the caller as-is.
pub async fn add_output_block(
    &self,
    request_id: &str,
    decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
    let scheduler = &self.scheduler;
    scheduler.add_output_block(request_id, decay_fraction).await
}
pub fn block_size(&self) -> u32 { pub fn block_size(&self) -> u32 {
self.block_size self.block_size
} }
...@@ -763,6 +784,12 @@ impl KvPushRouter { ...@@ -763,6 +784,12 @@ impl KvPushRouter {
.get_overlap_blocks(&request.token_ids, worker) .get_overlap_blocks(&request.token_ids, worker)
.await?; .await?;
// Extract expected_output_tokens from routing hints
let expected_output_tokens = request
.routing
.as_ref()
.and_then(|r| r.expected_output_tokens);
// Perform add_request if this router handles local updates // Perform add_request if this router handles local updates
if !is_query_only && handle_local_updates { if !is_query_only && handle_local_updates {
self.chooser self.chooser
...@@ -770,6 +797,7 @@ impl KvPushRouter { ...@@ -770,6 +797,7 @@ impl KvPushRouter {
context_id.to_string(), context_id.to_string(),
&request.token_ids, &request.token_ids,
overlap_blocks, overlap_blocks,
expected_output_tokens,
worker, worker,
) )
.await; .await;
...@@ -912,6 +940,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -912,6 +940,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
// Route to worker // Route to worker
let isl_tokens = request.token_ids.len();
let expected_output_tokens = request
.routing
.as_ref()
.and_then(|r| r.expected_output_tokens);
let track_output_blocks =
self.chooser.kv_router_config.router_track_output_blocks && handle_local_updates;
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.routing_mut().dp_rank = Some(dp_rank); backend_input.routing_mut().dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
...@@ -927,6 +963,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -927,6 +963,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false; let mut prefill_marked = false;
// Output block tracking state
let mut cumulative_osl: usize = 0;
let mut current_total_blocks = isl_tokens.div_ceil(block_size);
loop { loop {
tokio::select! { tokio::select! {
biased; biased;
...@@ -955,6 +995,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -955,6 +995,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
} }
// Track output blocks if enabled
if track_output_blocks {
let new_tokens = item.data.as_ref()
.map(|d| d.token_ids.len())
.unwrap_or(0);
cumulative_osl += new_tokens;
let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
if new_total_blocks > current_total_blocks {
// New block boundary crossed - add output block with decay
// Clamp eot to min 1 to avoid division by zero, and result to min 0.0
let decay_fraction = expected_output_tokens.map(|eot| {
(1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
});
if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
tracing::warn!(
"Failed to add output block for request {context_id}: {e}"
);
}
current_total_blocks = new_total_blocks;
}
}
yield item; yield item;
} }
} }
......
...@@ -293,6 +293,7 @@ pub enum ActiveSequenceEventData { ...@@ -293,6 +293,7 @@ pub enum ActiveSequenceEventData {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>,
}, },
Free, Free,
MarkPrefillCompleted, MarkPrefillCompleted,
......
...@@ -234,6 +234,7 @@ impl KvScheduler { ...@@ -234,6 +234,7 @@ impl KvScheduler {
request.token_seq, request.token_seq,
request.isl_tokens, request.isl_tokens,
selection.overlap_blocks, selection.overlap_blocks,
None, // expected_output_tokens not available in scheduler loop
selection.worker, selection.worker,
) )
.await .await
...@@ -304,10 +305,18 @@ impl KvScheduler { ...@@ -304,10 +305,18 @@ impl KvScheduler {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
self.slots self.slots
.add_request(request_id, token_sequence, isl, overlap, worker) .add_request(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
worker,
)
.await .await
} }
...@@ -321,6 +330,16 @@ impl KvScheduler { ...@@ -321,6 +330,16 @@ impl KvScheduler {
self.slots.free(&request_id.to_string()).await self.slots.free(&request_id.to_string()).await
} }
/// Delegates `add_output_block` to the slot tracker held in `self.slots`.
///
/// `request_id` is turned into an owned `String` first, since the slot
/// tracker is called with a reference to one; `decay_fraction` is forwarded
/// untouched and the tracker's result is returned directly.
pub async fn add_output_block(
    &self,
    request_id: &str,
    decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
    let owned_id = request_id.to_string();
    self.slots.add_output_block(&owned_id, decay_fraction).await
}
pub async fn get_potential_loads( pub async fn get_potential_loads(
&self, &self,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
......
...@@ -80,8 +80,16 @@ pub struct ActiveSequences { ...@@ -80,8 +80,16 @@ pub struct ActiveSequences {
prefill_tokens: HashMap<RequestId, usize>, prefill_tokens: HashMap<RequestId, usize>,
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, Weak<()>>, unique_blocks: HashMap<SequenceHash, Weak<()>>,
/// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks,
/// it contributes the fractional value instead of 1 to active_blocks()
fractional_blocks: HashMap<SequenceHash, f64>,
#[getter(copy)] #[getter(copy)]
block_size: usize, block_size: usize,
...@@ -104,7 +112,9 @@ impl ActiveSequences { ...@@ -104,7 +112,9 @@ impl ActiveSequences {
Self { Self {
active_seqs: HashMap::new(), active_seqs: HashMap::new(),
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
expected_output_tokens: HashMap::new(),
unique_blocks: HashMap::new(), unique_blocks: HashMap::new(),
fractional_blocks: HashMap::new(),
block_size, block_size,
active_tokens: 0, active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION, expiry_timer: Instant::now() + EXPIRY_DURATION,
...@@ -129,11 +139,37 @@ impl ActiveSequences { ...@@ -129,11 +139,37 @@ impl ActiveSequences {
&& weak.strong_count() == 0 && weak.strong_count() == 0
{ {
self.unique_blocks.remove(block); self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
} }
} }
pub fn active_blocks(&self) -> usize { pub fn active_blocks(&self) -> usize {
self.unique_blocks.len() let mut count = self.unique_blocks.len() as f64;
for (hash, frac) in &self.fractional_blocks {
if self.unique_blocks.contains_key(hash) {
// Subtract 1 (the full block) and add the fractional value
count = count - 1.0 + frac;
}
}
count.round() as usize
}
/// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value.
pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
let Some(blocks) = self.active_seqs.get(request_id) else {
tracing::warn!(
"Request {request_id} not found for set_single_ref_blocks_as_fractional"
);
return;
};
for (hash, rc) in blocks {
// A block with strong_count == 1 means only this request holds a reference
if Rc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction);
}
}
} }
/// Add a new request with its initial tokens /// Add a new request with its initial tokens
...@@ -144,6 +180,7 @@ impl ActiveSequences { ...@@ -144,6 +180,7 @@ impl ActiveSequences {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>,
) -> HashSet<RequestId> { ) -> HashSet<RequestId> {
// Check for double-add and log error, returning early // Check for double-add and log error, returning early
if self.active_seqs.contains_key(&request_id) { if self.active_seqs.contains_key(&request_id) {
...@@ -159,6 +196,12 @@ impl ActiveSequences { ...@@ -159,6 +196,12 @@ impl ActiveSequences {
.insert(request_id.clone(), prefill_tokens); .insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens; self.active_tokens += prefill_tokens;
// Store expected output tokens if provided
if let Some(tokens) = expected_output_tokens {
self.expected_output_tokens
.insert(request_id.clone(), tokens);
}
if let Some(sequence) = token_sequence { if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
.iter() .iter()
...@@ -231,6 +274,9 @@ impl ActiveSequences { ...@@ -231,6 +274,9 @@ impl ActiveSequences {
self.expiry_requests.remove(request_id); self.expiry_requests.remove(request_id);
// Remove expected output tokens tracking
self.expected_output_tokens.remove(request_id);
// Remove from active_seqs and get the token sequence // Remove from active_seqs and get the token sequence
let token_seq = match self.active_seqs.remove(request_id) { let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq, Some(seq) => seq,
...@@ -249,6 +295,46 @@ impl ActiveSequences { ...@@ -249,6 +295,46 @@ impl ActiveSequences {
self.active_blocks() self.active_blocks()
} }
/// Appends a placeholder output block to an active request.
///
/// Intended to be called during generation as output blocks are created.
/// The block is identified by a random hash (the first element of
/// `as_u64_pair()` on a fresh v4 UUID), registered through `touch_block`,
/// and pushed onto the request's block sequence.
///
/// When `decay_fraction` is `Some`, the value is then passed to
/// `set_single_ref_blocks_as_fractional` for this request. It expresses how
/// much a block should count:
/// - 1.0: counted in full (early in generation)
/// - 0.0: not counted (near the end of the expected output)
/// - computed as `1 - (current_osl / expected_output_tokens)`
///
/// Returns `true` when the block was added, and `false` (after logging a
/// warning) when `request_id` is not an active request.
pub fn add_output_block(
    &mut self,
    request_id: &RequestId,
    decay_fraction: Option<f64>,
) -> bool {
    // Existence is checked up front with a read-only lookup; the mutable
    // lookup happens only after `touch_block` has returned.
    if !self.active_seqs.contains_key(request_id) {
        tracing::warn!("Request {request_id} not found for add_output_block");
        return false;
    }

    // Random placeholder hash for the new block.
    let (placeholder_hash, _): (SequenceHash, _) = Uuid::new_v4().as_u64_pair();
    let block_ref = self.touch_block(&placeholder_hash);

    let sequence = self.active_seqs.get_mut(request_id).unwrap();
    sequence.push((placeholder_hash, block_ref));

    // Apply the decay weight only when the caller supplied one.
    if let Some(fraction) = decay_fraction {
        self.set_single_ref_blocks_as_fractional(request_id, fraction);
    }

    true
}
/// Force expiry of stale requests if the timer has elapsed /// Force expiry of stale requests if the timer has elapsed
/// Returns the set of expired request IDs that were removed /// Returns the set of expired request IDs that were removed
pub fn force_expiry(&mut self) -> HashSet<RequestId> { pub fn force_expiry(&mut self) -> HashSet<RequestId> {
...@@ -279,6 +365,7 @@ enum UpdateSequences { ...@@ -279,6 +365,7 @@ enum UpdateSequences {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>,
resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>, resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
}, },
Free { Free {
...@@ -287,6 +374,11 @@ enum UpdateSequences { ...@@ -287,6 +374,11 @@ enum UpdateSequences {
MarkPrefillCompleted { MarkPrefillCompleted {
request_id: RequestId, request_id: RequestId,
}, },
AddOutputBlock {
request_id: RequestId,
decay_fraction: Option<f64>,
resp_tx: tokio::sync::oneshot::Sender<bool>,
},
NewBlocks { NewBlocks {
token_sequence: Arc<Vec<SequenceHash>>, token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: tokio::sync::oneshot::Sender<usize>, resp_tx: tokio::sync::oneshot::Sender<usize>,
...@@ -428,9 +520,10 @@ impl ActiveSequencesMultiWorker { ...@@ -428,9 +520,10 @@ impl ActiveSequencesMultiWorker {
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
expected_output_tokens,
resp_tx, resp_tx,
} => { } => {
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap); let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap, expected_output_tokens);
let _ = resp_tx.send(removed); let _ = resp_tx.send(removed);
} }
UpdateSequences::Free { request_id } => { UpdateSequences::Free { request_id } => {
...@@ -439,6 +532,14 @@ impl ActiveSequencesMultiWorker { ...@@ -439,6 +532,14 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::MarkPrefillCompleted { request_id } => { UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.mark_prefill_completed(&request_id); active_sequences.mark_prefill_completed(&request_id);
} }
UpdateSequences::AddOutputBlock {
request_id,
decay_fraction,
resp_tx,
} => {
let success = active_sequences.add_output_block(&request_id, decay_fraction);
let _ = resp_tx.send(success);
}
UpdateSequences::NewBlocks { UpdateSequences::NewBlocks {
token_sequence, token_sequence,
resp_tx, resp_tx,
...@@ -535,6 +636,7 @@ impl ActiveSequencesMultiWorker { ...@@ -535,6 +636,7 @@ impl ActiveSequencesMultiWorker {
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
expected_output_tokens,
} => { } => {
request_to_worker.insert(event.request_id.clone(), event.worker); request_to_worker.insert(event.request_id.clone(), event.worker);
...@@ -546,6 +648,7 @@ impl ActiveSequencesMultiWorker { ...@@ -546,6 +648,7 @@ impl ActiveSequencesMultiWorker {
token_sequence: token_sequence.clone(), token_sequence: token_sequence.clone(),
isl: *isl, isl: *isl,
overlap: *overlap, overlap: *overlap,
expected_output_tokens: *expected_output_tokens,
resp_tx, resp_tx,
}); });
} else { } else {
...@@ -643,6 +746,7 @@ impl ActiveSequencesMultiWorker { ...@@ -643,6 +746,7 @@ impl ActiveSequencesMultiWorker {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
// Check for worker existence // Check for worker existence
...@@ -670,6 +774,7 @@ impl ActiveSequencesMultiWorker { ...@@ -670,6 +774,7 @@ impl ActiveSequencesMultiWorker {
token_sequence: token_sequence.clone(), token_sequence: token_sequence.clone(),
isl, isl,
overlap, overlap,
expected_output_tokens,
}, },
router_id: self.router_id, router_id: self.router_id,
}; };
...@@ -689,6 +794,7 @@ impl ActiveSequencesMultiWorker { ...@@ -689,6 +794,7 @@ impl ActiveSequencesMultiWorker {
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
expected_output_tokens,
resp_tx, resp_tx,
}) })
.map_err(|_| SequenceError::WorkerChannelClosed)?; .map_err(|_| SequenceError::WorkerChannelClosed)?;
...@@ -804,6 +910,59 @@ impl ActiveSequencesMultiWorker { ...@@ -804,6 +910,59 @@ impl ActiveSequencesMultiWorker {
Ok(()) Ok(())
} }
/// Add an output block with optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
/// The `decay_fraction` represents how "temporary" the block is based on
/// generation progress; it is forwarded unchanged to the worker that owns
/// `request_id`.
///
/// After the worker confirms the block was added, ActiveLoad metrics are
/// published for that worker.
///
/// # Errors
/// - `SequenceError::RequestNotFound` if `request_id` has no worker mapping,
///   or if the worker reports that it does not know the request.
/// - `SequenceError::WorkerNotFound` if the mapped worker has no sender.
/// - `SequenceError::WorkerChannelClosed` if the command cannot be sent or the
///   response channel is dropped before a reply arrives.
pub async fn add_output_block(
    &self,
    request_id: &RequestId,
    decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
    // Resolve which worker owns this request.
    let worker = self
        .request_to_worker
        .get(request_id)
        .map(|entry| *entry)
        .ok_or_else(|| SequenceError::RequestNotFound {
            request_id: request_id.clone(),
        })?;

    // Create response channel
    let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();

    // Look the sender up exactly once and send the command. A missing worker
    // becomes `WorkerNotFound` instead of a separate `contains_key` check
    // followed by `unwrap()`, which did two lookups and would panic if the
    // entry disappeared in between.
    // NOTE(review): the lookup result is a temporary dropped at the end of
    // this statement, so nothing from `senders` is held across the `.await`
    // below.
    self.senders
        .get(&worker)
        .ok_or_else(|| SequenceError::WorkerNotFound { worker })?
        .send(UpdateSequences::AddOutputBlock {
            request_id: request_id.clone(),
            decay_fraction,
            resp_tx,
        })
        .map_err(|_| SequenceError::WorkerChannelClosed)?;

    // Wait for the worker to report whether the block was added.
    let success = resp_rx
        .await
        .map_err(|_| SequenceError::WorkerChannelClosed)?;

    // `false` means the worker did not find the request.
    if !success {
        return Err(SequenceError::RequestNotFound {
            request_id: request_id.clone(),
        });
    }

    // Publish ActiveLoad metrics for this worker
    self.publish_active_load_for_worker(worker).await;

    Ok(())
}
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad /// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) { async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let Some(sender) = self.senders.get(&worker) else { let Some(sender) = self.senders.get(&worker) else {
...@@ -1028,15 +1187,15 @@ mod tests { ...@@ -1028,15 +1187,15 @@ mod tests {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0); seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12); assert_eq!(seq_manager.active_tokens(), 12);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0); seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16); assert_eq!(seq_manager.active_tokens(), 16);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4); seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16); assert_eq!(seq_manager.active_tokens(), 16);
...@@ -1112,6 +1271,7 @@ mod tests { ...@@ -1112,6 +1271,7 @@ mod tests {
Some(vec![0, 1, 2]), Some(vec![0, 1, 2]),
12, // ISL (3 blocks * 4 block_size) 12, // ISL (3 blocks * 4 block_size)
0, // no overlap 0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::new(0, 0), WorkerWithDpRank::new(0, 0),
) )
.await?; .await?;
...@@ -1123,6 +1283,7 @@ mod tests { ...@@ -1123,6 +1283,7 @@ mod tests {
Some(vec![3, 4]), Some(vec![3, 4]),
8, // ISL (2 blocks * 4 block_size) 8, // ISL (2 blocks * 4 block_size)
0, // no overlap 0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::new(0, 1), WorkerWithDpRank::new(0, 1),
) )
.await?; .await?;
...@@ -1134,6 +1295,7 @@ mod tests { ...@@ -1134,6 +1295,7 @@ mod tests {
Some(vec![0, 1, 2, 3]), Some(vec![0, 1, 2, 3]),
16, // ISL (4 blocks * 4 block_size) 16, // ISL (4 blocks * 4 block_size)
0, // no overlap 0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::new(1, 0), WorkerWithDpRank::new(1, 0),
) )
.await?; .await?;
...@@ -1268,6 +1430,7 @@ mod tests { ...@@ -1268,6 +1430,7 @@ mod tests {
None, // No token sequence None, // No token sequence
12, // ISL (12 tokens) 12, // ISL (12 tokens)
0, // no overlap 0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(0), WorkerWithDpRank::from_worker_id(0),
) )
.await?; .await?;
...@@ -1279,6 +1442,7 @@ mod tests { ...@@ -1279,6 +1442,7 @@ mod tests {
None, // No token sequence None, // No token sequence
8, // ISL (8 tokens) 8, // ISL (8 tokens)
0, // no overlap 0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
) )
.await?; .await?;
...@@ -1290,6 +1454,7 @@ mod tests { ...@@ -1290,6 +1454,7 @@ mod tests {
None, // No token sequence None, // No token sequence
16, // ISL (16 tokens) 16, // ISL (16 tokens)
0, // no overlap 0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(2), WorkerWithDpRank::from_worker_id(2),
) )
.await?; .await?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment