# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.

Run `pytest tests/models/test_mamba.py`.
"""
# NOTE(review): the `...utils` relative import below implies this module sits
# deeper in the package tree than tests/models/; the path in the docstring may
# be stale — confirm.
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal

# HF model ids every test in this module is parametrized over.
MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
    # See https://github.com/huggingface/transformers/pull/35943
    # "mistralai/Mamba-Codestral-7B-v0.1",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
    """Greedily decode each prompt with the HuggingFace reference model.

    Args:
        model_name: Model id or path handed to ``from_pretrained`` for both
            the tokenizer and the causal-LM model.
        example_prompts: Prompts to decode; each is processed on its own.
        max_tokens: Maximum number of new tokens generated per prompt.

    Returns:
        A list with one ``(token_ids, text)`` tuple per prompt, in prompt
        order. ``token_ids`` is the first sequence returned by
        ``model.generate`` converted to a Python list, and ``text`` is that
        sequence decoded with special tokens skipped.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Set the device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    outputs = []
    try:
        for prompt in example_prompts:
            # Tokenize the input prompt with truncation. Only input_ids are
            # forwarded to generate(); attention_mask is deliberately dropped
            # (see the comment on this function).
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            input_ids = inputs["input_ids"].to(model.device)

            # do_sample=False selects greedy decoding.
            generated_ids = model.generate(input_ids,
                                           max_new_tokens=max_tokens,
                                           do_sample=False)
            sequence = generated_ids[0]
            generated_text = tokenizer.decode(sequence,
                                              skip_special_tokens=True)

            outputs.append((sequence.tolist(), generated_text))
    finally:
        # Callers start a vLLM engine in the same process right after this
        # returns. Drop the reference model and hand cached CUDA blocks back
        # so the HF weights don't occupy GPU memory vLLM wants to use.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Greedy vLLM outputs must exactly match the HF reference outputs.

    Both the decoded text and the token ids are compared per prompt.
    """
    hf_outputs = generate_greedy(model, example_prompts, max_tokens)

    # Set max_num_seqs to keep Codestral from going OOM at fp32
    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    # Use the shared (token_ids, text) comparison helper, as the other tests
    # in this module do.
    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Batched greedy outputs must match prompt-by-prompt greedy outputs.

    Both runs use the same engine instance.
    """
    # To pass the small model tests, we need full precision.
    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        # One generate_greedy() call per prompt, each with a single-element
        # prompt list; keep that call's only output.
        for_loop_outputs = [
            vllm_model.generate_greedy([prompt], max_tokens)[0]
            for prompt in example_prompts
        ]

        batched_outputs = vllm_model.generate_greedy(example_prompts,
                                                     max_tokens)

    check_outputs_equal(
        outputs_0_lst=for_loop_outputs,
        outputs_1_lst=batched_outputs,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
                                                model: str, dtype: str,
                                                max_tokens: int) -> None:
    """Smoke-test chunked prefill combined with parallel sampling (n > 1).

    A tiny per-step token budget means prefill chunks share a forward pass
    with decode steps of sequences that are already running. No output is
    checked: the test passes as long as generation completes, and is meant
    to catch cache allocation errors for n > 1 decoding steps inside such a
    mixed prefill/decode pass.
    """
    params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)

    # At most 30 tokens and 10 sequences per step: forces prefill chunks to
    # be scheduled alongside decoding.
    engine_kwargs = dict(
        dtype=dtype,
        enable_chunked_prefill=True,
        max_num_batched_tokens=30,
        max_num_seqs=10,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        vllm_model.generate(example_prompts, params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str,
                         max_tokens: int,
                         chunked_prefill_token_size: int) -> None:
    """Greedy vLLM output with chunked prefill must exactly match HF output.

    The HF reference decodes each prompt without chunking, while vLLM runs
    with chunked prefill and a small per-step budget.
    """
    reference = generate_greedy(model, example_prompts, max_tokens)

    # The chunk size doubles as the per-step token budget and the sequence
    # cap, so any prompt longer than it has to be prefilled in chunks.
    with vllm_runner(model,
                     dtype=dtype,
                     enable_chunked_prefill=True,
                     max_num_batched_tokens=chunked_prefill_token_size,
                     max_num_seqs=chunked_prefill_token_size) as vllm_model:
        chunked = vllm_model.generate_greedy(example_prompts,
                                             max_tokens=max_tokens)

    check_outputs_equal(
        outputs_0_lst=chunked,
        outputs_1_lst=reference,
        name_0="chunked",
        name_1="non_chunked",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Near-greedy sampling with n=10 must reproduce repeated greedy decoding.

    The first prompt's greedy output is collected from 10 separate greedy
    runs over all prompts. Then all prompts are sampled once with ``n=10``
    and a near-zero temperature. The 10 samples for the first prompt must
    equal the 10 greedy outputs.
    """
    # Numerical differences produce slightly different output for these
    # prompts, so skip them. Slice rather than pop() so the caller's list is
    # left untouched.
    if 'state-spaces' in model:
        example_prompts = example_prompts[3:]

    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        for_loop_outputs = []
        for _ in range(10):
            for_loop_outputs.append(
                vllm_model.generate_greedy(example_prompts, max_tokens)[0])

        sampling_params = SamplingParams(n=10,
                                         temperature=0.001,
                                         seed=0,
                                         max_tokens=max_tokens)
        n_gt_1_outputs = vllm_model.generate(example_prompts, sampling_params)

    # The first prompt's entry unpacks into parallel lists of per-sample token
    # ids and texts; re-pair them as (token_ids, text) tuples for comparison.
    token_ids, texts = n_gt_1_outputs[0]
    n_gt_1_outputs = list(zip(token_ids, texts))

    check_outputs_equal(
        outputs_0_lst=n_gt_1_outputs,
        outputs_1_lst=for_loop_outputs,
        name_0="vllm_n_gt_1_outputs",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Run a batch whose size must be padded up to a CUDA-graph batch size.

    This test is for verifying that mamba cache is padded to CG captured
    batch size. If it's not, a torch RuntimeError will be raised because
    tensor dimensions aren't compatible.
    """
    vllm_config = EngineArgs(model=model).create_engine_config()

    # Grow a local copy of the prompts until pad_for_cudagraph() would change
    # the batch size, i.e. the batch has to be padded. Copying leaves the
    # fixture's list unmodified.
    prompts = list(example_prompts)
    while len(prompts) == vllm_config.pad_for_cudagraph(len(prompts)):
        prompts.append(prompts[0])

    try:
        with vllm_runner(model, dtype=dtype) as vllm_model:
            vllm_model.generate_greedy(prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Outputs must be identical with and without preemptions (recompute).

    A single engine generates twice: first with the first scheduler's
    ``ENABLE_ARTIFICIAL_PREEMPT`` flag set to True, then with it set to
    False. The two greedy outputs are compared.
    """
    assert dtype == "float"

    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        vllm_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = True
        preempt_vllm_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

        vllm_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = False
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=preempt_vllm_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="vllm_preemptions",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    """Mamba state management must survive more requests than blocks.

    This test verifies that Mamba inner-state management doesn't collapse
    when the number of incoming requests and finished_requests_ids is larger
    than the maximum Mamba block capacity. This can happen because Mamba
    supports a stateless mechanism that cleans up state for new incoming
    requests within a single step.
    """
    # 100 identical requests against an engine capped at 10 concurrent
    # sequences.
    try:
        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        pytest.fail("Mamba inner state wasn't cleaned up properly between "
                    "steps; finished requests registered unnecessarily")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    """Verify that the Mamba state is cleaned up between steps.

    Ten rounds of 100 copies of one prompt (one output token each) are pushed
    through an engine capped at 16 concurrent sequences. If the state is not
    cleaned up, an error would be expected.
    """
    try:
        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail("Mamba inner state wasn't cleaned up between states, "
                    "could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    """Smoke-test multi-step scheduling (8 steps) with at most 2 sequences."""
    # Pass the parametrized dtype through; without it the runner falls back
    # to its own default and the "dtype" parametrization has no effect.
    with vllm_runner(model,
                     dtype=dtype,
                     num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
                               max_tokens: int, example_prompts) -> None:
    """Greedy outputs with 8 scheduler steps must match single-step outputs."""
    # Both runners get the parametrized dtype; previously it was unused and
    # each runner used its own default.
    with vllm_runner(model,
                     dtype=dtype,
                     num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_multistep = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    with vllm_runner(model,
                     dtype=dtype,
                     num_scheduler_steps=1,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_single_step = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=vllm_outputs_multistep,
        outputs_1_lst=vllm_outputs_single_step,
        name_0="vllm_outputs_multistep",
        name_1="vllm_outputs_single_step",
    )