test_tensorizer.py 17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import gc
6
import json
7
import os
8
import pathlib
9
import subprocess
10
11
import sys
from typing import Any
12
13

import pytest
14
import torch
15

16
17
import vllm.model_executor.model_loader.tensorizer
from vllm import LLM, SamplingParams
18
from vllm.engine.arg_utils import EngineArgs
19
20
21
22
23
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         open_stream,
24
                                                         tensorize_vllm_model)
25

26
27
from vllm.model_executor.model_loader.tensorizer_loader import (
    BLACKLISTED_TENSORIZER_ARGS)
28
# yapf: enable
29
from vllm.utils import PlaceholderModule
30

31
from ..utils import VLLM_PATH, RemoteOpenAIServer, models_path_prefix
32
from .conftest import DummyExecutor, assert_from_collective_rpc
33

34
try:
35
    import tensorizer
36
37
38
39
40
    from tensorizer import EncryptionParams
except ImportError:
    tensorizer = PlaceholderModule("tensorizer")  # type: ignore[assignment]
    EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")

41
42
43
44
45

class TensorizerCaughtError(Exception):
    """Marker exception raised in place of an expected tensorizer error.

    The patched wrapper installed by ``patch_init_and_catch_error`` raises
    this (chained from the original error) so tests can assert on one
    well-known exception type.
    """


46
# Path of the ``examples`` directory located under VLLM_PATH.
EXAMPLES_PATH = VLLM_PATH / "examples"
47

48
49
# The trailing comma makes this a one-element tuple naming the
# ``pytest_asyncio`` plugin (needed by the ``@pytest.mark.asyncio`` test below).
pytest_plugins = "pytest_asyncio",

50
51
52
53
54
55
56
57
58
# Prompts shared by every generation test in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Sampling parameters shared by the generation tests. The seed is fixed
# (seed=0) and the tests compare generations before and after a tensorizer
# round-trip with ``==``.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

def patch_init_and_catch_error(self, obj, method_name,
                               expected_error: type[Exception]):
    """Wrap ``obj.<method_name>`` and then call ``self.load_model()``.

    The wrapper forwards every call to the original attribute; if that call
    raises ``expected_error`` it is re-raised as ``TensorizerCaughtError``
    (chained from the original exception).

    Args:
        self: Object whose ``load_model()`` is invoked after patching
            (presumably the worker, since this function is dispatched via
            ``collective_rpc`` -- confirm against the executor).
        obj: Object (class or module) that owns the attribute to patch.
        method_name: Name of the attribute on ``obj`` to wrap.
        expected_error: Exception type to translate.

    Raises:
        ValueError: If ``obj`` has no such attribute (or it is ``None``).
    """
    wrapped = getattr(obj, method_name, None)
    if wrapped is None:
        raise ValueError(f"Method '{method_name}' not found.")

    def _translating_proxy(*call_args, **call_kwargs):
        try:
            result = wrapped(*call_args, **call_kwargs)
        except expected_error as caught:
            raise TensorizerCaughtError from caught
        return result

    setattr(obj, method_name, _translating_proxy)

    self.load_model()


def assert_specific_tensorizer_error_is_raised(
    executor,
    obj: Any,
    method_name: str,
    expected_error: type[Exception],
):
    """Assert that loading the model raises ``TensorizerCaughtError``.

    Dispatches ``patch_init_and_catch_error`` through
    ``executor.collective_rpc`` with ``(obj, method_name, expected_error)``
    as its arguments and expects ``TensorizerCaughtError`` to surface.
    """
    rpc_args = (obj, method_name, expected_error)
    with pytest.raises(TensorizerCaughtError):
        executor.collective_rpc(patch_init_and_catch_error, args=rpc_args)
90

91

92
93
94
95
96
97
98
def is_curl_installed():
    """Return ``True`` if ``curl --version`` runs and exits successfully.

    This is evaluated at import time by a ``skipif`` marker, so it must
    never raise and should not print anything:

    * curl's version banner is discarded instead of leaking into the test
      output;
    * any ``OSError`` while launching curl (missing binary, binary without
      execute permission, ...) is treated as "not installed".
    """
    try:
        subprocess.check_call(
            ["curl", "--version"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, OSError):
        return False
    return True

99

100
101
102
103
104
def write_keyfile(keyfile_path: str):
    """Write a freshly generated random encryption key to *keyfile_path*.

    Any missing parent directories of the key file are created first.
    """
    params = EncryptionParams.random()
    destination = pathlib.Path(keyfile_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(params.key)
105
106
107
108


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        model_ref, vllm_runner, tmp_path, model_path):
    """An encrypted tensorizer round-trip must not change generations.

    Generates reference outputs with the plain model, serializes the model
    with an encryption key file, reloads it with ``load_format="tensorizer"``
    using the same key file, and compares the generations.

    ``model_ref``, ``model_path`` and ``vllm_runner`` are fixtures defined
    outside this file.
    """
    args = EngineArgs(model=model_ref)
    with vllm_runner(model_ref) as vllm_model:
        key_path = tmp_path / model_ref / "model.key"
        write_keyfile(key_path)

        outputs = vllm_model.generate(prompts, sampling_params)

    config_for_serializing = TensorizerConfig(tensorizer_uri=str(model_path),
                                              encryption_keyfile=str(key_path))

    tensorize_vllm_model(args, config_for_serializing)

    config_for_deserializing = TensorizerConfig(
        tensorizer_uri=str(model_path), encryption_keyfile=str(key_path))

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=config_for_deserializing,
    ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs
135
136
137


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path, model_ref,
                                                model_path):
    """A model serialized from HF weights must generate the same greedy text.

    The HF model's module is written to ``model_path`` with
    ``TensorSerializer``; the file is then loaded through
    ``load_format="tensorizer"`` and the greedy generations are compared
    with the ones produced by the HF runner.
    """
    with hf_runner(model_ref) as hf_model:
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=str(model_path),
                         num_readers=1,
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs
157
158


159
def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
    """Passing a ``TensorizerConfig`` without ``load_format`` must fail.

    The runner is expected to raise ``RuntimeError`` while the captured
    stdout/stderr contains the underlying ``ValueError`` message for
    ``LoadFormat.AUTO``.
    """
    model = None
    try:
        model = vllm_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    except RuntimeError:
        out, err = capfd.readouterr()
        combined_output = out + err
        assert ("ValueError: Model loader extra config "
                "is not supported for load "
                "format LoadFormat.AUTO") in combined_output
    else:
        # Without this branch the test would pass silently when the invalid
        # configuration is accepted.
        pytest.fail("Expected a RuntimeError for model loader extra config "
                    "without the tensorizer load format")
    finally:
        # Release the (possibly partially constructed) model and GPU memory.
        del model
        gc.collect()
        torch.cuda.empty_cache()


177
178
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd,
                                                  model_ref):
    """A ``TensorizerConfig`` with ``load_format="safetensors"`` must fail.

    The runner is expected to raise ``RuntimeError`` while the captured
    stdout/stderr contains the underlying ``ValueError`` message for
    ``LoadFormat.SAFETENSORS``.
    """
    model = None
    try:
        model = vllm_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    except RuntimeError:
        out, err = capfd.readouterr()

        combined_output = out + err
        assert ("ValueError: Model loader extra config is not supported "
                "for load format LoadFormat.SAFETENSORS") in combined_output
    else:
        # Without this branch the test would pass silently when the invalid
        # configuration is accepted.
        pytest.fail("Expected a RuntimeError for model loader extra config "
                    "with the safetensors load format")
    finally:
        # Release the (possibly partially constructed) model and GPU memory.
        del model
        gc.collect()
        torch.cuda.empty_cache()
195
196


197
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd):
    """A sharded load needs a rank template in ``tensorizer_uri``.

    Loading with ``tensor_parallel_size=2`` and a URI that has no ``%04d``
    style placeholder is expected to raise ``RuntimeError`` while the
    captured output contains the underlying ``ValueError`` message.
    """
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    # Local-path variant of the previously used
    # f"s3://tensorized/{model_ref}/fp16/model.tensors" URI.
    tensorized_path = f"{model_ref}/fp16/model.tensors"

    try:
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=tensorized_path,
                num_readers=1,
                s3_endpoint="object.ord1.coreweave.com",
            ),
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )
    except RuntimeError:
        out, err = capfd.readouterr()
        combined_output = out + err
        assert ("ValueError: For a sharded model, tensorizer_uri "
                "should include a string format template like '%04d' "
                "to be formatted with the rank "
                "of the shard") in combined_output
    else:
        # Without this branch the test would pass silently when the
        # template-less URI is accepted.
        pytest.fail("Expected a RuntimeError for a sharded model whose "
                    "tensorizer_uri has no rank template")
222

223

224
225
226
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
        vllm_runner, tmp_path):
    """A 2-way sharded, encrypted round-trip must not change generations.

    Reference outputs come from the un-sharded, un-tensorized model. The
    model is then serialized with ``tensor_parallel_size=2`` (one file per
    rank, selected through the ``%02d`` template) and reloaded with the same
    ``TensorizerConfig``.
    """
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    # record outputs from un-sharded un-tensorized model
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)

    # load model with two shards and serialize with encryption
    model_path = str(tmp_path / model_ref / "model-%02d.tensors")
    key_path = tmp_path / (model_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=str(key_path),
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    # One shard file per rank must exist after serialization.
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

    assert outputs == deserialized_outputs

269

270
@pytest.mark.flaky(reruns=3)
def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner,
                                                tmp_path, model_path):
    """A vLLM-tensorized model must generate the same outputs as the original.

    Serializes the model with ``tensorize_vllm_model``, checks that
    ``is_vllm_tensorized`` recognizes the result, reloads it with
    ``load_format="tensorizer"`` and compares the generations.
    """
    # Free memory left over from earlier tests before loading models.
    gc.collect()
    torch.cuda.empty_cache()

    config = TensorizerConfig(tensorizer_uri=str(model_path))
    args = EngineArgs(model=model_ref)

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)

    tensorize_vllm_model(args, config)
    assert is_vllm_tensorized(config)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532


def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref):
    """Serve a model given only its reference name and a tensors file.

    For backwards compatibility, Tensorizer must still be loadable for
    inference by passing the model reference name (not a local/S3 directory)
    together with the location of the model tensors.
    """
    tensors_uri = f"{just_serialize_model_tensors}/model.tensors"

    # Command-line arguments for the OpenAI API server.
    server_args = [
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        json.dumps({"tensorizer_uri": tensors_uri}),
    ]

    # Only loading the model and bringing the server up is under test, so
    # the context body is intentionally empty.
    with RemoteOpenAIServer(model_ref, server_args):
        pass


def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path):
    """Check that ``serialization_kwargs`` reach ``TensorSerializer.__init__``.

    ``serialization_test`` is handed to ``assert_from_collective_rpc``
    together with the serializable config. It wraps
    ``TensorSerializer.__init__`` to record the keyword arguments it
    receives, saves the tensorized model, and reports whether every entry of
    ``serialization_params`` was among the recorded keyword arguments.
    """

    serialization_params = {
        "limit_cpu_concurrency": 2,
    }
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path),
                              serialization_kwargs=serialization_params)
    llm = LLM(model=model_ref, )

    def serialization_test(self, *args, **kwargs):
        # This is performed in the ephemeral worker process, so monkey-patching
        # will actually work, and cleanup is guaranteed, so there is no
        # need to reset things afterwards.

        original_dict = serialization_params
        to_compare = {}

        original = tensorizer.serialization.TensorSerializer.__init__

        def tensorizer_serializer_wrapper(self, *args, **kwargs):
            # Record the keyword arguments, then defer to the real __init__.
            nonlocal to_compare
            to_compare = kwargs.copy()
            return original(self, *args, **kwargs)

        tensorizer.serialization.TensorSerializer.__init__ = (
            tensorizer_serializer_wrapper)

        tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"])
        self.save_tensorized_model(tensorizer_config=tensorizer_config, )
        # ``|`` binds tighter than ``==``: merging the expected kwargs into
        # the recorded ones must leave the recorded dict unchanged, i.e. every
        # expected key/value pair was actually passed.
        return to_compare | original_dict == to_compare

    kwargs = {"tensorizer_config": config.to_serializable()}

    assert assert_from_collective_rpc(llm, serialization_test, kwargs)


def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(
        tmp_path, capfd):
    """Check that ``deserialization_kwargs`` reach ``TensorDeserializer``.

    The model is serialized normally, then loaded with an illegal
    ``num_readers`` value; ``TensorDeserializer.__init__`` is expected to
    raise ``TypeError``, which the patched wrapper converts into
    ``TensorizerCaughtError``.
    """
    model_ref = "facebook/opt-125m"
    tensors_uri = str(tmp_path / (model_ref + ".tensors"))

    # Step 1: serialize the model so there is something to load.
    save_config = TensorizerConfig(
        tensorizer_uri=tensors_uri,
        serialization_kwargs={"limit_cpu_concurrency": 2},
    )
    tensorize_vllm_model(EngineArgs(model=model_ref), save_config)

    # Step 2: build a loader config carrying the illegal value.
    load_config = TensorizerConfig(
        tensorizer_uri=tensors_uri,
        deserialization_kwargs={"num_readers": "bar"},  # illegal value
    )
    load_args = EngineArgs(
        model=model_ref,
        load_format="tensorizer",
        model_loader_extra_config=load_config.to_serializable(),
    )
    executor = DummyExecutor(load_args.create_engine_config())

    # Step 3: the bad kwarg must surface from TensorDeserializer.__init__.
    assert_specific_tensorizer_error_is_raised(
        executor,
        tensorizer.serialization.TensorDeserializer,
        "__init__",
        TypeError,
    )


def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd):
    """Check that ``stream_kwargs`` are forwarded to ``open_stream``.

    The model is serialized normally, then loaded with an invalid stream
    ``mode``; ``open_stream`` is expected to raise ``ValueError``, which the
    patched wrapper converts into ``TensorizerCaughtError``.
    """
    model_ref = "facebook/opt-125m"
    tensors_uri = str(tmp_path / (model_ref + ".tensors"))

    # Step 1: serialize the model so there is something to load.
    save_config = TensorizerConfig(
        tensorizer_uri=tensors_uri,
        serialization_kwargs={"limit_cpu_concurrency": 2},
    )
    tensorize_vllm_model(EngineArgs(model=model_ref), save_config)

    # Step 2: build a loader config whose stream kwargs are invalid.
    load_config = TensorizerConfig(
        tensorizer_uri=tensors_uri,
        deserialization_kwargs={"num_readers": 1},
        stream_kwargs={"mode": "foo"},
    )
    load_args = EngineArgs(
        model=model_ref,
        load_format="tensorizer",
        model_loader_extra_config=load_config.to_serializable(),
    )
    executor = DummyExecutor(load_args.create_engine_config())

    # Step 3: the bad kwarg must surface from open_stream.
    assert_specific_tensorizer_error_is_raised(
        executor,
        vllm.model_executor.model_loader.tensorizer,
        "open_stream",
        ValueError,
    )


@pytest.mark.asyncio
async def test_serialize_and_serve_entrypoints(tmp_path):
    """Serialize with the example script, then serve the result.

    First runs ``examples/others/tensorize_vllm_model.py serialize`` in a
    subprocess and checks its stdout, then starts ``vllm serve`` with
    ``--load-format tensorizer`` and waits for the startup message.
    """
    model_ref = "facebook/opt-125m"

    suffix = "test"
    try:
        result = subprocess.run([
            sys.executable,
            f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
            model_ref, "serialize", "--serialized-directory",
            str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
            '{"limit_cpu_concurrency": 4}'
        ],
                                check=True,
                                capture_output=True,
                                text=True)
    except subprocess.CalledProcessError as e:
        # Surface the script's output before failing the test.
        print("Tensorizing failed.")
        print("STDOUT:\n", e.stdout)
        print("STDERR:\n", e.stderr)
        raise

    assert "Successfully serialized" in result.stdout

    # Next, try to serve with vllm serve
    model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors"

    model_loader_extra_config = {
        "tensorizer_uri": str(model_uri),
        "stream_kwargs": {
            "force_http": False,
        },
        "deserialization_kwargs": {
            "verify_hash": True,
            "num_readers": 8,
        }
    }

    cmd = [
        "-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost",
        "--load-format", "tensorizer", model_ref,
        "--model-loader-extra-config",
        json.dumps(model_loader_extra_config, indent=2)
    ]

    proc = await asyncio.create_subprocess_exec(
        sys.executable,
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )

    assert proc.stdout is not None
    server_output = proc.stdout
    startup_marker = b"Application startup complete."

    async def _wait_for_startup() -> bool:
        """Return True once the marker is seen, False if output ends first."""
        # Scan line by line: ``readuntil`` would keep the whole startup log
        # buffered and raise LimitOverrunError once it exceeds the stream
        # limit, and would raise IncompleteReadError on an early exit.
        while line := await server_output.readline():
            if startup_marker in line:
                return True
        return False

    try:
        started = await asyncio.wait_for(_wait_for_startup(), 180)
    except asyncio.TimeoutError:
        pytest.fail("Server did not start successfully")
    else:
        if not started:
            pytest.fail("Server exited before startup completed")
    finally:
        if proc.returncode is None:
            try:
                proc.terminate()
            except ProcessLookupError:
                # The process exited between the check and the signal.
                pass
        # Drain remaining output and reap the process on every path.
        await proc.communicate()


@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS)
def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd,
                                           illegal_value):
    """Blacklisted tensorizer arguments must be rejected when loading.

    The model is serialized normally, then loaded with an extra config that
    contains ``illegal_value`` as a key. The runner is expected to raise
    ``RuntimeError`` while the captured output contains the underlying
    ``ValueError`` message naming the argument.
    """

    serialization_params = {
        "limit_cpu_concurrency": 2,
    }

    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path),
                              serialization_kwargs=serialization_params)

    args = EngineArgs(model=model_ref)
    tensorize_vllm_model(args, config)

    loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"}

    try:
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=loader_tc,
        )
    except RuntimeError:
        out, err = capfd.readouterr()
        combined_output = out + err
        assert (f"ValueError: {illegal_value} is not an allowed "
                f"Tensorizer argument.") in combined_output
    else:
        # Without this branch the test would pass silently when the
        # blacklisted argument is accepted.
        pytest.fail(f"Expected a RuntimeError for blacklisted Tensorizer "
                    f"argument {illegal_value!r}")