test_use_trtllm_attention.py 6.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import patch

import pytest
import torch

from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    supports_trtllm_attention,
    use_trtllm_attention,
)

# Attention head layouts (query/output heads and key/value heads) keyed by a
# human-readable model label. Entries are looked up through get_config().
MODEL_CONFIGS = {
    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8),
    "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8),
    "Qwen2.5-0.5B": dict(num_qo_heads=14, num_kv_heads=2),
    "Mistral-7B": dict(num_qo_heads=32, num_kv_heads=8),
    # Gemma-2-9B uses 16 query heads and 8 KV heads; 8/4 is the Gemma-2-2B
    # layout and was mislabeled here.
    "Gemma-2-9B": dict(num_qo_heads=16, num_kv_heads=8),
    "Falcon-40B": dict(num_qo_heads=128, num_kv_heads=8),
}


def get_config(model: str) -> dict:
    """Look up the attention head configuration registered for ``model``.

    Raises ``KeyError`` when ``model`` is not a key of ``MODEL_CONFIGS``.
    """
    config = MODEL_CONFIGS[model]
    return config


# Baseline keyword arguments passed to use_trtllm_attention(); individual
# tests override single fields through _call().
DEFAULT_KWARGS = {
    **get_config("Llama-3-70B"),
    "num_tokens": 128,
    "max_seq_len": 4096,
    "dcp_world_size": 1,
    "kv_cache_dtype": "auto",
    "q_dtype": torch.bfloat16,
    "is_prefill": False,
    "force_use_trtllm": None,
    "has_sinks": False,
    "has_spec": False,
}


def _call(**overrides) -> bool:
    """Invoke ``use_trtllm_attention`` with ``DEFAULT_KWARGS`` plus overrides.

    Keys supplied in ``overrides`` take precedence over the defaults.
    """
    merged = dict(DEFAULT_KWARGS)
    merged.update(overrides)
    return use_trtllm_attention(**merged)


@pytest.fixture(autouse=True)
def _clear_supports_cache():
    """Reset the ``supports_trtllm_attention`` functools cache around each test.

    The cache is cleared before the test so it runs independently of earlier
    tests, and again afterwards so a value computed under this module's mocks
    is not left cached for whatever runs next in the same session.
    """
    supports_trtllm_attention.cache_clear()
    yield
    # Teardown: drop any result that was computed while patches were active.
    supports_trtllm_attention.cache_clear()


# supports_trtllm_attention


58
59
@patch("vllm.envs.VLLM_BATCH_INVARIANT", True)
def test_supports_batch_invariant_disables():
    """With VLLM_BATCH_INVARIANT patched to True, support is reported False."""
    assert supports_trtllm_attention() is False


63
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
    return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
def test_supports_sm100_with_artifactory(_art, _cap):
    """Capability probe True and artifactory True (batch-invariant off) -> True.

    Mock arguments arrive bottom-up: ``_art`` is the artifactory patch and
    ``_cap`` is the device-capability patch.
    """
    assert supports_trtllm_attention() is True


73
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
    # NOTE(review): aligned with test_supports_sm100_with_artifactory, which
    # patches is_device_capability_family; this test previously patched
    # is_device_capability. Confirm which probe supports_trtllm_attention
    # actually calls -- patching the wrong one leaves the result dependent on
    # the real hardware.
    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
    return_value=False,
)
def test_supports_non_sm100_platform(_cap):
    """Capability probe False (batch-invariant off) -> support is False."""
    assert supports_trtllm_attention() is False


82
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
    # NOTE(review): aligned with test_supports_sm100_with_artifactory, which
    # patches is_device_capability_family; this test previously patched
    # is_device_capability. Confirm which probe supports_trtllm_attention
    # actually calls.
    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
    return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
def test_supports_sm100_without_artifactory(_art, _cap):
    """Capability probe True but artifactory False -> support is False."""
    assert supports_trtllm_attention() is False


# can_use_trtllm_attention


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=False)
def test_can_use_force_disabled(_mock):
    """Expect False when force_use_trtllm_attention is mocked to return False."""
    heads = get_config("Llama-3-70B")
    num_qo_heads, num_kv_heads = heads["num_qo_heads"], heads["num_kv_heads"]
    assert can_use_trtllm_attention(num_qo_heads, num_kv_heads) is False


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_can_use_compatible_heads(_sup, _force):
    """Expect True for the Llama-3-70B head layout on a supported platform."""
    heads = get_config("Llama-3-70B")
    num_qo_heads, num_kv_heads = heads["num_qo_heads"], heads["num_kv_heads"]
    assert can_use_trtllm_attention(num_qo_heads, num_kv_heads) is True


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_can_use_incompatible_heads(_sup, _force):
    """Expect False for 40 query heads with 6 KV heads (40 % 6 != 0)."""
    usable = can_use_trtllm_attention(40, 6)
    assert usable is False


@pytest.mark.parametrize("model", list(MODEL_CONFIGS))
@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_can_use_platform_unsupported(_sup, _force, model):
    """Expect False for every registered model when the platform is unsupported."""
    heads = get_config(model)
    num_qo_heads, num_kv_heads = heads["num_qo_heads"], heads["num_kv_heads"]
    assert can_use_trtllm_attention(num_qo_heads, num_kv_heads) is False


# use_trtllm_attention


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_force_off(_mock):
    """Expect False when force_use_trtllm is explicitly False."""
    outcome = _call(force_use_trtllm=False)
    assert outcome is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_dcp_fallback(_mock):
    """Expect False when dcp_world_size is 2."""
    outcome = _call(dcp_world_size=2)
    assert outcome is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported(_mock):
    """Expect False with default arguments when the platform is unsupported."""
    outcome = _call()
    assert outcome is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported_force_on_still_false(_mock):
    """Expect False on an unsupported platform even with force_use_trtllm=True."""
    outcome = _call(force_use_trtllm=True)
    assert outcome is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads(_mock):
    """Expect False for 40 query heads with 6 KV heads."""
    head_overrides = dict(num_qo_heads=40, num_kv_heads=6)
    assert _call(**head_overrides) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads_force_on_still_false(_mock):
    """Expect False for 40/6 heads even with force_use_trtllm=True."""
    head_overrides = dict(num_qo_heads=40, num_kv_heads=6)
    outcome = _call(force_use_trtllm=True, **head_overrides)
    assert outcome is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_spec_decode_enables(_mock):
    """Expect True for speculative decode regardless of batch size.

    num_tokens=512 is the decode batch size that
    test_use_auto_decode_large_batch expects to yield False, so a True result
    here cannot come from the default auto-detection inputs alone.
    """
    assert _call(has_spec=True, is_prefill=False, num_tokens=512) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
@patch(
    "vllm.utils.flashinfer.current_platform.fp8_dtype",
    return_value=torch.float8_e4m3fn,
)
def test_use_fp8_query_forces_trtllm(_fp8, _sup):
    """Expect True when q_dtype equals the platform FP8 dtype.

    num_tokens=512 is the decode batch size that
    test_use_auto_decode_large_batch expects to yield False, so a True result
    here cannot come from the default auto-detection inputs alone.
    """
    assert _call(q_dtype=torch.float8_e4m3fn, num_tokens=512) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_sinks_force_trtllm(_mock):
    """Expect True when has_sinks is set, regardless of batch size.

    num_tokens=512 is the decode batch size that
    test_use_auto_decode_large_batch expects to yield False, so a True result
    here cannot come from the default auto-detection inputs alone.
    """
    assert _call(has_sinks=True, num_tokens=512) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_prefill_kv_auto(_mock):
    """Expect True for prefill with kv_cache_dtype="auto"."""
    outcome = _call(is_prefill=True, kv_cache_dtype="auto")
    assert outcome is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_prefill_kv_fp8(_mock):
    """Expect False for prefill with kv_cache_dtype="fp8"."""
    outcome = _call(is_prefill=True, kv_cache_dtype="fp8")
    assert outcome is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_small_batch(_mock):
    """Expect True for decode with 128 tokens and kv_cache_dtype="auto"."""
    outcome = _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto")
    assert outcome is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_large_batch(_mock):
    """Expect False for decode with 512 tokens and kv_cache_dtype="auto"."""
    outcome = _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto")
    assert outcome is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_force_on(_mock):
    """Expect True when force_use_trtllm is True, regardless of batch size.

    num_tokens=512 is the decode batch size that
    test_use_auto_decode_large_batch expects to yield False, so a True result
    here cannot come from the default auto-detection inputs alone.
    """
    assert _call(force_use_trtllm=True, num_tokens=512) is True