"tests/kernels/quantization/test_triton_scaled_mm.py" did not exist on "cf069aa8aa38a9003c254f8434a29ec6a3070b08"
test_hybrid.py 10.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Mor Zusman's avatar
Mor Zusman committed
3
4
import pytest

5
from tests.utils import multi_gpu_test
6
from vllm.engine.arg_utils import EngineArgs
7
from vllm.sampling_params import SamplingParams
8

9
10
11
12
13
14
15
16
17
18
19
20
21
22
from ...utils import check_logprobs_close, check_outputs_equal

# NOTE: The first model in each list is taken as the primary model,
# meaning that it will be used in all tests in this file.
# The rest of the models are only tested by test_models and test_batching.

SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # TODO: Add a Mamba2 model such as "mistralai/Mamba-Codestral-7B-v0.1".
    # It cannot be compared yet: the HF transformers Mamba2 implementation
    # is buggy for Codestral because it does not handle n_groups.
    # See https://github.com/huggingface/transformers/pull/35943
]
23

24
25
HYBRID_MODELS = [
    # Primary hybrid model: HYBRID_MODELS[0] is used by most tests below.
    "ai21labs/Jamba-tiny-dev",
    # NOTE: ibm-granite/granite-4.0-tiny-preview is currently skipped because
    # it is not yet available in huggingface transformers.
    # "ibm-granite/granite-4.0-tiny-preview",
    # NOTE: Running Plamo2 with the transformers implementation requires
    # installing the causal-conv1d package, which is not listed as a test
    # dependency because it is not compatible with pip-compile.
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
]
36
37
38

# Passed as ``max_num_seqs`` to the vLLM runner; kept small to avoid OOM.
MAX_NUM_SEQS: int = 4
Mor Zusman's avatar
Mor Zusman committed
39
40


41
42
43
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """
    Compare greedy top-``num_logprobs`` outputs of HF transformers and vLLM.

    Both runners generate ``max_tokens`` tokens for every example prompt; the
    vLLM run is limited to ``MAX_NUM_SEQS`` sequences. The two sets of
    outputs must have close logprobs.
    """
    with hf_runner(model) as transformers_model:
        transformers_outputs = (
            transformers_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs))

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as engine:
        engine_outputs = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=transformers_outputs,
        outputs_1_lst=engine_outputs,
        name_0="hf",
        name_1="vllm",
    )
Mor Zusman's avatar
Mor Zusman committed
66
67


68
69
70
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """
    Check that batching does not change results.

    Each prompt is first generated on its own, then all prompts are generated
    together in one call on the same engine; the logprobs must be close.
    """
    one_by_one = []
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as engine:
        for single_prompt in example_prompts:
            # Exactly one output is expected for a one-prompt batch.
            (only_result,) = engine.generate_greedy_logprobs(
                [single_prompt], max_tokens, num_logprobs)
            one_by_one.append(only_result)

        all_at_once = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=one_by_one,
        outputs_1_lst=all_at_once,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


97
98
99
100
101
102
103
104
105
106
107
108
109
110
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
    chunked_prefill_token_size: int,
) -> None:
    """
    Compare logprobs with chunked prefill enabled vs. disabled.

    ``chunked_prefill_token_size`` is passed as ``max_num_seqs`` to both runs
    and, for the chunked run only, as ``max_num_batched_tokens``.
    """
    budget = chunked_prefill_token_size

    with vllm_runner(model,
                     enable_chunked_prefill=True,
                     max_num_batched_tokens=budget,
                     max_num_seqs=budget) as engine:
        chunked_outputs = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model,
                     enable_chunked_prefill=False,
                     max_num_seqs=budget) as engine:
        unchunked_outputs = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=chunked_outputs,
        outputs_1_lst=unchunked_outputs,
        name_0="chunked",
        name_1="non_chunked",
    )


133
134
135
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Smoke-test chunked prefill combined with parallel sampling (n > 1).

    The small token budget makes prefill chunks share a forward pass with
    decoding tokens. Only successful completion is checked; a failure here
    may indicate that the cache is not allocated correctly for n > 1 decode
    steps inside a chunked-prefill pass that mixes prefill and decode.
    """
    parallel_params = SamplingParams(
        n=3,
        temperature=1,
        seed=0,
        max_tokens=max_tokens,
    )
    # Small batched-token budget forces prefill chunks to run with decoding.
    token_budget = MAX_NUM_SEQS * 3
    with vllm_runner(
            model,
            enable_chunked_prefill=True,
            max_num_batched_tokens=token_budget,
            max_num_seqs=MAX_NUM_SEQS,
    ) as engine:
        engine.generate(example_prompts, parallel_params)
163
164


165
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Verify that the mamba cache is padded to a CUDA-graph captured batch size.

    The prompt batch is grown until ``pad_for_cudagraph`` changes its size,
    i.e. until it is not itself a captured batch size. If the cache is not
    padded, a torch ``RuntimeError`` is raised because tensor dimensions
    aren't compatible.
    """
    vllm_config = EngineArgs(model=model,
                             trust_remote_code=True).create_engine_config()

    # Work on a copy: appending to the ``example_prompts`` fixture value in
    # place could leak extra prompts into other tests if the fixture's list
    # is shared.
    prompts = list(example_prompts)
    while len(prompts) == vllm_config.pad_for_cudagraph(len(prompts)):
        prompts.append(prompts[0])

    try:
        with vllm_runner(model) as vllm_model:
            vllm_model.generate_greedy(prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly")


194
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Tests that outputs are identical with and w/o preemptions (recompute).

    ``ENABLE_ARTIFICIAL_PREEMPT`` is set on the engine's first scheduler for
    the first greedy pass and cleared for the second; both passes run on the
    same engine and their outputs must be exactly equal.
    """
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        scheduler = vllm_model.model.llm_engine.scheduler[0]

        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
        preempt_vllm_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=preempt_vllm_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="vllm_preemptions",
        name_1="vllm",
    )


222
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that hybrid inner state management doesn't collapse when the
    number of incoming requests and finished_requests_ids is larger than the
    maximum mamba block capacity.

    This can happen because hybrid models support a statelessness mechanism
    that can clean up new incoming requests within a single step.

    A ``ValueError`` raised during generation is reported as a test failure.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            # 100 copies of one prompt (far above MAX_NUM_SEQS), 10 tokens
            # each.
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        pytest.fail("Hybrid inner state wasn't cleaned up properly between "
                    "steps; finished requests registered unnecessarily")


245
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_state_cleanup(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that the hybrid state is cleaned up between steps.

    Ten rounds of 100 copies of one prompt (1 token each) are run on the same
    engine. If the state is not cleaned up, a ``ValueError`` is expected,
    which is reported as a test failure.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail("Hybrid inner state wasn't cleaned up between steps, "
                    "could be related to finished_requests_ids")


266
267
268
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Multi-step scheduling (8 steps) must give the same greedy outputs as
    single-step scheduling.

    The 8-step engine runs first, then the 1-step engine; each is created
    with ``max_num_seqs=2``.
    """
    outputs_by_steps = {}
    for num_steps in (8, 1):
        with vllm_runner(model, num_scheduler_steps=num_steps,
                         max_num_seqs=2) as engine:
            outputs_by_steps[num_steps] = engine.generate_greedy(
                example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=outputs_by_steps[8],
        outputs_1_lst=outputs_by_steps[1],
        name_0="vllm_outputs_multistep",
        name_1="vllm_outputs_single_step",
    )


292
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_distributed_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """
    Logprobs with tensor parallel size 2 must be close to those with size 1.

    The TP=1 engine runs first, then the TP=2 engine; each is created with
    ``max_num_seqs=2``.
    """
    outputs_by_tp = {}
    for tp_size in (1, 2):
        with vllm_runner(model, tensor_parallel_size=tp_size,
                         max_num_seqs=2) as engine:
            outputs_by_tp[tp_size] = engine.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=outputs_by_tp[1],
        outputs_1_lst=outputs_by_tp[2],
        name_0="vllm_tp_1",
        name_1="vllm_tp_2",
    )