Unverified Commit 8b6e1d63 authored by Zzz9990's avatar Zzz9990 Committed by GitHub
Browse files

[Hardware][AMD] integrate aiter chunked prefill into vllm (#18596)


Signed-off-by: default avatarfsx950223 <fsx950223@outlook.com>
Signed-off-by: default avatarcharlifu <charlifu@amd.com>
Co-authored-by: default avatarfsx950223 <fsx950223@outlook.com>
Co-authored-by: default avatarcharlifu <charlifu@amd.com>
parent 735a9de7
...@@ -87,6 +87,7 @@ if TYPE_CHECKING: ...@@ -87,6 +87,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True
...@@ -653,6 +654,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -653,6 +654,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_USE_AITER_MLA": "VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")), ("true", "1")),
# Whether to use aiter mha ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MHA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),
# use rocm skinny gemms # use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": "VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
......
...@@ -215,6 +215,12 @@ class RocmPlatform(Platform): ...@@ -215,6 +215,12 @@ class RocmPlatform(Platform):
selected_backend = _Backend.ROCM_FLASH selected_backend = _Backend.ROCM_FLASH
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
and on_gfx9():
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_aiter_fa.AiterFlashAttentionBackend")
else:
logger.info("Using Triton Attention backend on V1 engine.") logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend") "triton_attn.TritonAttentionBackend")
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment