Unverified commit 42e95479, authored by vllmellm, committed by GitHub
Browse files

[ROCm][Test] Fix ROCM_AITER_UNIFIED_ATTN attn+quant fusion test (#37640)


Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent a32783bb
......@@ -53,6 +53,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
kv_cache_dtype: torch.dtype,
device: torch.device,
vllm_config: VllmConfig,
block_size: int,
**kwargs,
):
super().__init__()
......@@ -74,7 +75,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)
self.block_size = 16
self.block_size = block_size
# Initialize attn MetadataBuilder
self.builder = self.attn.attn_backend.get_builder_cls()(
......@@ -299,6 +300,9 @@ def test_attention_quant_pattern(
torch.set_default_dtype(dtype)
torch.manual_seed(42)
backend_cls = backend.get_class()
block_size = backend_cls.get_preferred_block_size(16)
model_config = ModelConfig(
model=model_name,
max_model_len=2048,
......@@ -342,6 +346,7 @@ def test_attention_quant_pattern(
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused,
block_size=block_size,
)
model_unfused = model_unfused.to(device)
result_unfused_0 = model_unfused(q, k, v) # noqa: F841 HACK: See #131044
......@@ -370,6 +375,7 @@ def test_attention_quant_pattern(
device=device,
vllm_config=vllm_config,
w=model_unfused.w,
block_size=block_size,
)
model_fused = model_fused.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment