Unverified commit c7a79d41, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention][3/n] Remove usage of deprecated `seq_lens_cpu` and...


[Attention][3/n] Remove usage of deprecated `seq_lens_cpu` and `num_computed_tokens_cpu` CommonAttentionMetadata properties (#31850)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 6409004b
...@@ -337,7 +337,7 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -337,7 +337,7 @@ class AiterFlashAttentionMetadataBuilder(
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
...@@ -367,7 +367,7 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -367,7 +367,7 @@ class AiterFlashAttentionMetadataBuilder(
if num_extends > 0: if num_extends > 0:
num_extends_slice = slice(num_decodes, num_decodes + num_extends) num_extends_slice = slice(num_decodes, num_decodes + num_extends)
query_lens_for_extend = query_lens_cpu[num_extends_slice] query_lens_for_extend = query_lens_cpu[num_extends_slice]
seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice] seq_lens_for_extend = seq_lens[num_extends_slice]
computed_kv_lens = seq_lens_for_extend - query_lens_for_extend computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
swa_metadata = None swa_metadata = None
if self.aot_sliding_window is not None: if self.aot_sliding_window is not None:
......
...@@ -124,7 +124,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat ...@@ -124,7 +124,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
prefix_kv_lens = torch.tensor( prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device [common_prefix_len], dtype=torch.int32, device=self.device
) )
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
suffix_kv_lens = suffix_kv_lens.to(self.device) suffix_kv_lens = suffix_kv_lens.to(self.device)
else: else:
cu_prefix_query_lens = None cu_prefix_query_lens = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment