# SPDX-License-Identifier: Apache-2.0

import sys
from unittest.mock import Mock, patch

import pytest
import torch

from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL


@pytest.fixture(autouse=True)
def clear_cache() -> None:
    """Reset the attention-backend selection cache before every test.

    Declared ``autouse`` so it runs for each test in this module without
    being requested explicitly. It calls ``cache_clear()`` on
    ``_cached_get_attn_backend`` so that each test case runs without a
    result cached by an earlier one.
    """
    _cached_get_attn_backend.cache_clear()


@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(
    name: str,
    use_v1: bool,
    device: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test that the attention selector can be set via environment variable.

    Note that we do not test FlashAttn because it is the default backend.

    Args:
        name: Backend name written to the ``STR_BACKEND_ENV_VAR`` variable.
        use_v1: Whether ``VLLM_USE_V1`` is set to ``"1"`` or ``"0"``.
        device: Which platform object replaces
            ``vllm.attention.selector.current_platform`` for the call.
        monkeypatch: pytest fixture used to scope the environment changes.

    Expected backend per device, as asserted below:
        * ``cpu``: always ``TORCH_SDPA``.
        * ``hip``: ``ROCM_ATTN_VLLM_V1`` with V1, otherwise ``ROCM_FLASH``.
        * ``openvino``: always ``OPENVINO``.
        * ``cuda``: only ``XFORMERS`` and ``FLASHINFER`` are checked;
          ``FLASH_ATTN_VLLM_V1`` with V1, otherwise the requested name.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, name)

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, torch.float16,
                                           16, False)
            assert backend.get_name() == "TORCH_SDPA"
        elif device == "hip":
            with patch("vllm.attention.selector.current_platform",
                       RocmPlatform()):
                backend = get_attn_backend(16, torch.float16, torch.float16,
                                           16, False)
            expected = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
            assert backend.get_name() == expected
        elif device == "openvino":
            # A Mock stands in for the `openvino` package in sys.modules for
            # the duration of the call, so `import openvino` resolves to it.
            with patch("vllm.attention.selector.current_platform",
                       OpenVinoPlatform()), patch.dict('sys.modules',
                                                       {'openvino': Mock()}):
                backend = get_attn_backend(16, torch.float16, torch.float16,
                                           16, False)
            assert backend.get_name() == "OPENVINO"
        else:
            # device == "cuda". Only these two names are exercised here; the
            # remaining parametrized names make no assertion on CUDA.
            if name in ["XFORMERS", "FLASHINFER"]:
                with patch("vllm.attention.selector.current_platform",
                           CudaPlatform()):
                    backend = get_attn_backend(16, torch.float16,
                                               torch.float16, 16, False)
                expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
                assert backend.get_name() == expected


def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    With ``STR_BACKEND_ENV_VAR`` set to ``STR_FLASH_ATTN_VAL``, every
    scenario below asserts that the backend returned by ``get_attn_backend``
    is *not* FlashAttn.

    Args:
        monkeypatch: pytest fixture; each temporary override is applied in
            its own ``monkeypatch.context()`` so it is rolled back as soon
            as the scenario that needs it finishes.
    """
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch
        # Scoped separately from `m` so leaving this block restores the real
        # capability query while the env var above stays in place.
        with monkeypatch.context() as arch_patch:
            arch_patch.setattr(torch.cuda, "get_device_capability", lambda:
                               (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed
        # A None entry in sys.modules makes `import vllm_flash_attn` fail;
        # the previous entry (or its absence) is restored on block exit.
        # NOTE(review): this call repeats the arguments of the "Unsupported
        # CUDA arch" call above. If get_attn_backend memoizes on its
        # arguments through _cached_get_attn_backend, this scenario may be
        # answered from the cache rather than re-evaluated -- confirm.
        with monkeypatch.context() as module_patch:
            module_patch.setitem(sys.modules, 'vllm_flash_attn', None)
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Attention-free models should bypass env and use PlaceholderAttention
        backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
        assert backend.get_name() != STR_FLASH_ATTN_VAL


@pytest.mark.parametrize("use_v1", [True, False])
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
    """Test backend selection when the env var holds an invalid backend name.

    ``STR_BACKEND_ENV_VAR`` is set to ``STR_INVALID_VAL`` and the selector's
    ``current_platform`` is replaced with a ``CudaPlatform`` instance.

    Args:
        use_v1: Whether ``VLLM_USE_V1`` is set to ``"1"`` or ``"0"``.
        monkeypatch: pytest fixture used to scope the environment changes.
    """
    with monkeypatch.context() as m, patch(
            "vllm.attention.selector.current_platform", CudaPlatform()):
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Test with head size 32
        backend = get_attn_backend(32, torch.float16, None, 16, False)
        expected = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
        assert backend.get_name() == expected

        # when head size == 16, backend will fall back to XFORMERS
        # (block size is 16 in both calls; only the head size differs).
        # this behavior is not yet supported on V1.
        # TODO: support fallback on V1!
        # https://github.com/vllm-project/vllm/issues/14524
        if not use_v1:
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() == "XFORMERS"