test_tensorizer.py 16.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import gc
4
5
import json
import os
6
import pathlib
7
import subprocess
8
from functools import partial
9
10
from unittest.mock import MagicMock, patch

11
import openai
12
import pytest
13
import torch
14
from huggingface_hub import snapshot_download
15
from typing import List, Tuple, Optional
16

zhuwenwen's avatar
zhuwenwen committed
17
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
18
from vllm.engine.arg_utils import EngineArgs
19
# yapf conflicts with isort for this docstring
20
21
22
23
24
25
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
26
27
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
28
29
from vllm.lora.request import LoRARequest

30
# yapf: enable
31
from vllm.utils import PlaceholderModule, import_from_path
32

33
from ..utils import VLLM_PATH, RemoteOpenAIServer
34
from .conftest import retry_until_skip
zhuwenwen's avatar
zhuwenwen committed
35
from ..utils import RemoteOpenAIServer, models_path_prefix
36

37
38
39
40
41
42
# tensorizer is an optional dependency: when it cannot be imported, bind
# placeholder objects to the same names so this module can still be imported.
# NOTE(review): presumably PlaceholderModule postpones the import error until
# the placeholder attribute is actually used — confirm in vllm.utils.
try:
    from tensorizer import EncryptionParams
except ImportError:
    tensorizer = PlaceholderModule("tensorizer")  # type: ignore[assignment]
    EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")

43
# Directory of the example scripts; the LoRA test imports the multi-LoRA
# example module from a path below it.
EXAMPLES_PATH = VLLM_PATH / "examples"
44

45
46
47
48
49
50
51
52
53
# Prompts shared by the generation tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Sampling parameters shared by the generation tests; the seed is fixed so
# that outputs of two runs can be compared for equality.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

54
# Default model used by the tests in this module, located under
# models_path_prefix. Several tests shadow this with a local `model_ref`.
model_ref = os.path.join(models_path_prefix, "facebook/opt-125m")
55
56
# Path of the helper script expected to sit next to this test file.
# NOTE(review): not referenced by any test visible in this module — confirm
# whether it is still needed.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
57

58

59
60
61
62
63
64
65
def is_curl_installed():
    """Return True when ``curl --version`` can be executed successfully.

    This is evaluated at import time by the ``skipif`` markers below, so it
    must never raise and should not write to the test output: curl's version
    banner is discarded, and any ``OSError`` raised while spawning the process
    (missing binary, missing execute permission, ...) is reported as "not
    installed" instead of breaking test collection.

    Returns:
        bool: True if the command exits with status 0, False otherwise.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # FileNotFoundError is a subclass of OSError, so the original
        # "binary is missing" case is still covered here.
        return False
    return True

66

67
68
69
70
71
def write_keyfile(keyfile_path: str):
    """Create a random encryption key and store its bytes at ``keyfile_path``.

    The key is generated first; the parent directories of the target are then
    created if they are missing, and the file is (over)written in binary mode.
    """
    key_bytes = EncryptionParams.random().key
    target = pathlib.Path(keyfile_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(key_bytes)
72
73


74
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """Check that ``load_with_tensorizer`` delegates to ``TensorizerAgent``.

    The agent class is patched, so the test only verifies the wiring: the
    agent is built once from the config and quant method, ``deserialize`` is
    called once, and its return value is handed back to the caller.
    """
    quant_method = MagicMock()
    expected_model = MagicMock()
    agent = mock_agent.return_value
    agent.deserialize.return_value = expected_model

    loaded = load_with_tensorizer(tensorizer_config,
                                  quant_method=quant_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=quant_method)
    agent.deserialize.assert_called_once()
    assert loaded == expected_model


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """Load already-tensorized pythia-1.4b weights and generate with them.

    NOTE(review): despite the test name, the tensorizer URI is derived from
    ``models_path_prefix`` with no explicit ``s3://`` scheme, while an S3
    endpoint is still passed to the config — confirm this is intended.
    """
    pythia_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    loader_config = TensorizerConfig(
        tensorizer_uri=f"{pythia_ref}/fp16/model.tensors",
        num_readers=1,
        s3_endpoint="object.ord1.coreweave.com",
    )

    with vllm_runner(pythia_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=loader_config) as runner:
        generated = runner.generate(prompts, sampling_params)

        # Any non-empty result counts as a successful load.
        assert generated
106
107
108
109
110


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Round-trip a model through encrypted serialization and compare outputs.

    The model is loaded normally, its generations are recorded, and it is
    serialized with a freshly written key file. It is then reloaded through
    the tensorizer load format with the same key and must generate the same
    outputs.

    NOTE(review): if ``model_ref`` is an absolute path, joining it onto
    ``tmp_path`` yields a location outside ``tmp_path`` — confirm.
    """
    tensors_file = tmp_path / (model_ref + ".tensors")
    key_file = tmp_path / (model_ref + ".key")

    with vllm_runner(model_ref) as source_model:
        write_keyfile(key_file)

        reference_outputs = source_model.generate(prompts, sampling_params)

        serialize_config = TensorizerConfig(tensorizer_uri=tensors_file,
                                            encryption_keyfile=key_file)
        serialize = partial(serialize_vllm_model,
                            tensorizer_config=serialize_config)
        source_model.apply_model(serialize)

    # A separate config instance is used for loading, as in serialization.
    deserialize_config = TensorizerConfig(tensorizer_uri=tensors_file,
                                          encryption_keyfile=key_file)

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=deserialize_config,
    ) as reloaded_model:
        reloaded_outputs = reloaded_model.generate(prompts, sampling_params)

        assert reference_outputs == reloaded_outputs
138
139
140
141


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """Serialize an HF model with tensorizer and reload it through vLLM.

    Greedy generations of the HF model are recorded before its module is
    written to a tensorizer stream; the vLLM model loaded from that stream
    must produce the same greedy generations.

    NOTE(review): if ``model_ref`` is an absolute path, joining it onto
    ``tmp_path`` yields a location outside ``tmp_path`` — confirm.
    """
    token_budget = 50
    tensors_file = tmp_path / (model_ref + ".tensors")

    with hf_runner(model_ref) as hf_model:
        reference_outputs = hf_model.generate_greedy(prompts,
                                                     max_tokens=token_budget)
        with open_stream(tensors_file, "wb+") as out_stream:
            writer = TensorSerializer(out_stream)
            writer.write_module(hf_model.model)

    loader_config = TensorizerConfig(
        tensorizer_uri=tensors_file,
        num_readers=1,
    )
    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=loader_config) as reloaded:
        reloaded_outputs = reloaded.generate_greedy(prompts,
                                                    max_tokens=token_budget)

        assert reference_outputs == reloaded_outputs
160
161


zhuwenwen's avatar
zhuwenwen committed
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the (prompt, sampling params, LoRA request) triples for a run.

    The first two entries carry no LoRA request. The remaining four each
    carry a ``LoRARequest`` pointing at ``lora_path``: three under the name
    ``sql-lora`` (id 1) and one under ``sql-lora2`` (id 2). Every entry gets
    its own freshly constructed ``SamplingParams`` instance.
    """
    icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    elector_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    def greedy_with_logprobs(**extra):
        # Fresh instance per call so that no two requests share params.
        return SamplingParams(temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              max_tokens=128,
                              **extra)

    def beam_search():
        return SamplingParams(n=3,
                              best_of=3,
                              use_beam_search=True,
                              temperature=0,
                              max_tokens=128,
                              stop_token_ids=[32003])

    return [
        ("A robot may not injure a human being", greedy_with_logprobs(),
         None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        (icao_prompt, greedy_with_logprobs(stop_token_ids=[32003]),
         LoRARequest("sql-lora", 1, lora_path)),
        (elector_prompt, beam_search(),
         LoRARequest("sql-lora", 1, lora_path)),
        (icao_prompt, greedy_with_logprobs(stop_token_ids=[32003]),
         LoRARequest("sql-lora2", 2, lora_path)),
        (elector_prompt, beam_search(),
         LoRARequest("sql-lora", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Feed the prompts to the engine one per step and print finished outputs.

    On every iteration at most one pending prompt is removed from the front
    of ``test_prompts`` (the caller's list is consumed) and submitted under an
    increasing numeric request id; the engine is then stepped once. The loop
    ends when no prompts are left and the engine reports no unfinished
    requests.
    """
    next_id = 0

    while True:
        # The engine is only queried once the prompt list is exhausted.
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            text, params, lora = test_prompts.pop(0)
            engine.add_request(str(next_id),
                               text,
                               params,
                               lora_request=lora)
            next_id += 1

        step_outputs: List[RequestOutput] = engine.step()
        for finished in (out for out in step_outputs if out.finished):
            print(finished)


243
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """Serialize Llama-2-7b, reload it via tensorizer with LoRA enabled.

    The prompts and the request loop come from the multi-LoRA example script,
    which is imported from its file path. The test passes if the requests can
    be processed on the reloaded model without raising.
    """
    example = import_from_path(
        "examples.offline_inference.multilora_inference",
        EXAMPLES_PATH / "offline_inference/multilora_inference.py",
    )

    llama_ref = os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-hf")
    # The adapter is taken from the local model store rather than downloaded.
    adapter_dir = os.path.join(models_path_prefix,
                               "yard1/llama-2-7b-sql-lora-test")
    lora_requests = example.create_test_prompts(adapter_dir)

    # Serialize model before deserializing and binding LoRA adapters
    with vllm_runner(llama_ref) as source_model:
        tensors_file = tmp_path / (llama_ref + ".tensors")
        serialize = partial(
            serialize_vllm_model,
            tensorizer_config=TensorizerConfig(tensorizer_uri=tensors_file))
        source_model.apply_model(serialize)

    loader_config = TensorizerConfig(
        tensorizer_uri=tensors_file,
        num_readers=1,
    )
    with vllm_runner(
            llama_ref,
            load_format="tensorizer",
            model_loader_extra_config=loader_config,
            enable_lora=True,
            max_loras=1,
            max_lora_rank=8,
            max_cpu_loras=2,
            max_num_seqs=50,
            max_model_len=1000,
    ) as reloaded_model:
        example.process_requests(reloaded_model.model.llm_engine,
                                 lora_requests)

        assert reloaded_model
281
282
283


def test_load_without_tensorizer_load_format(vllm_runner):
    """Passing a TensorizerConfig without a load format must raise ValueError.

    The config is built inside the ``pytest.raises`` block, so a ValueError
    from either the config or the runner satisfies the expectation.
    """
    runner = None
    with pytest.raises(ValueError):
        extra_config = TensorizerConfig(tensorizer_uri="test")
        runner = vllm_runner(model_ref,
                             model_loader_extra_config=extra_config)

    # Drop whatever may have been allocated before the failure.
    del runner
    gc.collect()
    torch.cuda.empty_cache()
292
293
294


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """Serve a tensorized model through the OpenAI-compatible API server.

    The model is serialized first; the server is then started with the
    tensorizer load format and the serialized path passed as JSON, and a
    single 5-token completion is requested and checked.
    """
    tensors_file = tmp_path / (model_ref + ".tensors")

    # Serialize the model.
    with vllm_runner(model_ref) as source_model:
        serialize = partial(
            serialize_vllm_model,
            tensorizer_config=TensorizerConfig(tensorizer_uri=tensors_file))
        source_model.apply_model(serialize)

    # Start the OpenAI API server on top of the serialized weights.
    extra_config_json = json.dumps({
        "tensorizer_uri": str(tensors_file),
    })
    server_args = [
        "--dtype",
        "float16",
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        extra_config_json,
    ]

    with RemoteOpenAIServer(model_ref, server_args) as server:
        print("Server ready.")

        api_client = server.get_client()
        completion = api_client.completions.create(
            model=model_ref,
            prompt="Hello, my name is",
            max_tokens=5,
            temperature=0.0,
        )

        assert completion.id is not None
        assert len(completion.choices) == 1
        first_choice = completion.choices[0]
        assert len(first_choice.text) >= 5
        assert first_choice.finish_reason == "length"
        expected_usage = openai.types.CompletionUsage(completion_tokens=5,
                                                      prompt_tokens=6,
                                                      total_tokens=11)
        assert completion.usage == expected_usage
334
335
336


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """A TensorizerConfig with the safetensors load format must raise.

    The config is built inside the ``pytest.raises`` block, so a ValueError
    from either the config or the runner satisfies the expectation.
    """
    runner = None
    with pytest.raises(ValueError):
        extra_config = TensorizerConfig(tensorizer_uri="test")
        runner = vllm_runner(model_ref,
                             load_format="safetensors",
                             model_loader_extra_config=extra_config)

    # Drop whatever may have been allocated before the failure.
    del runner
    gc.collect()
    torch.cuda.empty_cache()
346
347


348
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """Tensor-parallel loading from a URI without a shard template must fail.

    Everything, including building the paths and the config, happens inside
    the ``pytest.raises`` block, so a ValueError from any step satisfies it.
    """
    with pytest.raises(ValueError):
        pythia_ref = os.path.join(models_path_prefix,
                                  "EleutherAI/pythia-1.4b")
        # No "%d"-style placeholder in this path, so it names a single file.
        single_file_uri = f"{pythia_ref}/fp16/model.tensors"

        loader_config = TensorizerConfig(
            tensorizer_uri=single_file_uri,
            num_readers=1,
            s3_endpoint="object.ord1.coreweave.com",
        )
        vllm_runner(
            pythia_ref,
            load_format="tensorizer",
            model_loader_extra_config=loader_config,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )
365

366

367
368
369
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
        vllm_runner, tmp_path):
    """Compare an unsharded model with its encrypted two-shard round trip.

    Outputs are first recorded from the plain model. The model is then
    serialized with tensor parallel size 2 and an encryption key file, one
    tensor file per shard, reloaded the same way, and must generate the same
    outputs.
    """
    pythia_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    eager_opts = dict(disable_custom_all_reduce=True, enforce_eager=True)

    # Record outputs from the un-sharded, un-tensorized model.
    with vllm_runner(pythia_ref, **eager_opts) as base_model:
        reference_outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()

    # Load the model with two shards and serialize it with encryption.
    # "%02d" in the path is filled in with the shard index.
    shard_template = str(tmp_path / (pythia_ref + "-%02d.tensors"))
    key_file = tmp_path / (pythia_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=shard_template,
        encryption_keyfile=key_file,
    )

    sharded_engine_args = EngineArgs(
        model=pythia_ref,
        tensor_parallel_size=2,
        **eager_opts,
    )
    tensorize_vllm_model(
        engine_args=sharded_engine_args,
        tensorizer_config=tensorizer_config,
    )
    for shard in range(2):
        assert os.path.isfile(
            shard_template % shard), "Serialization subprocess failed"

    with vllm_runner(
            pythia_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            **eager_opts,
            model_loader_extra_config=tensorizer_config,
    ) as reloaded_model:
        reloaded_outputs = reloaded_model.generate(prompts, sampling_params)

    assert reference_outputs == reloaded_outputs

413

414
@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Round-trip opt-125m through vLLM tensorization and compare outputs.

    Cached GPU memory is released first. The model is then loaded, its
    outputs are recorded, and it is serialized; the result must be recognised
    as vLLM-tensorized. Finally it is reloaded through the tensorizer load
    format and must generate the same outputs.
    """
    gc.collect()
    torch.cuda.empty_cache()

    opt_ref = os.path.join(models_path_prefix, "facebook/opt-125m")
    tensors_file = tmp_path / (opt_ref + ".tensors")
    # The same config object is used for both writing and reading.
    shared_config = TensorizerConfig(tensorizer_uri=str(tensors_file))

    with vllm_runner(opt_ref) as source_model:
        reference_outputs = source_model.generate(prompts, sampling_params)

        serialize = partial(serialize_vllm_model,
                            tensorizer_config=shared_config)
        source_model.apply_model(serialize)

        assert is_vllm_tensorized(shared_config)

    with vllm_runner(opt_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=shared_config) as reloaded:
        reloaded_outputs = reloaded.generate(prompts, sampling_params)

        assert reference_outputs == reloaded_outputs