test_tensorizer.py 10.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import gc
import subprocess
import sys
from unittest.mock import MagicMock, patch

import pytest
import torch

from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams
from vllm.config import TensorizerConfig
from vllm.model_executor.tensorizer_loader import (
    EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
    load_with_tensorizer, open_stream)

# Prompts shared by the generation tests in this module.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Seeded sampling settings; the fixed seed is what the comparison tests rely
# on to expect identical generations from two separately loaded models.
sampling_params = SamplingParams(seed=0, top_p=0.95, temperature=0.8)

# Default model for tests that do not choose their own.
model_ref = "facebook/opt-125m"


def is_curl_installed():
    """Return True if a working ``curl`` executable can be run.

    Used by ``skipif`` markers at collection time. The probe's output is
    discarded so ``curl --version`` does not clutter the pytest log on every
    call. Any ``OSError`` (binary missing, not executable, ...) or a non-zero
    exit status is treated as "not installed".
    """
    try:
        subprocess.check_call(['curl', '--version'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # FileNotFoundError is a subclass of OSError; catching the base class
        # also covers e.g. PermissionError instead of crashing collection.
        return False
    return True


@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a TensorizerConfig for the "vllm" URI with vllm_tensorized on.

    The fixture is function-scoped (pytest's default), so each test gets a
    fresh instance and may mutate it freely.
    """
    return TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)


@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
    """load_with_tensorizer builds a TensorizerAgent and returns its result.

    The agent must be constructed exactly once from the config and the
    linear method, and whatever ``deserialize()`` returns is handed back.
    """
    linear_method = MagicMock()
    agent = mock_agent.return_value
    sentinel_model = MagicMock()
    agent.deserialize.return_value = sentinel_model

    loaded = load_with_tensorizer(tensorizer_config,
                                  linear_method=linear_method)

    mock_agent.assert_called_once_with(tensorizer_config,
                                       linear_method=linear_method)
    agent.deserialize.assert_called_once()
    assert loaded == sentinel_model


def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
    """is_vllm_serialized_tensorizer is True when vllm_tensorized is set.

    Despite the test name, what is exercised is the ``vllm_tensorized`` flag,
    not the contents of the URI.
    """
    tensorizer_config.vllm_tensorized = True
    assert is_vllm_serialized_tensorizer(tensorizer_config) is True


def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
    """is_vllm_serialized_tensorizer is False when vllm_tensorized is unset.

    Despite the test name, what is exercised is the ``vllm_tensorized`` flag,
    not the contents of the URI.
    """
    tensorizer_config.vllm_tensorized = False
    assert is_vllm_serialized_tensorizer(tensorizer_config) is False


def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
    """Serialize a vLLM-loaded model with tensorizer and reload it.

    Generations from the original engine and from an engine loaded with
    ``load_format="tensorizer"`` (``vllm_tensorized=True``) must be equal
    under the module's seeded ``sampling_params``.
    """
    vllm_model = vllm_runner(model_ref)
    # model_ref contains a "/", so this is a nested path under tmp_path.
    # NOTE(review): relies on open_stream creating the parent directory when
    # writing — confirm against tensorizer's open_stream.
    model_path = tmp_path / (model_ref + ".tensors")
    outputs = vllm_model.generate(prompts, sampling_params)
    # Reach through the engine to the torch module held by the driver
    # worker's model runner; that module is what gets serialized.
    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
             model_runner.model)
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream)
        serializer.write_module(model)
    # Drop all references to the first engine and release cached CUDA memory
    # before building the second engine.
    del vllm_model, model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_vllm_model = vllm_runner(model_ref,
                                    load_format="tensorizer",
                                    tensorizer_uri=model_path,
                                    num_readers=1,
                                    vllm_tensorized=True)
    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

    # Assumes SamplingParams being seeded ensures the outputs are deterministic
    assert outputs == deserialized_outputs


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
    """Load a tensorized model directly from an S3 URI and generate with it.

    The remote tensors were not produced from a vLLM module
    (``vllm_tensorized=False``); the test only checks that generation returns
    a non-empty result.
    """
    repo_id = "EleutherAI/pythia-1.4b"
    s3_uri = f"s3://tensorized/{repo_id}/fp16/model.tensors"

    runner = vllm_runner(
        repo_id,
        tensorizer_uri=s3_uri,
        load_format="tensorizer",
        num_readers=1,
        vllm_tensorized=False,
        s3_endpoint="object.ord1.coreweave.com",
    )

    assert runner.generate(prompts, sampling_params)


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
        vllm_runner, tmp_path):
    """Round-trip a vLLM model through tensorizer with encryption enabled.

    The model is serialized with a random ``EncryptionParams`` key. The raw
    key is written to its own file, and the model is reloaded with
    ``encryption_keyfile`` pointing at that file. Seeded generations before
    and after must be equal.
    """
    # NOTE(review): this test only touches local files under tmp_path; it is
    # not obvious why it is gated on cURL — confirm the skip is intended.
    vllm_model = vllm_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")
    key_path = tmp_path / (model_ref + ".key")
    outputs = vllm_model.generate(prompts, sampling_params)
    # Torch module held by the driver worker's model runner.
    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
             model_runner.model)

    encryption_params = EncryptionParams.random()
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
    # Persist the key separately so the loader can be given a keyfile path.
    with open_stream(key_path, "wb+") as stream:
        stream.write(encryption_params.key)
    # Free the first engine (and cached CUDA memory) before loading again.
    del vllm_model, model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_vllm_model = vllm_runner(model_ref,
                                    tensorizer_uri=model_path,
                                    load_format="tensorizer",
                                    encryption_keyfile=key_path,
                                    num_readers=1,
                                    vllm_tensorized=True)

    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

    # Assumes SamplingParams being seeded ensures the outputs are deterministic
    assert outputs == deserialized_outputs


def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
                                                tmp_path):
    """Serialize a Hugging Face model and load it into vLLM via tensorizer.

    The HF runner's module is written with TensorSerializer, then loaded by
    vLLM with ``vllm_tensorized=False``. Greedy generations from the HF
    runner and the tensorizer-loaded vLLM runner must be equal.
    """
    hf_model = hf_runner(model_ref)
    model_path = tmp_path / (model_ref + ".tensors")
    # Same generation budget for both runners so outputs are comparable.
    max_tokens = 50
    outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream)
        serializer.write_module(hf_model.model)
    # Release the HF model and cached CUDA memory before starting vLLM.
    del hf_model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_hf_model = vllm_runner(model_ref,
                                  tensorizer_uri=model_path,
                                  load_format="tensorizer",
                                  num_readers=1,
                                  vllm_tensorized=False)

    deserialized_outputs = loaded_hf_model.generate_greedy(
        prompts, max_tokens=max_tokens)

    assert outputs == deserialized_outputs


def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
    """Load a tensorized vLLM model with LoRA enabled and serve LoRA requests.

    The base model is serialized without LoRA, reloaded through tensorizer
    with ``enable_lora=True``, and driven with the multi-LoRA example's test
    prompts.
    """
    from huggingface_hub import snapshot_download

    from examples.multilora_inference import (create_test_prompts,
                                              process_requests)

    model_ref = "meta-llama/Llama-2-7b-hf"
    # Download the LoRA adapter used by the example prompts.
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)

    # Serialize model before deserializing and binding LoRA adapters
    vllm_model = vllm_runner(model_ref, )
    model_path = tmp_path / (model_ref + ".tensors")
    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
             model_runner.model)
    with open_stream(model_path, "wb+") as stream:
        serializer = TensorSerializer(stream)
        serializer.write_module(model)
    # Free the first engine (and cached CUDA memory) before loading again.
    del vllm_model, model
    gc.collect()
    torch.cuda.empty_cache()
    loaded_vllm_model = vllm_runner(
        model_ref,
        tensorizer_uri=model_path,
        load_format="tensorizer",
        num_readers=1,
        vllm_tensorized=True,
        enable_lora=True,
        max_loras=1,
        max_lora_rank=8,
        max_cpu_loras=2,
        max_num_seqs=50,
        max_model_len=1000,
    )
    process_requests(loaded_vllm_model.model.llm_engine, test_prompts)

    # NOTE(review): the runner object is presumably always truthy, so this
    # assert adds little; the effective check is that process_requests above
    # completes without raising.
    assert loaded_vllm_model


def test_load_without_tensorizer_load_format(vllm_runner):
    """A tensorizer_uri without load_format="tensorizer" is rejected."""
    unexpected_kwargs = {"tensorizer_uri": "test"}
    with pytest.raises(ValueError):
        vllm_runner(model_ref, **unexpected_kwargs)


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorize_vllm_model(tmp_path):
    """Run the tensorize helper script's serialize and deserialize commands.

    ``serialize`` writes ``model_ref`` under ``tmp_path``. ``deserialize``
    then loads the resulting ``model.tensors`` file. Both commands must exit
    with status 0.
    """
    # NOTE(review): the script path is relative to the current working
    # directory, so this assumes pytest is launched from the directory that
    # contains ``tensorizer/`` — confirm against the CI invocation.
    script = "tensorizer/tensorize_vllm_model_for_testing.py"

    # Test serialize command. ``sys.executable`` (rather than "python3" from
    # PATH) runs the script under the same interpreter/environment as pytest.
    serialize_args = [
        sys.executable, script, "--model", model_ref, "--dtype", "float16",
        "serialize", "--serialized-directory",
        str(tmp_path), "--suffix", "tests"
    ]
    result = subprocess.run(serialize_args, capture_output=True, text=True)
    print(result.stdout)  # Print the output of the serialize command

    assert result.returncode == 0, (f"Serialize command failed with output:"
                                    f"\n{result.stdout}\n{result.stderr}")

    # Location where the serialize command is expected to write the tensors.
    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"

    # Test deserialize command
    deserialize_args = [
        sys.executable, script, "--model", model_ref, "--dtype", "float16",
        "deserialize", "--path-to-tensors", path_to_tensors
    ]
    result = subprocess.run(deserialize_args, capture_output=True, text=True)
    assert result.returncode == 0, (f"Deserialize command failed with output:"
                                    f"\n{result.stdout}\n{result.stderr}")


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(tmp_path):
    """Serve a tensorized ``model_ref`` through the OpenAI-compatible server.

    The model is serialized with the helper script. A ``ServerRunner`` actor
    is then started against the serialized tensors, and the test waits until
    the actor answers ``ready``.
    """
    ## Serialize model
    # ``sys.executable`` runs the script under the same interpreter as pytest.
    serialize_args = [
        sys.executable, "tensorizer/tensorize_vllm_model_for_testing.py",
        "--model", model_ref, "--dtype", "float16", "serialize",
        "--serialized-directory",
        str(tmp_path), "--suffix", "tests"
    ]
    result = subprocess.run(serialize_args, capture_output=True, text=True)
    print(result.stdout)  # Print the output of the serialize command

    assert result.returncode == 0, (f"Serialize command failed with output:"
                                    f"\n{result.stdout}\n{result.stderr}")

    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"

    ## Start OpenAI API server
    openai_args = [
        "--model", model_ref, "--dtype", "float16", "--load-format",
        "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
        "--port", "8000"
    ]

    server = ServerRunner.remote(openai_args)

    # ``server.ready.remote()`` returns an ObjectRef, which is always truthy,
    # so asserting on it directly neither waits for nor checks readiness.
    # Resolving the ref blocks until the actor has answered; a startup failure
    # in the actor surfaces here as an exception.
    server.ready.remote().future().result()
    print("Server ready.")
    # NOTE(review): the actor and the server it launched are not torn down
    # here — confirm a fixture or Ray shutdown cleans them up.


def test_raise_value_error_on_invalid_load_format(vllm_runner):
    """Combining tensorizer_uri with a non-tensorizer load_format fails."""
    conflicting_kwargs = {"load_format": "safetensors", "tensorizer_uri": "test"}
    with pytest.raises(ValueError):
        vllm_runner(model_ref, **conflicting_kwargs)


def test_tensorizer_with_tp(vllm_runner):
    """Loading via tensorizer with tensor_parallel_size=2 raises ValueError."""
    model_ref = "EleutherAI/pythia-1.4b"
    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"

    # Only the call under test sits inside ``pytest.raises`` so that the
    # expected ValueError is attributable to runner construction alone.
    with pytest.raises(ValueError):
        vllm_runner(
            model_ref,
            tensorizer_uri=tensorized_path,
            load_format="tensorizer",
            num_readers=1,
            vllm_tensorized=False,
            s3_endpoint="object.ord1.coreweave.com",
            tensor_parallel_size=2,
        )


@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorizer_warn_quant(tmp_path):
    """Serializing with ``--quantization gptq`` reports a PerformanceWarning.

    Runs the helper script's ``serialize`` command on a GPTQ checkpoint and
    requires the string "PerformanceWarning" to appear on stderr.
    """
    model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
    # ``sys.executable`` runs the script under the same interpreter as pytest.
    serialize_args = [
        sys.executable, "tensorizer/tensorize_vllm_model_for_testing.py",
        "--model", model_ref, "--quantization", "gptq", "--tensorizer-uri",
        "test", "serialize", "--serialized-directory",
        str(tmp_path), "--suffix", "tests"
    ]
    result = subprocess.run(serialize_args, capture_output=True, text=True)
    # Include the subprocess output so a failure (e.g. the script crashing
    # before it can warn) is diagnosable from the test report.
    assert 'PerformanceWarning' in result.stderr, (
        f"Expected a PerformanceWarning on stderr (exit code "
        f"{result.returncode}):\n{result.stdout}\n{result.stderr}")