# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest
import torch

from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform

from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
from vllm.platforms import current_platform


@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the memoized attention-backend lookup before every test.

    ``_cached_get_attn_backend`` is LRU-cached, so a backend resolved in one
    test could otherwise be reused by a later test. Because the fixture is
    ``autouse``, pytest applies it to every test in this module without the
    tests having to request it.
    """
    cached_lookup = _cached_get_attn_backend
    cached_lookup.cache_clear()


# Backend names requested through the env var. On ROCm hosts only the ROCm
# backend is exercised; elsewhere the full set is parametrized.
_ENV_BACKEND_NAMES = (["ROCM_FLASH"] if current_platform.is_rocm() else
                      ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])


@pytest.mark.parametrize("name", _ENV_BACKEND_NAMES)
@pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(
    name: str,
    use_v1: bool,
    device: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test that the attention selector can be set via environment variable.

    Note that we do not test FlashAttn because it is the default backend.

    Each ``device`` case patches ``vllm.attention.selector.current_platform``
    with the matching platform object, so the selection is made against that
    platform rather than the host's.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, name)

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, torch.float16,
                                           16, False)
            # The CPU platform is expected to pick TORCH_SDPA regardless of
            # which backend name was requested.
            assert backend.get_name() == "TORCH_SDPA"
        elif device == "hip":
            with patch("vllm.attention.selector.current_platform",
                       RocmPlatform()):
                backend = get_attn_backend(16, torch.float16, torch.float16,
                                           16, False)
            expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
            assert backend.get_name() == expected
        elif name in ("XFORMERS", "FLASHINFER"):
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                backend = get_attn_backend(16, torch.float16,
                                           torch.float16, 16, False)
            # Under V1 the requested name is expected to be replaced by the
            # V1 FlashAttention backend; under V0 it is honoured as-is.
            expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
            assert backend.get_name() == expected
        # The remaining CUDA combinations (TORCH_SDPA, ROCM_FLASH) are
        # intentionally not asserted here.


def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    FlashAttention is requested through the backend env var. Each scenario
    below is a configuration the test expects FlashAttention to reject, so
    the selector must return some other backend every time.
    """
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend
    import sys

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch. The fake capability lives in its own scoped
        # context, so leaving the block restores only this attribute rather
        # than undoing every patch registered on the ``monkeypatch`` fixture.
        # The stub accepts the optional ``device`` argument of the real
        # ``torch.cuda.get_device_capability(device=None)``.
        with monkeypatch.context() as arch_patch:
            arch_patch.setattr(torch.cuda, "get_device_capability",
                               lambda device=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed: a ``None`` entry in ``sys.modules``
        # makes importing that name fail. The scoped context restores the
        # previous entry, or removes it if there was none, on exit.
        # NOTE(review): these arguments are identical to the unsupported-arch
        # call above, and ``_cached_get_attn_backend`` is an LRU cache that is
        # only cleared between tests, so this call may just return the cached
        # result. Confirm that 'vllm_flash_attn' is the module name the
        # selector actually imports before adding a cache_clear() here.
        with monkeypatch.context() as missing_fa:
            missing_fa.setitem(sys.modules, 'vllm_flash_attn', None)
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Attention-free models should bypass env and use PlaceholderAttention
        backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
        assert backend.get_name() != STR_FLASH_ATTN_VAL


@pytest.mark.parametrize("use_v1", [True, False])
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
    """Test that an invalid backend name in the env var is ignored.

    With ``current_platform`` patched to CUDA and the backend env var set to
    ``STR_INVALID_VAL``, the selector is expected to make its normal choice.
    """
    with monkeypatch.context() as m, patch(
            "vllm.attention.selector.current_platform", CudaPlatform()):
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Test with head size 32
        backend = get_attn_backend(32, torch.float16, None, 16, False)
        expected = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
        assert backend.get_name() == expected

        # With head size 16 (block size is 16 in both calls, so head size is
        # the only thing that changes) the backend is expected to fall back
        # to XFORMERS. This fallback is not yet supported on V1, so it is
        # only checked for V0.
        # TODO: support fallback on V1!
        # https://github.com/vllm-project/vllm/issues/14524
        if not use_v1:
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() == "XFORMERS"