Unverified commit 66e674cd, authored by Matthew Bonanni and committed by GitHub
Browse files

[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent dff0a2b3
...@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. " f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "Set --attention-config.backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes." "FlexAttention backend which supports all head sizes."
) )
......
...@@ -210,9 +210,6 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -210,9 +210,6 @@ class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym return quant_key == kFp8StaticTensorSym
def supports_quant_query_input(self) -> bool:
return current_platform.is_cuda()
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
...@@ -262,6 +259,8 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -262,6 +259,8 @@ class TritonAttentionImpl(AttentionImpl):
f"num_heads: {num_heads}." f"num_heads: {num_heads}."
) )
self.supports_quant_query_input = current_platform.is_cuda()
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment