Unverified Commit 050906b5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: track output tokens / blocks in the Router (optional) (#5452)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 9c08a2aa
......@@ -39,11 +39,18 @@ def get_aiperf_cmd(
num_prefix_prompts,
artifact_dir,
url="http://localhost:8888",
use_expected_osl=False,
):
"""Build aiperf command based on prefix ratio"""
prefix_length = int(isl * prefix_ratio)
synthetic_input_length = int(isl * (1 - prefix_ratio))
# Build nvext JSON with optional expected_output_tokens
nvext_dict = {"ignore_eos": True}
if use_expected_osl:
nvext_dict["expected_output_tokens"] = osl
nvext_json = json.dumps({"nvext": nvext_dict})
return [
"aiperf",
"profile",
......@@ -69,7 +76,7 @@ def get_aiperf_cmd(
"--extra-inputs",
"ignore_eos:true",
"--extra-inputs",
'{"nvext":{"ignore_eos":true}}',
nvext_json,
"--concurrency",
str(concurrency),
"--request-count",
......@@ -122,6 +129,7 @@ def run_benchmark_single_url(
num_prefix_prompts,
artifact_dir,
url,
use_expected_osl=False,
) -> Optional[Dict]:
"""Run aiperf benchmark for a single URL"""
aiperf_cmd = get_aiperf_cmd(
......@@ -136,6 +144,7 @@ def run_benchmark_single_url(
num_prefix_prompts,
artifact_dir,
url,
use_expected_osl,
)
logger.info(f"Running command for URL {url}: {' '.join(aiperf_cmd)}")
......@@ -201,6 +210,7 @@ def run_benchmark(
num_prefix_prompts,
output_dir,
urls,
use_expected_osl=False,
) -> Optional[Dict]:
"""Run aiperf benchmark for a specific prefix ratio"""
logger.info(
......@@ -227,6 +237,7 @@ def run_benchmark(
num_prefix_prompts,
artifact_dir,
urls[0],
use_expected_osl,
)
# Multiple URLs: split requests and concurrency
......@@ -259,6 +270,7 @@ def run_benchmark(
num_prefix_prompts,
artifact_dir,
url,
use_expected_osl,
)
logger.info(f"Launching process for URL {url}: {' '.join(aiperf_cmd)}")
......@@ -332,6 +344,11 @@ def main():
default=[0.1, 0.3, 0.5, 0.7, 0.9],
help="List of prefix ratios to test",
)
parser.add_argument(
"--use-expected-osl",
action="store_true",
help="Pass expected_output_tokens to nvext for router tracking",
)
args = parser.parse_args()
......@@ -363,6 +380,7 @@ def main():
args.num_prefix_prompts,
args.output_dir,
args.url, # Now passing list of URLs
args.use_expected_osl,
)
if result is not None:
......
......@@ -199,6 +199,11 @@ def main():
default=0,
help="Random seed for reproducibility (default: 0)",
)
parser.add_argument(
"--use-expected-osl",
action="store_true",
help="Pass expected_output_tokens to nvext for router tracking",
)
args = parser.parse_args()
......@@ -223,12 +228,38 @@ def main():
or args.max_osl is not None
)
if not needs_synthesis:
# No synthesis needed, use original dataset
if not needs_synthesis and not args.use_expected_osl:
# No synthesis or modification needed, use original dataset
trace_dataset_path = args.input_dataset
logger.info(
f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}"
)
elif not needs_synthesis and args.use_expected_osl:
# Only inject expected_output_tokens into nvext, no other synthesis
logger.info("Injecting expected_output_tokens into original trace dataset...")
# Read original dataset
requests = []
with open(args.input_dataset, "r") as f:
for line in f:
requests.append(json.loads(line.strip()))
# Inject expected_output_tokens into nvext for each request
for request in requests:
osl = request.get("output_tokens", 0)
if "nvext" not in request:
request["nvext"] = {}
request["nvext"]["expected_output_tokens"] = osl
# Write modified data to output directory
trace_dataset_path = os.path.join(
args.output_dir, "trace_with_expected_osl.jsonl"
)
with open(trace_dataset_path, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
logger.info(f"Modified trace data saved to: {trace_dataset_path}")
else:
# Generate synthetic data based on input dataset
logger.info("Generating synthetic trace data...")
......@@ -290,6 +321,17 @@ def main():
synthetic_trace_filename = "synthetic_trace.jsonl"
trace_dataset_path = os.path.join(args.output_dir, synthetic_trace_filename)
# Optionally inject expected_output_tokens into nvext for each request
if args.use_expected_osl:
for request in requests:
# Get the output_tokens (OSL) for this request
osl = request.get("output_tokens", 0)
# Initialize or update nvext with expected_output_tokens
if "nvext" not in request:
request["nvext"] = {}
request["nvext"]["expected_output_tokens"] = osl
logger.info("Injected expected_output_tokens into nvext for each request")
# Write synthetic data to file
with open(trace_dataset_path, "w") as f:
for request in requests:
......
......@@ -211,6 +211,13 @@ def parse_args():
default=True,
help="KV Router: When tracking active blocks, do not assume KV cache reuse (generate random hashes instead of computing actual block hashes). Useful when KV cache reuse is not expected. By default, KV cache reuse is assumed.",
)
parser.add_argument(
"--track-output-blocks",
action="store_true",
dest="router_track_output_blocks",
default=False,
help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected_output_tokens. By default, output blocks are not tracked.",
)
parser.add_argument(
"--enforce-disagg",
action="store_true",
......@@ -354,6 +361,7 @@ async def async_main():
use_kv_events=flags.use_kv_events,
router_replica_sync=flags.router_replica_sync,
router_track_active_blocks=flags.router_track_active_blocks,
router_track_output_blocks=flags.router_track_output_blocks,
router_assume_kv_reuse=flags.router_assume_kv_reuse,
router_snapshot_threshold=flags.router_snapshot_threshold,
router_reset_states=flags.router_reset_states,
......
......@@ -227,6 +227,14 @@ def parse_args():
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)",
)
parser.add_argument(
"--track-output-blocks",
action="store_true",
dest="router_track_output_blocks",
default=False,
help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected_output_tokens (default: False)",
)
parser.add_argument(
"--router-ttl-secs",
type=float,
......@@ -275,6 +283,7 @@ async def worker(runtime: DistributedRuntime):
f"router_replica_sync={args.router_replica_sync}, "
f"router_reset_states={args.router_reset_states}, "
f"router_track_active_blocks={args.router_track_active_blocks}, "
f"router_track_output_blocks={args.router_track_output_blocks}, "
f"router_ttl_secs={args.router_ttl_secs}, "
f"router_max_tree_size={args.router_max_tree_size}, "
f"router_prune_target_ratio={args.router_prune_target_ratio}"
......@@ -289,6 +298,7 @@ async def worker(runtime: DistributedRuntime):
router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states,
router_track_active_blocks=args.router_track_active_blocks,
router_track_output_blocks=args.router_track_output_blocks,
router_ttl_secs=args.router_ttl_secs,
router_max_tree_size=args.router_max_tree_size,
router_prune_target_ratio=args.router_prune_target_ratio,
......
......@@ -186,6 +186,7 @@ impl Flags {
self.use_kv_events,
self.router_replica_sync,
self.router_track_active_blocks,
None, // track_output_blocks
// defaulting below args (no longer maintaining new flags for dynamo-run)
None, // assume_kv_reuse
None,
......
......@@ -513,6 +513,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
Some(use_kv_events),
Some(router_replica_sync),
None, // track_active_blocks
None, // track_output_blocks
None, // assume_kv_reuse
None, // router_snapshot_threshold
None, // router_reset_states
......@@ -930,7 +931,13 @@ pub unsafe extern "C" fn dynamo_router_add_request(
};
kv_router
.add_request(request_id_clone.clone(), &tokens, overlap_blocks, worker)
.add_request(
request_id_clone.clone(),
&tokens,
overlap_blocks,
None,
worker,
)
.await;
tracing::debug!(
......
......@@ -50,7 +50,7 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......@@ -58,6 +58,7 @@ impl KvRouterConfig {
use_kv_events: bool,
router_replica_sync: bool,
router_track_active_blocks: bool,
router_track_output_blocks: bool,
router_assume_kv_reuse: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
......@@ -72,6 +73,7 @@ impl KvRouterConfig {
use_kv_events,
router_replica_sync,
router_track_active_blocks,
router_track_output_blocks,
router_assume_kv_reuse,
router_snapshot_threshold,
router_reset_states,
......
......@@ -1088,6 +1088,7 @@ class KvRouterConfig:
use_kv_events: bool = True,
router_replica_sync: bool = False,
router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False,
router_assume_kv_reuse: bool = True,
router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False,
......@@ -1104,6 +1105,9 @@ class KvRouterConfig:
use_kv_events: Whether to use KV events from workers (default: True)
router_replica_sync: Enable replica synchronization (default: False)
router_track_active_blocks: Track active blocks for load balancing (default: True)
router_track_output_blocks: Track output blocks during generation (default: False).
When enabled, the router adds placeholder blocks as tokens are generated
and applies fractional decay based on progress toward expected_output_tokens.
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
When True, computes actual block hashes. When False, generates random hashes.
router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
......
......@@ -138,6 +138,11 @@ pub struct KvRouterConfig {
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
/// Whether to track output blocks during generation (default: false)
/// When enabled, the router adds placeholder blocks as tokens are generated
/// and applies fractional decay based on progress toward expected_output_tokens.
pub router_track_output_blocks: bool,
/// Whether to assume KV cache reuse when tracking active blocks (default: true).
/// When true, computes actual block hashes for sequence tracking.
/// When false, generates random hashes (assuming no KV cache reuse).
......@@ -167,6 +172,7 @@ impl Default for KvRouterConfig {
use_kv_events: true,
router_replica_sync: false,
router_track_active_blocks: true,
router_track_output_blocks: false,
router_assume_kv_reuse: true,
router_snapshot_threshold: Some(1000000),
router_reset_states: false,
......@@ -187,6 +193,7 @@ impl KvRouterConfig {
use_kv_events: Option<bool>,
replica_sync: Option<bool>,
track_active_blocks: Option<bool>,
track_output_blocks: Option<bool>,
assume_kv_reuse: Option<bool>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
......@@ -202,6 +209,8 @@ impl KvRouterConfig {
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks),
router_track_output_blocks: track_output_blocks
.unwrap_or(default.router_track_output_blocks),
router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse),
router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold),
......@@ -522,6 +531,7 @@ impl KvRouter {
request_id: String,
tokens: &[u32],
overlap_blocks: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
) {
let isl_tokens = tokens.len();
......@@ -537,6 +547,7 @@ impl KvRouter {
maybe_seq_hashes,
isl_tokens,
overlap_blocks,
expected_output_tokens,
worker,
)
.await
......@@ -553,6 +564,16 @@ impl KvRouter {
self.scheduler.free(request_id).await
}
pub async fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.scheduler
.add_output_block(request_id, decay_fraction)
.await
}
pub fn block_size(&self) -> u32 {
self.block_size
}
......@@ -763,6 +784,12 @@ impl KvPushRouter {
.get_overlap_blocks(&request.token_ids, worker)
.await?;
// Extract expected_output_tokens from routing hints
let expected_output_tokens = request
.routing
.as_ref()
.and_then(|r| r.expected_output_tokens);
// Perform add_request if this router handles local updates
if !is_query_only && handle_local_updates {
self.chooser
......@@ -770,6 +797,7 @@ impl KvPushRouter {
context_id.to_string(),
&request.token_ids,
overlap_blocks,
expected_output_tokens,
worker,
)
.await;
......@@ -912,6 +940,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
// Route to worker
let isl_tokens = request.token_ids.len();
let expected_output_tokens = request
.routing
.as_ref()
.and_then(|r| r.expected_output_tokens);
let track_output_blocks =
self.chooser.kv_router_config.router_track_output_blocks && handle_local_updates;
let (mut backend_input, context) = request.into_parts();
backend_input.routing_mut().dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input);
......@@ -927,6 +963,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false;
// Output block tracking state
let mut cumulative_osl: usize = 0;
let mut current_total_blocks = isl_tokens.div_ceil(block_size);
loop {
tokio::select! {
biased;
......@@ -955,6 +995,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
}
// Track output blocks if enabled
if track_output_blocks {
let new_tokens = item.data.as_ref()
.map(|d| d.token_ids.len())
.unwrap_or(0);
cumulative_osl += new_tokens;
let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
if new_total_blocks > current_total_blocks {
// New block boundary crossed - add output block with decay
// Clamp eot to min 1 to avoid division by zero, and result to min 0.0
let decay_fraction = expected_output_tokens.map(|eot| {
(1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
});
if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
tracing::warn!(
"Failed to add output block for request {context_id}: {e}"
);
}
current_total_blocks = new_total_blocks;
}
}
yield item;
}
}
......
......@@ -293,6 +293,7 @@ pub enum ActiveSequenceEventData {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
},
Free,
MarkPrefillCompleted,
......
......@@ -234,6 +234,7 @@ impl KvScheduler {
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
None, // expected_output_tokens not available in scheduler loop
selection.worker,
)
.await
......@@ -304,10 +305,18 @@ impl KvScheduler {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
) -> Result<(), SequenceError> {
self.slots
.add_request(request_id, token_sequence, isl, overlap, worker)
.add_request(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
worker,
)
.await
}
......@@ -321,6 +330,16 @@ impl KvScheduler {
self.slots.free(&request_id.to_string()).await
}
pub async fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.slots
.add_output_block(&request_id.to_string(), decay_fraction)
.await
}
pub async fn get_potential_loads(
&self,
token_seq: Option<Vec<SequenceHash>>,
......
......@@ -80,8 +80,16 @@ pub struct ActiveSequences {
prefill_tokens: HashMap<RequestId, usize>,
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, Weak<()>>,
/// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks,
/// it contributes the fractional value instead of 1 to active_blocks()
fractional_blocks: HashMap<SequenceHash, f64>,
#[getter(copy)]
block_size: usize,
......@@ -104,7 +112,9 @@ impl ActiveSequences {
Self {
active_seqs: HashMap::new(),
prefill_tokens: HashMap::new(),
expected_output_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
fractional_blocks: HashMap::new(),
block_size,
active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
......@@ -129,11 +139,37 @@ impl ActiveSequences {
&& weak.strong_count() == 0
{
self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
}
}
pub fn active_blocks(&self) -> usize {
self.unique_blocks.len()
let mut count = self.unique_blocks.len() as f64;
for (hash, frac) in &self.fractional_blocks {
if self.unique_blocks.contains_key(hash) {
// Subtract 1 (the full block) and add the fractional value
count = count - 1.0 + frac;
}
}
count.round() as usize
}
/// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value.
pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
let Some(blocks) = self.active_seqs.get(request_id) else {
tracing::warn!(
"Request {request_id} not found for set_single_ref_blocks_as_fractional"
);
return;
};
for (hash, rc) in blocks {
// A block with strong_count == 1 means only this request holds a reference
if Rc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction);
}
}
}
/// Add a new request with its initial tokens
......@@ -144,6 +180,7 @@ impl ActiveSequences {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
) -> HashSet<RequestId> {
// Check for double-add and log error, returning early
if self.active_seqs.contains_key(&request_id) {
......@@ -159,6 +196,12 @@ impl ActiveSequences {
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
// Store expected output tokens if provided
if let Some(tokens) = expected_output_tokens {
self.expected_output_tokens
.insert(request_id.clone(), tokens);
}
if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
.iter()
......@@ -231,6 +274,9 @@ impl ActiveSequences {
self.expiry_requests.remove(request_id);
// Remove expected output tokens tracking
self.expected_output_tokens.remove(request_id);
// Remove from active_seqs and get the token sequence
let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq,
......@@ -249,6 +295,46 @@ impl ActiveSequences {
self.active_blocks()
}
/// Add an output block with a random hash and optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction (if provided) represents how "temporary" the block is:
/// - 1.0 means fully counted (early in generation)
/// - 0.0 means not counted (near end of expected output)
/// - Computed as: 1 - (current_osl / expected_output_tokens)
///
/// Returns true if the block was added, false if the request was not found.
pub fn add_output_block(
&mut self,
request_id: &RequestId,
decay_fraction: Option<f64>,
) -> bool {
// Check if request exists first (immutable borrow)
if !self.active_seqs.contains_key(request_id) {
tracing::warn!("Request {request_id} not found for add_output_block");
return false;
}
// Generate a random block hash using UUID
let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
// Touch the block (adds to unique_blocks)
let rc = self.touch_block(&random_hash);
// Now we can safely get_mut and push
self.active_seqs
.get_mut(request_id)
.unwrap()
.push((random_hash, rc));
// Apply fractional decay to all single-ref blocks in this request if provided
if let Some(frac) = decay_fraction {
self.set_single_ref_blocks_as_fractional(request_id, frac);
}
true
}
/// Force expiry of stale requests if the timer has elapsed
/// Returns the set of expired request IDs that were removed
pub fn force_expiry(&mut self) -> HashSet<RequestId> {
......@@ -279,6 +365,7 @@ enum UpdateSequences {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
},
Free {
......@@ -287,6 +374,11 @@ enum UpdateSequences {
MarkPrefillCompleted {
request_id: RequestId,
},
AddOutputBlock {
request_id: RequestId,
decay_fraction: Option<f64>,
resp_tx: tokio::sync::oneshot::Sender<bool>,
},
NewBlocks {
token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: tokio::sync::oneshot::Sender<usize>,
......@@ -428,9 +520,10 @@ impl ActiveSequencesMultiWorker {
token_sequence,
isl,
overlap,
expected_output_tokens,
resp_tx,
} => {
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap, expected_output_tokens);
let _ = resp_tx.send(removed);
}
UpdateSequences::Free { request_id } => {
......@@ -439,6 +532,14 @@ impl ActiveSequencesMultiWorker {
UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.mark_prefill_completed(&request_id);
}
UpdateSequences::AddOutputBlock {
request_id,
decay_fraction,
resp_tx,
} => {
let success = active_sequences.add_output_block(&request_id, decay_fraction);
let _ = resp_tx.send(success);
}
UpdateSequences::NewBlocks {
token_sequence,
resp_tx,
......@@ -535,6 +636,7 @@ impl ActiveSequencesMultiWorker {
token_sequence,
isl,
overlap,
expected_output_tokens,
} => {
request_to_worker.insert(event.request_id.clone(), event.worker);
......@@ -546,6 +648,7 @@ impl ActiveSequencesMultiWorker {
token_sequence: token_sequence.clone(),
isl: *isl,
overlap: *overlap,
expected_output_tokens: *expected_output_tokens,
resp_tx,
});
} else {
......@@ -643,6 +746,7 @@ impl ActiveSequencesMultiWorker {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
) -> Result<(), SequenceError> {
// Check for worker existence
......@@ -670,6 +774,7 @@ impl ActiveSequencesMultiWorker {
token_sequence: token_sequence.clone(),
isl,
overlap,
expected_output_tokens,
},
router_id: self.router_id,
};
......@@ -689,6 +794,7 @@ impl ActiveSequencesMultiWorker {
token_sequence,
isl,
overlap,
expected_output_tokens,
resp_tx,
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
......@@ -804,6 +910,59 @@ impl ActiveSequencesMultiWorker {
Ok(())
}
/// Add an output block with optional fractional decay weight
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction represents how "temporary" the block is based on generation progress.
pub async fn add_output_block(
&self,
request_id: &RequestId,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
// Verify worker still exists
if !self.senders.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
// Create response channel
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
// Send command to worker
self.senders
.get(&worker)
.unwrap()
.send(UpdateSequences::AddOutputBlock {
request_id: request_id.clone(),
decay_fraction,
resp_tx,
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
// Wait for response
let success = resp_rx
.await
.map_err(|_| SequenceError::WorkerChannelClosed)?;
if !success {
return Err(SequenceError::RequestNotFound {
request_id: request_id.clone(),
});
}
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(())
}
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let Some(sender) = self.senders.get(&worker) else {
......@@ -1028,15 +1187,15 @@ mod tests {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0);
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
......@@ -1112,6 +1271,7 @@ mod tests {
Some(vec![0, 1, 2]),
12, // ISL (3 blocks * 4 block_size)
0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::new(0, 0),
)
.await?;
......@@ -1123,6 +1283,7 @@ mod tests {
Some(vec![3, 4]),
8, // ISL (2 blocks * 4 block_size)
0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::new(0, 1),
)
.await?;
......@@ -1134,6 +1295,7 @@ mod tests {
Some(vec![0, 1, 2, 3]),
16, // ISL (4 blocks * 4 block_size)
0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::new(1, 0),
)
.await?;
......@@ -1268,6 +1430,7 @@ mod tests {
None, // No token sequence
12, // ISL (12 tokens)
0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(0),
)
.await?;
......@@ -1279,6 +1442,7 @@ mod tests {
None, // No token sequence
8, // ISL (8 tokens)
0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(1),
)
.await?;
......@@ -1290,6 +1454,7 @@ mod tests {
None, // No token sequence
16, // ISL (16 tokens)
0, // no overlap
None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(2),
)
.await?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment