tensorize_vllm_model.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import argparse
import dataclasses
6
import json
7
import logging
8
9
10
import os
import uuid

11
from vllm import LLM, SamplingParams
12
from vllm.engine.arg_utils import EngineArgs
13
14
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import (
15
16
17
18
    TensorizerArgs,
    TensorizerConfig,
    tensorize_lora_adapter,
    tensorize_vllm_model,
19
    tensorizer_kwargs_arg,
20
)
21
from vllm.utils import FlexibleArgumentParser
22

23
24
25
# Root logger (getLogger() is called without a name). Used further down to
# report which TensorizerConfig fields were overridden from
# --model-loader-extra-config.
logger = logging.getLogger()


26
27
28
29
# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and 
30
31
32
33
deserialize vLLM models. These models can be loaded using tensorizer 
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
or locally. Tensor encryption and decryption is also supported, although 
libsodium must be installed to use it. Install vllm with tensorizer support 
34
35
using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit
https://github.com/coreweave/tensorizer
36

37
38
To serialize a model, install vLLM from source, then run something 
like this from the root level of this repository:
39

40
python examples/others/tensorize_vllm_model.py \
41
   --model facebook/opt-125m \
42
   serialize \
43
44
   --serialized-directory s3://my-bucket \
   --suffix v1
45
46
   
This downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used. This
assumes your S3 credentials are specified as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide
`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint`
as CLI args to this script.
53
54
55
56

You can also encrypt the model weights with a randomly-generated key by 
providing a `--keyfile` argument.

57
58
To deserialize a model, you can run something like this from the root 
level of this repository:
59

60
python examples/others/tensorize_vllm_model.py \
61
62
63
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
64
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors
65
66
67
68
69
70

This downloads the model tensors from your S3 bucket and deserializes them.

You can also provide a `--keyfile` argument to decrypt the model weights if 
they were serialized with encryption.

71
72
73
74
75
76
To support distributed tensor-parallel models, each model shard will be
serialized to a separate file. The tensorizer_uri is then specified as a string
template with a format specifier such as '%03d' that will be rendered with the
shard's rank. Sharded models serialized with this script will be named as
model-rank-%03d.tensors

77
For more information on the available arguments for serializing, run
`python examples/others/tensorize_vllm_model.py serialize --help`.
79
80
81

Or for deserializing:

82
`python examples/others/tensorize_vllm_model.py deserialize --help`.
83

84
85
Once a model is serialized, tensorizer can be invoked with the `LLM` class 
directly to load models:
86
87
88

    llm = LLM(model="facebook/opt-125m",
              load_format="tensorizer",
89
90
91
92
93
94
95
96
97
98
99
100
101
102
              model_loader_extra_config=TensorizerConfig(
                    tensorizer_uri = path_to_tensors,
                    num_readers=3,
                    )
              )
            
A serialized model can be used during model loading for the vLLM OpenAI
inference server. `model_loader_extra_config` is exposed as the CLI arg
`--model-loader-extra-config`, and accepts a JSON string literal of the
TensorizerConfig arguments desired.

In order to see all of the available arguments usable to configure 
loading with tensorizer that are given to `TensorizerConfig`, run:

103
`python examples/others/tensorize_vllm_model.py deserialize --help`
104
105
106
107

under the `tensorizer options` section. These can also be used for
deserialization in this example script, although `--tensorizer-uri` and
`--path-to-tensors` are functionally the same in this case.
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

Tensorizer can also be used to save and load LoRA adapters. A LoRA adapter
can be serialized directly with the path to the LoRA adapter on HF Hub and
a TensorizerConfig object. In this script, passing a HF id to a LoRA adapter
will serialize the LoRA adapter artifacts to `--serialized-directory`.

You can then use the LoRA adapter with `vllm serve` by ensuring the LoRA
artifacts are in your model artifacts directory and specifying
`--enable-lora`. For example:

```
vllm serve <model_path> \
    --load-format tensorizer \
    --model-loader-extra-config '{"tensorizer_uri": "<model_path>.tensors"}' \
    --enable-lora
```
124
125
126
"""


127
def get_parser():
    """Build the command-line parser for this example script.

    The top-level parser receives the vLLM engine arguments via
    ``EngineArgs.add_cli_args`` plus ``--lora-path``. A required sub-command,
    stored as ``args.command``, selects the mode:

    * ``serialize``: ``--serialized-directory`` (required), ``--suffix``,
      ``--serialization-kwargs`` and ``--keyfile``.
    * ``deserialize``: ``--path-to-tensors``, ``--serialized-directory``,
      ``--keyfile`` and ``--deserialization-kwargs``, plus whatever
      ``TensorizerArgs.add_cli_args`` registers.

    Returns:
        The configured parser.
    """
    parser = FlexibleArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption is "
        "also supported, although libsodium must be installed to "
        "use it.")
    parser = EngineArgs.add_cli_args(parser)

    parser.add_argument(
        "--lora-path",
        type=str,
        required=False,
        help="Path to a LoRA adapter to "
        "serialize along with model tensors. This can then be deserialized "
        "along with the model by instantiating a TensorizerConfig object, "
        "creating a dict from it with TensorizerConfig.to_serializable(), "
        "and passing it to LoRARequest's initializer with the kwarg "
        "tensorizer_config_dict."
    )

    # The chosen sub-command name is stored on the namespace as `command`.
    subparsers = parser.add_subparsers(dest='command', required=True)

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--serialization-kwargs",
        type=tensorizer_kwargs_arg,
        required=False,
        help=("A JSON string containing additional keyword arguments to "
              "pass to Tensorizer's TensorSerializer during "
              "serialization."))

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary key,"
              " and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    # Neither location flag is required on its own: main() enforces that
    # exactly one of them (or its --model-loader-extra-config equivalent)
    # is supplied, and rejects the case where both are given.
    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=False,
        help="The local path or S3 URI to the model tensors to deserialize. "
             "Mutually exclusive with --serialized-directory.")

    deserialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=False,
        help="Directory with model artifacts for loading. Assumes a "
             "model.tensors file exists therein. Mutually exclusive with "
             "--path-to-tensors.")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    deserialize_parser.add_argument(
        "--deserialization-kwargs",
        type=tensorizer_kwargs_arg,
        required=False,
        help=("A JSON string containing additional keyword arguments to "
              "pass to Tensorizer's `TensorDeserializer` during "
              "deserialization."))

    TensorizerArgs.add_cli_args(deserialize_parser)

    return parser
231

232
233
234
235
236
237
238
239
240
def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
                                              cfg: TensorizerConfig):
    """Copy matching entries of ``extra_cfg`` onto ``cfg`` in place.

    Only keys that already exist as attributes on ``cfg`` are applied; any
    other key is skipped silently. Each applied key is logged at INFO level.

    Args:
        extra_cfg: Mapping of attribute name to the value it should take.
        cfg: The config object to update; it is mutated, nothing is returned.
    """
    for field_name, field_value in extra_cfg.items():
        if not hasattr(cfg, field_name):
            # Not an attribute of the config object: leave it untouched.
            continue
        setattr(cfg, field_name, field_value)
        logger.info(
            "Updating TensorizerConfig with %s from "
            "--model-loader-extra-config provided", field_name
        )
241

Reid's avatar
Reid committed
242
def deserialize(args, tensorizer_config):
    """Load a tensorized model into an ``LLM`` and return it.

    When ``args.lora_path`` is set, the config's ``lora_dir`` is pointed at
    its ``tensorizer_dir``, the engine is created with LoRA enabled, and a
    single generation using the adapter is printed as a load check.

    Args:
        args: Parsed CLI namespace; ``model``, ``tensor_parallel_size`` and
            ``lora_path`` are read.
        tensorizer_config: The ``TensorizerConfig`` handed to the engine as
            ``model_loader_extra_config``. Mutated when a LoRA path is given.

    Returns:
        The constructed ``LLM`` instance.
    """
    with_lora = bool(args.lora_path)
    if with_lora:
        tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir

    # Both modes share these engine arguments; LoRA only adds one flag.
    llm_kwargs = {
        "model": args.model,
        "load_format": "tensorizer",
        "tensor_parallel_size": args.tensor_parallel_size,
        "model_loader_extra_config": tensorizer_config,
    }
    if with_lora:
        llm_kwargs["enable_lora"] = True

    llm = LLM(**llm_kwargs)
    if not with_lora:
        return llm

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=256,
        stop=["[/assistant]"],
    )

    # Truncating this as the extra text isn't necessary
    prompts = [
        "[user] Write a SQL query to answer the question based on ..."
    ]

    # Test LoRA load
    lora_request = LoRARequest(
        "sql-lora",
        1,
        args.lora_path,
        tensorizer_config_dict=tensorizer_config.to_serializable(),
    )
    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=lora_request)
    print(outputs)
    return llm
281
282


Reid's avatar
Reid committed
283
def main():
    """Entry point: parse the CLI and run the chosen sub-command.

    Steps:

    1. Resolve S3 credentials, preferring CLI attributes over the
       ``S3_ACCESS_KEY_ID`` / ``S3_SECRET_ACCESS_KEY`` / ``S3_ENDPOINT_URL``
       environment variables.
    2. Resolve where the tensors live: ``--serialized-directory`` or
       ``--path-to-tensors``, each falling back to the ``tensorizer_dir`` /
       ``tensorizer_uri`` key of ``--model-loader-extra-config``. Exactly one
       of the two must end up set, otherwise the parser exits with an error.
    3. ``serialize``: build the target path
       ``<dir>/vllm/<model>/<suffix>/model[-rank-%03d].tensors``, optionally
       serialize the LoRA adapter, then serialize the model.
       ``deserialize``: build a ``TensorizerConfig`` and load the model.

    Raises:
        ValueError: If ``args.command`` is neither ``serialize`` nor
            ``deserialize``.
    """
    parser = get_parser()
    args = parser.parse_args()

    # getattr: the S3 attributes may be absent from the namespace depending
    # on which sub-command was parsed.
    s3_access_key_id = (getattr(args, 's3_access_key_id', None)
                        or os.environ.get("S3_ACCESS_KEY_ID", None))
    s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
                            or os.environ.get("S3_SECRET_ACCESS_KEY", None))
    s3_endpoint = (getattr(args, 's3_endpoint', None)
                   or os.environ.get("S3_ENDPOINT_URL", None))

    credentials = {
        "s3_access_key_id": s3_access_key_id,
        "s3_secret_access_key": s3_secret_access_key,
        "s3_endpoint": s3_endpoint
    }

    model_ref = args.model

    # Both sub-commands define --keyfile; any other command yields None.
    if args.command in ("serialize", "deserialize"):
        keyfile = args.keyfile
    else:
        keyfile = None

    # Accept either a raw JSON string or an already-parsed mapping, so that
    # json.loads is never handed a non-string value.
    extra_config = {}
    raw_extra_config = args.model_loader_extra_config
    if raw_extra_config:
        if isinstance(raw_extra_config, (str, bytes, bytearray)):
            extra_config = json.loads(raw_extra_config)
        else:
            extra_config = dict(raw_extra_config)

    # CLI flags take precedence; the extra config is only a fallback.
    # --path-to-tensors exists only on the deserialize sub-parser.
    tensorizer_dir = (args.serialized_directory or
                      extra_config.get("tensorizer_dir"))
    tensorizer_uri = (getattr(args, "path_to_tensors", None)
                      or extra_config.get("tensorizer_uri"))

    if tensorizer_dir and tensorizer_uri:
        parser.error("--serialized-directory and --path-to-tensors "
                     "cannot both be provided")

    if not tensorizer_dir and not tensorizer_uri:
        parser.error("Either --serialized-directory or --path-to-tensors "
                     "must be provided")

    if args.command == "serialize":
        # Keep only the attributes that are EngineArgs dataclass fields
        # before rebuilding the engine arguments from the namespace.
        eng_args_dict = {f.name: getattr(args, f.name) for f in
                         dataclasses.fields(EngineArgs)}

        engine_args = EngineArgs.from_cli_args(
            argparse.Namespace(**eng_args_dict)
        )

        input_dir = tensorizer_dir.rstrip('/')
        suffix = args.suffix if args.suffix else uuid.uuid4().hex
        base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
        if engine_args.tensor_parallel_size > 1:
            # One file per shard; %03d is rendered with the shard's rank.
            model_path = f"{base_path}/model-rank-%03d.tensors"
        else:
            model_path = f"{base_path}/model.tensors"

        tensorizer_config = TensorizerConfig(
            tensorizer_uri=model_path,
            encryption_keyfile=keyfile,
            serialization_kwargs=args.serialization_kwargs or {},
            **credentials
        )

        if args.lora_path:
            tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
            tensorize_lora_adapter(args.lora_path, tensorizer_config)

        merge_extra_config_with_tensorizer_config(extra_config,
                                                  tensorizer_config)
        tensorize_vllm_model(engine_args, tensorizer_config)

    elif args.command == "deserialize":
        # Pass the resolved location validated above rather than the raw CLI
        # flags, so a location given only through --model-loader-extra-config
        # is already present when the config object is constructed.
        tensorizer_config = TensorizerConfig(
            tensorizer_uri=tensorizer_uri,
            tensorizer_dir=tensorizer_dir,
            encryption_keyfile=keyfile,
            deserialization_kwargs=args.deserialization_kwargs or {},
            **credentials
        )

        merge_extra_config_with_tensorizer_config(extra_config,
                                                  tensorizer_config)
        deserialize(args, tensorizer_config)
    else:
        raise ValueError("Either serialize or deserialize must be specified.")
Reid's avatar
Reid committed
371
372
373
374


# Run the command-line interface only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()