"docs/source/models/supported_models.md" did not exist on "2b5bf20988edaab21621b78a9eb589edc93f2763"
test_tensorizer.py 11.7 KB
Newer Older
1
2
import json
import os
3
import pathlib
4
5
6
import subprocess
from unittest.mock import MagicMock, patch

7
import openai
8
import pytest
9
10
import torch
from tensorizer import EncryptionParams
11
12

from vllm import SamplingParams
13
from vllm.engine.arg_utils import EngineArgs
14
15
16
17
18
19
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
20
21
                                                         serialize_vllm_model,
                                                         tensorize_vllm_model)
22

23
from ..conftest import VllmRunner, cleanup
24
from ..utils import RemoteOpenAIServer
25

26
27
28
# yapf conflicts with isort for this docstring


29

30
31
32
33
34
35
36
37
38
39
# Prompts shared by the generation tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Sampling settings shared by the generation tests; the seed is fixed at 0.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    seed=0,
)

# Default model identifier; some tests rebind a local with the same name.
model_ref = "facebook/opt-125m"
40
41
# Path to the `tensorize_vllm_model_for_testing.py` script, resolved against
# the directory that contains this module.
tensorize_model_for_testing_script = os.path.join(
    os.path.split(__file__)[0],
    "tensorize_vllm_model_for_testing.py",
)
42
43
44
45
46
47
48
49

def is_curl_installed():
    """Report whether a runnable ``curl`` executable is available.

    The result feeds the ``skipif`` markers in this module, so the function
    runs while the module is being imported/collected.

    Returns:
        bool: True when ``curl --version`` exits with status 0, False when
        the executable cannot be launched or exits with a non-zero status.
    """
    try:
        # Discard curl's version banner; otherwise it is written to the
        # inherited stdout every time a skipif marker is evaluated.
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: curl ran but returned a non-zero exit code.
        # OSError: curl could not be launched at all (FileNotFoundError is a
        # subclass; PermissionError and similar are treated the same way).
        return False
    return True

50
51
52
53
54
55
56
57
58
59
60
61
62
63
def get_torch_model(vllm_runner: VllmRunner):
    """Return the object stored at the end of the runner's attribute chain.

    Walks ``model -> llm_engine -> model_executor -> driver_worker ->
    model_runner -> model`` starting from ``vllm_runner`` and returns the
    final attribute.

    Args:
        vllm_runner: Runner whose nested ``model`` attribute is wanted.

    Returns:
        The value of the last ``model`` attribute in the chain.
    """
    node = vllm_runner
    for attribute in ("model", "llm_engine", "model_executor",
                      "driver_worker", "model_runner", "model"):
        node = getattr(node, attribute)
    return node

def write_keyfile(keyfile_path: str):
    """Create random encryption parameters and write their key to a file.

    Missing parent directories of ``keyfile_path`` are created first; an
    existing file at that path is overwritten.

    Args:
        keyfile_path: Destination of the key file.
    """
    params = EncryptionParams.random()
    destination = pathlib.Path(keyfile_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open('wb') as key_file:
        key_file.write(params.key)
64
65
66



67
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """Check that load_with_tensorizer delegates to a TensorizerAgent.

    With ``TensorizerAgent`` patched out, the loader must build exactly one
    agent from the config and quant method, call ``deserialize`` on it once,
    and return whatever ``deserialize`` returned.
    """
    quant_stub = MagicMock()
    expected_model = MagicMock()
    agent = mock_agent.return_value
    agent.deserialize.return_value = expected_model

    loaded = load_with_tensorizer(tensorizer_config,
                                  quant_method=quant_stub)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=quant_stub)
    agent.deserialize.assert_called_once()
    assert loaded == agent.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """Load a tensorized model from an S3 URI and check it generates output.

    Skipped when cURL is not installed.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
    s3_config = TensorizerConfig(
        tensorizer_uri=tensorized_path,
        num_readers=1,
        s3_endpoint="object.ord1.coreweave.com",
    )

    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=s3_config) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate(
            prompts, sampling_params)

        # Only checks that something truthy came back.
        assert deserialized_outputs
98
99
100
101
102


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Check that an encrypted serialize/reload round trip keeps outputs.

    The model generates once, is serialized with an encryption keyfile, is
    reloaded through the "tensorizer" load format with the same keyfile, and
    the two sets of generations are compared for equality.
    """
    cleanup()

    tensors_file = tmp_path / (model_ref + ".tensors")
    key_file = tmp_path / (model_ref + ".key")

    with vllm_runner(model_ref) as vllm_model:
        write_keyfile(key_file)

        outputs = vllm_model.generate(prompts, sampling_params)

        encrypting_config = TensorizerConfig(tensorizer_uri=tensors_file,
                                             encryption_keyfile=key_file)
        serialize_vllm_model(get_torch_model(vllm_model), encrypting_config)

    # A separate config instance is used for loading, as in serialization.
    decrypting_config = TensorizerConfig(tensorizer_uri=tensors_file,
                                         encryption_keyfile=key_file)

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=decrypting_config,
    ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs
130
131
132
133


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """Check that a tensorized HF model reloads with identical outputs.

    The HF runner's module is written with ``TensorSerializer``, loaded
    through ``vllm_runner`` with the "tensorizer" load format, and the greedy
    generations from both runners are compared for equality.
    """
    max_tokens = 50
    model_path = tmp_path / (model_ref + ".tensors")

    with hf_runner(model_ref) as hf_model:
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            # NOTE(review): the serializer is never close()d explicitly --
            # confirm that leaving the stream context finalizes the file.
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    reload_config = TensorizerConfig(
        tensorizer_uri=model_path,
        num_readers=1,
    )
    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=reload_config
                     ) as loaded_hf_model:
        deserialized_outputs = loaded_hf_model.generate_greedy(
            prompts, max_tokens=max_tokens)

        assert outputs == deserialized_outputs
153
154
155
156
157
158
159
160
161
162
163
164
165


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """Check that a tensorized model loads with LoRA enabled.

    The base model is serialized first, then reloaded through the
    "tensorizer" load format with LoRA options, and the multi-LoRA example
    requests are run against its engine.
    """
    from huggingface_hub import snapshot_download

    from examples.multilora_inference import (create_test_prompts,
                                              process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)
    model_path = tmp_path / (model_ref + ".tensors")

    # Serialize model before deserializing and binding LoRA adapters
    with vllm_runner(model_ref) as vllm_model:
        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    reload_config = TensorizerConfig(
        tensorizer_uri=model_path,
        num_readers=1,
    )
    lora_engine_options = dict(
        enable_lora=True,
        max_loras=1,
        max_lora_rank=8,
        max_cpu_loras=2,
        max_num_seqs=50,
        max_model_len=1000,
    )
    with vllm_runner(model_ref,
                     load_format="tensorizer",
                     model_loader_extra_config=reload_config,
                     **lora_engine_options) as loaded_vllm_model:
        process_requests(loaded_vllm_model.model.llm_engine, test_prompts)

        assert loaded_vllm_model
189
190
191
192


def test_load_without_tensorizer_load_format(vllm_runner):
    """Expect a ValueError when a TensorizerConfig is passed without
    ``load_format`` being specified."""
    with pytest.raises(ValueError):
        # The config is built inside the context, as in the call it feeds.
        orphan_config = TensorizerConfig(tensorizer_uri="test")
        vllm_runner(model_ref, model_loader_extra_config=orphan_config)
196
197
198


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """Serve a tensorized model through the OpenAI-compatible API server.

    The model is serialized locally, the server is started with
    ``--load-format tensorizer`` pointing at it, and a 5-token completion is
    requested and checked.
    """
    model_path = tmp_path / (model_ref + ".tensors")

    ## Serialize model
    with vllm_runner(model_ref) as vllm_model:
        serialize_vllm_model(get_torch_model(vllm_model),
                             TensorizerConfig(tensorizer_uri=model_path))

    # The extra config travels on the command line as a JSON string.
    model_loader_extra_config = {
        "tensorizer_uri": str(model_path),
    }

    ## Start OpenAI API server
    openai_args = [
        "--dtype",
        "float16",
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        json.dumps(model_loader_extra_config),
    ]

    with RemoteOpenAIServer(model_ref, openai_args) as server:
        print("Server ready.")

        client = server.get_client()
        completion = client.completions.create(
            model=model_ref,
            prompt="Hello, my name is",
            max_tokens=5,
            temperature=0.0,
        )

        expected_usage = openai.types.CompletionUsage(completion_tokens=5,
                                                      prompt_tokens=6,
                                                      total_tokens=11)
        assert completion.id is not None
        assert len(completion.choices) == 1
        first_choice = completion.choices[0]
        assert len(first_choice.text) >= 5
        assert first_choice.finish_reason == "length"
        assert completion.usage == expected_usage
233
234
235
236


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """Expect a ValueError when a TensorizerConfig is combined with the
    "safetensors" load format."""
    with pytest.raises(ValueError):
        # The config is built inside the context, as in the call it feeds.
        mismatched_config = TensorizerConfig(tensorizer_uri="test")
        vllm_runner(model_ref,
                    load_format="safetensors",
                    model_loader_extra_config=mismatched_config)
241
242


243
244
245
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
    """Expect a ValueError for tensor-parallel loading from a plain URI.

    The tensorizer URI used here contains no format placeholder, and the
    runner is asked for ``tensor_parallel_size=2``. Skipped with fewer than
    two CUDA devices.
    """
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    with pytest.raises(ValueError):
        untemplated_config = TensorizerConfig(
            tensorizer_uri=tensorized_path,
            num_readers=1,
            s3_endpoint="object.ord1.coreweave.com",
        )
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=untemplated_config,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
        )
261

262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
                                                                    tmp_path):
    """Check that an encrypted, sharded round trip keeps outputs.

    Baseline generations come from the plain model. The model is then
    serialized with ``tensor_parallel_size=2`` and an encryption keyfile into
    one file per shard, reloaded through the "tensorizer" load format, and
    the generations are compared for equality. Skipped with fewer than two
    CUDA devices.

    Both runners are used as context managers, like the other tests in this
    module, so they are torn down even when generation or an assertion fails.
    """
    model_ref = "EleutherAI/pythia-1.4b"

    # record outputs from un-sharded un-tensorized model
    with vllm_runner(
            model_ref,
            disable_custom_all_reduce=True,
            enforce_eager=True,
    ) as base_model:
        outputs = base_model.generate(prompts, sampling_params)
        base_model.model.llm_engine.model_executor.shutdown()

    # Drop the last reference before the explicit cleanup so the baseline
    # cannot linger while the sharded model is being serialized.
    del base_model
    cleanup()

    # load model with two shards and serialize with encryption
    # "%02d" is filled in with the shard index (see the isfile checks below).
    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
    key_path = tmp_path / (model_ref + ".key")

    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        encryption_keyfile=key_path,
    )

    tensorize_vllm_model(
        engine_args=EngineArgs(
            model=model_ref,
            tensor_parallel_size=2,
            disable_custom_all_reduce=True,
            enforce_eager=True,
        ),
        tensorizer_config=tensorizer_config,
    )
    for shard_index in range(2):
        assert os.path.isfile(
            model_path % shard_index), "Serialization subprocess failed"
    cleanup()

    with vllm_runner(
            model_ref,
            tensor_parallel_size=2,
            load_format="tensorizer",
            disable_custom_all_reduce=True,
            enforce_eager=True,
            model_loader_extra_config=tensorizer_config,
    ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs

313
314

def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """Check that a vLLM-serialized model reloads with identical outputs.

    Also asserts that ``is_vllm_tensorized`` is true for the config once the
    model has been serialized.
    """
    cleanup()

    model_ref = "facebook/opt-125m"
    tensors_file = tmp_path / (model_ref + ".tensors")
    # The same config object is used for serializing and for loading.
    config = TensorizerConfig(tensorizer_uri=str(tensors_file))

    with vllm_runner(model_ref) as vllm_model:
        outputs = vllm_model.generate(prompts, sampling_params)
        serialize_vllm_model(get_torch_model(vllm_model), config)

        assert is_vllm_tensorized(config)

    with vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=config,
    ) as loaded_vllm_model:
        deserialized_outputs = loaded_vllm_model.generate(
            prompts, sampling_params)

        assert outputs == deserialized_outputs