cpu_worker.py 10 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
from importlib import util
5
6
7
8
9
10
11
12
13
from typing import Optional

import torch

from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
14
from vllm.platforms import CpuArchEnum, current_platform
15
from vllm.sequence import IntermediateTensors
16
from vllm.utils import PlaceholderModule
17
18
19
20
21
22
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
                                       init_worker_distributed_environment)

23
24
25
26
27
28
29
# psutil and numa are optional dependencies, used only for automatic
# NUMA-aware thread binding. If either import fails, placeholders are bound
# instead so that importing this module still succeeds.
# NOTE(review): PlaceholderModule is presumably designed to raise a helpful
# error on first use -- confirm against vllm.utils.
try:
    import psutil
    from numa import info
except ImportError:
    psutil = PlaceholderModule("psutil")  # type: ignore[assignment]
    numa = PlaceholderModule("numa")  # type: ignore[assignment]
    # The rest of this module refers to `info`, not `numa`. Without this
    # binding a failed import would leave `info` undefined, so any use
    # would surface as a NameError rather than a missing-dependency error.
    info = numa.placeholder_attr("info")  # type: ignore[assignment]

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
logger = init_logger(__name__)


class CPUWorker(Worker):

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False):
        """Initialize the base worker, then apply CPU-specific settings.

        All arguments are forwarded unchanged to the parent ``Worker``
        constructor.
        """
        super().__init__(
            vllm_config,
            local_rank,
            rank,
            distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        # Hint appended to every warning logged when automatic thread
        # binding cannot be carried out.
        self.manually_bind_threads_suggestion = (
            "To get better performance, please try to manually bind threads.")

        # Custom all-reduce is switched off for this worker.
        self.parallel_config.disable_custom_all_reduce = True
50
51
52
53

    def init_device(self):
        """Bind OpenMP threads, set up distributed state, build the runner.

        The CPU binding of this rank is derived from
        ``envs.VLLM_CPU_OMP_THREADS_BIND``:

        * ``"auto"``: the CPU list is computed from the NUMA topology, with
          a dedicated code path when the platform reports POWERPC.
        * anything else: the value is split on ``|`` and the entry at index
          ``self.rank`` is used.

        If the resulting binding is ``"all"`` the explicit binding call is
        skipped.

        Raises:
            IndexError: if the ``|``-separated setting has no entry for
                this rank.
        """
        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC:
                self.local_omp_cpuid = (
                    self.get_cpus_id_binding_based_on_numa_nodes_ppc64le())
            else:
                self.local_omp_cpuid = (
                    self.get_cpus_id_binding_based_on_numa_nodes())
        else:
            per_rank_cpuids = omp_cpuids.split("|")
            # Fail with an actionable message instead of a bare
            # "list index out of range" when the setting does not provide
            # an entry for this rank.
            if self.rank >= len(per_rank_cpuids):
                raise IndexError(
                    f"VLLM_CPU_OMP_THREADS_BIND={omp_cpuids!r} provides "
                    f"{len(per_rank_cpuids)} '|'-separated CPU list(s), "
                    f"but rank {self.rank} needs its own entry.")
            self.local_omp_cpuid = per_rank_cpuids[self.rank]

        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            # A non-empty return value is forwarded to the log.
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        # (the last ":"-separated component of the init method).
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank,
                                            current_platform.dist_backend)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: CPUModelRunner = CPUModelRunner(
            self.vllm_config, torch.device("cpu"))

    def sleep(self, level: int = 1) -> None:
        """Log that sleep mode is unsupported on CPU and do nothing else.

        ``level`` is accepted but not used.
        """
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Log that sleep mode is unsupported on CPU and do nothing else.

        ``tags`` is accepted but not used.
        """
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def determine_available_memory(self) -> int:
        """Return ``cache_config.cpu_kvcache_space_bytes`` as the budget."""
        kv_cache_bytes = self.cache_config.cpu_kvcache_space_bytes
        return kv_cache_bytes  # type: ignore

    def compile_or_warm_up_model(self) -> None:
        """Re-seed the RNG, then warm up the model runner."""
        # Seeding again here keeps the random state independent of whatever
        # model initialization and profiling consumed beforehand.
        seed = self.model_config.seed
        set_random_seed(seed)
        self.model_runner.warming_up_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step for this pipeline-parallel rank.

        Ranks other than the first receive their input tensors from the
        pipeline group; ranks other than the last send their output tensors
        onward and return ``None``. On the last rank the runner output is
        returned only when this worker is the driver worker.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            intermediate_tensors = None
        else:
            received = pp_group.recv_tensor_dict(
                all_gather_group=get_tp_group())
            intermediate_tensors = IntermediateTensors(received)

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        if pp_group.is_last_rank:
            assert isinstance(output, ModelRunnerOutput)
            if self.is_driver_worker:
                return output
            return None

        # Not the last stage: hand the intermediate tensors to the next one.
        assert isinstance(output, IntermediateTensors)
        pp_group.send_tensor_dict(output.tensors,
                                  all_gather_group=get_tp_group())
        return None
124

125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    def warn_inability_to_detect_numa(self) -> None:
        """Log that auto thread-binding failed: NUMA nodes not detected."""
        message = ("Auto thread-binding failed due to the "
                   "inability to detect numa nodes. %s")
        logger.warning(message, self.manually_bind_threads_suggestion)

    def warn_lack_of_numa_and_psutil(self) -> None:
        """Log that auto thread-binding failed: numa / psutil missing."""
        message = ("Auto thread-binding failed due to "
                   "the lack of package numa and psutil. %s")
        logger.warning(message, self.manually_bind_threads_suggestion)

    def warn_world_size_too_large(self, world_size: int,
                                  node_to_cpus_len: int) -> None:
        """Log that auto thread-binding failed: more ranks than NUMA nodes.

        Args:
            world_size: number of ranks that need a NUMA node.
            node_to_cpus_len: number of NUMA nodes available for binding.
        """
        message = ("Auto thread-binding failed due to "
                   "world size: %d being larger than "
                   "allowed NUMA nodes number: %d. %s")
        logger.warning(message, world_size, node_to_cpus_len,
                       self.manually_bind_threads_suggestion)

    def get_cpus_allow_list_and_numa_size(self):
        """Return ``(cpus_allow_list, numa_size)`` for the current process.

        ``cpus_allow_list`` is the CPU affinity psutil reports for this
        process; ``numa_size`` is the configured-node count reported by
        ``numa.info``.
        """
        this_process = psutil.Process()
        return (this_process.cpu_affinity(),
                info.get_num_configured_nodes())

    def auto_thread_binding_based_on_numa_nodes(self, world_size: int,
                                                rank_to_cpus: str) -> str:
        """Compute the CPU binding of this rank from the NUMA topology.

        Each rank is mapped to one NUMA node (indexed by ``self.rank`` among
        the nodes that still have CPUs after intersecting with the process
        affinity). From that node's sorted CPU ids, the first
        ``cpu_count_per_numa - num_of_reserved_cpu`` entries are kept.

        Args:
            world_size: total number of ranks.
            rank_to_cpus: binding to return unchanged when auto-binding is
                not possible.

        Returns:
            A comma-separated CPU id string, or ``rank_to_cpus`` unchanged
            when detection fails or there are fewer usable NUMA nodes than
            ranks (a warning is logged in those cases).
        """
        cpus_allow_list, numa_size = self.get_cpus_allow_list_and_numa_size()
        if not numa_size:
            self.warn_inability_to_detect_numa()
            return rank_to_cpus

        # psutil documents that cpu_count() may return None when the value
        # cannot be determined; treat that like any other detection failure
        # instead of crashing on the integer division below.
        cpu_count = psutil.cpu_count(logical=False)
        if not cpu_count:
            logger.warning(
                "Auto thread-binding failed due to the "
                "inability to detect the number of physical cores. %s",
                self.manually_bind_threads_suggestion)
            return rank_to_cpus

        cpu_count_per_numa = cpu_count // numa_size
        # Never reserve more than half of a node's cores.
        num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                  cpu_count_per_numa // 2)

        allowed_cpus = set(cpus_allow_list)
        node_to_cpus = []
        for node_id in range(numa_size):
            node_cpus = allowed_cpus.intersection(info.node_to_cpus(node_id))
            if node_cpus:
                # Sort so that the slice below picks a deterministic,
                # lowest-id-first subset rather than depending on set
                # iteration order (matches the ppc64le variant).
                node_to_cpus.append(sorted(node_cpus))

        node_to_cpus_len = len(node_to_cpus)
        if world_size > node_to_cpus_len:
            self.warn_world_size_too_large(world_size, node_to_cpus_len)
            return rank_to_cpus

        end = cpu_count_per_numa - num_of_reserved_cpu
        rank_to_cpus_list = node_to_cpus[self.rank][:end]
        rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
        logger.info("auto thread-binding list: %s", rank_to_cpus)
        return rank_to_cpus

    def libnuma_and_psutil_found(self) -> bool:
        """Return True only if specs for both ``numa`` and ``psutil`` exist.

        Uses ``importlib.util.find_spec`` for the lookup.
        """
        required_modules = ("numa", "psutil")
        return all(
            util.find_spec(name) is not None for name in required_modules)

185
186
187
188
189
190
    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return the CPU id binding for this rank, based on NUMA nodes.

        When numa / psutil cannot be found, a warning is logged and the
        current ``self.local_omp_cpuid`` is returned unchanged.
        """
        fallback_binding = self.local_omp_cpuid
        world_size = self.vllm_config.parallel_config.world_size

        if not self.libnuma_and_psutil_found():
            self.warn_lack_of_numa_and_psutil()
            return fallback_binding

        # Setup OpenMP thread affinity based on NUMA nodes automatically
        return self.auto_thread_binding_based_on_numa_nodes(
            world_size, fallback_binding)
197

198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
    def select_threads_per_power_core(self,
                                      node_cpu_ids: list[int]) -> list[int]:
        """Keep the CPU ids whose value modulo 8 is below 4.

        Input order is preserved.
        """
        # NOTE(review): presumably CPU ids are laid out in groups of 8 per
        # core, so this keeps the first 4 threads of each core -- confirm.
        selected: list[int] = []
        for cpu_id in node_cpu_ids:
            position_in_group = cpu_id % 8
            if position_in_group < 4:
                selected.append(cpu_id)
        return selected

    def auto_thread_binding_based_on_numa_nodes_ppc64le(
            self, world_size: int, rank_to_cpus: str) -> str:
        """Compute the CPU binding of this rank for the ppc64le code path.

        The NUMA node at index ``self.rank`` (among nodes that still have
        CPUs after intersecting with the process affinity) is filtered
        through ``select_threads_per_power_core``; a number of reserved CPUs
        (at most half of the filtered list) is then dropped from its tail.

        Args:
            world_size: total number of ranks.
            rank_to_cpus: binding to return unchanged when auto-binding is
                not possible.

        Returns:
            A comma-separated CPU id string, or ``rank_to_cpus`` unchanged
            when NUMA nodes cannot be detected or there are fewer usable
            nodes than ranks (a warning is logged in those cases).
        """
        cpus_allow_list, numa_size = self.get_cpus_allow_list_and_numa_size()
        if not numa_size:
            self.warn_inability_to_detect_numa()
            return rank_to_cpus

        # Sorted allowed CPU ids per NUMA node; nodes left with no allowed
        # CPU are skipped.
        per_node_cpus = []
        for node_id in range(numa_size):
            usable = set(info.node_to_cpus(node_id)) & set(cpus_allow_list)
            if usable:
                per_node_cpus.append(sorted(usable))

        usable_node_count = len(per_node_cpus)
        if world_size > usable_node_count:
            self.warn_world_size_too_large(world_size, usable_node_count)
            return rank_to_cpus

        candidates = self.select_threads_per_power_core(
            per_node_cpus[self.rank])
        candidate_count = len(candidates)
        reserved = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                       candidate_count // 2)
        kept = candidates[:candidate_count - reserved]
        binding = ','.join(map(str, kept))
        logger.info("ppc64le thread-binding list: %s", binding)
        return binding
231
232
233
234
235
236
237
238
239
240
241

    def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
        """Return the CPU id binding for this rank on Power (ppc64le).

        Delegates to ``auto_thread_binding_based_on_numa_nodes_ppc64le``,
        which selects a subset of threads per core for each NUMA node. When
        numa / psutil cannot be found, a warning is logged and the current
        ``self.local_omp_cpuid`` is returned unchanged.

        NOTE(review): the original author states that this is robust to the
        SMT mode (SMT-8, SMT-4, etc.) because the OS only exposes available
        threads, and that it avoids oversubscribing logical CPUs on Power
        -- not verifiable from this file.
        """
        fallback_binding = self.local_omp_cpuid
        world_size = self.vllm_config.parallel_config.world_size

        if not self.libnuma_and_psutil_found():
            self.warn_lack_of_numa_and_psutil()
            return fallback_binding

        return self.auto_thread_binding_based_on_numa_nodes_ppc64le(
            world_size, fallback_binding)