test_attention_selector.py 5.28 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from unittest.mock import patch
4
5
6
7

import pytest
import torch

8
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
9
10
11
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
12
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
13
14


15
16
17
18
19
20
21
@pytest.fixture(autouse=True)
def clear_cache():
    """Drop memoized backend selections before every test.

    The fixture is autouse, so each test case starts with an empty
    ``_cached_get_attn_backend`` cache and must resolve its backend afresh.
    """
    _cached_get_attn_backend.cache_clear()


22
@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(
    name: str,
    use_v1: bool,
    device: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check that the backend env var steers the attention selector.

    For every (backend name, V0/V1, device) combination the selector's
    platform is patched and the resulting backend name is compared with the
    one expected for that platform. FlashAttn is not part of the matrix
    because it is the default backend.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, name)

        # Choose the platform to patch in and the backend name it should
        # produce for the requested override.
        if device == "cpu":
            # CPU is expected to report TORCH_SDPA whatever was requested.
            platform = CpuPlatform()
            expected_name = "TORCH_SDPA"
        elif device == "hip":
            platform = RocmPlatform()
            expected_name = ("TRITON_ATTN_VLLM_V1"
                             if use_v1 else "ROCM_FLASH")
        elif name in ("XFORMERS", "FLASHINFER"):
            # On CUDA the requested backend is expected on V0; V1 is
            # expected to report its FlashAttn backend.
            platform = CudaPlatform()
            expected_name = "FLASH_ATTN_VLLM_V1" if use_v1 else name
        else:
            # The remaining names are not exercised on CUDA.
            return

        with patch("vllm.attention.selector.current_platform", platform):
            selected = get_attn_backend(16, torch.float16, torch.float16, 16,
                                        False)
        assert selected.get_name() == expected_name
61
62


63
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    With the backend env var forced to FlashAttn, each scenario below sets up
    one condition labelled as unsupported and asserts that the selector
    returns some backend other than FlashAttn.

    NOTE(review): test_invalid_env in this file expects head size 16 to fall
    back to XFORMERS on V0, so the head_size=16 scenarios here may pass
    regardless of the condition they set up -- confirm whether a head size
    that FlashAttn accepts should be used instead.
    """
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend
    import sys

    def select_uncached(*args):
        # The autouse `clear_cache` fixture runs only once per test, while the
        # selector result is memoized (presumably keyed on the call
        # arguments). Clear it before every scenario so that a repeated
        # argument tuple is re-evaluated under the patches currently active
        # instead of being answered from an earlier scenario.
        _cached_get_attn_backend.cache_clear()
        return get_attn_backend(*args)

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch
        # The nested context undoes the capability patch on exit, so the
        # following scenarios see the real device capability again.
        with monkeypatch.context() as arch_patch:
            # Accept any arguments so the stub also works if it is called
            # with a device argument.
            arch_patch.setattr(torch.cuda, "get_device_capability",
                               lambda *args, **kwargs: (7, 5))
            backend = select_uncached(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = select_uncached(16, torch.float8_e4m3fn, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = select_uncached(16, torch.float16, "fp8", 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = select_uncached(16, torch.float16, None, 8, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed
        # A None entry in sys.modules makes `import vllm_flash_attn` raise
        # ImportError. Leaving the nested context restores the original entry
        # (or removes the key if there was none).
        with monkeypatch.context() as module_patch:
            module_patch.setitem(sys.modules, 'vllm_flash_attn', None)
            backend = select_uncached(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = select_uncached(17, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Attention-free models should bypass env and use PlaceholderAttention
        backend = select_uncached(16, torch.float16, torch.float16, 16, True)
        assert backend.get_name() != STR_FLASH_ATTN_VAL
113
114


115
@pytest.mark.parametrize("use_v1", [True, False])
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
    """Check backend selection when the env var names an invalid backend.

    The selector's platform is patched to CUDA. With head size 32 a FlashAttn
    backend is expected (the V1 variant when ``use_v1`` is set). On V0, head
    size 16 is expected to fall back to XFORMERS.
    """
    v1_flag = "1" if use_v1 else "0"

    with monkeypatch.context() as m:
        with patch("vllm.attention.selector.current_platform",
                   CudaPlatform()):
            m.setenv("VLLM_USE_V1", v1_flag)
            m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

            # Head size 32: a FlashAttn backend is expected despite the
            # invalid override.
            selected = get_attn_backend(32, torch.float16, None, 16, False)
            expected_name = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
            assert selected.get_name() == expected_name

            # Head size 16 (block size stays 16 in both calls): the backend
            # is expected to fall back to XFORMERS. This behavior is not yet
            # supported on V1, so it is only checked on V0.
            # TODO: support fallback on V1!
            # https://github.com/vllm-project/vllm/issues/14524
            if not use_v1:
                selected = get_attn_backend(16, torch.float16, None, 16,
                                            False)
                assert selected.get_name() == "XFORMERS"