# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for SSM (Mamba-family) and hybrid models.

Outputs are compared against HF transformers and across vLLM
configurations (V0/V1, batching, chunked prefill, preemption, TP).
"""

import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_logprobs_close, check_outputs_equal

# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model

# NOTE: The first model in each of SSM_MODELS / HYBRID_MODELS is taken as
# the primary model: it is used by every test parametrized with
# SSM_MODELS[0] / HYBRID_MODELS[0]. The rest of the models are only
# exercised by test_models and test_batching.

# State-space (Mamba-family) models.
SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    "yujiepan/mamba2-codestral-v0.1-tiny-random",
]

# Hybrid models (attention layers combined with SSM layers).
HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    # skipping until vLLM implementation issues are resolved
    # "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]

# Models whose vLLM output is NOT compared against HF transformers output.
HF_UNSUPPORTED_MODELS = [
    # The HF transformers implementation of Mamba2 is buggy for Codestral
    # as it doesn't handle n_groups, so the test doesn't compare vLLM
    # output with HF output.
    # See https://github.com/huggingface/transformers/pull/35943
    "yujiepan/mamba2-codestral-v0.1-tiny-random",
    # transformers 4.55 is still producing garbage for this model
    # TODO(tdoublep): follow-up on transformers side
    "ibm-granite/granite-4.0-tiny-preview",
]

# Models that test_models additionally runs on the V1 engine.
V1_SUPPORTED_MODELS = [
    "state-spaces/mamba-130m-hf",
    "ai21labs/Jamba-tiny-dev",
    "yujiepan/mamba2-codestral-v0.1-tiny-random",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]

# Avoid OOM
MAX_NUM_SEQS = 4


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare greedy logprobs of vLLM V0 (and V1, if supported) to a reference.

    HF transformers output is the reference unless the model is listed in
    ``HF_UNSUPPORTED_MODELS`` or the installed transformers version is
    incompatible with it. In that case the V1 output is compared against
    the V0 output instead.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        hf_version_check = model_info.check_transformers_version(
            on_fail="return")
    except ValueError:
        # NOTE(review): presumably raised when the model is not registered
        # in HF_EXAMPLE_MODELS; treat it as having no version constraint.
        hf_version_check = None

    if hf_version_check is not None:
        print(f"Skipping transformers comparison because: {hf_version_check}")

    # Only instantiate the HF model when its outputs will be compared.
    # Loading it under an incompatible transformers version (or for a model
    # with a known-bad HF implementation) is wasted work and may even fail.
    hf_outputs = None
    if model not in HF_UNSUPPORTED_MODELS and hf_version_check is None:
        with hf_runner(model) as hf_model:
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    vllm_v1_outputs = None
    if model in V1_SUPPORTED_MODELS:
        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            if model in HYBRID_MODELS:
                # required due to reorder_batch behaviour
                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
                             enable_prefix_caching=False) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)

    if hf_outputs is not None:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_v0_outputs,
            name_0="hf",
            name_1="vllm-v0",
        )

    if vllm_v1_outputs is not None:
        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
        check_logprobs_close(
            outputs_0_lst=ref_outputs,
            outputs_1_lst=vllm_v1_outputs,
            name_0="hf" if hf_outputs is not None else "vllm-v0",
            name_1="vllm-v1",
        )


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that batched generation matches prompt-by-prompt generation.

    Greedy logprobs from running each prompt on its own must be close to
    those from running all prompts as a single batch on the same engine.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        # NOTE(review): presumably the model is not registered in
        # HF_EXAMPLE_MODELS; proceed without the availability checks.
        pass

    for_loop_outputs = []
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        for prompt in example_prompts:
            single_output, = vllm_model.generate_greedy_logprobs(
                [prompt], max_tokens, num_logprobs)
            for_loop_outputs.append(single_output)

        batched_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=for_loop_outputs,
        outputs_1_lst=batched_outputs,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
    chunked_prefill_token_size: int,
) -> None:
    """Check that chunked prefill produces the same logprobs as no chunking.

    The chunked run caps both ``max_num_batched_tokens`` and
    ``max_num_seqs`` at ``chunked_prefill_token_size``. Prompts longer than
    that token budget therefore have to be prefilled in several chunks.
    """
    max_num_seqs = chunked_prefill_token_size
    max_num_batched_tokens = chunked_prefill_token_size

    with vllm_runner(model,
                     enable_chunked_prefill=True,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:
        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
                                                      max_tokens, num_logprobs)

    with vllm_runner(model,
                     enable_chunked_prefill=False,
                     max_num_seqs=max_num_seqs) as vllm_model:
        non_chunked = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=chunked,
        outputs_1_lst=non_chunked,
        name_0="chunked",
        name_1="non_chunked",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Tests chunked prefill in conjunction with n > 1.

    In this case, prefill is populated with decoding tokens and
    we test that it doesn't fail.

    This test might fail if cache is not allocated correctly for n > 1
    decoding steps inside a chunked prefill forward pass
    (where we have both prefill and decode together)
    """
    sampling_params = SamplingParams(n=3,
                                     temperature=1,
                                     seed=0,
                                     max_tokens=max_tokens)
    with vllm_runner(
            model,
            enable_chunked_prefill=True,
            # forces prefill chunks with decoding
            max_num_batched_tokens=MAX_NUM_SEQS * 3,
            max_num_seqs=MAX_NUM_SEQS,
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    This test is for verifying that mamba cache is padded to CG captured
    batch size. If it's not, a torch RuntimeError will be raised because
    tensor dimensions aren't compatible.
    """
    vllm_config = EngineArgs(model=model,
                             trust_remote_code=True).create_engine_config()

    # Work on a copy so the list provided by the `example_prompts` fixture
    # is not mutated by this test.
    prompts = list(example_prompts)
    # Grow the batch until its size is not already a CUDA-graph batch size,
    # so running it requires padding up to a captured size.
    while len(prompts) == vllm_config.pad_for_cudagraph(len(prompts)):
        prompts.append(prompts[0])

    try:
        with vllm_runner(model) as vllm_model:
            vllm_model.generate_greedy(prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Tests that outputs are identical with and w/o preemptions (recompute).
    """
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        scheduler = vllm_model.llm.llm_engine.scheduler[0]

        # First pass: let the scheduler inject artificial preemptions.
        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
        preempt_vllm_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

        # Second pass on the same engine: no artificial preemptions.
        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=preempt_vllm_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="vllm_preemptions",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    This test is for verifying that the hybrid inner state management doesn't
    collapse in case where the number of incoming requests and
    finished_requests_ids is larger than the maximum mamba block capacity.

    This could generally happen due to the fact that hybrid does support
    statelessness mechanism where it can cleanup new incoming requests in
    a single step.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            # Far more requests than MAX_NUM_SEQS.
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        pytest.fail("Hybrid inner state wasn't cleaned up properly between "
                    "steps, finished requests registered unnecessarily")


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_state_cleanup(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    This test is for verifying that the Hybrid state is cleaned up between
    steps.

    If it's not cleaned, an error would be expected.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            # Repeated short (1-token) generations over a large batch, so
            # state must be released and reused many times.
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail("Hybrid inner state wasn't cleaned up between states, "
                    "could be related to finished_requests_ids")


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_distributed_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that tensor-parallel size 2 matches tensor-parallel size 1.

    Greedy logprobs from a TP=1 run and a TP=2 run of the same model must
    be close. Requires at least two GPUs.
    """
    with vllm_runner(model, tensor_parallel_size=1,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, tensor_parallel_size=2,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=vllm_outputs_tp_1,
        outputs_1_lst=vllm_outputs_tp_2,
        name_0="vllm_tp_1",
        name_1="vllm_tp_2",
    )


@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_full_cuda_graph(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check V1 with ``full_cuda_graph`` compilation against a reference.

    The reference is HF transformers output when the model is not listed in
    ``HF_UNSUPPORTED_MODELS``, and vLLM V0 output otherwise. Whenever HF
    output is available, the V0 output is also checked against it.
    """
    try:
        info = HF_EXAMPLE_MODELS.find_hf_info(model)
        info.check_available_online(on_fail="skip")
        info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_with_hf = model not in HF_UNSUPPORTED_MODELS
    hf_outputs = None
    with hf_runner(model) as hf_model:
        if compare_with_hf:
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as v0_model:
        vllm_v0_outputs = v0_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with monkeypatch.context() as patch:
        patch.setenv("VLLM_USE_V1", "1")
        if model in HYBRID_MODELS:
            # required due to reorder_batch behaviour
            patch.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
        with vllm_runner(model,
                         max_num_seqs=MAX_NUM_SEQS,
                         compilation_config={'full_cuda_graph': True},
                         enable_prefix_caching=False) as v1_model:
            vllm_v1_outputs = v1_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)

    if hf_outputs is None:
        ref_outputs, ref_name = vllm_v0_outputs, "vllm-v0"
    else:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_v0_outputs,
            name_0="hf",
            name_1="vllm-v0",
        )
        ref_outputs, ref_name = hf_outputs, "hf"

    check_logprobs_close(
        outputs_0_lst=ref_outputs,
        outputs_1_lst=vllm_v1_outputs,
        name_0=ref_name,
        name_1="vllm-v1",
    )