# test_tensorizer.py — tests for vLLM's tensorizer-based model serialization
# and loading.
import gc
import json
import os
import pathlib
import subprocess
from unittest.mock import MagicMock, patch

import openai
import pytest
import torch
from huggingface_hub import snapshot_download

from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
# yapf: enable
from vllm.utils import PlaceholderModule, import_from_path

from ..conftest import VllmRunner
from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import retry_until_skip

try:
    from tensorizer import EncryptionParams
except ImportError:
    tensorizer = PlaceholderModule("tensorizer")  # type: ignore[assignment]
    EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")

# Root of the repository's example scripts (used to import the multi-LoRA
# example helpers).
EXAMPLES_PATH = VLLM_PATH / "examples"

# Prompts shared by the generation tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object. The fixed seed keeps sampled outputs
# reproducible, so the original-vs-deserialized comparisons are meaningful.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

# Default (small) model used by most tests; some tests shadow it locally.
model_ref = "facebook/opt-125m"
# Path to the tensorize_vllm_model_for_testing.py helper next to this file.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")


def is_curl_installed():
    """Return True if a working ``curl`` executable can be run.

    Runs ``curl --version`` and discards its output. This function is
    evaluated inside ``pytest.mark.skipif`` at collection time, so printing
    the version banner would only add noise to the test log.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (no curl on PATH) as well as
        # PermissionError (curl present but not executable).
        return False
    return True


def get_torch_model(vllm_runner: VllmRunner):
    """Return the torch model object held by a vLLM runner.

    Follows runner.model -> llm_engine -> model_executor -> driver_worker
    -> model_runner -> model.

    NOTE(review): this attribute path assumes an executor that exposes a
    ``driver_worker`` (i.e. the model lives in this process) — confirm
    before using with other executor types.
    """
    model_runner = (vllm_runner.model.llm_engine.model_executor
                    .driver_worker.model_runner)
    return model_runner.model


def write_keyfile(keyfile_path: "str | os.PathLike[str]"):
    """Generate a random tensorizer encryption key and write it to disk.

    Args:
        keyfile_path: Destination file (a ``str`` or ``pathlib.Path``;
            callers in this module pass ``Path`` objects). Missing parent
            directories are created.
    """
    encryption_params = EncryptionParams.random()
    keyfile = pathlib.Path(keyfile_path)
    keyfile.parent.mkdir(parents=True, exist_ok=True)
    # The key is raw bytes, so write it in binary mode.
    keyfile.write_bytes(encryption_params.key)


@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """load_with_tensorizer constructs one TensorizerAgent with the given
    config/quant_method and returns the result of its deserialize()."""
    mock_linear_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  quant_method=mock_linear_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=mock_linear_method)
    mock_agent_instance.deserialize.assert_called_once()
    assert result == mock_agent_instance.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """A tensorized model stored in the S3 bucket served from
    ``object.ord1.coreweave.com`` can be loaded and used for generation.

    Requires network access to that endpoint.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=tensorized_path,
                         num_readers=1,
                         s3_endpoint="object.ord1.coreweave.com",
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(
            prompts, sampling_params)

        assert deserialized_outputs


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Serializing a vLLM model with an encryption key and loading it back
    with the same key reproduces the original model's outputs."""
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")
        key_path = tmp_path / (model_ref + ".key")
        # The key must exist before serialization, which encrypts with it.
        write_keyfile(key_path)

        outputs = vllm_model.generate(prompts, sampling_params)

        config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
                                                  encryption_keyfile=key_path)
        serialize_vllm_model(get_torch_model(vllm_model),
                             config_for_serializing)

    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                encryption_keyfile=key_path)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config_for_deserializing
                     ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """A Hugging Face model written directly with TensorSerializer loads in
    vLLM via load_format="tensorizer" with identical greedy outputs."""
    with hf_runner(model_ref) as hf_model:
        model_path = tmp_path / (model_ref + ".tensors")
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=model_path,
                         num_readers=1,
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """A tensorized base model can be deserialized with LoRA enabled and
    process the requests built by the multilora_inference example."""
    multilora_inference = import_from_path(
        "examples.offline_inference.multilora_inference",
        EXAMPLES_PATH / "offline_inference/multilora_inference.py",
    )

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = multilora_inference.create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=model_path,
                num_readers=1,
            ),
            enable_lora=True,
            max_loras=1,
            max_lora_rank=8,
            max_cpu_loras=2,
            max_num_seqs=50,
            max_model_len=1000,
    ) as loaded_vllm_model:
        multilora_inference.process_requests(
            loaded_vllm_model.model.llm_engine, test_prompts)

        assert loaded_vllm_model


def test_load_without_tensorizer_load_format(vllm_runner):
    """Passing a TensorizerConfig without load_format="tensorizer" raises
    ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))

    # Drop the reference and release cached GPU memory for later tests.
    del model
    gc.collect()
    torch.cuda.empty_cache()


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """An OpenAI-compatible server launched with --load-format tensorizer
    serves completions from a locally serialized model."""
    # Serialize model
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

        model_loader_extra_config = {
            "tensorizer_uri": str(model_path),
        }

    # Start OpenAI API server
    openai_args = [
        "--dtype",
        "float16",
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        json.dumps(model_loader_extra_config),
    ]

    with RemoteOpenAIServer(model_ref, openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(model=model_ref,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)

        assert completion.id is not None
        assert len(completion.choices) == 1
        assert len(completion.choices[0].text) >= 5
        assert completion.choices[0].finish_reason == "length"
        assert completion.usage == openai.types.CompletionUsage(
            completion_tokens=5, prompt_tokens=6, total_tokens=11)


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """A TensorizerConfig combined with a non-tensorizer load_format
    ("safetensors") raises ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))

    # Drop the reference and release cached GPU memory for later tests.
    del model
    gc.collect()
    torch.cuda.empty_cache()


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """Loading a single tensorized URI (no per-shard "%02d"-style template)
    with tensor_parallel_size=2 raises ValueError."""
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    # Only the runner construction belongs inside pytest.raises.
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=tensorized_path,
                num_readers=1,
                s3_endpoint="object.ord1.coreweave.com",
            ),
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
        vllm_runner, tmp_path):
    """A model tensorized with TP=2 and encryption, then reloaded with TP=2,
    matches the outputs of the un-sharded, un-tensorized model."""
    model_ref = "EleutherAI/pythia-1.4b"
    # record outputs from un-sharded un-tensorized model
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()

    # load model with two shards and serialize with encryption
    # "%02d" is filled with the shard index; one file per shard is expected
    # (checked via model_path % 0 and model_path % 1 below).
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    # NOTE(review): nothing here writes key_path, so tensorize_vllm_model
    # presumably generates the key file — confirm.
    key_path = tmp_path / (model_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

    assert outputs == deserialized_outputs


@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """A vLLM model serialized with serialize_vllm_model is reported as
    vLLM-tensorized and reloads with identical outputs."""
    # Release memory held over from earlier tests before loading a model.
    gc.collect()
    torch.cuda.empty_cache()
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        serialize_vllm_model(get_torch_model(vllm_model), config)

        assert is_vllm_tensorized(config)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs