save_sharded_state.py 3.11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

Then, the model can be loaded with

llm = LLM(
    model="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs
30
from vllm.utils import FlexibleArgumentParser
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

def parse_args():
    """Parse the command-line options for this script.

    The parser is a ``FlexibleArgumentParser`` populated with every option
    registered by ``EngineArgs.add_cli_args``, plus three script-specific
    options: ``--output``/``-o`` (required), ``--file-pattern`` and
    ``--max-file-size``.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.add_argument("--output",
                        "-o",
                        required=True,
                        type=str,
                        help="path to output checkpoint")
    parser.add_argument("--file-pattern",
                        type=str,
                        help="string pattern of saved filenames")
    # The size is a byte count, so parse it as an int. With ``type=str`` a
    # value given on the command line would be a string while the default
    # stays an int, giving the option two different runtime types.
    parser.add_argument("--max-file-size",
                        type=int,
                        default=5 * 1024**3,
                        help="max size (in bytes) of each safetensors file")
    return parser.parse_args()
49
50
51
52
53
54
55
56
57
58
59
60
61
62


def main(args):
    """Load a model with vLLM and dump its sharded state to ``args.output``.

    After the shards are written, every entry of the source model directory
    whose extension is not ``.bin``, ``.pt`` or ``.safetensors`` is copied
    into the output directory as well.

    Args:
        args: Parsed command-line namespace. It is passed to
            ``EngineArgs.from_cli_args`` and must also provide ``output``,
            ``file_pattern`` and ``max_file_size``.

    Raises:
        ValueError: If ``enable_lora`` is set, or if the model path is not
            a local directory.
    """
    engine_args = EngineArgs.from_cli_args(args)
    if engine_args.enable_lora:
        raise ValueError("Saving with enable_lora=True is not supported!")
    model_path = engine_args.model
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")

    # Create LLM instance from arguments
    llm = LLM(**dataclasses.asdict(engine_args))

    # Prepare output directory. ``parents=True`` lets a nested output path
    # be created in one go instead of failing on a missing parent.
    Path(args.output).mkdir(parents=True, exist_ok=True)

    # Dump worker states to output directory. An engine exposing an
    # ``engine_core`` attribute is treated as V1; anything else as V0.
    is_v1_engine = hasattr(llm.llm_engine, "engine_core")

    if is_v1_engine:
        print("Using V1 engine save path")
        llm.llm_engine.engine_core.save_sharded_state(
            path=args.output,
            pattern=args.file_pattern,
            max_size=args.max_file_size)
    else:
        print("Using V0 engine save path")
        model_executor = llm.llm_engine.model_executor
        model_executor.save_sharded_state(path=args.output,
                                          pattern=args.file_pattern,
                                          max_size=args.max_file_size)

    # Copy metadata files to output directory, skipping anything with a
    # weight-file extension.
    weight_suffixes = (".bin", ".pt", ".safetensors")
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] in weight_suffixes:
            continue
        source = os.path.join(model_path, file)
        if os.path.isdir(source):
            # ``dirs_exist_ok=True`` keeps a re-run into an existing output
            # directory from failing on already-copied subdirectories.
            shutil.copytree(source,
                            os.path.join(args.output, file),
                            dirs_exist_ok=True)
        else:
            shutil.copy(source, args.output)


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the save.
    args = parse_args()
    main(args)