# test_tensorizer.py: tests for vLLM's tensorizer-based model loading.
import gc
import json
import os
import pathlib
import subprocess
from unittest.mock import MagicMock, patch

import openai
import pytest
import torch
from huggingface_hub import snapshot_download
from tensorizer import EncryptionParams

from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
# yapf: enable
from vllm.utils import import_from_path

from ..conftest import VllmRunner
from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import retry_until_skip
31

32
EXAMPLES_PATH = VLLM_PATH / "examples"

# Prompts shared by the generation-comparison tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object. Sampling is seeded (seed=0) so that the
# original and the deserialized model can be compared output-for-output.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

# Default model used by tests that do not override ``model_ref`` locally.
model_ref = "facebook/opt-125m"

# NOTE(review): not referenced by the tests visible in this module; kept in
# case other code imports it - confirm before removing.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
46

47

48
49
50
51
52
53
54
def is_curl_installed():
    """Return True if a runnable ``curl`` executable is available.

    Evaluated at import time by the ``skipif`` markers below, so the version
    banner curl prints is discarded to keep pytest output clean.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
        return True
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (curl missing) and also
        # PermissionError (curl present but not executable), either of
        # which should skip the tests rather than break collection.
        return False

55

56
57
def get_torch_model(vllm_runner: VllmRunner):
    """Return the model object held by ``vllm_runner``'s driver worker.

    Walks ``model -> llm_engine -> model_executor -> driver_worker ->
    model_runner -> model``. This reaches into engine internals, so it only
    works with executors that expose a ``driver_worker`` attribute.
    """
    return (vllm_runner.model.llm_engine.model_executor.driver_worker
            .model_runner.model)

65
66
67
68
69
70

def write_keyfile(keyfile_path: str):
    """Create a random tensorizer encryption key and store it at a path.

    Missing parent directories are created. The raw key bytes are written in
    binary mode, replacing any existing file at ``keyfile_path``.
    """
    params = EncryptionParams.random()
    destination = pathlib.Path(keyfile_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, 'wb') as key_file:
        key_file.write(params.key)
71
72


73
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """``load_with_tensorizer`` constructs an agent and returns its result.

    ``TensorizerAgent`` is patched out, so only the wiring is checked: the
    agent is built once with the config and the forwarded ``quant_method``,
    ``deserialize()`` is called once, and its return value is passed back.
    ``tensorizer_config`` is a fixture (presumably from conftest - confirm).
    """
    mock_linear_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  quant_method=mock_linear_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=mock_linear_method)
    mock_agent_instance.deserialize.assert_called_once()
    assert result == mock_agent_instance.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """A pre-tensorized model stored on S3 loads and generates output.

    Needs network access to the S3-compatible endpoint configured below.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=tensorized_path,
                         num_readers=1,
                         s3_endpoint="object.ord1.coreweave.com",
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(
            prompts, sampling_params)

        assert deserialized_outputs
105
106
107
108
109


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """An encrypted serialize/deserialize round trip preserves outputs.

    A random key is written to disk, the model is serialized with it, and a
    fresh runner loads the encrypted tensors with the same key. Generations
    from both runners must match.
    """
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")
        key_path = tmp_path / (model_ref + ".key")
        write_keyfile(key_path)

        outputs = vllm_model.generate(prompts, sampling_params)

        config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
                                                  encryption_keyfile=key_path)
        serialize_vllm_model(get_torch_model(vllm_model),
                             config_for_serializing)

    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                encryption_keyfile=key_path)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config_for_deserializing
                     ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs
135
136
137
138


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """An HF model written with ``TensorSerializer`` loads correctly in vLLM.

    The HF module is serialized directly via ``TensorSerializer`` rather than
    ``serialize_vllm_model``. vLLM then loads it with
    ``load_format="tensorizer"``. Both sides use ``generate_greedy`` so that
    their outputs can be compared exactly.
    """
    with hf_runner(model_ref) as hf_model:
        model_path = tmp_path / (model_ref + ".tensors")
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=model_path,
                         num_readers=1,
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs
157
158
159


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """A tensorized base model can be loaded with LoRA enabled.

    The base model is serialized first. It is then reloaded through
    tensorizer with LoRA enabled and driven with the multi-LoRA example's
    test requests. The real check is that ``process_requests`` completes
    without raising.
    """
    # Reuse the request helpers from the multi-LoRA example script.
    multilora_inference = import_from_path(
        "examples.multilora_inference",
        EXAMPLES_PATH / "multilora_inference.py",
    )

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = multilora_inference.create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=model_path,
                num_readers=1,
            ),
            enable_lora=True,
            max_loras=1,
            max_lora_rank=8,
            max_cpu_loras=2,
            max_num_seqs=50,
            max_model_len=1000,
    ) as loaded_vllm_model:
        multilora_inference.process_requests(
            loaded_vllm_model.model.llm_engine, test_prompts)

        assert loaded_vllm_model
194
195
196


def test_load_without_tensorizer_load_format(vllm_runner):
    """A TensorizerConfig without ``load_format="tensorizer"`` is rejected.

    Constructing the runner is expected to raise ``ValueError``.
    """
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    # Drop any reference to a runner and release cached CUDA memory so that
    # later tests start from a clean state.
    del model
    gc.collect()
    torch.cuda.empty_cache()
205
206
207


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """The OpenAI-compatible server can serve a tensorizer-loaded model."""
    # Serialize the model with vLLM first.
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

        model_loader_extra_config = {
            "tensorizer_uri": str(model_path),
        }

    # Start the OpenAI API server and point it at the serialized tensors.
    openai_args = [
        "--dtype",
        "float16",
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        # Passed as a CLI argument, so the extra config is JSON-encoded.
        json.dumps(model_loader_extra_config),
    ]

    with RemoteOpenAIServer(model_ref, openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(model=model_ref,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)

        assert completion.id is not None
        assert len(completion.choices) == 1
        assert len(completion.choices[0].text) >= 5
        assert completion.choices[0].finish_reason == "length"
        assert completion.usage == openai.types.CompletionUsage(
            completion_tokens=5, prompt_tokens=6, total_tokens=11)
245
246
247


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """A TensorizerConfig combined with a non-tensorizer format is rejected.

    Using ``load_format="safetensors"`` together with a tensorizer extra
    config is expected to raise ``ValueError``.
    """
    model = None
    with pytest.raises(ValueError):
        model = vllm_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
    # Drop any reference to a runner and release cached CUDA memory so that
    # later tests start from a clean state.
    del model
    gc.collect()
    torch.cuda.empty_cache()
257
258


259
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """A single, un-templated tensorizer URI is rejected when TP=2.

    With tensor parallelism each rank presumably needs its own shard file
    (the TP round-trip test uses a ``%02d`` template for this). A plain URI
    is therefore expected to raise ``ValueError``.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    # Keep only the call under test inside ``pytest.raises``.
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=tensorized_path,
                num_readers=1,
                s3_endpoint="object.ord1.coreweave.com",
            ),
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )
276

277

278
279
280
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
        vllm_runner, tmp_path):
    """An encrypted, TP=2 sharded round trip matches un-sharded outputs."""
    model_ref = "EleutherAI/pythia-1.4b"

    # Record outputs from the un-sharded, un-tensorized model.
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        # Shut the executor down explicitly before the TP=2 engine below is
        # built (presumably to free GPU resources - confirm).
        base_model.model.llm_engine.model_executor.shutdown()

    # Load the model with two shards and serialize it with encryption.
    # ``%02d`` is filled with the shard index, giving one file per shard.
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    # NOTE(review): no keyfile is written here, unlike the non-TP encrypted
    # test; presumably tensorize_vllm_model generates it - confirm.
    key_path = tmp_path / (model_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

    assert outputs == deserialized_outputs

324

325
@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Serialize a vLLM model, reload it via tensorizer, compare outputs."""
    # Reclaim memory left over from earlier tests before loading a model.
    gc.collect()
    torch.cuda.empty_cache()

    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        serialize_vllm_model(get_torch_model(vllm_model), config)

        assert is_vllm_tensorized(config)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs