test_tensorizer.py 12.1 KB
Newer Older
1
import gc
2
3
import json
import os
4
import pathlib
5
6
7
import subprocess
from unittest.mock import MagicMock, patch

8
import openai
9
import pytest
10
11
import torch
from tensorizer import EncryptionParams
12
13

from vllm import SamplingParams
14
from vllm.engine.arg_utils import EngineArgs
15
16
17
18
19
20
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
21
22
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
23

24
from ..conftest import VllmRunner
25
from ..utils import RemoteOpenAIServer
26
from .conftest import retry_until_skip
27

28
29
30
# yapf conflicts with isort for this docstring


31
32
33
34
35
36
37
38
39
40
# Default model used by tests that do not pick their own.
model_ref = "facebook/opt-125m"

# Prompts shared by every generation-comparison test in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling settings shared by the tests; the fixed seed is what lets outputs
# from two separately loaded copies of a model be compared for equality.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

# Helper script located in the same directory as this test module.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__),
    "tensorize_vllm_model_for_testing.py",
)
43

44

45
46
47
48
49
50
51
def is_curl_installed():
    """Report whether a working ``curl`` executable is available.

    Used as a ``skipif`` condition, so it runs during test collection. The
    probe's output is discarded so collection is not cluttered with curl's
    version banner.

    Returns:
        ``True`` if ``curl --version`` runs and exits with status 0;
        ``False`` if the executable is missing, cannot be launched, or exits
        with a non-zero status.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (curl not on PATH) as well as
        # PermissionError and other failures to launch the process.
        return False
    return True

52

53
54
def get_torch_model(vllm_runner: VllmRunner):
    """Return the model object held by a running ``VllmRunner``.

    Follows ``model -> llm_engine -> model_executor -> driver_worker ->
    model_runner -> model``. The tests pass the result straight to
    ``serialize_vllm_model``.
    """
    engine = vllm_runner.model.llm_engine
    worker = engine.model_executor.driver_worker
    return worker.model_runner.model

62
63
64
65
66
67

def write_keyfile(keyfile_path: str):
    """Generate a random Tensorizer encryption key and save its raw bytes.

    Missing parent directories are created first; an existing file at the
    destination is overwritten.

    Args:
        keyfile_path: Where to write the key. Callers in this module also
            pass ``pathlib.Path`` objects here.
    """
    key_bytes = EncryptionParams.random().key
    destination = pathlib.Path(keyfile_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(key_bytes)
68
69


70
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """``load_with_tensorizer`` builds a single TensorizerAgent and returns
    whatever that agent's ``deserialize`` produces."""
    fake_quant_method = MagicMock()
    agent = mock_agent.return_value
    agent.deserialize.return_value = MagicMock()

    loaded = load_with_tensorizer(tensorizer_config,
                                  quant_method=fake_quant_method)

    # The agent is constructed once, with the config and quant method passed
    # through unchanged, and deserialized exactly once.
    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=fake_quant_method)
    agent.deserialize.assert_called_once()
    assert loaded == agent.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """A pre-tensorized model stored in S3 can be loaded and generate."""
    model_ref = "EleutherAI/pythia-1.4b"
    s3_config = TensorizerConfig(
        tensorizer_uri=f"s3://tensorized/{model_ref}/fp16/model.tensors",
        num_readers=1,
        s3_endpoint="object.ord1.coreweave.com",
    )

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=s3_config) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(prompts,
                                                        sampling_params)
        assert deserialized_outputs
102
103
104
105
106


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Serializing with an encryption key and loading with that same key
    reproduces the original model's generations."""
    model_path = tmp_path / (model_ref + ".tensors")
    key_path = tmp_path / (model_ref + ".key")

    with vllm_runner(model_ref) as vllm_model:
        write_keyfile(key_path)
        outputs = vllm_model.generate(prompts, sampling_params)
        encrypting_config = TensorizerConfig(tensorizer_uri=model_path,
                                             encryption_keyfile=key_path)
        serialize_vllm_model(get_torch_model(vllm_model), encrypting_config)

    # A separate config instance is used for loading, pointing at the same
    # tensors file and key.
    decrypting_config = TensorizerConfig(tensorizer_uri=model_path,
                                         encryption_keyfile=key_path)
    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=decrypting_config
                     ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(prompts,
                                                          sampling_params)
        assert outputs == deserialized_outputs
134
135
136
137


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """A Hugging Face model written directly with ``TensorSerializer`` loads
    in vLLM with the tensorizer format and yields identical greedy output."""
    model_path = tmp_path / (model_ref + ".tensors")
    max_tokens = 50

    with hf_runner(model_ref) as hf_model:
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    load_config = TensorizerConfig(tensorizer_uri=model_path, num_readers=1)
    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=load_config) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)
        assert outputs == deserialized_outputs
