"vscode:/vscode.git/clone" did not exist on "c0a350ca73fa8f10bd718fdcec47b075dc1d102f"
Unverified commit b28246f6, authored by Sage Moore and committed by GitHub
Browse files

[ROCm][V1][Bugfix] Add get_builder_cls method to the ROCmAttentionBackend class (#14065)


Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent 3b5567a2
...@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -49,6 +50,10 @@ class ROCmAttentionBackend(AttentionBackend): ...@@ -49,6 +50,10 @@ class ROCmAttentionBackend(AttentionBackend):
def use_cascade_attention(*args, **kwargs) -> bool: def use_cascade_attention(*args, **kwargs) -> bool:
return False return False
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
class ROCmAttentionImpl(AttentionImpl): class ROCmAttentionImpl(AttentionImpl):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment