Unverified commit c07e2ca6, authored by Ming Yang and committed by GitHub
Browse files

Fix Mamba state corruption from referencing stale block table entries (#37728)

parent 4df5fa74
...@@ -121,6 +121,12 @@ class BlockTable: ...@@ -121,6 +121,12 @@ class BlockTable:
self.num_blocks_per_row[row_idx] = 0 self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx) self.append_row(block_ids, row_idx)
def clear_row(self, row_idx: int) -> None:
    """Zero the block IDs tracked for ``row_idx`` and mark the row empty.

    Only the first ``num_blocks_per_row[row_idx]`` entries of the row are
    overwritten; anything past the tracked count is left as-is. A row whose
    tracked count is already zero is not touched.

    Args:
        row_idx: Index of the row to clear.
    """
    num_blocks = self.num_blocks_per_row[row_idx]
    if num_blocks > 0:
        # Wipe the stale IDs so a later reader of this row cannot pick up
        # entries left behind by whatever occupied it before.
        # NOTE(review): this writes only to the ``.np`` array; presumably a
        # separate commit/copy step publishes it elsewhere — confirm.
        self.block_table.np[row_idx, :num_blocks] = 0
        self.num_blocks_per_row[row_idx] = 0
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src] num_blocks = self.num_blocks_per_row[src]
block_table_np = self.block_table.np block_table_np = self.block_table.np
...@@ -275,6 +281,10 @@ class MultiGroupBlockTable: ...@@ -275,6 +281,10 @@ class MultiGroupBlockTable:
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx) block_table.add_row(block_ids[i], row_idx)
def clear_row(self, row_idx: int) -> None:
    """Clear ``row_idx`` in every block table held in ``self.block_tables``.

    Delegates to each table's own ``clear_row`` in iteration order.

    Args:
        row_idx: Index of the row to clear in each table.
    """
    for block_table in self.block_tables:
        block_table.clear_row(row_idx)
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.move_row(src, tgt) block_table.move_row(src, tgt)
......
...@@ -496,6 +496,7 @@ class InputBatch: ...@@ -496,6 +496,7 @@ class InputBatch:
self._req_ids[req_index] = None self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index].clear() self.spec_token_ids[req_index].clear()
self.block_table.clear_row(req_index)
# LoRA # LoRA
lora_id = self.request_lora_mapping[req_index] lora_id = self.request_lora_mapping[req_index]
......
...@@ -5376,6 +5376,12 @@ class GPUModelRunner( ...@@ -5376,6 +5376,12 @@ class GPUModelRunner(
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
# Sync block table CPU->GPU so cleared rows from
# remove_request() are visible to the attention metadata
# builder. Without this, stale block IDs from finished
# requests can corrupt Mamba state.
self.input_batch.block_table.commit_block_table(num_reqs_padded)
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
attn_metadata, _ = self._build_attention_metadata( attn_metadata, _ = self._build_attention_metadata(
num_tokens=num_tokens_unpadded, num_tokens=num_tokens_unpadded,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment