# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import platform
from collections.abc import Callable
from typing import Any

import torch

from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment

logger = init_logger(__name__)


class CPUWorker(Worker):
    """vLLM worker that executes the model on CPU.

    Reuses the generic ``Worker`` plumbing but replaces device setup with
    CPU-specific steps: checking required ``LD_PRELOAD`` libraries, binding
    OpenMP threads to CPU cores, and building a :class:`CPUModelRunner` on
    ``torch.device("cpu")``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(
            vllm_config,
            local_rank,
            rank,
            distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Custom all-reduce is always disabled for CPU workers.
        self.parallel_config.disable_custom_all_reduce = True

        # Torch profiler. Enabled and configured through profiler_config;
        # only CPU activities are recorded.
        self.profiler: Any | None = None
        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler == "torch":
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            self.profiler = TorchProfilerWrapper(
                profiler_config,
                worker_name=worker_name,
                local_rank=self.local_rank,
                activities=["CPU"],
            )

    def init_device(self) -> None:
        """Prepare this process to run the model on CPU.

        Steps, in order: verify required pre-loaded libraries, bind OpenMP
        threads, publish the shared-memory identifier, initialize the
        distributed environment, seed RNGs and construct the model runner.

        Raises:
            RuntimeError: if ``libtcmalloc`` (or, on x86, ``libiomp``) is
                not listed in ``LD_PRELOAD``.
        """

        # Check whether critical libraries are loaded
        def check_preloaded_libs(name: str):
            # Substring match, so full paths such as
            # ".../libtcmalloc_minimal.so.4" are accepted.
            ld_preload_list = os.environ.get("LD_PRELOAD", "")
            if name not in ld_preload_list:
                raise RuntimeError(
                    f"{name} is not found in LD_PRELOAD. "
                    "Please follow the section `set LD_PRELOAD` in "
                    "https://docs.vllm.ai/en/latest/getting_started/installation/cpu/ "
                    "to setup required pre-loaded libraries."
                )

        check_preloaded_libs("libtcmalloc")
        if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
            check_preloaded_libs("libiomp")

        # Setup OpenMP threads affinity.
        self.local_omp_cpuid = self._resolve_local_omp_cpuid()
        if self.local_omp_cpuid != "nobind":
            ret = torch.ops._C.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory.
        # Uses the text after the last ':' of the init method (e.g. a port).
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
        # Initialize the distributed environment.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: CPUModelRunner = CPUModelRunner(
            self.vllm_config, torch.device("cpu")
        )

    def _resolve_local_omp_cpuid(self) -> str:
        """Decide which CPU ids this worker's OpenMP threads bind to.

        Interprets ``envs.VLLM_CPU_OMP_THREADS_BIND``:

        * ``"nobind"``: do not bind threads.
        * ``"auto"``: derive the ids from the CPU/NUMA topology via
          :meth:`_get_autobind_cpu_ids`. This works on Linux only; other
          systems and unrecognized architectures fall back to ``"nobind"``.
        * anything else: a ``|``-separated list with one CPU-id spec per
          rank. With local data parallelism, local DP rank ``d`` owns the
          slice ``[d * world_size, (d + 1) * world_size)``.

        Returns:
            The CPU-id spec for ``self.rank``, or ``"nobind"``.

        Raises:
            IndexError: if a manual list has no entry for this rank.
        """
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        if omp_cpuids == "nobind":
            return "nobind"

        if omp_cpuids == "auto":
            # Under numa binding some cores reserved for kv transfer in
            # nixl_connector.py
            if platform.system() != "Linux":
                # Auto-binding relies on Linux topology information. Without
                # this guard, "auto" fell through to the manual-list branch
                # below and was passed verbatim to init_cpu_threads_env.
                logger.info(
                    "VLLM_CPU_OMP_THREADS_BIND=auto is only supported on "
                    "Linux; OpenMP threads will not be bound on %s.",
                    platform.system(),
                )
                return "nobind"

            cpu_arch = current_platform.get_cpu_architecture()
            if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X):
                # For S390X/POWERPC SMT-8/4/2: keep CPUs with id % 8 < 4.
                return self._get_autobind_cpu_ids(
                    lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
                )
            if cpu_arch == CpuArchEnum.X86:
                # For x86 SMT-2, use 1 CPU per core (the highest-id sibling,
                # since each core's list is sorted by id).
                return self._get_autobind_cpu_ids(lambda cpus: cpus[-1:])
            if cpu_arch == CpuArchEnum.ARM:
                # For AArch64, no SMT: use every CPU.
                return self._get_autobind_cpu_ids(lambda cpus: cpus)
            return "nobind"

        # Manually specified per-rank CPU lists.
        omp_cpuids_list = omp_cpuids.split("|")
        local_dp_rank = self.parallel_config.data_parallel_rank_local
        if local_dp_rank is not None:
            world_size = self.parallel_config.world_size
            start = local_dp_rank * world_size
            omp_cpuids_list = omp_cpuids_list[start : start + world_size]
        if self.rank >= len(omp_cpuids_list):
            raise IndexError(
                f"VLLM_CPU_OMP_THREADS_BIND={omp_cpuids!r} provides no CPU "
                f"list for rank {self.rank} (local DP rank {local_dp_rank}); "
                "specify one '|'-separated entry per rank."
            )
        return omp_cpuids_list[self.rank]

    def sleep(self, level: int = 1) -> None:
        """No-op: sleep mode is unsupported on CPU; only logs a warning."""
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def wake_up(self, tags: list[str] | None = None) -> None:
        """No-op: sleep mode is unsupported on CPU; only logs a warning."""
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def determine_available_memory(self) -> int:
        """Return the configured CPU KV-cache size in bytes (0 if unset)."""
        return self.cache_config.cpu_kvcache_space_bytes or 0

    def compile_or_warm_up_model(self) -> float:
        """Warm up the model runner and return the recorded compile time."""
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        self.model_runner.warming_up_model()
        return self.compilation_config.compilation_time

    def _get_autobind_cpu_ids(
        self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]
    ) -> str:
        """Return CPU ids to bind based on NUMA nodes.

        Currently, for local rank N, only CPU ids on the N-th node of the
        allowed NUMA node list are selected. When ``VLLM_CPU_SIM_MULTI_NUMA``
        is set to anything other than ``"0"``, multiple NUMA nodes are
        simulated instead: the CPU list is split into equal contiguous
        chunks, one per worker across all DP ranks.

        After per-core selection, trailing CPUs are reserved for other
        processes. The count is ``VLLM_CPU_NUM_OF_RESERVED_CPU``, or by
        default 1 when ``world_size > 1`` or
        ``data_parallel_size_local > 1``, else 0.

        Args:
            cpu_selector: a callable that selects CPUs from the CPU list of
                one physical core. Its input is a LogicalCPUInfo list sorted
                by ``LogicalCPUInfo.id``; it returns the selected
                LogicalCPUInfo list.

        Returns:
            A comma-separated string of the selected CPU ids, ascending.

        Raises:
            AssertionError: if there are not enough allowed NUMA nodes
                (unless simulating), or if the reservation would consume
                all selected CPUs.
        """
        # simulate multiple numa nodes, for testing
        sim_multi_numa_nodes = os.environ.get("VLLM_CPU_SIM_MULTI_NUMA", "0") != "0"

        allowed_numa_nodes, logical_cpu_list = (
            CpuPlatform.get_allowed_cpu_core_node_list()
        )

        local_world_size = self.parallel_config.local_world_size
        assert len(allowed_numa_nodes) >= local_world_size or sim_multi_numa_nodes, (
            f"Not enough allowed NUMA nodes to bind threads of "
            f"{local_world_size} local CPUWorkers. "
            f"Allowed NUMA nodes are {allowed_numa_nodes}. "
            "Please try to bind threads manually."
        )

        if not sim_multi_numa_nodes:
            # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`
            selected_numa_node = allowed_numa_nodes[self.local_rank]  # type: ignore
            logical_cpu_list = [
                x for x in logical_cpu_list if x.numa_node == selected_numa_node
            ]
        else:
            # This is a bit tricky because the internal DP size
            # is always 1 for non-MoE models
            world_size_across_dp = (
                self.parallel_config.world_size
                * self.parallel_config._api_process_count
            )
            assert len(logical_cpu_list) >= world_size_across_dp
            logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
            sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
            assert self.parallel_config.data_parallel_rank_local is not None
            # Global worker index across local DP ranks selects the chunk.
            start_idx = (
                self.local_rank
                + self.parallel_config.world_size
                * self.parallel_config.data_parallel_rank_local
            ) * sim_cpu_num_per_node
            logical_cpu_list = logical_cpu_list[
                start_idx : (start_idx + sim_cpu_num_per_node)
            ]

        # Select CPUs from each physical core via cpu_selector
        core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
        for cpu_info in logical_cpu_list:
            core_to_cpus.setdefault(cpu_info.physical_core, []).append(cpu_info)
        selected_cpus: list[LogicalCPUInfo] = []
        for core_cpus in core_to_cpus.values():
            selected_cpus.extend(cpu_selector(sorted(core_cpus, key=lambda x: x.id)))
        logical_cpu_list = sorted(selected_cpus, key=lambda x: x.id)

        # Reserve CPUs for other processes
        reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
        if reserve_cpu_num is None:
            need_reserve = (
                self.parallel_config.world_size > 1
                or self.parallel_config.data_parallel_size_local > 1
            )
            reserve_cpu_num = 1 if need_reserve else 0
        assert len(logical_cpu_list) > reserve_cpu_num, (
            f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
            f"should be less than {len(logical_cpu_list)}."
        )
        # Drop the last `reserve_cpu_num` CPUs. The length-based bound avoids
        # the `[:-0]` pitfall, which would empty the list.
        logical_cpu_list = logical_cpu_list[: len(logical_cpu_list) - reserve_cpu_num]

        logger.info(
            "auto thread-binding list (id, physical core): %s",
            [(x.id, x.physical_core) for x in logical_cpu_list],
        )
        return ",".join(str(x.id) for x in logical_cpu_list)

    def profile(
        self, is_start: bool = True, profile_prefix: str | None = None
    ) -> None:
        """Start (``is_start=True``) or stop the torch profiler.

        ``profile_prefix`` is not used by this implementation; presumably it
        is accepted to match the base-class signature.

        Raises:
            RuntimeError: if profiling was not enabled via profiler_config.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()