Unverified commit 713e9e48, authored by Richard Huo, committed by GitHub
Browse files

fix: DIS-706 skip offloading the G1 matched blocks during offloading (#3299)


Signed-off-by: richardhuo-nv <rihuo@nvidia.com>
parent 836d7417
...@@ -115,6 +115,7 @@ pub trait Slot: std::fmt::Debug { ...@@ -115,6 +115,7 @@ pub trait Slot: std::fmt::Debug {
tokens: &[u32], tokens: &[u32],
block_ids: &[usize], block_ids: &[usize],
computed_position: usize, computed_position: usize,
is_new_request: bool,
) -> Result<(), SlotError>; ) -> Result<(), SlotError>;
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>; fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;
...@@ -592,6 +593,7 @@ impl Slot for VllmConnectorSlot { ...@@ -592,6 +593,7 @@ impl Slot for VllmConnectorSlot {
tokens: &[u32], tokens: &[u32],
block_ids: &[usize], block_ids: &[usize],
computed_position: usize, computed_position: usize,
is_new_request: bool,
) -> Result<(), SlotError> { ) -> Result<(), SlotError> {
// TRTLLM's KV Connector Manager will have (computed_position - external matches) // TRTLLM's KV Connector Manager will have (computed_position - external matches)
// in onboarding case // in onboarding case
...@@ -630,10 +632,21 @@ impl Slot for VllmConnectorSlot { ...@@ -630,10 +632,21 @@ impl Slot for VllmConnectorSlot {
self.device_blocks.extend(block_ids); self.device_blocks.extend(block_ids);
} }
// This approach is fragile, but it’s the only way currently to skip evaluating
// the device matched blocks and to avoid offloading them again.
// TODO: Consider adding an indicator in the scheduler output to distinguish between
// matched and unmatched device blocks/tokens from the scheduler.
let maybe_have_device_matched_blocks =
is_new_request && computed_position > 0 && self.evaluated_blocks == 0;
if maybe_have_device_matched_blocks {
self.evaluated_blocks = (computed_position + 1) / self.block_size;
}
let num_candidate_blocks = let num_candidate_blocks =
((computed_position + 1) / self.block_size) - self.evaluated_blocks; ((computed_position + 1) / self.block_size).saturating_sub(self.evaluated_blocks);
if num_candidate_blocks != 0 { if num_candidate_blocks > 0 {
// do we have a mechanism for skipping gpu cache hit blocks? not sure yet. // do we have a mechanism for skipping gpu cache hit blocks? not sure yet.
// for now, offload all the blocks to the host // for now, offload all the blocks to the host
let offload_block_ids: Vec<usize> = self let offload_block_ids: Vec<usize> = self
......
...@@ -334,6 +334,7 @@ impl Leader for KvConnectorLeader { ...@@ -334,6 +334,7 @@ impl Leader for KvConnectorLeader {
&new_req.prompt_token_ids, &new_req.prompt_token_ids,
&new_req.block_ids, &new_req.block_ids,
new_req.num_computed_tokens, new_req.num_computed_tokens,
true,
)?; )?;
if let Some(pending_ops) = slot.take_pending_operations() { if let Some(pending_ops) = slot.take_pending_operations() {
...@@ -364,6 +365,7 @@ impl Leader for KvConnectorLeader { ...@@ -364,6 +365,7 @@ impl Leader for KvConnectorLeader {
&cached_req.new_token_ids, &cached_req.new_token_ids,
&cached_req.new_block_ids, &cached_req.new_block_ids,
cached_req.num_computed_tokens, cached_req.num_computed_tokens,
false,
)?; )?;
if let Some(pending_ops) = slot.take_pending_operations() { if let Some(pending_ops) = slot.take_pending_operations() {
......
...@@ -739,7 +739,7 @@ mod tests { ...@@ -739,7 +739,7 @@ mod tests {
let disk_pool = if let Some(disk_blocks) = disk_blocks { let disk_pool = if let Some(disk_blocks) = disk_blocks {
config.num_blocks = disk_blocks; config.num_blocks = disk_blocks;
Some(build_layout( Some(build_layout(
config, config.clone(),
layout_type, layout_type,
agent, agent,
&DiskAllocator, &DiskAllocator,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment