test_sharded_state_loader.py 5.62 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import fnmatch
5
import multiprocessing as mp
6
7
8
9
10
11
12
13
14
import os
import shutil
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
15
from vllm.model_executor.model_loader import ShardedStateLoader
16
17
18
19
20
21
22
23
24
25

# Prompts generated in both runs (original checkpoint vs. sharded state).
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Shared sampling settings for both generation runs. The test compares the
# two runs' outputs for exact equality, so decoding is kept deterministic
# (temperature=0), and EOS is ignored so every prompt yields a completion of
# up to `max_tokens` tokens to compare.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=256,
    ignore_eos=True,
)


def test_filter_subtensors():
    """`_filter_subtensors` drops entries that share storage with another.

    `x` is the very same tensor object as `b`, and `y`/`z` are slices of
    `c`. Only the `a`, `b` and `c` entries should survive, and each must be
    the identical tensor object that was passed in.
    """
    state_dict = {
        "a": torch.empty(2),
        "b": torch.empty((2, 4)),
        "c": torch.empty((2, 4, 8)),
    }
    # Add entries that alias / view the tensors above.
    state_dict.update({
        "x": state_dict["b"],
        "y": state_dict["c"][1, 2, :],
        "z": state_dict["c"][1, :, 4],
    })
    filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
    assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
    for key, tensor in filtered_state_dict.items():
        # NOTE: don't use `equal` here, as the tensor might contain NaNs
        # (torch.empty leaves memory uninitialized); identity is what matters.
        assert tensor is state_dict[key]


@pytest.fixture(scope="module")
def llama_3p2_1b_files():
    """Download meta-llama/Llama-3.2-1B-Instruct and yield its local path.

    Module-scoped, so the snapshot is resolved once for all tests in this
    file. Files matching `*.bin*` and everything under `original/` are
    excluded from the download.
    """
    input_dir = snapshot_download("meta-llama/Llama-3.2-1B-Instruct",
                                  ignore_patterns=["*.bin*", "original/*"])
    yield input_dir


def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
    """Load `input_dir` with vLLM and save its sharded state to `output_dir`.

    Used as a subprocess target by the test below. `kwargs` are forwarded to
    `LLM`. After the worker state is written, the non-weight contents of
    `input_dir` are copied next to it (see `_copy_metadata_files`) so
    `output_dir` can later be loaded with `load_format="sharded_state"`.
    """
    llm_sharded_writer = LLM(model=input_dir, **kwargs)

    # Check which engine version is being used: only the V1 engine exposes
    # an `engine_core` attribute.
    is_v1_engine = hasattr(llm_sharded_writer.llm_engine, "engine_core")

    # Dump worker states to output directory
    if is_v1_engine:
        # For V1 engine, we need to use engine_core.save_sharded_state
        print("Using V1 engine save path")
        llm_sharded_writer.llm_engine.engine_core.save_sharded_state(
            path=output_dir)
    else:
        # For V0 engine
        print("Using V0 engine save path")
        model_executor = llm_sharded_writer.llm_engine.model_executor
        model_executor.save_sharded_state(path=output_dir)

    _copy_metadata_files(input_dir, output_dir, weights_patterns)


def _copy_metadata_files(input_dir, output_dir, weights_patterns):
    """Copy the non-weight entries of `input_dir` into `output_dir`.

    Args:
        input_dir: source model directory.
        output_dir: destination; must already exist (a missing path would make
            `shutil.copy` write a *file* at that path).
        weights_patterns: fnmatch patterns. Top-level files whose name matches
            any of them are skipped.

    Sub-directories are copied recursively and are NOT filtered by
    `weights_patterns`. `shutil.copytree` raises if a same-named
    sub-directory already exists in `output_dir`.
    """
    for name in os.listdir(input_dir):
        src_path = os.path.join(input_dir, name)
        if os.path.isdir(src_path):
            shutil.copytree(src_path, os.path.join(output_dir, name))
        elif not any(
                fnmatch.fnmatch(name, pattern) for pattern in weights_patterns):
            shutil.copy(src_path, output_dir)


def _run_generate(input_dir, queue: mp.Queue, **kwargs):
    """Generate completions for `prompts` and send them back over `queue`.

    Subprocess target: builds an `LLM` for `input_dir` (extra `kwargs` go to
    the constructor), runs it on the module-level `prompts` with
    `sampling_params`, and puts one list on `queue`, holding the attribute
    dict of each request's first completion.
    """
    engine = LLM(model=input_dir, **kwargs)
    request_outputs = engine.generate(prompts, sampling_params)

    payload = []
    for request_output in request_outputs:
        payload.append(request_output.outputs[0].__dict__)
    queue.put(payload)

    # close() + join_thread() block until the queue's feeder thread has
    # flushed the payload to the pipe, so nothing is lost when we exit.
    queue.close()
    queue.join_thread()


def _generate_in_subprocess(ctx, model_dir, **llm_kwargs):
    """Run `_run_generate` on `model_dir` in a fresh process.

    Args:
        ctx: multiprocessing context used to create the queue and process.
        model_dir: model directory handed to `_run_generate`.
        **llm_kwargs: forwarded to `_run_generate` (and from there to `LLM`).

    Returns:
        The list of output dicts the child put on the queue.

    Fails the calling test, instead of hanging it, if the child exits
    without sending any output.
    """
    from queue import Empty

    result_queue = ctx.Queue()
    proc = ctx.Process(target=_run_generate,
                       args=(model_dir, result_queue),
                       kwargs=llm_kwargs)
    proc.start()

    # Call get() before join() to prevent deadlock:
    # If join() is called before get() and the queue is full,
    # the child process may block while writing to the queue and never
    # terminate, causing the parent to wait indefinitely on join().
    # See: https://github.com/vllm-project/vllm/pull/22371#discussion_r2257773814
    #
    # get() is polled with a timeout instead of blocking, so a child that
    # crashes before put() fails the test rather than hanging it forever.
    outputs = None
    while outputs is None and proc.is_alive():
        try:
            outputs = result_queue.get(timeout=5)
        except Empty:
            pass  # Child still running: re-check liveness and poll again.
    if outputs is None:
        # The child has exited. _run_generate joins the queue's feeder thread
        # before returning, so anything it sent is already in the pipe; a
        # short final wait settles whether output exists.
        try:
            outputs = result_queue.get(timeout=1)
        except Empty:
            pytest.fail(f"Generation subprocess for {model_dir!r} exited "
                        f"with code {proc.exitcode} without producing "
                        "output")

    proc.join()
    result_queue.close()
    result_queue.join_thread()
    return outputs


@pytest.mark.parametrize("enable_lora", [False, True])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
                              llama_3p2_1b_files,
                              monkeypatch: pytest.MonkeyPatch):
    """Outputs must match before and after a sharded-state round trip.

    1. A subprocess saves the sharded state of the downloaded checkpoint.
    2. Two subprocesses generate with the original checkpoint and with the
       saved sharded state (`load_format="sharded_state"`).
    3. The two results must be equal.

    `monkeypatch` is requested but currently unused.
    """
    if num_gpus_available < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    weights_patterns = ("*.safetensors", )
    gpu_memory_utilization = 0.8
    input_dir = llama_3p2_1b_files
    ctx = mp.get_context("spawn")

    # Run in separate processes for memory & CUDA isolation
    with TemporaryDirectory() as output_dir:
        writer = ctx.Process(target=_run_writer,
                             args=(input_dir, output_dir, weights_patterns),
                             kwargs=dict(
                                 tensor_parallel_size=tp_size,
                                 gpu_memory_utilization=gpu_memory_utilization,
                                 enforce_eager=True,
                             ))
        writer.start()
        writer.join()
        # Fail here with a clear message rather than later, when loading an
        # incomplete sharded state would produce a confusing error.
        assert writer.exitcode == 0, (
            f"Sharded-state writer exited with code {writer.exitcode}")

        out_before = _generate_in_subprocess(
            ctx,
            input_dir,
            enable_lora=enable_lora,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tp_size,
        )
        out_after = _generate_in_subprocess(
            ctx,
            output_dir,
            enable_lora=enable_lora,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tp_size,
            load_format="sharded_state",
        )

        assert out_before == out_after