# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_logprobs_close, check_outputs_equal

# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model

# NOTE: The first model in each list is taken as the primary model,
# meaning that it will be used in all tests in this file.
# The remaining models are only exercised by the tests that parametrize over
# the full lists (test_models and test_batching).

# State-space (SSM) models.
SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    "mistralai/Mamba-Codestral-7B-v0.1",
]

# Hybrid models (test_models runs them on V1 with an attention backend set).
HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    # NOTE: Running Plamo2 with the transformers implementation requires the
    # causal-conv1d package, which is not listed as a test dependency because
    # it is not compatible with pip-compile.
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-ai-platform/Bamba-9B-v1",
    "nvidia/Nemotron-H-8B-Base-8K",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]

# Models whose vLLM outputs are NOT compared against HF transformers.
HF_UNSUPPORTED_MODELS = [
    # The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
    # doesn't compare vLLM output with HF output.
    # See https://github.com/huggingface/transformers/pull/35943
    "mistralai/Mamba-Codestral-7B-v0.1",
    # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers
    # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1
    "nvidia/Nemotron-H-8B-Base-8K",
    # NOTE: Currently the test fails due to HF transformers issue fixed in:
    # https://github.com/huggingface/transformers/pull/39033
    # We will enable vLLM test for Granite after next HF transformers release.
    "ibm-granite/granite-4.0-tiny-preview",
]

# Models that test_models additionally runs with VLLM_USE_V1=1.
V1_SUPPORTED_MODELS = [
    "state-spaces/mamba-130m-hf",
    "ai21labs/Jamba-tiny-dev",
    "mistralai/Mamba-Codestral-7B-v0.1",
    "ibm-ai-platform/Bamba-9B-v1",
    "Zyphra/Zamba2-1.2B-instruct",
    "nvidia/Nemotron-H-8B-Base-8K",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]

# Cap the number of concurrently scheduled sequences to avoid OOM.
MAX_NUM_SEQS = 4


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare greedy logprobs between HF transformers, vLLM V0 and vLLM V1.

    - HF outputs are produced only for models not in HF_UNSUPPORTED_MODELS;
      when available they are compared against vLLM V0.
    - For models in V1_SUPPORTED_MODELS, vLLM V1 is also run (with prefix
      caching disabled) and compared against HF if available, otherwise
      against vLLM V0.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    except ValueError:
        # Presumably the model is not in the example registry; nothing to
        # check. Only the lookup is guarded so errors from the checks below
        # are not swallowed.
        pass
    else:
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")

    hf_outputs = None
    if model not in HF_UNSUPPORTED_MODELS:
        # Only construct the HF runner when its outputs are actually compared.
        with hf_runner(model) as hf_model:
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    vllm_v1_outputs = None
    if model in V1_SUPPORTED_MODELS:
        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            if model in HYBRID_MODELS:
                # required due to reorder_batch behaviour
                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
                             enable_prefix_caching=False) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)

    if hf_outputs is not None:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_v0_outputs,
            name_0="hf",
            name_1="vllm-v0",
        )

    if model in V1_SUPPORTED_MODELS:
        # Prefer HF as the reference; fall back to vLLM V0 when HF is skipped.
        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
        check_logprobs_close(
            outputs_0_lst=ref_outputs,
            outputs_1_lst=vllm_v1_outputs,
            name_0="hf" if hf_outputs is not None else "vllm-v0",
            name_1="vllm-v1",
        )


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that batched generation matches one-prompt-at-a-time generation.

    Both passes use the same vLLM instance with greedy decoding; the
    resulting logprobs must be close.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    except ValueError:
        # Presumably the model is not in the example registry; nothing to
        # check. Only the lookup is guarded so errors from the checks below
        # are not swallowed.
        pass
    else:
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")

    for_loop_outputs = []
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        for prompt in example_prompts:
            # The trailing-comma unpack asserts exactly one output per call.
            single_output, = vllm_model.generate_greedy_logprobs(
                [prompt], max_tokens, num_logprobs)
            for_loop_outputs.append(single_output)

        batched_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=for_loop_outputs,
        outputs_1_lst=batched_outputs,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
    chunked_prefill_token_size: int,
) -> None:
    """Greedy logprobs must agree with and without chunked prefill.

    The per-step token budget and the sequence cap are both set to
    ``chunked_prefill_token_size``. This forces prefill into small chunks
    in the chunked run. The non-chunked run keeps the same sequence cap.
    """
    budget = chunked_prefill_token_size

    with vllm_runner(model,
                     enable_chunked_prefill=True,
                     max_num_batched_tokens=budget,
                     max_num_seqs=budget) as engine:
        chunked_outputs = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model,
                     enable_chunked_prefill=False,
                     max_num_seqs=budget) as engine:
        unchunked_outputs = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=chunked_outputs,
        outputs_1_lst=unchunked_outputs,
        name_0="chunked",
        name_1="non_chunked",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """Smoke-test chunked prefill combined with parallel sampling (n > 1).

    With a small token budget, prefill chunks are scheduled together with
    decode tokens in the same forward pass. The test only checks that
    generation completes without error. A failure here may indicate that the
    cache is not allocated correctly for n > 1 decode steps inside a chunked
    prefill forward pass.
    """
    params = SamplingParams(
        n=3,
        temperature=1,
        seed=0,
        max_tokens=max_tokens,
    )
    # A budget of MAX_NUM_SEQS * 3 tokens forces prefill chunks to share a
    # step with decoding.
    engine_kwargs = dict(
        enable_chunked_prefill=True,
        max_num_batched_tokens=MAX_NUM_SEQS * 3,
        max_num_seqs=MAX_NUM_SEQS,
    )
    with vllm_runner(model, **engine_kwargs) as engine:
        engine.generate(example_prompts, params)


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """Verify that the mamba cache is padded to the CUDA-graph batch size.

    The prompt batch is grown until its size is not a captured CUDA-graph
    size, so it must be padded. If the mamba cache is not padded accordingly,
    a torch RuntimeError is raised because tensor dimensions don't match.
    """
    vllm_config = EngineArgs(model=model,
                             trust_remote_code=True).create_engine_config()

    # Work on a copy so the fixture's list is never mutated in place.
    prompts = list(example_prompts)
    # Grow the batch while its size equals its padded size, i.e. until
    # pad_for_cudagraph has to round it up.
    while len(prompts) == vllm_config.pad_for_cudagraph(len(prompts)):
        prompts.append(prompts[0])

    try:
        with vllm_runner(model) as vllm_model:
            vllm_model.generate_greedy(prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """Tests that outputs are identical with and w/o preemptions (recompute).

    One engine generates twice: first with the scheduler's artificial
    preemption enabled, then with it disabled. The greedy outputs of the two
    passes must match exactly.
    """
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        scheduler = vllm_model.llm.llm_engine.scheduler[0]

        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
        preempt_vllm_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=preempt_vllm_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="vllm_preemptions",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that hybrid inner state management doesn't collapse when the
    number of incoming requests plus finished_requests_ids exceeds the
    maximum mamba block capacity.

    This can happen because the hybrid implementation supports a stateless
    mechanism that can clean up newly incoming requests within a single step.
    """
    try:
        # 100 requests against MAX_NUM_SEQS slots, 10 tokens each.
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        pytest.fail("Hybrid inner state wasn't cleaned up properly between "
                    "steps, finished requests registered unnecessarily")


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_state_cleanup(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that the hybrid state is cleaned up between steps.

    Ten rounds of 100 identical single-token requests are run on one engine.
    A ValueError during these rounds is treated as a state-cleanup failure.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as runner:
            for _round in range(10):
                batch = [example_prompts[0]] * 100
                runner.generate_greedy(batch, 1)
    except ValueError:
        pytest.fail("Hybrid inner state wasn't cleaned up between states, "
                    "could be related to finished_requests_ids")


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """Greedy outputs with 8 scheduler steps must equal those with 1 step."""
    outputs_by_steps = {}
    # Multi-step engine runs first, then the single-step engine.
    for num_steps in (8, 1):
        with vllm_runner(model,
                         num_scheduler_steps=num_steps,
                         max_num_seqs=2) as vllm_model:
            outputs_by_steps[num_steps] = vllm_model.generate_greedy(
                example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=outputs_by_steps[8],
        outputs_1_lst=outputs_by_steps[1],
        name_0="vllm_outputs_multistep",
        name_1="vllm_outputs_single_step",
    )


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_distributed_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Greedy logprobs with tensor parallelism 2 must be close to TP=1."""
    outputs_by_tp = {}
    # Single-GPU run first, then the two-GPU run.
    for tp_size in (1, 2):
        with vllm_runner(model,
                         tensor_parallel_size=tp_size,
                         max_num_seqs=2) as vllm_model:
            outputs_by_tp[tp_size] = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=outputs_by_tp[1],
        outputs_1_lst=outputs_by_tp[2],
        name_0="vllm_tp_1",
        name_1="vllm_tp_2",
    )