tensorize_vllm_model.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
import logging
6
7
8
import os
import uuid

9
from vllm import LLM, SamplingParams
10
from vllm.engine.arg_utils import EngineArgs
11
12
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import (
13
14
15
16
    TensorizerArgs,
    TensorizerConfig,
    tensorize_lora_adapter,
    tensorize_vllm_model,
17
    tensorizer_kwargs_arg,
18
)
19
from vllm.utils import FlexibleArgumentParser
20

21
22
23
# No name is passed to getLogger(), so this is the root logger rather than a
# logger scoped to this module.
logger = logging.getLogger()


24
25
"""
tensorize_vllm_model.py is a script that can be used to serialize and 
26
27
28
29
deserialize vLLM models. These models can be loaded using tensorizer 
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
or locally. Tensor encryption and decryption is also supported, although 
libsodium must be installed to use it. Install vllm with tensorizer support 
30
31
using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit
https://github.com/coreweave/tensorizer
32

33
34
To serialize a model, install vLLM from source, then run something 
like this from the root level of this repository:
35

36
python examples/others/tensorize_vllm_model.py \
37
   --model facebook/opt-125m \
38
   serialize \
39
40
   --serialized-directory s3://my-bucket \
   --suffix v1
41
42
   
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
43
44
and saves it to your S3 bucket. A local directory can also be used. This
assumes your S3 credentials are specified as environment variables
45
46
47
48
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and 
`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide 
`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint` 
as CLI args to this script.
49
50
51
52

You can also encrypt the model weights with a randomly-generated key by 
providing a `--keyfile` argument.

53
54
To deserialize a model, you can run something like this from the root 
level of this repository:
55

56
python examples/others/tensorize_vllm_model.py \
57
58
59
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
60
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors
61
62
63
64
65
66

Which downloads the model tensors from your S3 bucket and deserializes them.

You can also provide a `--keyfile` argument to decrypt the model weights if 
they were serialized with encryption.

67
68
69
70
71
72
To support distributed tensor-parallel models, each model shard will be
serialized to a separate file. The tensorizer_uri is then specified as a string
template with a format specifier such as '%03d' that will be rendered with the
shard's rank. Sharded models serialized with this script will be named as
model-rank-%03d.tensors

73
For more information on the available arguments for serializing, run 
74
`python examples/others/tensorize_vllm_model.py serialize --help`.
75
76
77

Or for deserializing:

78
`python examples/others/tensorize_vllm_model.py deserialize --help`.
79

80
81
Once a model is serialized, tensorizer can be invoked with the `LLM` class 
directly to load models:
82

83
84
85
86
87
88
89
90
```python
from vllm import LLM
llm = LLM(
    "s3://my-bucket/vllm/facebook/opt-125m/v1", 
    load_format="tensorizer"
)
```

91
92
            
A serialized model can be used during model loading for the vLLM OpenAI
93
94
95
96
97
98
inference server:

```
vllm serve s3://my-bucket/vllm/facebook/opt-125m/v1 \
    --load-format tensorizer
```
99
100
101
102

In order to see all of the available arguments usable to configure 
loading with tensorizer that are given to `TensorizerConfig`, run:

103
`python examples/others/tensorize_vllm_model.py deserialize --help`
104
105
106
107

under the `tensorizer options` section. These can also be used for
deserialization in this example script, although `--tensorizer-uri` and
`--path-to-tensors` are functionally the same in this case.
108
109
110
111
112
113
114
115
116
117
118

Tensorizer can also be used to save and load LoRA adapters. A LoRA adapter
can be serialized directly with the path to the LoRA adapter on HF Hub and
a TensorizerConfig object. In this script, passing the HF id of a LoRA adapter
via `--lora-path` will serialize the LoRA adapter artifacts to
`--serialized-directory`.

You can then use the LoRA adapter with `vllm serve`, for instance, by ensuring 
the LoRA artifacts are in your model artifacts directory and specifying 
`--enable-lora`. For instance:

```
119
vllm serve s3://my-bucket/vllm/facebook/opt-125m/v1 \
120
    --load-format tensorizer \
121
    --enable-lora 
122
```
123
124
125
"""


126
def get_parser():
    """Build the command-line parser for this example script.

    The top-level parser carries every argument registered by
    ``EngineArgs.add_cli_args`` plus ``--lora-path``, and requires a
    ``command`` sub-command that is either ``serialize`` or ``deserialize``.
    The ``deserialize`` sub-parser additionally receives the options
    registered by ``TensorizerArgs.add_cli_args``.

    Returns:
        The configured top-level parser.
    """
    root = FlexibleArgumentParser(
        description=(
            "An example script that can be used to serialize and "
            "deserialize vLLM models. These models "
            "can be loaded using tensorizer directly to the GPU "
            "extremely quickly. Tensor encryption and decryption is "
            "also supported, although libsodium must be installed to "
            "use it."
        )
    )
    root = EngineArgs.add_cli_args(root)

    root.add_argument(
        "--lora-path",
        type=str,
        required=False,
        help=(
            "Path to a LoRA adapter to "
            "serialize along with model tensors. This can then be deserialized "
            "along with the model by instantiating a TensorizerConfig object, "
            "creating a dict from it with TensorizerConfig.to_serializable(), "
            "and passing it to LoRARequest's initializer with the kwarg "
            "tensorizer_config_dict."
        ),
    )

    commands = root.add_subparsers(dest="command", required=True)

    # Options are declared as (flag, add_argument kwargs) pairs and registered
    # in order, so the generated --help output keeps the same ordering.
    serialize_options = (
        (
            "--suffix",
            dict(
                type=str,
                required=False,
                help=(
                    "The suffix to append to the serialized model directory, which is "
                    "used to construct the location of the serialized model tensors, "
                    "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
                    "`--suffix` is `v1`, the serialized model tensors will be "
                    "saved to "
                    "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
                    "If none is provided, a random UUID will be used."
                ),
            ),
        ),
        (
            "--serialized-directory",
            dict(
                type=str,
                required=True,
                help=(
                    "The directory to serialize the model to. "
                    "This can be a local directory or S3 URI. The path to where the "
                    "tensors are saved is a combination of the supplied `dir` and model "
                    "reference ID. For instance, if `dir` is the serialized directory, "
                    "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
                    "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
                    "where `suffix` is given by `--suffix` or a random UUID if not "
                    "provided."
                ),
            ),
        ),
        (
            "--serialization-kwargs",
            dict(
                type=tensorizer_kwargs_arg,
                required=False,
                help=(
                    "A JSON string containing additional keyword arguments to "
                    "pass to Tensorizer's TensorSerializer during "
                    "serialization."
                ),
            ),
        ),
        (
            "--keyfile",
            dict(
                type=str,
                required=False,
                help=(
                    "Encrypt the model weights with a randomly-generated binary key,"
                    " and save the key at this path"
                ),
            ),
        ),
    )

    deserialize_options = (
        (
            "--path-to-tensors",
            dict(
                type=str,
                required=False,
                help="The local path or S3 URI to the model tensors to deserialize. ",
            ),
        ),
        (
            "--serialized-directory",
            dict(
                type=str,
                required=False,
                help=(
                    "Directory with model artifacts for loading. Assumes a "
                    "model.tensors file exists therein. Can supersede "
                    "--path-to-tensors."
                ),
            ),
        ),
        (
            "--keyfile",
            dict(
                type=str,
                required=False,
                help=(
                    "Path to a binary key to use to decrypt the model weights,"
                    " if the model was serialized with encryption"
                ),
            ),
        ),
        (
            "--deserialization-kwargs",
            dict(
                type=tensorizer_kwargs_arg,
                required=False,
                help=(
                    "A JSON string containing additional keyword arguments to "
                    "pass to Tensorizer's `TensorDeserializer` during "
                    "deserialization."
                ),
            ),
        ),
    )

    serialize_cmd = commands.add_parser(
        "serialize", help="Serialize a model to `--serialized-directory`"
    )
    for flag, options in serialize_options:
        serialize_cmd.add_argument(flag, **options)

    deserialize_cmd = commands.add_parser(
        "deserialize",
        help=(
            "Deserialize a model from `--path-to-tensors`"
            " to verify it can be loaded and used."
        ),
    )
    for flag, options in deserialize_options:
        deserialize_cmd.add_argument(flag, **options)

    # Registered last, after the script-specific deserialize options.
    TensorizerArgs.add_cli_args(deserialize_cmd)

    return root
252

253
254

def merge_extra_config_with_tensorizer_config(extra_cfg: dict, cfg: TensorizerConfig):
    """Copy matching entries of ``extra_cfg`` onto ``cfg`` in place.

    Only keys that already exist as attributes on ``cfg`` are applied; every
    other key is skipped without any message. Each applied key is logged at
    INFO level.

    Args:
        extra_cfg: Mapping of attribute names to override values.
        cfg: Config object that is mutated with the matching overrides.
    """
    for name in extra_cfg:
        if not hasattr(cfg, name):
            continue
        setattr(cfg, name, extra_cfg[name])
        logger.info(
            "Updating TensorizerConfig with %s "
            "from --model-loader-extra-config provided",
            name,
        )
263

264

Reid's avatar
Reid committed
265
def deserialize(args, tensorizer_config):
    """Load a tensorized model into an ``LLM`` and return it.

    When ``args.lora_path`` is set, ``tensorizer_config.lora_dir`` is pointed
    at ``tensorizer_config.tensorizer_dir``, the engine is created with LoRA
    enabled, and the result of one generation with the adapter is printed.

    Args:
        args: Parsed CLI namespace; ``model``, ``tensor_parallel_size`` and
            ``lora_path`` are read from it.
        tensorizer_config: Config passed to the engine as
            ``model_loader_extra_config``; mutated when a LoRA path is given.

    Returns:
        The constructed ``LLM`` instance.
    """
    llm_kwargs = dict(
        model=args.model,
        load_format="tensorizer",
        tensor_parallel_size=args.tensor_parallel_size,
        model_loader_extra_config=tensorizer_config,
    )

    # Plain model load: nothing else to exercise.
    if not args.lora_path:
        return LLM(**llm_kwargs)

    # The config object is shared with llm_kwargs, so this assignment is
    # visible to the engine constructed below.
    tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
    engine = LLM(enable_lora=True, **llm_kwargs)

    generation_params = SamplingParams(
        temperature=0, max_tokens=256, stop=["[/assistant]"]
    )

    # Truncating this as the extra text isn't necessary
    prompt_batch = ["[user] Write a SQL query to answer the question based on ..."]

    # Test LoRA load
    adapter_request = LoRARequest(
        "sql-lora",
        1,
        args.lora_path,
        tensorizer_config_dict=tensorizer_config.to_serializable(),
    )
    outputs = engine.generate(
        prompt_batch,
        generation_params,
        lora_request=adapter_request,
    )
    print(outputs)
    return engine
303
304


Reid's avatar
Reid committed
305
def main():
    """Parse the command line and run the ``serialize`` or ``deserialize`` command.

    S3 credentials are read from the parsed arguments when present and truthy,
    otherwise from the ``S3_ACCESS_KEY_ID``, ``S3_SECRET_ACCESS_KEY`` and
    ``S3_ENDPOINT_URL`` environment variables.

    The tensor location comes from ``--serialized-directory`` /
    ``--path-to-tensors`` or, failing that, from the ``tensorizer_dir`` /
    ``tensorizer_uri`` keys of ``--model-loader-extra-config``. Exactly one of
    the two must be set; otherwise ``parser.error`` is called.

    Raises:
        ValueError: If the parsed command is neither ``serialize`` nor
            ``deserialize``.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Values from the CLI namespace win; the environment is only a fallback.
    credentials = {
        "s3_access_key_id": getattr(args, "s3_access_key_id", None)
        or os.environ.get("S3_ACCESS_KEY_ID", None),
        "s3_secret_access_key": getattr(args, "s3_secret_access_key", None)
        or os.environ.get("S3_SECRET_ACCESS_KEY", None),
        "s3_endpoint": getattr(args, "s3_endpoint", None)
        or os.environ.get("S3_ENDPOINT_URL", None),
    }

    model_ref = args.model

    if args.command in ("serialize", "deserialize"):
        keyfile = args.keyfile
    else:
        keyfile = None

    # The extra config may be a raw JSON string or, depending on how the
    # engine CLI declares the option, an already-decoded mapping.
    # json.loads() raises TypeError on a mapping, so only decode text.
    raw_extra_config = args.model_loader_extra_config
    if not raw_extra_config:
        extra_config = {}
    elif isinstance(raw_extra_config, (str, bytes, bytearray)):
        extra_config = json.loads(raw_extra_config)
    else:
        extra_config = dict(raw_extra_config)

    # Resolve the tensor location once: CLI flags first, extra config second.
    tensorizer_dir = args.serialized_directory or extra_config.get("tensorizer_dir")
    tensorizer_uri = getattr(args, "path_to_tensors", None) or extra_config.get(
        "tensorizer_uri"
    )

    if tensorizer_dir and tensorizer_uri:
        parser.error(
            "--serialized-directory and --path-to-tensors cannot both be provided"
        )

    if not tensorizer_dir and not tensorizer_uri:
        parser.error(
            "Either --serialized-directory or --path-to-tensors must be provided"
        )

    if args.command == "serialize":
        engine_args = EngineArgs.from_cli_args(args)

        # Output layout: <dir>/vllm/<model>/<suffix>/<file>.
        input_dir = tensorizer_dir.rstrip("/")
        suffix = args.suffix or uuid.uuid4().hex
        base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
        if engine_args.tensor_parallel_size > 1:
            # One file per shard; '%03d' is filled in with the shard's rank.
            model_path = f"{base_path}/model-rank-%03d.tensors"
        else:
            model_path = f"{base_path}/model.tensors"

        tensorizer_config = TensorizerConfig(
            tensorizer_uri=model_path,
            encryption_keyfile=keyfile,
            serialization_kwargs=args.serialization_kwargs or {},
            **credentials,
        )

        if args.lora_path:
            tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
            tensorize_lora_adapter(args.lora_path, tensorizer_config)

        merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config)
        tensorize_vllm_model(engine_args, tensorizer_config)

    elif args.command == "deserialize":
        # Pass the location resolved and validated above, not the raw CLI
        # values, so that a location given only through
        # --model-loader-extra-config reaches the constructor instead of
        # being patched onto the config afterwards.
        tensorizer_config = TensorizerConfig(
            tensorizer_uri=tensorizer_uri,
            tensorizer_dir=tensorizer_dir,
            encryption_keyfile=keyfile,
            deserialization_kwargs=args.deserialization_kwargs or {},
            **credentials,
        )

        merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config)
        deserialize(args, tensorizer_config)
    else:
        raise ValueError("Either serialize or deserialize must be specified.")
Reid's avatar
Reid committed
389
390
391
392


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()