test_torchao.py 14.1 KB
Newer Older
Driss Guessous's avatar
Driss Guessous committed
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Driss Guessous's avatar
Driss Guessous committed
3
4
5
import importlib.util

import pytest
6
import torch
Driss Guessous's avatar
Driss Guessous committed
7

8
from vllm.model_executor.model_loader import get_model_loader
9
10
from vllm.platforms import current_platform

11
# Device type string of the current vLLM platform; tests combine it with an
# index (f"{DEVICE_TYPE}:0") to build ``pt_load_map_location`` values.
DEVICE_TYPE = current_platform.device_type

# Dtype list kept at module level; not referenced by the tests in this file.
DTYPE = ["bfloat16"]

# torchao is an optional dependency: every test below is skipped when the
# package cannot be found on the import path.
TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None


17
18
19
20
@pytest.mark.skipif(
    current_platform.is_rocm() and current_platform.is_fp8_fnuz(),
    reason="Only fp8_fnuz supported on CDNA3 architecture",
)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_pre_quantized_model(vllm_runner):
    """Smoke-test loading and running a checkpoint pre-quantized with torchao.

    Skipped on ROCm platforms that report the fp8 "fnuz" format; presumably
    the checkpoint stores a different fp8 variant (TODO confirm).
    """
    with vllm_runner(
        "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.15.0",
        quantization="torchao",
        dtype="bfloat16",
        enforce_eager=True,
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
    assert output


33
34
35
36
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.parametrize(
    "pt_load_map_location",
    [
        f"{DEVICE_TYPE}:0",
        # Dict-form map location, currently disabled:
        # {"": "cuda"},
    ],
)
def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, pt_load_map_location):
    """Load and run an int8 weight-only (partially quantized, per the model
    name) torchao checkpoint of OPT-125m with an explicit map location.
    """
    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    model_name = "jerryzh168/opt-125m-int8wo-partial-quant"
    with vllm_runner(
        model_name=model_name,
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location=pt_load_map_location,
        enforce_eager=True,
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

        assert output


56
57
58
59
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
    """Load and run an int8 weight-only torchao checkpoint of Qwen2.5-VL-3B."""
    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
    with vllm_runner(
        model_name=model_name,
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location=f"{DEVICE_TYPE}:0",
        enforce_eager=True,
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

        assert output


72
73
74
75
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated: keep the trailing spaces so words
    # do not run together in the reported skip reason.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that require newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
    """Load and run an OPT-125m checkpoint quantized with torchao AWQ on top
    of an int4 weight-only config (per the model name).
    """
    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    model_name = "torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2-0.14.0.dev"
    with vllm_runner(
        model_name=model_name,
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location=f"{DEVICE_TYPE}:0",
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

        assert output
90
91
92


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_online_quant_config_dict_json(vllm_runner, enable_pickle):
    """Test online quantization and the ``load_weights`` integration point.

    The torchao config is serialized to a JSON string and handed to vLLM via
    the ``quantization_config_dict_json`` hf override. After a first greedy
    generation, the checkpoint weights are re-read and pushed through
    ``model.load_weights``; a second greedy generation must match the first.

    ``enable_pickle`` is a fixture defined elsewhere; presumably it allows
    ``apply_model`` to ship the local callable to the workers (TODO confirm).
    """
    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    model_name = "facebook/opt-125m"

    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow()
    )
    hf_overrides = {
        "quantization_config_dict_json": json.dumps(
            config_to_dict(torchao_quant_config)
        )
    }
    with vllm_runner(
        model_name=model_name,
        dtype="bfloat16",
        pt_load_map_location=f"{DEVICE_TYPE}:0",
        quantization="torchao",
        hf_overrides=hf_overrides,
        enforce_eager=True,
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

        load_config = llm.llm.llm_engine.vllm_config.load_config
        model_config = llm.llm.llm_engine.vllm_config.model_config

        def load_weights(model):
            # Re-read the checkpoint and feed it through the model's
            # load_weights path (the integration point under test).
            model_loader = get_model_loader(load_config)
            weights_iterator = model_loader.get_all_weights(model_config, model)
            model.load_weights(weights_iterator)

        llm.apply_model(load_weights)

        reload_output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
        # Greedy decoding of the same weights must reproduce the first
        # result; [0][0] is the first field of the first prompt's output
        # (presumably the generated token ids).
        assert output[0][0] == reload_output[0][0]
135
136
137


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_online_quant_config_file(vllm_runner):
    """Test on-the-fly (online) quantization configured through a file.

    A torchao Float8 dynamic-activation / Float8-weight config is serialized
    to JSON in a temporary file whose path is passed to vLLM via the
    ``quantization_config_file`` hf override. The temporary file is removed
    when the test finishes, whether it passes or fails.
    """
    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    model_name = "facebook/opt-125m"
    import json
    import os
    from tempfile import NamedTemporaryFile

    from torchao.core.config import config_to_dict
    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

    # delete=False keeps the file on disk after it is closed, so vLLM can
    # reopen it by name (reopening an open temp file fails on some
    # platforms). It must therefore be removed explicitly below.
    with NamedTemporaryFile(mode="w", delete=False) as f:
        f.write(json.dumps(config_to_dict(config)))
        config_file_name = str(f.name)

    try:
        hf_overrides = {"quantization_config_file": config_file_name}
        with vllm_runner(
            model_name=model_name,
            dtype="bfloat16",
            pt_load_map_location=f"{DEVICE_TYPE}:0",
            quantization="torchao",
            hf_overrides=hf_overrides,
            enforce_eager=True,
        ) as llm:
            output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

            assert output
    finally:
        os.remove(config_file_name)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_reload_weights():
    """Reload real weights in place into an online-quantized model.

    The engine is created with dummy weights (``load_format="dummy"``). Then
    the load format is switched to ``auto`` through ``collective_rpc``, and
    ``reload_weights`` loads the real checkpoint in place. Finally a short
    greedy generation is run to confirm the model still produces text.
    """
    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

    from vllm import LLM, SamplingParams

    # Like the other tests: clear dynamo caches possibly left behind by
    # earlier tests in this process.
    torch._dynamo.reset()

    torchao_quant_config = Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow()
    )
    hf_overrides = {
        "quantization_config_dict_json": json.dumps(
            config_to_dict(torchao_quant_config)
        )
    }

    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        dtype="bfloat16",
        load_format="dummy",
        enforce_eager=True,
        quantization="torchao",
        hf_overrides=hf_overrides,
    )
    # Update load format from `dummy` to `auto`
    llm.collective_rpc(
        "update_config", args=({"load_config": {"load_format": "auto"}},)
    )
    # Now reload real weights inplace
    llm.collective_rpc("reload_weights")
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0, top_p=0.95)
    outputs = llm.generate(prompts, sampling_params)
    # make sure it runs
    for output in outputs:
        generated_text = output.outputs[0].text
        assert generated_text
        # can also uncomment locally to make sure the generated
        # output makes sense
        # prompt = output.prompt
        # print(f"Prompt:    {prompt!r}")
        # print(f"Output:    {generated_text!r}")
        # print("-" * 60)
224
225


226
227
228
229
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated: keep the trailing spaces so words
    # do not run together in the reported skip reason.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that require newer versions (0.15.0.dev+) for now"
)
def test_safetensors_model_loading_with_params(vllm_runner):
    """Load and run a sharded safetensors torchao checkpoint (Qwen3-8B INT4).

    No ``quantization`` argument is passed; presumably vLLM detects torchao
    from the checkpoint's own config (TODO confirm).
    """
    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    # using this model to test safetensors loading with file sharding
    model_name = "torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors"
    with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

        assert output


242
243
244
245
246
247
248
249
250
251
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated: keep the trailing spaces so words
    # do not run together in the reported skip reason.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that require newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
    """Load and run an OPT-125m checkpoint whose torchao config is a
    ModuleFqnToConfig keyed by regex patterns (per the model name).
    """
    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    model_name = "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev"
    with vllm_runner(
        model_name=model_name, dtype="bfloat16", pt_load_map_location=f"{DEVICE_TYPE}:0"
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

    assert output


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated: keep the trailing spaces so words
    # do not run together in the reported skip reason.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that require newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch):
    """We load a model with Int4Tensor (plain format) linear weights
    and verify that the weight is updated to Int4PreshuffledTensor
    after loading in vllm

    The swap is only expected when fbgemm_gpu_genai is available and the GPU
    is SM90 or newer; otherwise the weight must not be preshuffled. The
    weight's vLLM-side attributes (requires_grad, input/output dims,
    weight_loader) are also checked to survive the conversion.
    """
    from torchao.quantization import Int4PreshuffledTensor
    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90

    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    # Presumably required so apply_model can ship the local callables below
    # to the workers (TODO confirm).
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev"
    # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
    # have meta kernel implemented yet, can remove this flag after that is implemented
    with vllm_runner(
        model_name=model_name,
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location=f"{DEVICE_TYPE}:0",
        enforce_eager=True,
    ) as llm:

        def has_int4_preshuffled_tensor_weight(model):
            return isinstance(
                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
                Int4PreshuffledTensor,
            )

        def get_weight_attrs(model):
            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
            return [
                weight.requires_grad,
                weight.input_dim,
                weight.output_dim,
                hasattr(weight, "weight_loader"),
            ]

        llm_engine = llm.get_llm().llm_engine
        # apply_model returns one result per worker (presumably); any() /
        # [0] pick an aggregate and the first worker respectively.
        has_int4_preshuffled_tensor = any(
            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
        )
        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]

        # making sure we are using Int4PreshuffledTensor on H100 GPU, when
        # fbgemm_gpu_genai
        # library is installed, otherwise it should be using Int4Tensor
        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
            assert has_int4_preshuffled_tensor
        else:
            assert not has_int4_preshuffled_tensor

        # [requires_grad, input_dim, output_dim, has weight_loader]
        assert weight_attrs == [False, 1, 0, True]
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)

        assert output


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated: keep the trailing spaces so words
    # do not run together in the reported skip reason.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that require newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
    vllm_runner, monkeypatch
):
    """We load a bf16 model and online quantize the model to int4, then verify that
    the weights are updated to Int4PreshuffledTensor after online quantization

    The int4 weight-only config (group_size=128, plain packing) is passed via
    the ``quantization_config_dict_json`` hf override. The preshuffled swap
    is only expected with fbgemm_gpu_genai on an SM90+ GPU, and the weight's
    vLLM-side attributes are checked to survive the conversion.
    """
    from torchao.quantization import Int4PreshuffledTensor
    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90

    # Clear dynamo caches possibly left behind by earlier tests in this process.
    torch._dynamo.reset()
    model_name = "facebook/opt-125m"

    # Presumably required so apply_model can ship the local callables below
    # to the workers (TODO confirm).
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import Int4WeightOnlyConfig

    torchao_quant_config = Int4WeightOnlyConfig(
        group_size=128, int4_packing_format="plain"
    )
    hf_overrides = {
        "quantization_config_dict_json": json.dumps(
            config_to_dict(torchao_quant_config)
        )
    }

    # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
    # have meta kernel implemented yet, can remove this flag after that is implemented
    with vllm_runner(
        model_name=model_name,
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location=f"{DEVICE_TYPE}:0",
        hf_overrides=hf_overrides,
        enforce_eager=True,
    ) as llm:

        def has_int4_preshuffled_tensor_weight(model):
            return isinstance(
                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
                Int4PreshuffledTensor,
            )

        def get_weight_attrs(model):
            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
            return [
                weight.requires_grad,
                weight.input_dim,
                weight.output_dim,
                hasattr(weight, "weight_loader"),
            ]

        llm_engine = llm.get_llm().llm_engine
        # apply_model returns one result per worker (presumably); any() /
        # [0] pick an aggregate and the first worker respectively.
        has_int4_preshuffled_tensor = any(
            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
        )
        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]

        # making sure we are using Int4PreshuffledTensor on H100 GPU, when
        # fbgemm_gpu_genai
        # library is installed, otherwise it should be using Int4Tensor
        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
            assert has_int4_preshuffled_tensor
        else:
            assert not has_int4_preshuffled_tensor

        # [requires_grad, input_dim, output_dim, has weight_loader]
        assert weight_attrs == [False, 1, 0, True]
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

        assert output


Driss Guessous's avatar
Driss Guessous committed
401
402
if __name__ == "__main__":
    # pytest.main returns the exit status instead of exiting; propagate it so
    # running this file directly reports test failures to the shell.
    raise SystemExit(pytest.main([__file__]))