Unverified Commit 7eb6cb6c authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Update tests to remove deprecated env vars (#30563)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 9ca8cb38
...@@ -39,7 +39,7 @@ docker run \ ...@@ -39,7 +39,7 @@ docker run \
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
cd tests cd tests
pytest -v -s v1/core pytest -v -s v1/core
pytest -v -s v1/engine pytest -v -s v1/engine
......
...@@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs( ...@@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs(
@pytest.mark.parametrize("model_executor", ["uni", "mp"]) @pytest.mark.parametrize("model_executor", ["uni", "mp"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models( def test_models(
monkeypatch: pytest.MonkeyPatch,
hf_runner, hf_runner,
model: str, model: str,
backend: str, backend: str,
...@@ -77,9 +76,6 @@ def test_models( ...@@ -77,9 +76,6 @@ def test_models(
model_executor: str, model_executor: str,
enable_prompt_embeds: bool, enable_prompt_embeds: bool,
) -> None: ) -> None:
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend)
# 5042 tokens for gemma2 # 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096 # gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window # we need a prompt with more than 4096 tokens to test the sliding window
...@@ -104,6 +100,7 @@ def test_models( ...@@ -104,6 +100,7 @@ def test_models(
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
distributed_executor_backend=model_executor, distributed_executor_backend=model_executor,
attention_config={"backend": backend},
) as vllm_model: ) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
...@@ -161,12 +158,6 @@ def test_models_distributed( ...@@ -161,12 +158,6 @@ def test_models_distributed(
): # noqa ): # noqa
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
if attention_backend:
monkeypatch_context.setenv(
"VLLM_ATTENTION_BACKEND",
attention_backend,
)
for k, v in extra_env.items(): for k, v in extra_env.items():
monkeypatch_context.setenv(k, v) monkeypatch_context.setenv(k, v)
...@@ -178,6 +169,7 @@ def test_models_distributed( ...@@ -178,6 +169,7 @@ def test_models_distributed(
# if we run HF first, the cuda initialization will be done and it # if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method # will hurt multiprocessing backend with fork method
# (the default method). # (the default method).
attention_config = {"backend": attention_backend} if attention_backend else None
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
...@@ -185,6 +177,7 @@ def test_models_distributed( ...@@ -185,6 +177,7 @@ def test_models_distributed(
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
attention_config=attention_config,
) as vllm_model: ) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
......
...@@ -208,7 +208,8 @@ def test_attn_quant( ...@@ -208,7 +208,8 @@ def test_attn_quant(
# To capture subprocess logs, we need to know whether spawn or fork is used. # To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general. # Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
# Testing properties # Testing properties
...@@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ...@@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
# To capture subprocess logs, we need to know whether spawn or fork is used. # To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general. # Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
# Testing properties # Testing properties
...@@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp( ...@@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp(
# To capture subprocess logs, we need to know whether spawn or fork is used. # To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general. # Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
# Testing properties # Testing properties
......
...@@ -89,7 +89,6 @@ class TestSetting: ...@@ -89,7 +89,6 @@ class TestSetting:
], ],
) )
def test_compile_correctness( def test_compile_correctness(
monkeypatch: pytest.MonkeyPatch,
test_setting: TestSetting, test_setting: TestSetting,
): ):
# this test is run under multiple suits, with different GPUs. # this test is run under multiple suits, with different GPUs.
...@@ -107,8 +106,6 @@ def test_compile_correctness( ...@@ -107,8 +106,6 @@ def test_compile_correctness(
f"{cuda_device_count_stateless()}" f"{cuda_device_count_stateless()}"
) )
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
final_args = [ final_args = [
*model_args, *model_args,
"-pp", "-pp",
...@@ -116,6 +113,7 @@ def test_compile_correctness( ...@@ -116,6 +113,7 @@ def test_compile_correctness(
"-tp", "-tp",
str(tp_size), str(tp_size),
"-cc.cudagraph_mode=none", "-cc.cudagraph_mode=none",
f"--attention-backend={attn_backend}",
] ]
all_args: list[list[str]] = [] all_args: list[list[str]] = []
......
...@@ -74,7 +74,6 @@ def llm_pair(request): ...@@ -74,7 +74,6 @@ def llm_pair(request):
# Force native sampler to avoid potential nondeterminism in FlashInfer # Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1. # when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0", "VLLM_USE_FLASHINFER_SAMPLER": "0",
**backend_config.env_vars,
} }
with temporary_environ(env_vars): with temporary_environ(env_vars):
full = LLM( full = LLM(
...@@ -170,16 +169,10 @@ class TestFullCUDAGraph: ...@@ -170,16 +169,10 @@ class TestFullCUDAGraph:
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend(): def test_full_cudagraph_with_invalid_backend():
with (
temporary_environ(
{
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
# Flex_Attention is not supported with full cuda graph # Flex_Attention is not supported with full cuda graph
} with pytest.raises(RuntimeError):
),
pytest.raises(RuntimeError),
):
LLM( LLM(
model="Qwen/Qwen2-1.5B-Instruct", model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"), compilation_config=CompilationConfig(cudagraph_mode="FULL"),
attention_config={"backend": "FLEX_ATTENTION"},
) )
...@@ -197,20 +197,19 @@ def test_custom_compile_config( ...@@ -197,20 +197,19 @@ def test_custom_compile_config(
], ],
) )
def test_fp8_kv_scale_compile( def test_fp8_kv_scale_compile(
monkeypatch: pytest.MonkeyPatch,
compilation_mode: int, compilation_mode: int,
model: str, model: str,
backend: AttentionBackendEnum | None, backend: AttentionBackendEnum | None,
): ):
if backend:
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs = { model_kwargs = {
"quantization": "fp8", "quantization": "fp8",
"kv_cache_dtype": "fp8_e4m3", "kv_cache_dtype": "fp8_e4m3",
"calculate_kv_scales": True, "calculate_kv_scales": True,
"max_model_len": 512, "max_model_len": 512,
} }
if backend:
model_kwargs["attention_config"] = {"backend": backend.name}
run_model(compilation_mode, model, **model_kwargs) run_model(compilation_mode, model, **model_kwargs)
......
...@@ -219,14 +219,12 @@ def _test_cp_gsm8k( ...@@ -219,14 +219,12 @@ def _test_cp_gsm8k(
] ]
) )
server_env = {}
if attn_backend: if attn_backend:
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend server_args.append(f"--attention-backend={attn_backend}")
with RemoteOpenAIServer( with RemoteOpenAIServer(
model_id, model_id,
server_args, server_args,
env_dict=server_env,
max_wait_seconds=720, max_wait_seconds=720,
) as remote_server: ) as remote_server:
host = f"http://{remote_server.host}" host = f"http://{remote_server.host}"
......
...@@ -20,12 +20,10 @@ from ..utils import compare_two_settings, create_new_process_for_each_test ...@@ -20,12 +20,10 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
) )
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_pp_cudagraph( def test_pp_cudagraph(
monkeypatch: pytest.MonkeyPatch,
PP_SIZE: int, PP_SIZE: int,
MODEL_NAME: str, MODEL_NAME: str,
ATTN_BACKEND: LiteralString, ATTN_BACKEND: LiteralString,
): ):
with monkeypatch.context() as m:
cudagraph_args = [ cudagraph_args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -34,8 +32,8 @@ def test_pp_cudagraph( ...@@ -34,8 +32,8 @@ def test_pp_cudagraph(
str(PP_SIZE), str(PP_SIZE),
"--distributed-executor-backend", "--distributed-executor-backend",
"mp", "mp",
f"--attention-backend={ATTN_BACKEND}",
] ]
m.setenv("VLLM_ATTENTION_BACKEND", ATTN_BACKEND)
eager_args = cudagraph_args + ["--enforce-eager"] eager_args = cudagraph_args + ["--enforce-eager"]
......
...@@ -9,7 +9,7 @@ from typing import Annotated, Literal ...@@ -9,7 +9,7 @@ from typing import Annotated, Literal
import pytest import pytest
from vllm.config import CompilationConfig, config from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import ( from vllm.engine.arg_utils import (
EngineArgs, EngineArgs,
contains_type, contains_type,
...@@ -298,6 +298,139 @@ def test_compilation_config(): ...@@ -298,6 +298,139 @@ def test_compilation_config():
) )
def test_attention_config():
from vllm.attention.backends.registry import AttentionBackendEnum
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# default value
args = parser.parse_args([])
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.attention_config == AttentionConfig()
# set backend via dot notation
args = parser.parse_args(["--attention-config.backend", "FLASH_ATTN"])
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.attention_config.backend is not None
assert engine_args.attention_config.backend.name == "FLASH_ATTN"
# set backend via --attention-backend shorthand
args = parser.parse_args(["--attention-backend", "FLASHINFER"])
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.attention_backend is not None
assert engine_args.attention_backend == "FLASHINFER"
# set all fields via dot notation
args = parser.parse_args(
[
"--attention-config.backend",
"FLASH_ATTN",
"--attention-config.flash_attn_version",
"3",
"--attention-config.use_prefill_decode_attention",
"true",
"--attention-config.flash_attn_max_num_splits_for_cuda_graph",
"16",
"--attention-config.use_cudnn_prefill",
"true",
"--attention-config.use_trtllm_ragged_deepseek_prefill",
"true",
"--attention-config.use_trtllm_attention",
"true",
"--attention-config.disable_flashinfer_prefill",
"true",
"--attention-config.disable_flashinfer_q_quantization",
"true",
]
)
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.attention_config.backend is not None
assert engine_args.attention_config.backend.name == "FLASH_ATTN"
assert engine_args.attention_config.flash_attn_version == 3
assert engine_args.attention_config.use_prefill_decode_attention is True
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
assert engine_args.attention_config.use_cudnn_prefill is True
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
assert engine_args.attention_config.use_trtllm_attention is True
assert engine_args.attention_config.disable_flashinfer_prefill is True
assert engine_args.attention_config.disable_flashinfer_q_quantization is True
# set to string form of a dict with all fields
args = parser.parse_args(
[
"--attention-config="
'{"backend": "FLASHINFER", "flash_attn_version": 2, '
'"use_prefill_decode_attention": false, '
'"flash_attn_max_num_splits_for_cuda_graph": 8, '
'"use_cudnn_prefill": false, '
'"use_trtllm_ragged_deepseek_prefill": false, '
'"use_trtllm_attention": false, '
'"disable_flashinfer_prefill": false, '
'"disable_flashinfer_q_quantization": false}',
]
)
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.attention_config.backend is not None
assert engine_args.attention_config.backend.name == "FLASHINFER"
assert engine_args.attention_config.flash_attn_version == 2
assert engine_args.attention_config.use_prefill_decode_attention is False
assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
assert engine_args.attention_config.use_cudnn_prefill is False
assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
assert engine_args.attention_config.use_trtllm_attention is False
assert engine_args.attention_config.disable_flashinfer_prefill is False
assert engine_args.attention_config.disable_flashinfer_q_quantization is False
# test --attention-backend flows into VllmConfig.attention_config
args = parser.parse_args(
[
"--model",
"facebook/opt-125m",
"--attention-backend",
"FLASH_ATTN",
]
)
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
vllm_config = engine_args.create_engine_config()
assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN
# test --attention-config.backend flows into VllmConfig.attention_config
args = parser.parse_args(
[
"--model",
"facebook/opt-125m",
"--attention-config.backend",
"FLASHINFER",
]
)
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
vllm_config = engine_args.create_engine_config()
assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER
# test --attention-backend and --attention-config.backend are mutually exclusive
args = parser.parse_args(
[
"--model",
"facebook/opt-125m",
"--attention-backend",
"FLASH_ATTN",
"--attention-config.backend",
"FLASHINFER",
]
)
assert args is not None
engine_args = EngineArgs.from_cli_args(args)
with pytest.raises(ValueError, match="mutually exclusive"):
engine_args.create_engine_config()
def test_prefix_cache_default(): def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([]) args = parser.parse_args([])
......
...@@ -76,14 +76,9 @@ def default_server_args(with_tool_parser: bool): ...@@ -76,14 +76,9 @@ def default_server_args(with_tool_parser: bool):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def gptoss_server( def gptoss_server(default_server_args: list[str]):
monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str] server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
): with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
with monkeypatch_module.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
with RemoteOpenAIServer(
GPT_OSS_MODEL_NAME, default_server_args
) as remote_server:
yield remote_server yield remote_server
......
...@@ -6,7 +6,9 @@ from unittest.mock import patch ...@@ -6,7 +6,9 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
...@@ -73,18 +75,18 @@ def generate_params(): ...@@ -73,18 +75,18 @@ def generate_params():
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_env( def test_backend_selection(
device: str, device: str,
name: str, name: str,
use_mla: bool, use_mla: bool,
block_size: int, block_size: int,
monkeypatch: pytest.MonkeyPatch,
): ):
"""Test attention backend selection with valid device-backend pairs.""" """Test attention backend selection with valid device-backend pairs."""
with monkeypatch.context() as m: # Create AttentionConfig with the specified backend
m.setenv("VLLM_ATTENTION_BACKEND", name) attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size) backend = get_attn_backend(16, torch.float16, None, block_size)
...@@ -217,6 +219,10 @@ def test_env( ...@@ -217,6 +219,10 @@ def test_env(
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str): def test_fp32_fallback(device: str):
"""Test attention backend selection with fp32.""" """Test attention backend selection with fp32."""
# Use default config (no backend specified)
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16) backend = get_attn_backend(16, torch.float32, None, 16)
...@@ -232,12 +238,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): ...@@ -232,12 +238,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
pytest.skip( pytest.skip(
"Skipping as current backend selector does not " "Skipping as current backend selector does not "
"handle fallbacks when a backend is set via env var." "handle fallbacks when a backend is explicitly set."
) )
with monkeypatch.context() as m: attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
# Unsupported CUDA arch # Unsupported CUDA arch
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16) backend = get_attn_backend(16, torch.float16, None, 16)
...@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): ...@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
def test_invalid_env(monkeypatch: pytest.MonkeyPatch): def test_invalid_backend():
"""Test that invalid attention backend names raise ValueError.""" """Test that invalid attention backend names raise ValueError."""
with ( with (
monkeypatch.context() as m, pytest.raises(ValueError),
patch("vllm.platforms.current_platform", CudaPlatform()),
): ):
m.setenv("VLLM_ATTENTION_BACKEND", "INVALID") # Invalid backend name should raise ValueError when creating enum
AttentionConfig(backend=AttentionBackendEnum["INVALID"])
# Should raise ValueError for invalid backend
with pytest.raises(ValueError) as exc_info:
get_attn_backend(32, torch.float16, None, 16)
assert "Invalid value 'INVALID'" in str(exc_info.value)
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
...@@ -16,32 +18,41 @@ def clear_cache(): ...@@ -16,32 +18,41 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.") @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch): def test_selector(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
# Set the current platform to ROCm using monkeypatch # Set the current platform to ROCm using monkeypatch
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
# Test standard ROCm attention # Test standard ROCm attention
attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN" assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
# MLA test for deepseek related # MLA test for deepseek related
# Change the attention backend to triton MLA
attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA)
vllm_config = VllmConfig(attention_config=attention_config)
# change the attention backend to triton MLA with set_current_vllm_config(vllm_config):
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# If attention backend is None # If attention backend is None
# If use_mla is true # If use_mla is true
# The selected backend is triton MLA # The selected backend is triton MLA
m.setenv("VLLM_ATTENTION_BACKEND", "") attention_config = AttentionConfig(backend=None)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# change the attention backend to AITER MLA # Change the attention backend to AITER MLA
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA") attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
assert backend.get_name() == "ROCM_AITER_MLA" assert backend.get_name() == "ROCM_AITER_MLA"
...@@ -49,7 +60,14 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -49,7 +60,14 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# If use_mla is true # If use_mla is true
# If VLLM_ROCM_USE_AITER is enabled # If VLLM_ROCM_USE_AITER is enabled
# The selected backend is ROCM_AITER_MLA # The selected backend is ROCM_AITER_MLA
m.setenv("VLLM_ATTENTION_BACKEND", "") with monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", "1") m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
attention_config = AttentionConfig(backend=None)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(
576, torch.bfloat16, "auto", 1, False, use_mla=True
)
assert backend.get_name() == "ROCM_AITER_MLA" assert backend.get_name() == "ROCM_AITER_MLA"
...@@ -37,7 +37,7 @@ def set_seed(seed): ...@@ -37,7 +37,7 @@ def set_seed(seed):
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7", reason="CUDA not available or PyTorch version < 2.7",
) )
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): def test_flex_attention_vs_default_backend(vllm_runner):
"""Test that FlexAttention produces the same outputs as the default backend. """Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with This test compares the outputs from the FlexAttention backend with
...@@ -54,9 +54,6 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ...@@ -54,9 +54,6 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
] ]
# Run with flex attention # Run with flex attention
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
set_seed(seed) set_seed(seed)
with vllm_runner( with vllm_runner(
model_name, model_name,
...@@ -64,13 +61,13 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ...@@ -64,13 +61,13 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
tensor_parallel_size=1, tensor_parallel_size=1,
num_gpu_blocks_override=128, num_gpu_blocks_override=128,
enforce_eager=True, enforce_eager=True,
attention_config={"backend": "FLEX_ATTENTION"},
) as llm_flex: ) as llm_flex:
output_flex = llm_flex.generate_greedy_logprobs( output_flex = llm_flex.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs prompts, max_tokens, num_logprobs
) )
# Run with default backend # Run with default backend
with monkeypatch.context() as m:
set_seed(seed) set_seed(seed)
with vllm_runner( with vllm_runner(
model_name, model_name,
...@@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ...@@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7", reason="CUDA not available or PyTorch version < 2.7",
) )
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): def test_encoder_flex_attention_vs_default_backend(vllm_runner):
"""Test that FlexAttention produces the same outputs as the default backend. """Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with This test compares the outputs from the FlexAttention backend with
...@@ -110,8 +107,6 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ...@@ -110,8 +107,6 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
] ]
# Run with flex attention # Run with flex attention
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
with vllm_runner( with vllm_runner(
model_name, model_name,
runner="pooling", runner="pooling",
...@@ -119,21 +114,19 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ...@@ -119,21 +114,19 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
tensor_parallel_size=1, tensor_parallel_size=1,
max_model_len=100, max_model_len=100,
enforce_eager=True, enforce_eager=True,
attention_config={"backend": "FLEX_ATTENTION"},
) as llm_flex: ) as llm_flex:
flex_outputs = llm_flex.embed(prompts) flex_outputs = llm_flex.embed(prompts)
# Run with default backend # Run with default backend
with ( with vllm_runner(
monkeypatch.context() as m,
vllm_runner(
model_name, model_name,
runner="pooling", runner="pooling",
dtype=torch.bfloat16, dtype=torch.bfloat16,
tensor_parallel_size=1, tensor_parallel_size=1,
max_model_len=100, max_model_len=100,
enforce_eager=True, enforce_eager=True,
) as llm_default, ) as llm_default:
):
default_outputs = llm_default.embed(prompts) default_outputs = llm_default.embed(prompts)
check_embeddings_close( check_embeddings_close(
......
...@@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME ...@@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME
models = [MODEL_NAME] models = [MODEL_NAME]
@pytest.fixture(autouse=True) @pytest.fixture
def set_attention_backend_for_rocm(monkeypatch): def granite_speech_attention_config():
"""Return attention config for Granite Speech tests on ROCm."""
if current_platform.is_rocm(): if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") return {"backend": "TRITON_ATTN"}
return None
def run_test( def run_test(
...@@ -53,6 +55,7 @@ def run_test( ...@@ -53,6 +55,7 @@ def run_test(
num_logprobs: int, num_logprobs: int,
tensor_parallel_size: int, tensor_parallel_size: int,
distributed_executor_backend: str | None = None, distributed_executor_backend: str | None = None,
attention_config: dict | None = None,
): ):
"""Inference result should be the same between hf and vllm. """Inference result should be the same between hf and vllm.
...@@ -80,6 +83,7 @@ def run_test( ...@@ -80,6 +83,7 @@ def run_test(
enable_lora=True, enable_lora=True,
max_lora_rank=64, max_lora_rank=64,
enforce_eager=True, enforce_eager=True,
attention_config=attention_config,
) as vllm_model: ) as vllm_model:
lora_request = LoRARequest("audio", 1, audio_lora_path) lora_request = LoRARequest("audio", 1, audio_lora_path)
vllm_outputs_per_case = [ vllm_outputs_per_case = [
...@@ -131,6 +135,7 @@ def test_models( ...@@ -131,6 +135,7 @@ def test_models(
vllm_runner, vllm_runner,
model: str, model: str,
audio_assets: AudioTestAssets, audio_assets: AudioTestAssets,
granite_speech_attention_config,
dtype: str, dtype: str,
max_model_len: int, max_model_len: int,
max_tokens: int, max_tokens: int,
...@@ -157,4 +162,5 @@ def test_models( ...@@ -157,4 +162,5 @@ def test_models(
max_tokens=max_tokens, max_tokens=max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
tensor_parallel_size=1, tensor_parallel_size=1,
attention_config=granite_speech_attention_config,
) )
...@@ -2,23 +2,17 @@ ...@@ -2,23 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM pooling tests.""" """Pytest configuration for vLLM pooling tests."""
import os import pytest
import warnings
from vllm.platforms import current_platform from vllm.platforms import current_platform
def pytest_collection_modifyitems(config, items): @pytest.fixture
"""Set FLEX_ATTENTION backend for SigLIP tests on ROCm.""" def siglip_attention_config():
if not current_platform.is_rocm(): """Return attention config for SigLIP tests on ROCm.
return
siglip_tests = [item for item in items if "test_siglip" in item.nodeid] On ROCm, SigLIP tests require FLEX_ATTENTION backend.
"""
if siglip_tests: if current_platform.is_rocm():
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" return {"backend": "FLEX_ATTENTION"}
warnings.warn( return None
"ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
UserWarning,
stacklevel=1,
)
...@@ -38,6 +38,7 @@ def _run_test( ...@@ -38,6 +38,7 @@ def _run_test(
*, *,
dtype: str, dtype: str,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
attention_config: dict[str, Any] | None = None,
) -> None: ) -> None:
if tokenization_kwargs is None: if tokenization_kwargs is None:
tokenization_kwargs = {} tokenization_kwargs = {}
...@@ -49,6 +50,7 @@ def _run_test( ...@@ -49,6 +50,7 @@ def _run_test(
enforce_eager=True, enforce_eager=True,
max_model_len=64, max_model_len=64,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
attention_config=attention_config,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.embed( vllm_outputs = vllm_model.embed(
input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
...@@ -90,6 +92,7 @@ def test_models_text( ...@@ -90,6 +92,7 @@ def test_models_text(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
image_assets, image_assets,
siglip_attention_config,
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
...@@ -108,6 +111,7 @@ def test_models_text( ...@@ -108,6 +111,7 @@ def test_models_text(
"padding": "max_length", "padding": "max_length",
"max_length": 64, "max_length": 64,
}, # siglip2 was trained with this padding setting. }, # siglip2 was trained with this padding setting.
attention_config=siglip_attention_config,
) )
...@@ -117,6 +121,7 @@ def test_models_image( ...@@ -117,6 +121,7 @@ def test_models_image(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
image_assets, image_assets,
siglip_attention_config,
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
...@@ -133,6 +138,7 @@ def test_models_image( ...@@ -133,6 +138,7 @@ def test_models_image(
input_images, input_images,
model, model,
dtype=dtype, dtype=dtype,
attention_config=siglip_attention_config,
) )
...@@ -141,6 +147,7 @@ def test_models_image( ...@@ -141,6 +147,7 @@ def test_models_image(
def test_models_text_image_no_crash( def test_models_text_image_no_crash(
vllm_runner, vllm_runner,
image_assets, image_assets,
siglip_attention_config,
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
...@@ -154,6 +161,7 @@ def test_models_text_image_no_crash( ...@@ -154,6 +161,7 @@ def test_models_text_image_no_crash(
enforce_eager=True, enforce_eager=True,
max_model_len=64, max_model_len=64,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
attention_config=siglip_attention_config,
) as vllm_model: ) as vllm_model:
with pytest.raises(ValueError, match="not both"): with pytest.raises(ValueError, match="not both"):
vllm_model.embed(texts, images=images) vllm_model.embed(texts, images=images)
......
...@@ -75,7 +75,6 @@ def test_models( ...@@ -75,7 +75,6 @@ def test_models(
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true") m.setenv("TOKENIZERS_PARALLELISM", "true")
m.setenv("VLLM_ATTENTION_BACKEND", backend)
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8 NUM_LOG_PROBS = 8
...@@ -86,6 +85,7 @@ def test_models( ...@@ -86,6 +85,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
kv_cache_dtype="auto", kv_cache_dtype="auto",
attention_config={"backend": backend},
) as vllm_model: ) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs( baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS example_prompts, max_tokens, NUM_LOG_PROBS
...@@ -97,6 +97,7 @@ def test_models( ...@@ -97,6 +97,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
attention_config={"backend": backend},
) as vllm_model: ) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs( test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS example_prompts, max_tokens, NUM_LOG_PROBS
......
...@@ -108,11 +108,12 @@ def can_initialize( ...@@ -108,11 +108,12 @@ def can_initialize(
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
monkeypatch.context() as m, monkeypatch.context() as m,
): ):
if model_arch == "GptOssForCausalLM":
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3. # L4 supports FA3.
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") attention_config = (
{"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None
)
if model_arch == "WhisperForConditionalGeneration": if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
...@@ -143,6 +144,7 @@ def can_initialize( ...@@ -143,6 +144,7 @@ def can_initialize(
else "vllm", else "vllm",
hf_overrides=hf_overrides_fn, hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs, max_num_seqs=model_info.max_num_seqs,
attention_config=attention_config,
) )
......
...@@ -94,26 +94,20 @@ def mock_on_gfx9(): ...@@ -94,26 +94,20 @@ def mock_on_gfx9():
None, None,
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(), AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
), ),
# Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 # Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
(
{"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
None,
AttentionBackendEnum.ROCM_ATTN.get_path(),
),
# Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
( (
{"VLLM_ROCM_USE_AITER": "1"}, {"VLLM_ROCM_USE_AITER": "1"},
"TRITON_ATTN", "TRITON_ATTN",
AttentionBackendEnum.TRITON_ATTN.get_path(), AttentionBackendEnum.TRITON_ATTN.get_path(),
), ),
# Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0 # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
# (explicitly disabled) # (explicitly disabled)
( (
{"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"}, {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
None, None,
AttentionBackendEnum.TRITON_ATTN.get_path(), AttentionBackendEnum.TRITON_ATTN.get_path(),
), ),
# Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
( (
{"VLLM_ROCM_USE_AITER": "1"}, {"VLLM_ROCM_USE_AITER": "1"},
"ROCM_ATTN", "ROCM_ATTN",
......
...@@ -249,8 +249,8 @@ def create_dummy_kv_cache( ...@@ -249,8 +249,8 @@ def create_dummy_kv_cache(
@dataclass @dataclass
class BackendConfig: class BackendConfig:
name: str name: str
env_vars: dict attention_config: dict
comp_config: dict # compilation config comp_config: dict
specific_gpu_arch: tuple | None = None specific_gpu_arch: tuple | None = None
...@@ -259,10 +259,10 @@ full_cg_backend_configs = { ...@@ -259,10 +259,10 @@ full_cg_backend_configs = {
# FA3 on Hopper # FA3 on Hopper
"FA3": BackendConfig( "FA3": BackendConfig(
name="FA3", name="FA3",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN", "backend": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "3", "flash_attn_version": 3,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
...@@ -272,9 +272,7 @@ full_cg_backend_configs = { ...@@ -272,9 +272,7 @@ full_cg_backend_configs = {
# FlashMLA on Hopper # FlashMLA on Hopper
"FlashMLA": BackendConfig( "FlashMLA": BackendConfig(
name="FlashMLA", name="FlashMLA",
env_vars={ attention_config={"backend": "FLASHMLA"},
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -283,9 +281,7 @@ full_cg_backend_configs = { ...@@ -283,9 +281,7 @@ full_cg_backend_configs = {
# Cutlass MLA on Blackwell # Cutlass MLA on Blackwell
"CutlassMLA": BackendConfig( "CutlassMLA": BackendConfig(
name="CutlassMLA", name="CutlassMLA",
env_vars={ attention_config={"backend": "CUTLASS_MLA"},
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -294,9 +290,7 @@ full_cg_backend_configs = { ...@@ -294,9 +290,7 @@ full_cg_backend_configs = {
# FlashInfer MLA on Blackwell # FlashInfer MLA on Blackwell
"FlashInferMLA": BackendConfig( "FlashInferMLA": BackendConfig(
name="FlashInferMLA", name="FlashInferMLA",
env_vars={ attention_config={"backend": "FLASHINFER_MLA"},
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -305,9 +299,9 @@ full_cg_backend_configs = { ...@@ -305,9 +299,9 @@ full_cg_backend_configs = {
# FlashAttention MLA on Hopper # FlashAttention MLA on Hopper
"FlashAttentionMLA": BackendConfig( "FlashAttentionMLA": BackendConfig(
name="FlashAttentionMLA", name="FlashAttentionMLA",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", "backend": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_mode": "FULL_DECODE_ONLY",
...@@ -317,10 +311,10 @@ full_cg_backend_configs = { ...@@ -317,10 +311,10 @@ full_cg_backend_configs = {
# FA2 # FA2
"FA2": BackendConfig( "FA2": BackendConfig(
name="FA2", name="FA2",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN", "backend": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "2", "flash_attn_version": 2,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
...@@ -329,7 +323,7 @@ full_cg_backend_configs = { ...@@ -329,7 +323,7 @@ full_cg_backend_configs = {
# Triton Attention # Triton Attention
"TritonAttn": BackendConfig( "TritonAttn": BackendConfig(
name="TritonAttn", name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, attention_config={"backend": "TRITON_ATTN"},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -337,14 +331,17 @@ full_cg_backend_configs = { ...@@ -337,14 +331,17 @@ full_cg_backend_configs = {
# FlashInfer # FlashInfer
"FlashInfer": BackendConfig( "FlashInfer": BackendConfig(
name="FlashInfer", name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, attention_config={"backend": "FLASHINFER"},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
), ),
"RocmAttn": BackendConfig( "RocmAttn": BackendConfig(
name="RocmAttn", name="RocmAttn",
env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"}, attention_config={
"backend": "ROCM_ATTN",
"use_prefill_decode_attention": True,
},
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
}, },
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment