# test_tensorizer.py: tests for vLLM's tensorizer model loader.
1
import gc
2
3
import json
import os
4
5
6
import subprocess
from unittest.mock import MagicMock, patch

7
import openai
8
import pytest
9
import ray
10
11
12
import torch

from vllm import SamplingParams
13
14
15
16
17
18
19
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         TensorSerializer,
                                                         is_vllm_tensorized,
                                                         load_with_tensorizer,
                                                         open_stream,
                                                         serialize_vllm_model)
20

21
22
from ..utils import ServerRunner

23
24
25
# yapf conflicts with isort over how to wrap the multi-line tensorizer
# import above, which is why that import is preceded by "# yapf: disable".


26
27
28
29
30
31
32
33
34
35
# Prompts shared by the generation tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object. seed=0 is relied on by the tests that
# compare generations from two separately loaded copies of a model.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

# Default model for tests that do not override ``model_ref`` locally.
model_ref = "facebook/opt-125m"

# Path to tensorize_vllm_model_for_testing.py in this test's directory.
tensorize_model_for_testing_script = os.path.join(
    os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")


def is_curl_installed():
    """Return True if a runnable ``curl`` executable is available.

    Used in ``skipif`` markers, so it runs at test-collection time. The
    output of ``curl --version`` is discarded to keep test logs clean.

    Returns:
        bool: True when ``curl --version`` exits successfully, else False.
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
        return True
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (curl not on PATH) as well as
        # PermissionError and other failures to launch the binary.
        return False


@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` whose URI is the placeholder ``"vllm"``.

    Requested by ``test_load_with_tensorizer``, where the tensorizer agent
    is mocked. Because of ``autouse=True`` pytest also builds this config
    for every other test in the module.
    """
    config = TensorizerConfig(tensorizer_uri="vllm")
    return config


@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """``load_with_tensorizer`` delegates to a ``TensorizerAgent``.

    The agent class is patched out, so this checks only the wiring:
    - the agent is constructed with the config and ``quant_method``;
    - ``deserialize`` is called exactly once;
    - its return value is passed back unchanged.
    """
    mock_quant_method = MagicMock()
    mock_agent_instance = mock_agent.return_value
    mock_agent_instance.deserialize.return_value = MagicMock()

    result = load_with_tensorizer(tensorizer_config,
                                  quant_method=mock_quant_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       quant_method=mock_quant_method)
    mock_agent_instance.deserialize.assert_called_once()
    # Identity, not just equality: the exact deserialized object is returned.
    assert result is mock_agent_instance.deserialize.return_value


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """A pre-tensorized model stored in S3 can be loaded and generate text."""
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    config = TensorizerConfig(
        tensorizer_uri=tensorized_path,
        num_readers=1,
        s3_endpoint="object.ord1.coreweave.com",
    )
    loaded_vllm_model = vllm_runner(model_ref,
                                    load_format="tensorizer",
                                    model_loader_extra_config=config)

    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

    assert deserialized_outputs


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """An encrypted serialized vLLM model reproduces the original outputs.

    Steps:
    1. Serialize the model with ``encryption_key_path=key_path``.
    2. Drop the original model.
    3. Reload using that key file as ``encryption_keyfile``.
    4. Compare the generations of the two models.
    """
    vllm_model = vllm_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")
    key_path = tmp_path / (model_ref + ".key")
    outputs = vllm_model.generate(prompts, sampling_params)

    config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
    serialize_vllm_model(vllm_model.model.llm_engine,
                         config_for_serializing,
                         encryption_key_path=key_path)

    # Drop the first model and release cached CUDA memory before loading
    # a second copy.
    del vllm_model
    gc.collect()
    torch.cuda.empty_cache()

    config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                encryption_keyfile=key_path)

    loaded_vllm_model = vllm_runner(
        model_ref,
        load_format="tensorizer",
        model_loader_extra_config=config_for_deserializing)

    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

    assert outputs == deserialized_outputs


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """HF weights written with ``TensorSerializer`` load in vLLM unchanged.

    Greedy outputs from the Hugging Face model are compared with greedy
    outputs from vLLM after it loads the serialized module via tensorizer.
    """
    with hf_runner(model_ref) as hf_model:
        model_path = tmp_path / (model_ref + ".tensors")
        max_tokens = 50
        outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
        with open_stream(model_path, "wb+") as stream:
            serializer = TensorSerializer(stream)
            serializer.write_module(hf_model.model)

    config = TensorizerConfig(
        tensorizer_uri=model_path,
        num_readers=1,
    )
    loaded_vllm_model = vllm_runner(model_ref,
                                    load_format="tensorizer",
                                    model_loader_extra_config=config)

    deserialized_outputs = loaded_vllm_model.generate_greedy(
        prompts, max_tokens=max_tokens)

    assert outputs == deserialized_outputs


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """A tensorized base model can be reloaded with LoRA enabled.

    Serializes Llama-2-7b, reloads it through tensorizer with LoRA settings,
    and runs the prompts built by ``examples.multilora_inference`` for the
    downloaded adapter.
    """
    from huggingface_hub import snapshot_download

    from examples.multilora_inference import (create_test_prompts,
                                              process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    vllm_model = vllm_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")

    serialize_vllm_model(vllm_model.model.llm_engine,
                         TensorizerConfig(tensorizer_uri=model_path))

    # Drop the first model and release cached CUDA memory before reloading.
    del vllm_model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_vllm_model = vllm_runner(
        model_ref,
        load_format="tensorizer",
        model_loader_extra_config=TensorizerConfig(
            tensorizer_uri=model_path,
            num_readers=1,
        ),
        enable_lora=True,
        max_loras=1,
        max_lora_rank=8,
        max_cpu_loras=2,
        max_num_seqs=50,
        max_model_len=1000,
    )
    process_requests(loaded_vllm_model.model.llm_engine, test_prompts)

    assert loaded_vllm_model


def test_load_without_tensorizer_load_format(vllm_runner):
    """A TensorizerConfig without ``load_format="tensorizer"`` raises."""
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
    """The OpenAI-compatible API server can serve a tensorized model.

    Serializes ``model_ref``, starts the server with
    ``--load-format tensorizer``, and checks one greedy completion.
    """
    port = 8000

    ## Serialize model
    vllm_model = vllm_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")

    serialize_vllm_model(vllm_model.model.llm_engine,
                         TensorizerConfig(tensorizer_uri=model_path))

    # Sent to the server as JSON, so the path must be a plain string.
    model_loader_extra_config = {
        "tensorizer_uri": str(model_path),
    }

    # Drop the in-process model and release cached CUDA memory before
    # starting the server.
    del vllm_model
    gc.collect()
    torch.cuda.empty_cache()

    ## Start OpenAI API server
    openai_args = [
        "--model", model_ref, "--dtype", "float16", "--load-format",
        "tensorizer", "--model-loader-extra-config",
        json.dumps(model_loader_extra_config), "--port", str(port)
    ]

    server = ServerRunner.remote(openai_args)

    assert ray.get(server.ready.remote())
    print("Server ready.")

    client = openai.OpenAI(
        base_url=f"http://localhost:{port}/v1",
        api_key="token-abc123",
    )
    completion = client.completions.create(model=model_ref,
                                           prompt="Hello, my name is",
                                           max_tokens=5,
                                           temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    assert completion.choices[0].finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """A TensorizerConfig combined with a non-tensorizer format raises."""
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="safetensors",
            model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))


def test_tensorizer_with_tp(vllm_runner):
    """Tensorizer loading with ``tensor_parallel_size=2`` raises ValueError."""
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
    config = TensorizerConfig(
        tensorizer_uri=tensorized_path,
        num_readers=1,
        s3_endpoint="object.ord1.coreweave.com",
    )

    # Keep the raises-block narrow so only engine construction can satisfy
    # it; building this same config succeeds in test_can_deserialize_s3.
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=config,
            tensor_parallel_size=2,
        )


def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
    """A vLLM model serialized and reloaded via tensorizer generates the same.

    Steps:
    1. Generate with the original engine.
    2. Serialize it and check ``is_vllm_tensorized`` accepts the result.
    3. Reload through ``load_format="tensorizer"``.
    4. Compare the two generations.
    """
    opt_ref = "facebook/opt-125m"
    serialized_config = TensorizerConfig(
        tensorizer_uri=str(tmp_path / (opt_ref + ".tensors")))

    original_model = vllm_runner(opt_ref)
    expected = original_model.generate(prompts, sampling_params)
    serialize_vllm_model(original_model.model.llm_engine, serialized_config)

    assert is_vllm_tensorized(serialized_config)

    # Drop the first model and release cached CUDA memory before reloading.
    del original_model
    gc.collect()
    torch.cuda.empty_cache()

    reloaded_model = vllm_runner(opt_ref,
                                 load_format="tensorizer",
                                 model_loader_extra_config=serialized_config)
    actual = reloaded_model.generate(prompts, sampling_params)

    assert expected == actual