156
157
158
159
160
161
162
163
164
165
166
167
168


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """A tensorized base model can be reloaded with LoRA enabled and run the
    multi-LoRA example requests."""
    from huggingface_hub import snapshot_download

    from examples.multilora_inference import (create_test_prompts,
                                              process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)
    model_path = tmp_path / (model_ref + ".tensors")

    # Serialize the plain base model; LoRA adapters are only bound when the
    # model is loaded back below.
    with vllm_runner(model_ref) as vllm_model:
        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    load_config = TensorizerConfig(tensorizer_uri=model_path, num_readers=1)
    lora_options = {
        "enable_lora": True,
        "max_loras": 1,
        "max_lora_rank": 8,
        "max_cpu_loras": 2,
        "max_num_seqs": 50,
        "max_model_len": 1000,
    }
    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=load_config,
                     **lora_options) as loaded_vllm_model:
        process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
        assert loaded_vllm_model
192
193
194


def test_load_without_tensorizer_load_format(vllm_runner):
    """A TensorizerConfig given without ``load_format="tensorizer"`` is
    rejected with ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(model_ref,
                            model_loader_extra_config=TensorizerConfig(
                                tensorizer_uri="test"))
    # ``model`` is still None when construction raised. Collecting garbage
    # and emptying the CUDA cache reclaims anything the aborted attempt
    # allocated.
    del model
    gc.collect()
    torch.cuda.empty_cache()
203
204
205


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """An OpenAI-compatible server started with ``--load-format tensorizer``
    serves completions from a locally tensorized model."""
    model_path = tmp_path / (model_ref + ".tensors")

    # Serialize the model to local disk first.
    with vllm_runner(model_ref) as vllm_model:
        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    # Start the OpenAI API server pointed at the serialized tensors.
    extra_config_json = json.dumps({"tensorizer_uri": str(model_path)})
    openai_args = [
        "--dtype", "float16",
        "--load-format", "tensorizer",
        "--model-loader-extra-config", extra_config_json,
    ]

    with RemoteOpenAIServer(model_ref, openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(model=model_ref,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)

        choices = completion.choices
        assert completion.id is not None
        assert len(choices) == 1
        assert len(choices[0].text) >= 5
        assert choices[0].finish_reason == "length"
        assert completion.usage == openai.types.CompletionUsage(
            completion_tokens=5, prompt_tokens=6, total_tokens=11)
240
241
242


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """A TensorizerConfig paired with a non-tensorizer ``load_format`` (here
    ``"safetensors"``) is rejected with ValueError."""
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(model_ref,
                            load_format="safetensors",
                            model_loader_extra_config=TensorizerConfig(
                                tensorizer_uri="test"))
    # ``model`` is still None when construction raised. Collecting garbage
    # and emptying the CUDA cache reclaims anything the aborted attempt
    # allocated.
    del model
    gc.collect()
    torch.cuda.empty_cache()
252
253


254
255
256
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """Tensor-parallel loading from a single, un-templated tensors URI must
    raise ValueError.

    NOTE(review): presumably rejected because each TP rank needs its own
    shard and the URI carries no per-rank template (the sharded test in this
    module uses a ``"-%02d.tensors"`` path) — confirm against the loader.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    # Only the runner construction belongs inside ``pytest.raises``.
    with pytest.raises(ValueError):
        # Entering the runner as a context manager means that, if the
        # expected ValueError is NOT raised, the runner is still torn down
        # before pytest reports the failure rather than leaking into later
        # tests.
        with vllm_runner(
                model_ref,
                load_format="tensorizer",
                model_loader_extra_config=TensorizerConfig(
                    tensorizer_uri=tensorized_path,
                    num_readers=1,
                    s3_endpoint="object.ord1.coreweave.com",
                ),
                tensor_parallel_size=2,
                disable_custom_all_reduce=True,
        ):
            pass
272

273

274
275
276
277
278
279
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
                                                                    tmp_path):
    """Encrypted, TP-sharded serialization round-trips without changing
    generations.

    Reference outputs come from an un-sharded, un-tensorized load. The model
    is then serialized across two tensor-parallel shards with an encryption
    keyfile, reloaded from those shards, and the outputs are compared.
    """
    model_ref = "EleutherAI/pythia-1.4b"

    # Reference run: loaded the ordinary way, without sharding.
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        expected_outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()

    # The "%02d" in the URI is filled with the shard index, giving one file
    # per shard (checked below with indices 0 and 1).
    sharded_uri = str(tmp_path / (model_ref + "-%02d.tensors"))
    # NOTE(review): no keyfile is written here; presumably
    # tensorize_vllm_model generates it — confirm.
    tensorizer_config = TensorizerConfig(
        tensorizer_uri=sharded_uri,
        encryption_keyfile=tmp_path / (model_ref + ".key"),
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    for shard_index in (0, 1):
        assert os.path.isfile(sharded_uri % shard_index), \
            "Serialization subprocess failed"

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(prompts,
                                                          sampling_params)

    assert expected_outputs == deserialized_outputs

321

322
323

@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Round-trip an unencrypted model through tensorizer and check that
    generations match before and after."""
    # Start from a clean allocator state. NOTE(review): presumably this
    # matters because the decorator may re-run the test — confirm.
    gc.collect()
    torch.cuda.empty_cache()

    model_ref = "facebook/opt-125m"
    config = TensorizerConfig(
        tensorizer_uri=str(tmp_path / (model_ref + ".tensors")))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        serialize_vllm_model(get_torch_model(vllm_model), config)
        assert is_vllm_tensorized(config)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(prompts,
                                                          sampling_params)
        assert outputs == deserialized_outputs