# test_attention_selector.py: tests for vllm.attention.selector.which_attn_to_use.
from unittest.mock import patch

import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
    """Test that the attention selector can be set via environment variable.

    Note that we do not test FlashAttn because it is the default backend.

    For the CPU, ROCm ("hip") and OpenVINO platforms this test expects the
    selector to return that platform's own backend whatever ``name`` was
    requested; only on CUDA is the requested backend expected to be returned.
    """
    override_backend_env_variable(monkeypatch, name)

    # device -> (platform class patched in as the selector's current_platform,
    # backend name the selector is expected to return on that platform).
    cases = {
        "cpu": (CpuPlatform, "TORCH_SDPA"),
        "hip": (RocmPlatform, "ROCM_FLASH"),
        "openvino": (OpenVinoPlatform, "OPENVINO"),
        "cuda": (CudaPlatform, name),
    }
    platform_cls, expected = cases[device]

    # Only the platform under test is instantiated, as in a per-branch patch.
    with patch("vllm.attention.selector.current_platform", platform_cls()):
        backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                    False)
    assert backend.name == expected


def test_flash_attn(monkeypatch):
    """Test FlashAttn validation.

    The backend is requested as FlashAttention through the environment
    variable; each case below uses a configuration this test treats as
    unsupported, and asserts the selector returns some other backend.
    """
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # which_attn_to_use

    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

    # As exercised by the cases below, the positional arguments are:
    # (head_size, dtype, kv_cache_dtype, block_size, is_attention_free).

    # Unsupported CUDA arch
    with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
        backend = which_attn_to_use(16, torch.float16, None, 16, False)
        assert backend.name != STR_FLASH_ATTN_VAL

    # Unsupported data type
    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
    assert backend.name != STR_FLASH_ATTN_VAL

    # Unsupported kv cache data type
    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
    assert backend.name != STR_FLASH_ATTN_VAL

    # Unsupported block size
    backend = which_attn_to_use(16, torch.float16, None, 8, False)
    assert backend.name != STR_FLASH_ATTN_VAL

    # flash-attn is not installed: a None entry in sys.modules makes
    # `import vllm_flash_attn` raise ImportError while the patch is active.
    with patch.dict('sys.modules', {'vllm_flash_attn': None}):
        backend = which_attn_to_use(16, torch.float16, None, 16, False)
        assert backend.name != STR_FLASH_ATTN_VAL

    # Unsupported head size
    backend = which_attn_to_use(17, torch.float16, None, 16, False)
    assert backend.name != STR_FLASH_ATTN_VAL

    # Attention-free models should bypass env and use PlaceholderAttention
    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
    assert backend.name != STR_FLASH_ATTN_VAL


def test_invalid_env(monkeypatch):
    """Throw an exception if the backend name is invalid.

    The environment variable is set to an unrecognised backend name, and the
    selector is expected to reject it with ``ValueError``.
    """
    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
    with pytest.raises(ValueError):
        which_attn_to_use(16, torch.float16, None, 16, False)