# test_tensorizer.py (11.8 KB) -- web-viewer page header residue ("Newer" / "Older")
1
2
import json
import os
3
import pathlib
4
5
6
import subprocess
from unittest.mock import MagicMock, patch

7
import openai
8
import pytest
9
import ray
10
11
import torch
from tensorizer import EncryptionParams
12
13

from vllm import SamplingParams
14
from vllm.engine.arg_utils import EngineArgs
15
16
17
18
19
20
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
21
22
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
23

24
from ..conftest import VllmRunner, cleanup
25
from ..utils import RemoteOpenAIServer
26

27
28
29
# yapf conflicts with isort for this docstring


30

31
32
33
34
35
36
37
38
39
40
# Prompts shared by the generation-based tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling settings shared by the generation tests (fixed seed).
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

# Module-level model reference; several tests shadow it with a local value.
model_ref = "facebook/opt-125m"

# Path to the helper script that sits next to this test file.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__),
    "tensorize_vllm_model_for_testing.py",
)
43
44
45
46
47
48
49
50
51


def is_curl_installed():
    """Report whether the ``curl`` executable can be run on this machine.

    Runs ``curl --version`` and treats a zero exit status as "installed".
    The probe's stdout/stderr are discarded: this function is evaluated in
    ``skipif`` decorators when the module is imported, and the version
    banner would otherwise be printed once per decorated test.

    Returns:
        bool: ``True`` if ``curl --version`` exited with status 0, ``False``
        if it exited non-zero or could not be launched at all.
    """
    try:
        subprocess.check_call(
            ['curl', '--version'],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (curl not on PATH) and also other
        # launch failures such as PermissionError, which would otherwise
        # abort the import of this test module.
        return False
    return True

52
53
54
55
56
57
58
59
60
61
62
63
64
65
def get_torch_model(vllm_runner: VllmRunner):
    """Return the ``model`` attribute of the runner's driver-worker runner.

    Follows the attribute chain
    ``model -> llm_engine -> model_executor -> driver_worker ->
    model_runner -> model`` starting from ``vllm_runner``.
    """
    engine = vllm_runner.model.llm_engine
    driver_worker = engine.model_executor.driver_worker
    return driver_worker.model_runner.model

def write_keyfile(keyfile_path: str):
    """Create random encryption parameters and write their key to a file.

    Missing parent directories of ``keyfile_path`` are created first; the
    key is then written to the file in binary mode, replacing any existing
    content.
    """
    key_bytes = EncryptionParams.random().key
    target = pathlib.Path(keyfile_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open('wb') as key_file:
        key_file.write(key_bytes)
66
67
68

@pytest.fixture(autouse=True)
def tensorizer_config():
    """Autouse fixture returning a ``TensorizerConfig`` with URI ``"vllm"``."""
    return TensorizerConfig(tensorizer_uri="vllm")


73
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """Check that ``load_with_tensorizer`` delegates to ``TensorizerAgent``.

    With ``TensorizerAgent`` patched out, the loader must construct the agent
    exactly once with the given config and ``quant_method``, call
    ``deserialize`` exactly once, and return what ``deserialize`` returned.
    """
    quant_method = MagicMock()
    agent = mock_agent.return_value
    expected_model = MagicMock()
    agent.deserialize.return_value = expected_model

    loaded = load_with_tensorizer(tensorizer_config,
                                  quant_method=quant_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=quant_method)
    agent.deserialize.assert_called_once()
    assert loaded == expected_model


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """Load a tensorized model from an S3 URI and generate with it.

    Skipped when curl is not installed. The test only requires that
    generation returns a truthy result.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
    s3_config = TensorizerConfig(
        tensorizer_uri=tensorized_path,
        num_readers=1,
        s3_endpoint="object.ord1.coreweave.com",
    )

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=s3_config) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(
            prompts, sampling_params)

        assert deserialized_outputs
104
105
106
107
108


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Round-trip a model through encrypted serialization.

    A keyfile is written under ``tmp_path``, the model is serialized with
    that keyfile, then loaded back with ``load_format="tensorizer"`` and the
    same keyfile. Outputs before and after must be equal.
    """
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")
        key_path = tmp_path / (model_ref + ".key")
        write_keyfile(key_path)

        # Reference outputs from the model before it is serialized.
        outputs = vllm_model.generate(prompts, sampling_params)

        config_for_serializing = TensorizerConfig(
            tensorizer_uri=model_path,
            encryption_keyfile=key_path,
        )
        torch_model = get_torch_model(vllm_model)
        serialize_vllm_model(torch_model, config_for_serializing)

    config_for_deserializing = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )
    loader_kwargs = {
        "load_format": "tensorizer",
        "model_loader_extra_config": config_for_deserializing,
    }

    with vllm_runner(model_ref, **loader_kwargs) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs
135
136
137
138


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """Serialize an HF-runner model and reload it via the tensorizer loader.

    Greedy outputs from the HF runner are compared with greedy outputs from
    the runner that loads the serialized file; they must be equal.
    """
    with hf_runner(model_ref) as hf_model:
        model_path = tmp_path / (model_ref + ".tensors")
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    reload_config = TensorizerConfig(
        tensorizer_uri=model_path,
        num_readers=1,
    )
    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=reload_config,
    ) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs
158
159
160
161
162
163
164
165
166
167
168
169
170


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """Load a tensorized model with LoRA enabled and run LoRA requests.

    The model is serialized first, then loaded with
    ``load_format="tensorizer"`` and ``enable_lora=True``; the LoRA test
    prompts are processed through the loaded engine.
    """
    from huggingface_hub import snapshot_download

    from examples.multilora_inference import (create_test_prompts,
                                              process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")
        torch_model = get_torch_model(vllm_model)
        serialize_vllm_model(torch_model,
                             TensorizerConfig(tensorizer_uri=model_path))

    reload_config = TensorizerConfig(
        tensorizer_uri=model_path,
        num_readers=1,
    )
    lora_kwargs = {
        "enable_lora": True,
        "max_loras": 1,
        "max_lora_rank": 8,
        "max_cpu_loras": 2,
        "max_num_seqs": 50,
        "max_model_len": 1000,
    }
    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=reload_config,
                     **lora_kwargs) as loaded_vllm_model:
        process_requests(loaded_vllm_model.model.llm_engine, test_prompts)

        assert loaded_vllm_model
194
195
196
197


def test_load_without_tensorizer_load_format(vllm_runner):
    """Expect ``ValueError`` when a ``TensorizerConfig`` is passed as the
    extra loader config without selecting a ``load_format``."""
    with pytest.raises(ValueError):
        extra_config = TensorizerConfig(tensorizer_uri="test")
        vllm_runner(model_ref, model_loader_extra_config=extra_config)
201
202
203


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """Serve a tensorized model through the OpenAI-compatible API server.

    The model is serialized under ``tmp_path``, a ``RemoteOpenAIServer`` is
    started with ``--load-format tensorizer`` pointing at that file, and a
    single 5-token completion is requested and checked.
    """
    ## Serialize model
    with vllm_runner(model_ref) as vllm_model:
        model_path = tmp_path / (model_ref + ".tensors")
        torch_model = get_torch_model(vllm_model)
        serialize_vllm_model(torch_model,
                             TensorizerConfig(tensorizer_uri=model_path))

        model_loader_extra_config = {"tensorizer_uri": str(model_path)}

    ## Start OpenAI API server
    extra_config_json = json.dumps(model_loader_extra_config)
    openai_args = [
        "--model",
        model_ref,
        "--dtype",
        "float16",
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        extra_config_json,
    ]

    server = RemoteOpenAIServer(openai_args)
    print("Server ready.")

    client = server.get_client()
    completion = client.completions.create(
        model=model_ref,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
    )

    assert completion.id is not None
    assert len(completion.choices) == 1
    first_choice = completion.choices[0]
    assert len(first_choice.text) >= 5
    assert first_choice.finish_reason == "length"

    expected_usage = openai.types.CompletionUsage(completion_tokens=5,
                                                  prompt_tokens=6,
                                                  total_tokens=11)
    assert completion.usage == expected_usage
238
239
240
241


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """Expect ``ValueError`` when a ``TensorizerConfig`` is combined with
    ``load_format="safetensors"``."""
    with pytest.raises(ValueError):
        extra_config = TensorizerConfig(tensorizer_uri="test")
        vllm_runner(model_ref,
                    load_format="safetensors",
                    model_loader_extra_config=extra_config)
246
247


248
249
250
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """Expect ``ValueError`` for tensor-parallel loading from a fixed URI.

    The URI used here is a single fixed path with no ``%``-style placeholder
    (compare the ``-%02d.tensors`` path used by the sharded round-trip test
    in this module). Loading it with ``tensor_parallel_size=2`` must raise.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    with pytest.raises(ValueError):
        untemplated_config = TensorizerConfig(
            tensorizer_uri=tensorized_path,
            num_readers=1,
            s3_endpoint="object.ord1.coreweave.com",
        )
        vllm_runner(model_ref,
                    load_format="tensorizer",
                    model_loader_extra_config=untemplated_config,
                    tensor_parallel_size=2,
                    disable_custom_all_reduce=True)
266

267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
                                                                    tmp_path):
    """Round-trip a model through sharded, encrypted serialization.

    Steps:
      1. Record outputs from a plain (non-tensorized) runner.
      2. Call ``tensorize_vllm_model`` with ``tensor_parallel_size=2`` and a
         ``%02d``-templated URI, then check that shard files 0 and 1 exist.
      3. Load the shards with ``load_format="tensorizer"`` and compare the
         outputs with those from step 1.

    Both runners are used as context managers, like the other tests in this
    module, so the final loaded model is released too and a failing
    ``generate`` call cannot leave a runner alive for later tests.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    # record outputs from un-sharded un-tensorized model
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        # Shut the executor down explicitly before the runner is released.
        base_model.model.llm_engine.model_executor.shutdown()
    cleanup()
    ray.shutdown()

    # load model with two shards and serialize with encryption
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    key_path = tmp_path / (model_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    # One file per shard is expected: the template filled with 0 and 1.
    assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
    assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
    cleanup()
    ray.shutdown()

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config,
    ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs

320
321
322
323
324
325

def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Round-trip a model through ``serialize_vllm_model`` without encryption.

    After serialization ``is_vllm_tensorized`` must be truthy for the
    config, and the outputs of the reloaded model must equal the outputs
    recorded before serialization.
    """
    model_ref = "facebook/opt-125m"
    tensors_file = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(tensors_file))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        torch_model = get_torch_model(vllm_model)
        serialize_vllm_model(torch_model, config)

        assert is_vllm_tensorized(config)

    loader_kwargs = {
        "load_format": "tensorizer",
        "model_loader_extra_config": config,
    }
    with vllm_runner(model_ref, **loader_kwargs) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs