Unverified Commit e0327c9d, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention][1/n] Remove usage of deprecated `seq_lens_cpu` and...


[Attention][1/n] Remove usage of deprecated `seq_lens_cpu` and `num_computed_tokens_cpu` CommonAttentionMetadata properties (#31773)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 14df02b4
...@@ -126,12 +126,12 @@ def create_and_prepopulate_kv_cache( ...@@ -126,12 +126,12 @@ def create_and_prepopulate_kv_cache(
Tuple of (kv_cache, updated_block_table) Tuple of (kv_cache, updated_block_table)
""" """
batch_size = len(k_contexts) batch_size = len(k_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = ( query_lens = (
common_attn_metadata.query_start_loc_cpu[1:] common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1] - common_attn_metadata.query_start_loc_cpu[:-1]
) )
context_lens = common_attn_metadata.num_computed_tokens_cpu context_lens = seq_lens - query_lens
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
......
...@@ -154,12 +154,12 @@ def create_and_prepopulate_kv_cache( ...@@ -154,12 +154,12 @@ def create_and_prepopulate_kv_cache(
MLA KV cache tensor MLA KV cache tensor
""" """
batch_size = len(kv_c_contexts) batch_size = len(kv_c_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = ( query_lens = (
common_attn_metadata.query_start_loc_cpu[1:] common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1] - common_attn_metadata.query_start_loc_cpu[:-1]
) )
context_lens = common_attn_metadata.num_computed_tokens_cpu context_lens = seq_lens - query_lens
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
......
...@@ -297,7 +297,7 @@ def test_sparse_backend_decode_correctness( ...@@ -297,7 +297,7 @@ def test_sparse_backend_decode_correctness(
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat( positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
starts[:-1], seg_lengths starts[:-1], seg_lengths
) )
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32) seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32)
prefix_lengths = seq_lengths - seg_lengths prefix_lengths = seq_lengths - seg_lengths
positions += np.repeat(prefix_lengths, seg_lengths) positions += np.repeat(prefix_lengths, seg_lengths)
......
...@@ -870,7 +870,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -870,7 +870,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Guard access to seq_lens_cpu, which may not always be needed # Guard access to seq_lens_cpu, which may not always be needed
# and can be expensive to retrieve in async mode. # and can be expensive to retrieve in async mode.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None seq_lens_cpu = (
common_attn_metadata.seq_lens.cpu() if needs_seq_lens_cpu else None
)
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
num_blocks_np = ( num_blocks_np = (
(seq_lens_np + (page_size - 1)) // page_size (seq_lens_np + (page_size - 1)) // page_size
......
...@@ -727,9 +727,7 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -727,9 +727,7 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
block_table_tensor, seq_lens, block_size, num_gpu_blocks block_table_tensor, seq_lens, block_size, num_gpu_blocks
) )
offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( offset_tensor = common_attn_metadata.compute_num_computed_tokens()
self.device, non_blocking=True
)
out = FlexAttentionMetadata( out = FlexAttentionMetadata(
causal=common_attn_metadata.causal, causal=common_attn_metadata.causal,
......
...@@ -791,7 +791,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -791,7 +791,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata = None prefill_metadata = None
if num_prefills > 0: if num_prefills > 0:
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
reqs_start = num_decodes # prefill_start reqs_start = num_decodes # prefill_start
......
...@@ -511,7 +511,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -511,7 +511,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
# For pure decode batches, prefill_request_id will be None # For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill # For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0: if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
......
...@@ -221,7 +221,7 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet ...@@ -221,7 +221,7 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
prefix_kv_lens = torch.tensor( prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device [common_prefix_len], dtype=torch.int32, device=self.device
) )
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
suffix_kv_lens = suffix_kv_lens.to(self.device) suffix_kv_lens = suffix_kv_lens.to(self.device)
else: else:
cu_prefix_query_lens = None cu_prefix_query_lens = None
......
...@@ -100,6 +100,8 @@ class CommonAttentionMetadata: ...@@ -100,6 +100,8 @@ class CommonAttentionMetadata:
_seq_lens_cpu: torch.Tensor | None = None _seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None _num_computed_tokens_cpu: torch.Tensor | None = None
_num_computed_tokens_cache: torch.Tensor | None = None
@property @property
@deprecated( @deprecated(
""" """
...@@ -130,6 +132,13 @@ class CommonAttentionMetadata: ...@@ -130,6 +132,13 @@ class CommonAttentionMetadata:
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
return self._num_computed_tokens_cpu return self._num_computed_tokens_cpu
def compute_num_computed_tokens(self) -> torch.Tensor:
    """Return the per-request count of already-computed tokens.

    The value is ``seq_lens - query_lens``, where ``query_lens`` is the
    difference between consecutive entries of ``query_start_loc``. The
    arithmetic is done directly on ``seq_lens`` and ``query_start_loc``
    (no explicit host transfer happens here), and the result is memoized
    in ``_num_computed_tokens_cache`` so later calls return the same tensor.
    """
    memo = self._num_computed_tokens_cache
    if memo is not None:
        # Already computed for this metadata object; reuse it.
        return memo

    # Per-request query length = end offset - start offset.
    span_ends = self.query_start_loc[1:]
    span_starts = self.query_start_loc[:-1]
    memo = self.seq_lens - (span_ends - span_starts)

    self._num_computed_tokens_cache = memo
    return memo
# TODO(lucas): remove once we have FULL-CG spec-decode support # TODO(lucas): remove once we have FULL-CG spec-decode support
def unpadded( def unpadded(
self, num_actual_tokens: int, num_actual_reqs: int self, num_actual_tokens: int, num_actual_reqs: int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment