Unverified Commit 4c73be14 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention][2/n] Remove usage of deprecated `seq_lens_cpu` and...


[Attention][2/n] Remove usage of deprecated `seq_lens_cpu` and `num_computed_tokens_cpu` CommonAttentionMetadata properties (#31774)
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 2f4bdee6
...@@ -142,8 +142,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -142,8 +142,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
m = common_attn_metadata m = common_attn_metadata
query_start_loc = m.query_start_loc query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu context_lens_tensor = m.compute_num_computed_tokens()
context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if ( if (
...@@ -370,6 +369,5 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -370,6 +369,5 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
num_accepted_tokens = torch.diff(m.query_start_loc) num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu() num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu) return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
...@@ -215,7 +215,10 @@ class Mamba2AttentionMetadataBuilder( ...@@ -215,7 +215,10 @@ class Mamba2AttentionMetadataBuilder(
num_prefills = common.num_prefills num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens num_decode_tokens = common.num_decode_tokens
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[ num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs num_reqs - num_prefills : num_reqs
] ]
query_start_loc_p_cpu = ( query_start_loc_p_cpu = (
......
...@@ -138,9 +138,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -138,9 +138,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
mamba_block_size: int, mamba_block_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
self.device
)
# Block index of the last computed token # Block index of the last computed token
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1 block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
# which is <= block index for the first scheduled token # which is <= block index for the first scheduled token
...@@ -193,13 +191,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -193,13 +191,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.enable_prefix_caching: if self.vllm_config.cache_config.enable_prefix_caching:
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
# Return a tensor of shape (#requests, #max blocks) # Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables: # Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
( (
block_idx_last_computed_token, block_idx_last_computed_token,
block_idx_first_scheduled_token, block_idx_first_scheduled_token,
...@@ -212,15 +209,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -212,15 +209,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
if num_prefills > 0: if num_prefills > 0:
if num_computed_tokens is None:
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
num_computed_tokens_cpu = num_computed_tokens.cpu()
query_start_loc_p = ( query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :] common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens - num_decode_tokens
) )
has_initial_states_cpu = ( has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[ num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
num_reqs - num_prefills : num_reqs
]
> 0
) )
has_initial_states_p = has_initial_states_cpu.to( has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device common_attn_metadata.query_start_loc.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment