Unverified commit fd618871, authored by Rohan Potdar and committed by GitHub
Browse files

[Bugfix]: Fix ROCm fusion attn test; use AttentionBackend utils to create kv cache (#33948)


Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
parent 67a42b5a
...@@ -92,6 +92,8 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -92,6 +92,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize attention metadata.""" """Initialize attention metadata."""
# TODO (Rohan138) reuse utils from vllm/v1/worker/gpu/attn_utils.py
# Create common attn metadata # Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size) batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata( common_attn_metadata = create_common_attn_metadata(
...@@ -100,58 +102,31 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -100,58 +102,31 @@ class AttentionQuantPatternModel(torch.nn.Module):
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks num_blocks = batch_size * max_blocks
backend = self.attn.backend
# Fetch the attention backend and kv cache shape and stride order
# TODO(luka) use get_kv_cache_stride_order attn_backend = self.attn.attn_backend
# Create dummy KV cache for the selected backend kv_cache_shape = attn_backend.get_kv_cache_shape(
if backend == AttentionBackendEnum.ROCM_ATTN: num_blocks, self.block_size, self.num_kv_heads, self.head_size
# k/v as 1st dimension )
# HND: [num_blocks, num_kv_heads, block_size, head_size] try:
kv_cache = torch.zeros( kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
2, except (AttributeError, NotImplementedError):
num_blocks, kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
self.num_kv_heads,
self.block_size, kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
self.head_size, inv_order = [
dtype=self.kv_cache_dtype, kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
device=self.device, ]
)
elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: # Create dummy KV cache
# k/v as 1st dimension raw_tensor = torch.zeros(
# NHD: [num_blocks, block_size, num_kv_heads, head_size] 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
kv_cache = torch.zeros( dtype=self.kv_cache_dtype,
2, device=self.device,
num_blocks, )
self.block_size, raw_tensor = raw_tensor.view(kv_cache_shape)
self.num_kv_heads, kv_cache = raw_tensor.permute(*inv_order)
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.TRITON_ATTN:
# k/v as 2nd dimension
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.FLASHINFER:
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
).permute(0, 1, 3, 2, 4)
else:
raise ValueError(f"Unsupported backend: {backend}")
self.attn.kv_cache = [kv_cache] self.attn.kv_cache = [kv_cache]
# Build attn metadata # Build attn metadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment