Unverified commit 496e991d, authored by Thomas Parnell, committed by GitHub
Browse files

[Doc] Consistent naming of attention backends (#9498)


Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 696b01af
...@@ -32,7 +32,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -32,7 +32,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "flash-attn" return "FLASH_ATTN"
@staticmethod @staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_impl_cls() -> Type["FlashAttentionImpl"]:
......
...@@ -40,7 +40,7 @@ class FlashInferBackend(AttentionBackend): ...@@ -40,7 +40,7 @@ class FlashInferBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "flashinfer" return "FLASHINFER"
@staticmethod @staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]: def get_impl_cls() -> Type["FlashInferImpl"]:
......
...@@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend): ...@@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "ipex-attn" return "IPEX"
@staticmethod @staticmethod
def get_impl_cls() -> Type["IpexAttnBackendImpl"]: def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
......
...@@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend): ...@@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "openvino" return "OPENVINO"
@staticmethod @staticmethod
def get_impl_cls(): def get_impl_cls():
......
...@@ -11,6 +11,10 @@ from vllm.attention.backends.utils import CommonAttentionState ...@@ -11,6 +11,10 @@ from vllm.attention.backends.utils import CommonAttentionState
class PallasAttentionBackend(AttentionBackend): class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS"
@staticmethod @staticmethod
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl return PallasAttentionBackendImpl
......
...@@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend): ...@@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "placeholder-attn" return "NO_ATTENTION"
@staticmethod @staticmethod
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
......
...@@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend): ...@@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "rocm-flash-attn" return "ROCM_FLASH"
@staticmethod @staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
......
...@@ -25,7 +25,7 @@ class TorchSDPABackend(AttentionBackend): ...@@ -25,7 +25,7 @@ class TorchSDPABackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "torch-sdpa" return "TORCH_SDPA"
@staticmethod @staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]: def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
......
...@@ -317,8 +317,8 @@ class CommonAttentionState(AttentionState): ...@@ -317,8 +317,8 @@ class CommonAttentionState(AttentionState):
if is_encoder_decoder_model: if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend. # The encoder decoder model works only with XFormers backend.
# Assert the same. # Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \ assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'xformers', but "\ f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'" f" got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model( self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata) batch_size=batch_size, attn_metadata=attn_metadata)
...@@ -337,8 +337,8 @@ class CommonAttentionState(AttentionState): ...@@ -337,8 +337,8 @@ class CommonAttentionState(AttentionState):
if is_encoder_decoder_model: if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend. # The encoder decoder model works only with XFormers backend.
# Assert the same. # Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \ assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'xformers', but "\ f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'" f" got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model( self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers) attn_metadata=attn_metadata, input_buffers=input_buffers)
...@@ -356,8 +356,8 @@ class CommonAttentionState(AttentionState): ...@@ -356,8 +356,8 @@ class CommonAttentionState(AttentionState):
if is_encoder_decoder_model: if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend. # The encoder decoder model works only with XFormers backend.
# Assert the same. # Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \ assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'xformers', but "\ f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'" f" got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model( self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers) attn_metadata, input_buffers)
......
...@@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend): ...@@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "xformers" return "XFORMERS"
@staticmethod @staticmethod
def get_impl_cls() -> Type["XFormersImpl"]: def get_impl_cls() -> Type["XFormersImpl"]:
......
...@@ -179,7 +179,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -179,7 +179,7 @@ class TP1DraftModelRunner(ModelRunner):
return False return False
# TODO: Add support for other attn backends # TODO: Add support for other attn backends
if self.attn_backend.get_name() != "flash-attn": if self.attn_backend.get_name() != "FLASH_ATTN":
return False return False
# TODO: Add support for LORA # TODO: Add support for LORA
......
...@@ -184,7 +184,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -184,7 +184,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if not disable_mqa_scorer: if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name( if scorer_worker.model_runner.attn_backend.get_name(
) != "flash-attn": ) != "FLASH_ATTN":
disable_mqa_scorer = True disable_mqa_scorer = True
logger.info( logger.info(
"[Speculative Decoding] Disabling MQA scorer as the " "[Speculative Decoding] Disabling MQA scorer as the "
......
...@@ -1855,7 +1855,7 @@ class CUDAGraphRunner(nn.Module): ...@@ -1855,7 +1855,7 @@ class CUDAGraphRunner(nn.Module):
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
if self.backend_name != "placeholder-attn": if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_( self.input_buffers["slot_mapping"].copy_(
attn_metadata.slot_mapping, non_blocking=True) attn_metadata.slot_mapping, non_blocking=True)
......
...@@ -29,8 +29,8 @@ if TYPE_CHECKING: ...@@ -29,8 +29,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"] MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
-> List[str]: -> List[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment