test_tensorizer.py 16.4 KB
Newer Older
1
import gc
2
3
import json
import os
4
import pathlib
5
6
7
import subprocess
from unittest.mock import MagicMock, patch

8
import openai
9
import pytest
10
11
import torch
from tensorizer import EncryptionParams
12

zhuwenwen's avatar
zhuwenwen committed
13
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
14
from vllm.engine.arg_utils import EngineArgs
15
16
17
18
19
20
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
21
22
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
23

24
from ..conftest import VllmRunner
25
from ..utils import RemoteOpenAIServer, models_path_prefix
26
from .conftest import retry_until_skip
27

zhuwenwen's avatar
zhuwenwen committed
28
29
30
from typing import List, Optional, Tuple
from vllm.lora.request import LoRARequest

31
32
33
# yapf conflicts with isort for this docstring


34
35
36
37
38
39
40
41
42
# Prompts shared by the generation tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

# Default model used by tests that do not pick their own.
model_ref = os.path.join(models_path_prefix, "facebook/opt-125m")
# Helper script expected to live next to this test file.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
46

47

48
49
50
51
52
53
54
def is_curl_installed():
    """Return True if ``curl --version`` can be run successfully.

    A non-zero exit status, or any failure to launch the process, is
    reported as "not installed". The command's output is discarded: this
    helper is evaluated inside ``skipif`` decorators, so printing the curl
    version banner each time would only clutter the test log.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (no curl on PATH) as well as
        # cases such as a curl binary that is present but not executable.
        return False
    return True

55

56
57
def get_torch_model(vllm_runner: "VllmRunner"):
    """Return the model object held by the runner's driver worker.

    Follows the attribute chain
    ``model.llm_engine.model_executor.driver_worker.model_runner.model``
    starting from ``vllm_runner``.
    """
    driver_worker = vllm_runner.model.llm_engine.model_executor.driver_worker
    return driver_worker.model_runner.model

65
66
67
68
69
70

def write_keyfile(keyfile_path: str):
    """Generate random encryption parameters and store their key on disk.

    Missing parent directories of ``keyfile_path`` are created first; the
    key is then written in binary mode, replacing any existing file.
    """
    params = EncryptionParams.random()
    target = pathlib.Path(keyfile_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(params.key)
71
72


73
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """load_with_tensorizer builds one agent and returns its deserialize()."""
    mock_linear_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  quant_method=mock_linear_method)

    # The agent must be constructed exactly once with the arguments that
    # were passed to load_with_tensorizer.
    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=mock_linear_method)
    mock_agent_instance.deserialize.assert_called_once()
    assert result == mock_agent_instance.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """A model loaded via the tensorizer format produces some output."""
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    tensorized_path = f"{model_ref}/fp16/model.tensors"

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=tensorized_path,
                         num_readers=1,
                         s3_endpoint="object.ord1.coreweave.com",
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(prompts,
                                                        sampling_params)

        assert deserialized_outputs
105
106
107
108
109


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Encrypted serialize -> deserialize round trip keeps outputs equal."""
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")
        key_path = tmp_path / (model_ref + ".key")
        write_keyfile(key_path)

        # Reference outputs from the model before it is serialized.
        outputs = vllm_model.generate(prompts, sampling_params)

        config_for_serializing = TensorizerConfig(
            tensorizer_uri=model_path,
            encryption_keyfile=key_path
        )
        serialize_vllm_model(get_torch_model(vllm_model),
                             config_for_serializing)

    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                encryption_keyfile=key_path)

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=config_for_deserializing) as loaded_vllm_model:  # noqa: E501

        deserialized_outputs = loaded_vllm_model.generate(prompts,
                                                          sampling_params)

        assert outputs == deserialized_outputs
137
138
139
140


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """An HF model written with TensorSerializer loads with equal outputs."""
    with hf_runner(model_ref) as hf_model:
        model_path = tmp_path / (model_ref + ".tensors")
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=model_path,
                         num_readers=1,
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs
159
160


zhuwenwen's avatar
zhuwenwen committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build (prompt, sampling params, LoRA request) triples for testing.

    The first two entries target the base model (no LoRA request); the
    remaining four carry a LoRA request. Two adapter names/ids are used,
    both pointing at the same ``lora_path``. The caller in this file runs
    with `max_loras=1`, so the intent is that requests for the second
    adapter are processed after those for the first one have finished.
    """
    icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    elector_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    def greedy_sql_params() -> SamplingParams:
        # A fresh object per request, as in a literal list of constructors.
        return SamplingParams(temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              max_tokens=128,
                              stop_token_ids=[32003])

    def beam_sql_params() -> SamplingParams:
        return SamplingParams(n=3,
                              best_of=3,
                              use_beam_search=True,
                              temperature=0,
                              max_tokens=128,
                              stop_token_ids=[32003])

    cases: List[Tuple[str, SamplingParams, Optional[LoRARequest]]] = []

    # Base-model requests.
    cases.append(("A robot may not injure a human being",
                  SamplingParams(temperature=0.0,
                                 logprobs=1,
                                 prompt_logprobs=1,
                                 max_tokens=128), None))
    cases.append(("To be or not to be,",
                  SamplingParams(temperature=0.8,
                                 top_k=5,
                                 presence_penalty=0.2,
                                 max_tokens=128), None))

    # LoRA requests.
    cases.append((icao_prompt, greedy_sql_params(),
                  LoRARequest("sql-lora", 1, lora_path)))
    cases.append((elector_prompt, beam_sql_params(),
                  LoRARequest("sql-lora", 1, lora_path)))
    cases.append((icao_prompt, greedy_sql_params(),
                  LoRARequest("sql-lora2", 2, lora_path)))
    cases.append((elector_prompt, beam_sql_params(),
                  LoRARequest("sql-lora", 1, lora_path)))
    return cases


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Submit the prompts one per engine step and print finished outputs.

    Entries are popped from the front of ``test_prompts``, so the caller's
    list is empty afterwards. Request ids are "0", "1", ... in submission
    order. The loop ends once no prompts remain and the engine reports no
    unfinished requests.
    """
    next_id = 0

    while True:
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            text, params, lora = test_prompts.pop(0)
            engine.add_request(str(next_id),
                               text,
                               params,
                               lora_request=lora)
            next_id += 1

        for step_output in engine.step():
            if not step_output.finished:
                continue
            print(step_output)

242
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """A tensorized model can be reloaded with LoRA enabled and serve requests."""
    # NOTE: create_test_prompts / process_requests are the module-local
    # helpers (previously imported from examples.multilora_inference), and
    # the adapter is read from models_path_prefix instead of being fetched
    # with huggingface_hub.snapshot_download.
    model_ref = os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-hf")
    lora_path = os.path.join(models_path_prefix,
                             "yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    with vllm_runner(model_ref, ) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=model_path,
                num_readers=1,
            ),
            enable_lora=True,
            max_loras=1,
            max_lora_rank=8,
            max_cpu_loras=2,
            max_num_seqs=50,
            max_model_len=1000,
    ) as loaded_vllm_model:
        process_requests(loaded_vllm_model.model.llm_engine, test_prompts)

        assert loaded_vllm_model
277
278
279


def test_load_without_tensorizer_load_format(vllm_runner):
    """Passing a TensorizerConfig without load_format raises ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    # Drop the reference and release cached GPU memory before later tests.
    del model
    gc.collect()
    torch.cuda.empty_cache()
288
289
290


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """The OpenAI-compatible server can serve a tensorizer-loaded model."""
    # Serialize model
    with vllm_runner(model_ref, ) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

        model_loader_extra_config = {
            "tensorizer_uri": str(model_path),
        }

    # Start OpenAI API server
    openai_args = [
        "--dtype", "float16", "--load-format",
        "tensorizer", "--model-loader-extra-config",
        json.dumps(model_loader_extra_config),
    ]

    with RemoteOpenAIServer(model_ref, openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(model=model_ref,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)

        assert completion.id is not None
        assert len(completion.choices) == 1
        assert len(completion.choices[0].text) >= 5
        assert completion.choices[0].finish_reason == "length"
        assert completion.usage == openai.types.CompletionUsage(
            completion_tokens=5, prompt_tokens=6, total_tokens=11)
325
326
327


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """A TensorizerConfig with load_format="safetensors" raises ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    # Drop the reference and release cached GPU memory before later tests.
    del model
    gc.collect()
    torch.cuda.empty_cache()
337
338


339
340
341
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """Tensor-parallel loading from a URI with no shard template must fail."""
    # Setup stays outside pytest.raises so that only the runner construction
    # itself can satisfy the expected ValueError.
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    tensorized_path = f"{model_ref}/fp16/model.tensors"

    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=tensorized_path,
                num_readers=1,
                s3_endpoint="object.ord1.coreweave.com",
            ),
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )
357

358

359
360
361
362
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
                                                                    tmp_path):
    """Sharded, encrypted round trip with TP=2 matches the base outputs."""
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    # record outputs from un-sharded un-tensorized model
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()

    # load model with two shards and serialize with encryption
    # ("%02d" in the path is filled in with the shard index below)
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    key_path = tmp_path / (model_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(prompts,
                                                          sampling_params)

    assert outputs == deserialized_outputs

406

407
408

@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Serialize -> deserialize round trip of a vLLM model keeps outputs."""
    # Start from a clean slate: collect garbage and release cached GPU memory.
    gc.collect()
    torch.cuda.empty_cache()
    model_ref = os.path.join(models_path_prefix, "facebook/opt-125m")
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        serialize_vllm_model(get_torch_model(vllm_model), config)

        assert is_vllm_tensorized(config)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(prompts,
                                                          sampling_params)

        assert outputs == deserialized_outputs