Unverified commit fd618871, authored by Rohan Potdar, committed by GitHub
Browse files

[Bugfix]: Fix ROCm fusion attn test; use AttentionBackend utils to create kv cache (#33948)


Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
parent 67a42b5a
......@@ -92,6 +92,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize attention metadata."""
# TODO (Rohan138) reuse utils from vllm/v1/worker/gpu/attn_utils.py
# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
......@@ -100,58 +102,31 @@ class AttentionQuantPatternModel(torch.nn.Module):
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
backend = self.attn.backend
# TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend
if backend == AttentionBackendEnum.ROCM_ATTN:
# k/v as 1st dimension
# HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimension
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.block_size,
self.num_kv_heads,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.TRITON_ATTN:
# k/v as 2nd dimension
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.FLASHINFER:
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
).permute(0, 1, 3, 2, 4)
else:
raise ValueError(f"Unsupported backend: {backend}")
# Fetch the attention backend and kv cache shape and stride order
attn_backend = self.attn.attn_backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
# Create dummy KV cache
raw_tensor = torch.zeros(
2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache]
# Build attn metadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment