# test_tensorizer.py: tests for the vLLM tensorizer model loader.
import gc
import json
import os
import pathlib
import subprocess
from typing import List, Optional, Tuple, Union
from unittest.mock import MagicMock, patch

import openai
import pytest
import torch
from huggingface_hub import snapshot_download
from tensorizer import EncryptionParams

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
# yapf: enable
from vllm.utils import import_from_path

from ..conftest import VllmRunner
from ..utils import VLLM_PATH, RemoteOpenAIServer, models_path_prefix
from .conftest import retry_until_skip

EXAMPLES_PATH = VLLM_PATH / "examples"

# Prompts shared by the generation-equivalence tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object. The fixed seed keeps sampling reproducible,
# which the `outputs == deserialized_outputs` comparisons depend on.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

# Default model for most tests, joined onto the local models_path_prefix.
model_ref = os.path.join(models_path_prefix, "facebook/opt-125m")

# NOTE(review): not referenced elsewhere in this module; kept in case other
# modules import it — confirm before removing.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")


def is_curl_installed():
    """Return True if a runnable ``curl`` executable is available.

    This runs at import time inside ``skipif`` markers, so curl's version
    banner goes to DEVNULL instead of cluttering test-collection output.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
        return True
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (curl not on PATH) and also
        # PermissionError (curl present but not executable).
        return False


def get_torch_model(vllm_runner: "VllmRunner"):
    """Return the model object held by the driver worker of a ``VllmRunner``.

    Follows runner.model -> llm_engine -> model_executor -> driver_worker ->
    model_runner -> model. The result is handed to ``serialize_vllm_model``;
    presumably a ``torch.nn.Module`` (confirm against the model runner).
    """
    # String annotation: the runner type is only needed by type checkers.
    return (vllm_runner.model
            .llm_engine
            .model_executor
            .driver_worker
            .model_runner
            .model)


def write_keyfile(keyfile_path: Union[str, os.PathLike]):
    """Write a freshly generated random tensorizer encryption key to disk.

    Args:
        keyfile_path: Destination file. Tests pass a ``pathlib.Path`` here,
            so any path-like object is accepted. Missing parent directories
            are created.
    """
    encryption_params = EncryptionParams.random()
    keyfile = pathlib.Path(keyfile_path)
    keyfile.parent.mkdir(parents=True, exist_ok=True)
    # Raw key bytes; the same file is later passed as `encryption_keyfile`.
    keyfile.write_bytes(encryption_params.key)


@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """``load_with_tensorizer`` builds one TensorizerAgent and returns the
    result of its ``deserialize()`` call."""
    mock_linear_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  quant_method=mock_linear_method)

    # Config and quant method must be forwarded to the agent unchanged.
    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=mock_linear_method)
    mock_agent_instance.deserialize.assert_called_once()
    assert result is mock_agent_instance.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """Load pythia-1.4b via the tensorizer loader and check it generates."""
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    # NOTE(review): the test name and `s3_endpoint` suggest an S3 URI, but
    # this path is built from the local model directory — confirm intended.
    tensorized_path = f"{model_ref}/fp16/model.tensors"

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=tensorized_path,
                         num_readers=1,
                         s3_endpoint="object.ord1.coreweave.com",
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(
            prompts, sampling_params)

        assert deserialized_outputs


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Round-trip a model through encrypted tensorizer serialization and
    check that generation outputs are unchanged."""
    # model_ref is joined onto models_path_prefix; if that prefix is
    # absolute, `tmp_path / model_ref` would discard tmp_path and write into
    # the model store. Keep artifacts inside tmp_path by using the basename.
    artifact_stem = os.path.basename(model_ref)
    model_path = tmp_path / (artifact_stem + ".tensors")
    key_path = tmp_path / (artifact_stem + ".key")

    with vllm_runner(model_ref) as vllm_model:
        write_keyfile(key_path)

        outputs = vllm_model.generate(prompts, sampling_params)

        config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
                                                  encryption_keyfile=key_path)
        serialize_vllm_model(get_torch_model(vllm_model),
                             config_for_serializing)

    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                encryption_keyfile=key_path)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config_for_deserializing
                     ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """Serialize an HF model with TensorSerializer, load it in vLLM via the
    tensorizer loader, and compare greedy outputs."""
    # Basename keeps the artifact inside tmp_path even if model_ref is an
    # absolute path (pathlib's `/` discards the left side in that case).
    model_path = tmp_path / (os.path.basename(model_ref) + ".tensors")
    max_tokens = 50

    with hf_runner(model_ref) as hf_model:
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=model_path,
                         num_readers=1,
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build a mixed list of base-model and LoRA requests.

    The first two entries target the base model (no LoRARequest). The other
    four use LoRA adapters named "sql-lora" (id 1) and "sql-lora2" (id 2),
    both loaded from ``lora_path``. With ``max_loras=1`` the expectation is
    that requests for the second adapter run only after the first adapter's
    requests have finished. Every entry gets its own SamplingParams instance.

    NOTE(review): the final entry goes back to "sql-lora" — confirm this
    ordering is intentional.
    """
    icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    elector_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    def greedy_sql():
        # Greedy decoding with logprobs, stopping on token 32003.
        return SamplingParams(max_tokens=128,
                              temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              stop_token_ids=[32003])

    def beam_sql():
        # Three-way beam search, stopping on token 32003.
        return SamplingParams(n=3,
                              best_of=3,
                              use_beam_search=True,
                              temperature=0,
                              max_tokens=128,
                              stop_token_ids=[32003])

    base_requests = [
        ("A robot may not injure a human being",
         SamplingParams(max_tokens=128,
                        temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1), None),
        ("To be or not to be,",
         SamplingParams(max_tokens=128,
                        temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2), None),
    ]
    lora_requests = [
        (icao_prompt, greedy_sql(), LoRARequest("sql-lora", 1, lora_path)),
        (elector_prompt, beam_sql(), LoRARequest("sql-lora", 1, lora_path)),
        (icao_prompt, greedy_sql(), LoRARequest("sql-lora2", 2, lora_path)),
        (elector_prompt, beam_sql(), LoRARequest("sql-lora", 1, lora_path)),
    ]
    return base_requests + lora_requests


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Drive ``engine`` until every prompt has been submitted and finished.

    At most one prompt is submitted before each ``engine.step()``, so
    submission and decoding are interleaved. Finished outputs are printed.
    ``test_prompts`` is consumed in place (popped from the front) and is
    empty on return. Request ids are "0", "1", ... in submission order.
    """
    next_request_id = 0

    while True:
        has_pending = bool(test_prompts)
        # Stop once nothing is left to submit and the engine is idle.
        if not has_pending and not engine.has_unfinished_requests():
            break

        if has_pending:
            prompt, params, lora_request = test_prompts.pop(0)
            engine.add_request(str(next_request_id),
                               prompt,
                               params,
                               lora_request=lora_request)
            next_request_id += 1

        for finished in (out for out in engine.step() if out.finished):
            print(finished)


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """Deserialize a tensorized model with LoRA enabled and run the
    multi-LoRA example workload against it."""
    multilora_inference = import_from_path(
        "examples.multilora_inference",
        EXAMPLES_PATH / "multilora_inference.py",
    )

    # NOTE(review): unlike the LoRA adapter, the base model is not resolved
    # under models_path_prefix — confirm this is intended for offline runs.
    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = os.path.join(models_path_prefix,
                             "yard1/llama-2-7b-sql-lora-test")
    test_prompts = multilora_inference.create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters.
    # Basename keeps the artifact inside tmp_path even for absolute refs.
    model_path = tmp_path / (os.path.basename(model_ref) + ".tensors")
    with vllm_runner(model_ref) as vllm_model:
        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=model_path,
                num_readers=1,
            ),
            enable_lora=True,
            max_loras=1,
            max_lora_rank=8,
            max_cpu_loras=2,
            max_num_seqs=50,
            max_model_len=1000,
    ) as loaded_vllm_model:
        multilora_inference.process_requests(
            loaded_vllm_model.model.llm_engine, test_prompts)

        assert loaded_vllm_model


def test_load_without_tensorizer_load_format(vllm_runner):
    """A TensorizerConfig passed without ``load_format="tensorizer"`` must be
    rejected with ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    # Drop any runner reference and release cached GPU memory before the
    # next test.
    del model
    gc.collect()
    torch.cuda.empty_cache()


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """Serve a tensorized model through the OpenAI-compatible API server and
    check a short temperature-0 completion."""
    ## Serialize model
    # Basename keeps the artifact inside tmp_path even for absolute refs.
    model_path = tmp_path / (os.path.basename(model_ref) + ".tensors")
    with vllm_runner(model_ref) as vllm_model:
        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    model_loader_extra_config = {
        "tensorizer_uri": str(model_path),
    }

    ## Start OpenAI API server
    openai_args = [
        "--dtype",
        "float16",
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        json.dumps(model_loader_extra_config),
    ]

    with RemoteOpenAIServer(model_ref, openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(model=model_ref,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)

        assert completion.id is not None
        assert len(completion.choices) == 1
        assert len(completion.choices[0].text) >= 5
        assert completion.choices[0].finish_reason == "length"
        assert completion.usage == openai.types.CompletionUsage(
            completion_tokens=5, prompt_tokens=6, total_tokens=11)


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """A TensorizerConfig combined with ``load_format="safetensors"`` must be
    rejected with ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    # Drop any runner reference and release cached GPU memory before the
    # next test.
    del model
    gc.collect()
    torch.cuda.empty_cache()


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """A tensor-parallel load whose URI has no per-rank ``%``-style template
    must be rejected with ValueError."""
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
    tensorized_path = f"{model_ref}/fp16/model.tensors"

    # Only the runner construction is expected to raise; setup stays outside
    # so an unrelated ValueError there cannot make the test pass.
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=tensorized_path,
                num_readers=1,
                s3_endpoint="object.ord1.coreweave.com",
            ),
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
        vllm_runner, tmp_path):
    """Serialize a 2-way tensor-parallel model with encryption and check the
    deserialized sharded model matches the un-sharded model's outputs."""
    model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")

    # record outputs from un-sharded un-tensorized model
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()

    # load model with two shards and serialize with encryption.
    # Basename keeps artifacts inside tmp_path even though model_ref may be
    # absolute; "%02d" is filled per shard (see the `% 0` / `% 1` checks).
    artifact_stem = os.path.basename(model_ref)
    model_path = str(tmp_path / (artifact_stem + "-%02d.tensors"))
    key_path = tmp_path / (artifact_stem + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

    assert outputs == deserialized_outputs


@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Serialize a vLLM model, confirm it is detected as vLLM-tensorized, and
    check the deserialized model reproduces the original outputs."""
    # Release memory left over from earlier tests before loading a model.
    gc.collect()
    torch.cuda.empty_cache()
    model_ref = os.path.join(models_path_prefix, "facebook/opt-125m")
    # Basename keeps the artifact inside tmp_path even for absolute refs.
    model_path = tmp_path / (os.path.basename(model_ref) + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        serialize_vllm_model(get_torch_model(vllm_model), config)

        assert is_vllm_tensorized(config)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs