Unverified Commit 825c2dc1 authored by Kevin McKay's avatar Kevin McKay Committed by GitHub
Browse files

[Bugfix][Hardware][AMD] Fix last_page_len calculation in AITER MLA decode (#31282)


Signed-off-by: default avatarc0de128 <kevin.mckay@outlook.com>
parent 1f43c121
...@@ -88,6 +88,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -88,6 +88,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO: we can disambiguate between decode and mixed-prefill decode here # TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually # so we can only use the persistent buffer if a cudagraph is actually
# being used. # being used.
# paged_kv_last_page_len is always 1s (kernel block size is always 1),
# so we create it once and reuse slices in both eager and cudagraph modes.
self.paged_kv_last_page_len = torch.ones(
max_num_reqs, dtype=torch.int32, device=device
)
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros( self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
...@@ -95,9 +102,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -95,9 +102,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_indices = torch.zeros( self.paged_kv_indices = torch.zeros(
max_num_pages, dtype=torch.int32, device=device max_num_pages, dtype=torch.int32, device=device
) )
self.paged_kv_last_page_len = torch.zeros(
max_num_reqs, dtype=torch.int32, device=device
)
self.qo_indptr = torch.zeros( self.qo_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
...@@ -122,7 +126,9 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -122,7 +126,9 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
).unsqueeze(0) < seq_lens_device.unsqueeze(1) ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask] paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device) # kernel block size is always 1, so each page has exactly 1 token.
# last_page_len is always 1 - just slice the pre-initialized buffer.
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
paged_kv_indptr = torch.cat( paged_kv_indptr = torch.cat(
[ [
...@@ -148,11 +154,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -148,11 +154,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
self.paged_kv_last_page_len[:num_reqs].copy_( # paged_kv_last_page_len already uses the pre-initialized buffer slice
paged_kv_last_page_len, non_blocking=True # (set above), so no copy needed - buffer is always 1s.
)
self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
self.qo_indptr[: 1 + num_reqs].copy_( self.qo_indptr[: 1 + num_reqs].copy_(
query_start_loc_device, non_blocking=True query_start_loc_device, non_blocking=True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment