Unverified commit 42e95479, authored by vllmellm and committed by GitHub
Browse files

[ROCm][Test] Fix ROCM_AITER_UNIFIED_ATTN attn+quant fusion test (#37640)


Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent a32783bb
...@@ -53,6 +53,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -53,6 +53,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
kv_cache_dtype: torch.dtype, kv_cache_dtype: torch.dtype,
device: torch.device, device: torch.device,
vllm_config: VllmConfig, vllm_config: VllmConfig,
block_size: int,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
...@@ -74,7 +75,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -74,7 +75,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.attn._k_scale = self.attn._k_scale.to(device) self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device) self.attn._v_scale = self.attn._v_scale.to(device)
self.block_size = 16 self.block_size = block_size
# Initialize attn MetadataBuilder # Initialize attn MetadataBuilder
self.builder = self.attn.attn_backend.get_builder_cls()( self.builder = self.attn.attn_backend.get_builder_cls()(
...@@ -299,6 +300,9 @@ def test_attention_quant_pattern( ...@@ -299,6 +300,9 @@ def test_attention_quant_pattern(
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(42) torch.manual_seed(42)
backend_cls = backend.get_class()
block_size = backend_cls.get_preferred_block_size(16)
model_config = ModelConfig( model_config = ModelConfig(
model=model_name, model=model_name,
max_model_len=2048, max_model_len=2048,
...@@ -342,6 +346,7 @@ def test_attention_quant_pattern( ...@@ -342,6 +346,7 @@ def test_attention_quant_pattern(
kv_cache_dtype=FP8_DTYPE, kv_cache_dtype=FP8_DTYPE,
device=device, device=device,
vllm_config=vllm_config_unfused, vllm_config=vllm_config_unfused,
block_size=block_size,
) )
model_unfused = model_unfused.to(device) model_unfused = model_unfused.to(device)
result_unfused_0 = model_unfused(q, k, v) # noqa: F841 HACK: See #131044 result_unfused_0 = model_unfused(q, k, v) # noqa: F841 HACK: See #131044
...@@ -370,6 +375,7 @@ def test_attention_quant_pattern( ...@@ -370,6 +375,7 @@ def test_attention_quant_pattern(
device=device, device=device,
vllm_config=vllm_config, vllm_config=vllm_config,
w=model_unfused.w, w=model_unfused.w,
block_size=block_size,
) )
model_fused = model_fused.to(device) model_fused = model_fused.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment