cpu_worker.py 7.37 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import platform
5
from collections.abc import Callable
6
from typing import Any
7
8
9
10
11
12

import torch

from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
13
from vllm.platforms import CpuArchEnum, current_platform
14
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
15
from vllm.profiler.wrapper import TorchProfilerWrapper
16
from vllm.utils.torch_utils import set_random_seed
17
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
18
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
19
20
21
22
23

# Module-level logger used by CPUWorker for thread-binding, sleep-mode and
# other informational messages.
logger = init_logger(__name__)


class CPUWorker(Worker):
    """Worker that runs a vLLM model on CPU.

    Subclasses the GPU ``Worker`` and overrides device setup (OpenMP thread
    binding, distributed init, ``CPUModelRunner`` on ``torch.device("cpu")``),
    memory accounting, sleep mode (unsupported, only warns) and profiling.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(
            vllm_config,
            local_rank,
            rank,
            distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Always force custom all-reduce off for CPU workers (presumably a
        # GPU-specific path).
        self.parallel_config.disable_custom_all_reduce = True

        # Torch profiler. Enabled and configured through profiler_config.
        self.profiler: Any | None = None
        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler == "torch":
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            self.profiler = TorchProfilerWrapper(
                profiler_config,
                worker_name=worker_name,
                local_rank=self.local_rank,
                activities=["CPU"],
            )

    def init_device(self):
        """Bind OpenMP threads, init the distributed env and build the runner.

        Thread binding is driven by ``VLLM_CPU_OMP_THREADS_BIND``:

        * ``"auto"``: on Linux, pick CPUs from this rank's NUMA node via
          ``_get_autobind_cpu_ids`` (x86, POWERPC and S390X only); on other
          architectures or non-Linux systems no binding is done.
        * ``"nobind"``: no binding.
        * anything else: a ``|``-separated list of per-rank CPU id lists,
          sliced per local data-parallel rank when DP is in use.
        """
        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        if omp_cpuids == "auto":
            cpu_arch = (
                current_platform.get_cpu_architecture()
                if platform.system() == "Linux"
                else None
            )
            if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X):
                # For S390X/POWERPC SMT-8/4/2
                self.local_omp_cpuid = self._get_autobind_cpu_ids(
                    lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
                )
            elif cpu_arch == CpuArchEnum.X86:
                # For x86 SMT-2, use 1 CPU per core
                self.local_omp_cpuid = self._get_autobind_cpu_ids(
                    lambda cpus: cpus[-1:]
                )
            else:
                # Unsupported architecture, or not Linux (cpu_arch is None).
                # "auto" is not a CPU id list, so it must never reach the
                # manual-list branch below.
                self.local_omp_cpuid = "nobind"
        elif omp_cpuids == "nobind":
            self.local_omp_cpuid = "nobind"
        else:
            local_dp_rank = self.parallel_config.data_parallel_rank_local
            omp_cpuids_list = omp_cpuids.split("|")
            if local_dp_rank is not None:
                # Each local DP rank owns a contiguous slice of world_size
                # entries in the user-provided list.
                world_size = self.parallel_config.world_size
                omp_cpuids_list = omp_cpuids_list[
                    local_dp_rank * world_size : (local_dp_rank + 1) * world_size
                ]
            self.local_omp_cpuid = omp_cpuids_list[self.rank]

        if self.local_omp_cpuid != "nobind":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
        # Initialize the distributed environment.
        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: CPUModelRunner = CPUModelRunner(
            self.vllm_config, torch.device("cpu")
        )

    def sleep(self, level: int = 1) -> None:
        """Sleep mode is unsupported on CPU; log a warning and do nothing."""
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Sleep mode is unsupported on CPU; log a warning and do nothing."""
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def determine_available_memory(self) -> int:
        """Return the KV-cache budget in bytes.

        This is ``cache_config.cpu_kvcache_space_bytes``, or 0 when that value
        is unset/falsy.
        """
        return self.cache_config.cpu_kvcache_space_bytes or 0

    def compile_or_warm_up_model(self) -> None:
        """Reseed the RNG and warm up the CPU model runner."""
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        self.model_runner.warming_up_model()

    def _get_autobind_cpu_ids(
        self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]
    ) -> str:
        """Return the CPU ids this worker should bind its OpenMP threads to.

        The method considers only CPUs on NUMA node
        ``allowed_numa_nodes[self.local_rank]``. It groups them by physical
        core and lets ``cpu_selector`` choose which logical CPUs of each core
        to keep. It then sorts the result by id and drops the last
        ``VLLM_CPU_NUM_OF_RESERVED_CPU`` entries (the highest ids) so that
        other processes have room to run. When that variable is unset, one
        CPU is reserved if ``world_size > 1`` or
        ``data_parallel_size_local > 1``, otherwise none.

        Args:
            cpu_selector: Called once per physical core with that core's
                logical CPUs sorted by ``LogicalCPUInfo.id``. Returns the
                subset of those CPUs to bind.

        Returns:
            The selected CPU ids in ascending order, joined with ``","``.
        """
        allowed_numa_nodes, logical_cpu_list = (
            CpuPlatform.get_allowed_cpu_core_node_list()
        )
        assert len(allowed_numa_nodes) >= self.parallel_config.world_size, (
            f"No enough allowed NUMA nodes to bind threads of "
            f"{self.parallel_config.world_size} CPUWorkers. "
            f"Allowed NUMA nodes are {allowed_numa_nodes}. "
            "Please try to bind threads manually."
        )

        # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`
        selected_numa_node = allowed_numa_nodes[self.local_rank]  # type: ignore
        node_cpus = [x for x in logical_cpu_list if x.numa_node == selected_numa_node]

        # Group logical CPUs by physical core, then select CPUs from each
        # core via cpu_selector (e.g. one SMT sibling per core).
        core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
        for cpu_info in node_cpus:
            core_to_cpus.setdefault(cpu_info.physical_core, []).append(cpu_info)
        selected_cpus: list[LogicalCPUInfo] = []
        for core_cpus in core_to_cpus.values():
            core_cpus_by_id = sorted(core_cpus, key=lambda x: x.id)
            selected_cpus.extend(cpu_selector(core_cpus_by_id))
        selected_cpus.sort(key=lambda x: x.id)

        # Reserve CPUs for other processes
        reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
        if reserve_cpu_num is None:
            need_reserve = (
                self.parallel_config.world_size > 1
                or self.parallel_config.data_parallel_size_local > 1
            )
            reserve_cpu_num = 1 if need_reserve else 0

        assert len(selected_cpus) > reserve_cpu_num, (
            f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
            f"should less than {len(selected_cpus)}."
        )
        if reserve_cpu_num != 0:
            # The list is sorted by id, so this drops the highest-id CPUs.
            selected_cpus = selected_cpus[:-reserve_cpu_num]

        logger.info(
            "auto thread-binding list (id, physical core): %s",
            [(x.id, x.physical_core) for x in selected_cpus],
        )
        return ",".join(str(x.id) for x in selected_cpus)

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the torch profiler.

        Raises:
            RuntimeError: If profiling was not enabled via ``profiler_config``.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()