# test_tensorizer.py 11.8 KB
# (source-viewer header residue: "Newer Older")
1
2
import json
import os
3
import pathlib
4
5
6
import subprocess
from unittest.mock import MagicMock, patch

7
import openai
8
import pytest
9
10
import torch
from tensorizer import EncryptionParams
11
12

from vllm import SamplingParams
13
from vllm.engine.arg_utils import EngineArgs
14
15
16
17
18
19
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
20
21
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
22

23
from ..conftest import VllmRunner, cleanup
24
from ..utils import RemoteOpenAIServer
25

26
27
28
# yapf conflicts with isort for the tensorizer import above (hence the
# "yapf: disable" marker placed before it)


29

30
31
32
33
34
35
36
37
38
39
# Model used by the tests below unless a test overrides it locally.
model_ref = "facebook/opt-125m"

# Prompts shared by the generation tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling settings shared by the generation tests; a fixed seed is passed
# so that runs can be compared against each other.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    seed=0,
)
40
41
# Path (as a string) of the helper script that sits next to this test file.
tensorize_model_for_testing_script = str(
    pathlib.Path(__file__).with_name("tensorize_vllm_model_for_testing.py"))
42
43
44
45
46
47
48
49
50


def is_curl_installed() -> bool:
    """Return True if ``curl --version`` can be executed successfully.

    The probe's stdout and stderr are discarded so that evaluating this
    function (it is used in ``skipif`` decorators, i.e. at import time) does
    not print curl's version banner into the test output.

    Returns:
        True when the command ran and exited with status 0, False when the
        executable could not be started or exited with a non-zero status.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: curl started but exited non-zero.
        # OSError: curl could not be started at all; this includes
        # FileNotFoundError (not on PATH) and PermissionError (found but
        # not executable), which would otherwise abort test collection.
        return False
    return True

51
52
53
54
55
56
57
58
59
60
61
62
63
64
def get_torch_model(vllm_runner: VllmRunner):
    """Return the model object held by the runner's driver worker.

    Follows the attribute chain ``model -> llm_engine -> model_executor ->
    driver_worker -> model_runner -> model`` starting from ``vllm_runner``.
    """
    engine = vllm_runner.model.llm_engine
    driver_worker = engine.model_executor.driver_worker
    return driver_worker.model_runner.model

def write_keyfile(keyfile_path: str):
    """Write the key of freshly generated random encryption params to a file.

    Missing parent directories of ``keyfile_path`` are created first; the
    file is then (over)written in binary mode with the key bytes.
    """
    random_params = EncryptionParams.random()
    destination = pathlib.Path(keyfile_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(random_params.key)
65
66
67

@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` whose ``tensorizer_uri`` is ``"vllm"``.

    The fixture is ``autouse``, so it is set up for every test in this
    module; tests that need the object itself request it by name.
    """
    return TensorizerConfig(tensorizer_uri="vllm")


72
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """Check that ``load_with_tensorizer`` delegates to ``TensorizerAgent``.

    ``TensorizerAgent`` is replaced by a mock. The test verifies that the
    agent is constructed exactly once with the config and ``quant_method``
    it was given, that ``deserialize`` is called exactly once, and that the
    value returned by ``deserialize`` is what ``load_with_tensorizer``
    returns.
    """
    mock_linear_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  quant_method=mock_linear_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=mock_linear_method)
    mock_agent_instance.deserialize.assert_called_once()
    assert result == mock_agent_instance.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """Load a tensorized model from an S3 URI and check that it generates.

    The model is loaded with ``load_format="tensorizer"`` from the
    ``s3://tensorized/...`` URI via the configured S3 endpoint; the test only
    asserts that generation returns a truthy result.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=tensorized_path,
                         num_readers=1,
                         s3_endpoint="object.ord1.coreweave.com",
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(
            prompts, sampling_params)

        assert deserialized_outputs
103
104
105
106
107


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Encrypted serialize/deserialize round trip must not change outputs.

    A model is loaded normally, its outputs are recorded, and it is
    serialized to ``tmp_path`` with an encryption keyfile. The model is then
    reloaded through the tensorizer load format with the same keyfile, and
    its outputs are compared with the recorded ones.
    """
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")
        key_path = tmp_path / (model_ref + ".key")
        # Also creates the parent directory under tmp_path (model_ref
        # contains a "/").
        write_keyfile(key_path)

        outputs = vllm_model.generate(prompts, sampling_params)

        config_for_serializing = TensorizerConfig(
            tensorizer_uri=model_path,
            encryption_keyfile=key_path,
        )
        serialize_vllm_model(get_torch_model(vllm_model),
                             config_for_serializing)

    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                encryption_keyfile=key_path)

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=config_for_deserializing,
    ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs
134
135
136
137


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """A serialized HF model loaded via tensorizer must match HF outputs.

    The HF runner's model is written with ``TensorSerializer`` to a file
    under ``tmp_path``; greedy outputs from the HF runner are then compared
    with greedy outputs from a runner that loads that file through the
    tensorizer load format.
    """
    with hf_runner(model_ref) as hf_model:
        model_path = tmp_path / (model_ref + ".tensors")
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=TensorizerConfig(
                         tensorizer_uri=model_path,
                         num_readers=1,
                     )) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs
157
158
159
160
161
162
163
164
165
166
167
168
169


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """A tensorized model must load with LoRA enabled and process requests.

    The base model is serialized first, then reloaded through the tensorizer
    load format with LoRA options enabled, and the LoRA test prompts are run
    through its engine.
    """
    from huggingface_hub import snapshot_download

    from examples.multilora_inference import (create_test_prompts,
                                              process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=model_path,
                num_readers=1,
            ),
            enable_lora=True,
            max_loras=1,
            max_lora_rank=8,
            max_cpu_loras=2,
            max_num_seqs=50,
            max_model_len=1000,
    ) as loaded_vllm_model:
        process_requests(loaded_vllm_model.model.llm_engine, test_prompts)

        assert loaded_vllm_model
193
194
195
196


def test_load_without_tensorizer_load_format(vllm_runner):
    """Expect ``ValueError`` when a ``TensorizerConfig`` is passed as
    ``model_loader_extra_config`` without setting ``load_format``."""
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
200
201
202


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """Serve a tensorized model through the OpenAI-compatible API server.

    The model is serialized to ``tmp_path``, a ``RemoteOpenAIServer`` is
    started with ``--load-format tensorizer`` and a JSON-encoded
    ``--model-loader-extra-config`` pointing at the serialized file, and a
    single completion request is checked for id, choice count, text length,
    finish reason and token usage.
    """
    ## Serialize model
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")

        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

        # The path is converted to str so the config can be JSON-encoded
        # for the command line below.
        model_loader_extra_config = {
            "tensorizer_uri": str(model_path),
        }

    ## Start OpenAI API server
    openai_args = [
        "--model", model_ref, "--dtype", "float16", "--load-format",
        "tensorizer", "--model-loader-extra-config",
        json.dumps(model_loader_extra_config),
    ]

    with RemoteOpenAIServer(openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(model=model_ref,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)

        assert completion.id is not None
        assert len(completion.choices) == 1
        assert len(completion.choices[0].text) >= 5
        assert completion.choices[0].finish_reason == "length"
        assert completion.usage == openai.types.CompletionUsage(
            completion_tokens=5, prompt_tokens=6, total_tokens=11)
237
238
239
240


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """Expect ``ValueError`` when a ``TensorizerConfig`` is combined with
    ``load_format="safetensors"``."""
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
245
246


247
248
249
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """Expect ``ValueError`` for tensor-parallel loading from a plain URI.

    The runner is created with ``tensor_parallel_size=2`` and a
    ``tensorizer_uri`` that contains no format placeholder; constructing it
    is expected to raise ``ValueError``.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    # Only the runner construction sits inside the raises block, so that an
    # unrelated failure during setup cannot satisfy the expectation.
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=TensorizerConfig(
                tensorizer_uri=tensorized_path,
                num_readers=1,
                s3_endpoint="object.ord1.coreweave.com",
            ),
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )
265

266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
                                                                    tmp_path):
    """Sharded, encrypted round trip must reproduce the base model outputs.

    Steps:
      1. Record outputs from the model loaded without tensorizer.
      2. Serialize the model with ``tensor_parallel_size=2`` and an
         encryption keyfile via ``tensorize_vllm_model``; one file per shard
         is expected at ``model_path % 0`` and ``model_path % 1``.
      3. Reload the shards through the tensorizer load format and compare
         the outputs with those from step 1.

    Both runners are used as context managers, like the other tests in this
    module, so each model is released even when a step fails and the final
    two-GPU model does not stay alive after the test returns.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    # record outputs from un-sharded un-tensorized model
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()
    # Explicit cleanup kept from the original flow before the next engine is
    # created; presumably redundant with the runner's own exit handling.
    cleanup()

    # load model with two shards and serialize with encryption
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    key_path = tmp_path / (model_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
    cleanup()

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config,
    ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs

317
318
319
320
321
322

def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Unencrypted serialize/deserialize round trip must not change outputs.

    The model is loaded normally, its outputs are recorded, and it is
    serialized with ``serialize_vllm_model``; ``is_vllm_tensorized`` must
    then accept the config. The model is reloaded through the tensorizer
    load format and its outputs are compared with the recorded ones.
    """
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        serialize_vllm_model(get_torch_model(vllm_model), config)

        assert is_vllm_tensorized(config)

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=config) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs