Unverified Commit d02421a7 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Refactor CPU affinity and memory management (#39781)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent b1dc87a0
......@@ -23,22 +23,22 @@ if [ "$failed_req" -ne 0 ]; then
exit 1
fi
#echo "--- DP+TP"
#vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
#server_pid=$!
#timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
#vllm bench serve \
# --backend vllm \
# --dataset-name random \
# --model meta-llama/Llama-3.2-3B-Instruct \
# --num-prompts 20 \
# --result-dir ./test_results \
# --result-filename dp_pp.json \
# --save-result \
# --endpoint /v1/completions
#kill -s SIGTERM $server_pid; wait $server_pid || true
#failed_req=$(jq '.failed' ./test_results/dp_pp.json)
#if [ "$failed_req" -ne 0 ]; then
# echo "Some requests were failed!"
# exit 1
#fi
echo "--- DP+TP"
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
vllm bench serve \
--backend vllm \
--dataset-name random \
--model meta-llama/Llama-3.2-3B-Instruct \
--num-prompts 20 \
--result-dir ./test_results \
--result-filename dp_pp.json \
--save-result \
--endpoint /v1/completions
kill -s SIGTERM $server_pid; wait $server_pid || true
failed_req=$(jq '.failed' ./test_results/dp_pp.json)
if [ "$failed_req" -ne 0 ]; then
echo "Some requests were failed!"
exit 1
fi
......@@ -45,6 +45,7 @@ jobs:
- name: Smoke test vllm serve
run: |
# Start server in background
VLLM_CPU_KVCACHE_SPACE=1 \
vllm serve Qwen/Qwen3-0.6B \
--max-model-len=2K \
--load-format=dummy \
......
......@@ -30,6 +30,21 @@ else()
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
# locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
# and create a local shim dir with it
vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
find_library(OPEN_MP
NAMES gomp
PATHS ${VLLM_TORCH_GOMP_SHIM_DIR}
NO_DEFAULT_PATH
REQUIRED
)
# Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch
if (OPEN_MP)
set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}")
endif()
endif()
if (NOT MACOSX_FOUND)
......@@ -175,20 +190,6 @@ if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND
if(NOT NPROC)
set(NPROC 4)
endif()
# locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
# and create a local shim dir with it
vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
find_library(OPEN_MP
NAMES gomp
PATHS ${VLLM_TORCH_GOMP_SHIM_DIR}
NO_DEFAULT_PATH
REQUIRED
)
# Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch
if (OPEN_MP)
set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}")
endif()
# Fetch and populate ACL
if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}")
......
......@@ -141,6 +141,8 @@ void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
torch::Tensor slot_mapping,
const int64_t block_size);
void init_cpu_memory_env(std::vector<int64_t> node_ids);
namespace cpu_utils {
void eagle_prepare_inputs_padded_kernel_impl(
const torch::Tensor& cu_num_draft_tokens,
......@@ -431,6 +433,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"block_size) -> ()",
&compute_slot_mapping_kernel_impl);
ops.def("init_cpu_memory_env(SymInt[] node_ids) -> ()", &init_cpu_memory_env);
// Speculative decoding kernels
ops.def(
"eagle_prepare_inputs_padded_kernel_impl(Tensor cu_num_draft_tokens, "
......
......@@ -13,13 +13,80 @@
#include "cpu/utils.hpp"
#ifdef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) {
return std::string(
"Warning: NUMA is not enabled in this build. `init_cpu_threads_env` has "
"no effect to setup thread affinity.");
}
void init_cpu_memory_env(std::vector<int64_t> node_ids) {}
#else
void init_cpu_memory_env(std::vector<int64_t> node_ids) {
// Memory node binding
if (numa_available() != -1) {
// Concatenate all node_ids into a single comma-separated string
if (!node_ids.empty()) {
std::string node_ids_str;
for (const int node_id : node_ids) {
if (!node_ids_str.empty()) {
node_ids_str += ",";
}
node_ids_str += std::to_string(node_id);
}
#endif
bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
bitmask* src_mask = numa_get_mems_allowed();
int pid = getpid();
if (mask && src_mask) {
// move all existing pages to the specified numa node.
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
int page_num = numa_migrate_pages(pid, src_mask, mask);
if (page_num == -1) {
TORCH_WARN("numa_migrate_pages failed. errno: " +
std::to_string(errno));
}
// Restrict memory allocation to the selected NUMA node(s).
// Enhances memory locality for the threads bound to those NUMA CPUs.
if (node_ids.size() > 1) {
errno = 0;
numa_set_interleave_mask(mask);
if (errno != 0) {
TORCH_WARN("numa_set_interleave_mask failed. errno: " +
std::to_string(errno));
} else {
TORCH_WARN(
"NUMA binding: Using INTERLEAVE policy for memory "
"allocation across multiple NUMA nodes (nodes: " +
node_ids_str +
"). Memory allocations will be "
"interleaved across the specified NUMA nodes.");
}
} else {
errno = 0;
numa_set_membind(mask);
if (errno != 0) {
TORCH_WARN("numa_set_membind failed. errno: " +
std::to_string(errno));
} else {
TORCH_WARN(
"NUMA binding: Using MEMBIND policy for memory "
"allocation on the NUMA nodes (" +
node_ids_str +
"). Memory allocations will be "
"strictly bound to these NUMA nodes.");
}
}
numa_set_strict(1);
numa_free_nodemask(mask);
numa_free_nodemask(src_mask);
} else {
TORCH_WARN(
"numa_parse_nodestring or numa_get_run_node_mask failed. errno: " +
std::to_string(errno));
}
}
}
}
#endif // VLLM_NUMA_DISABLED
namespace cpu_utils {
ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
......
......@@ -173,7 +173,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
COPY --from=vllm-test-deps /vllm-workspace/requirements/test/cpu.txt requirements/test/cpu.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -r requirements/dev.txt && \
uv pip install -r requirements/lint.txt && \
uv pip install -r requirements/test/cpu.txt && \
pre-commit install --hook-type pre-commit --hook-type commit-msg
ENTRYPOINT ["bash"]
......
......@@ -46,7 +46,7 @@ AITER_MODEL_LIST = [
),
pytest.param(
"openai-community/gpt2", # gpt2
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
marks=[pytest.mark.core_model],
),
pytest.param("Milos/slovak-gpt-j-405M"), # gptj
pytest.param("bigcode/tiny_starcoder_py"), # gpt_bigcode
......@@ -143,11 +143,6 @@ def test_models(
# in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
if current_platform.is_cpu() and model in ("openai-community/gpt2",):
# These models are sensitive to the rounding error
# Fuse ops to reduce rounding
monkeypatch.setenv("VLLM_CPU_CI_ENV", "0")
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs
......
......@@ -101,8 +101,6 @@ class CacheConfig:
kv_cache_dtype_skip_layers: list[str] = field(default_factory=list)
"""Layer patterns to skip KV cache quantization. Accepts layer indices
(e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
cpu_kvcache_space_bytes: int | None = None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded: int | None = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
......@@ -183,7 +181,6 @@ class CacheConfig:
"num_gpu_blocks_override",
"enable_prefix_caching",
"prefix_caching_hash_algo",
"cpu_kvcache_space_bytes",
"mamba_page_size_padded",
"user_specified_block_size",
"user_specified_mamba_block_size",
......
......@@ -6,15 +6,16 @@ import os
import platform
import subprocess
import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING
import psutil
import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.utils.ompmultiprocessing import OMPProcessManager
from vllm.utils.cpu_resource_utils import (
DEVICE_CONTROL_ENV_VAR,
get_memory_node_info,
)
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum
......@@ -38,49 +39,13 @@ def get_max_threads(pid=0):
raise NotImplementedError("Unsupported OS")
@dataclass
class LogicalCPUInfo:
id: int = -1
physical_core: int = -1
numa_node: int = -1
@classmethod
def _int(cls, value: str) -> int:
try:
int_value = int(value)
except Exception:
int_value = -1
return int_value
@staticmethod
def json_decoder(obj_dict: dict):
id = obj_dict.get("cpu")
physical_core = obj_dict.get("core")
numa_node = obj_dict.get("node")
if not (id is None or physical_core is None or numa_node is None):
return LogicalCPUInfo(
id=LogicalCPUInfo._int(id),
physical_core=LogicalCPUInfo._int(physical_core),
numa_node=LogicalCPUInfo._int(numa_node),
)
else:
return obj_dict
class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_name: str = "cpu"
device_type: str = "cpu"
dispatch_key: str = "CPU"
dist_backend: str = "gloo"
device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
omp_process_manager = None
# Simultaneous Multithreading (SMT) level for OpenMP:
# 4 on PowerPC, 1 on non-PowerPC architectures
smt = 1
global_cpu_mask = None
simulate_numa = int(os.environ.get("_SIM_MULTI_NUMA", 0))
device_control_env_var = DEVICE_CONTROL_ENV_VAR
@property
def supported_dtypes(self) -> list[torch.dtype]:
......@@ -123,29 +88,9 @@ class CpuPlatform(Platform):
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import format_gib
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
node_dir = "/sys/devices/system/node"
if kv_cache_space is None:
nodes = (
[d for d in os.listdir(node_dir) if d.startswith("node")]
if os.path.exists(node_dir)
else []
)
num_numa_nodes = len(nodes) or 1
free_cpu_memory = psutil.virtual_memory().total // num_numa_nodes
DEFAULT_CPU_MEM_UTILIZATION = 0.5
kv_cache_space = int(free_cpu_memory * DEFAULT_CPU_MEM_UTILIZATION)
logger.warning_once(
"VLLM_CPU_KVCACHE_SPACE not set. Using %s GiB for KV cache.",
format_gib(kv_cache_space),
)
else:
kv_cache_space *= GiB_bytes
meminfo = get_memory_node_info(device_id)
return kv_cache_space
return meminfo.total_memory
@classmethod
def set_device(cls, device: torch.device) -> None:
......@@ -180,6 +125,12 @@ class CpuPlatform(Platform):
"otherwise the performance is not optimized."
)
# Lagecy setting
env_key = "VLLM_CPU_KVCACHE_SPACE"
if env_key in os.environ and os.environ[env_key] != "":
kv_cache_space = int(os.environ[env_key])
cache_config.kv_cache_memory_bytes = kv_cache_space * GiB_bytes
scheduler_config = vllm_config.scheduler_config
# async scheduling is not required on CPU
scheduler_config.async_scheduling = False
......@@ -198,8 +149,6 @@ class CpuPlatform(Platform):
)
cache_config.cache_dtype = "auto"
cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()
parallel_config = vllm_config.parallel_config
# OMP requires the MP executor to function correctly, UniProc is not
# supported as it is not possible to set the OMP environment correctly
......@@ -278,21 +227,45 @@ class CpuPlatform(Platform):
os.environ["TORCHINDUCTOR_CPP_DYNAMIC_THREADS"] = "1"
ld_preload_str = os.getenv("LD_PRELOAD", "")
# Intel and CLANG OpenMP setting
if "libiomp5.so" in ld_preload_str or "libomp5" in ld_preload_str:
# The time(milliseconds) that a thread should wait after
# completing the execution of a parallel region, before sleeping.
os.environ["KMP_BLOCKTIME"] = "1"
# Prevents the CPU to run into low performance state
os.environ["KMP_TPAUSE"] = "0"
# Provides fine granularity parallelism
os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"
cpu_architecture = Platform.get_cpu_architecture()
if (
platform.system() == "Linux"
and cpu_architecture
in (CpuArchEnum.ARM, CpuArchEnum.POWERPC, CpuArchEnum.X86)
and not (
"libomp" in ld_preload_str
or "libgomp" in ld_preload_str
or "libiomp" in ld_preload_str
)
):
# We need to LD_PRELOAD PyTorch's libgomp, otherwise only
# one core will be properly utilized when we thread-bind
# See: https://github.com/vllm-project/vllm/issues/27369
# TODO: Remove once:
# https://github.com/pytorch/pytorch/issues/166087 is fixed
# We need to find the location of PyTorch's libgomp
torch_pkg = os.path.dirname(torch.__file__)
site_root = os.path.dirname(torch_pkg)
# Search both torch.libs and torch/lib - See:
# https://github.com/vllm-project/vllm/issues/30470
torch_libs_paths = [
os.path.join(site_root, "torch.libs"),
os.path.join(torch_pkg, "lib"),
]
pytorch_libgomp_so_candidates = []
for torch_libs in torch_libs_paths:
pytorch_libgomp_so_candidates.extend(
glob.glob(os.path.join(torch_libs, "libgomp*.so*"))
)
if pytorch_libgomp_so_candidates:
pytorch_libgomp_so = pytorch_libgomp_so_candidates[0]
if ld_preload_str:
ld_preload_str += ":"
ld_preload_str += pytorch_libgomp_so
os.environ["LD_PRELOAD"] = ld_preload_str
# LD_PRELOAD libtcmalloc, bundled under vllm/libs to reduce
# memory allocation overhead
if (
......@@ -331,13 +304,6 @@ class CpuPlatform(Platform):
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
# CI specific "quick" NUMA simulation - split all available CPUs
# into a fake NUMA topology
if os.environ.get("VLLM_CPU_SIM_MULTI_NUMA", None) is not None:
os.environ["_SIM_MULTI_NUMA"] = str(
vllm_config.parallel_config.world_size
* vllm_config.parallel_config._api_process_count
)
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
......@@ -345,78 +311,6 @@ class CpuPlatform(Platform):
# Move that logic here so block_size is chosen by the backend.
pass
@classmethod
def get_omp_manager(cls) -> OMPProcessManager:
# initialise the OMP resource management if need be and return the manager
if cls.omp_process_manager is None:
if cls.get_cpu_architecture() == CpuArchEnum.POWERPC:
cls.smt = 4
cls.omp_process_manager = OMPProcessManager(
affinity=cls.get_global_cpu_mask(), smt=cls.smt
)
# we need to fix up the topology returned by the OMP Manager for
# simulated NUMA environments in CI
if cls.simulate_numa > 0:
logger.info(
"Adjusting numa topology to resemble at least %d nodes",
int(cls.simulate_numa),
)
om = cls.omp_process_manager
while len(om.omp_places) < cls.simulate_numa:
new_omp_places = []
touched = False
for omp_place in om.omp_places:
if len(omp_place["mask"]) > 1:
touched = True
cpu_list = sorted(list(omp_place["mask"]))
new_omp_places.append(
{
"mask": set(cpu_list[0 : int(len(cpu_list) / 2)]),
"available": True,
}
)
new_omp_places.append(
{
"mask": set(cpu_list[int(len(cpu_list) / 2) :]),
"available": True,
}
)
if touched:
om.omp_places = new_omp_places
else:
raise ValueError(
"Cannot split the existing NUMA topology to match "
"simulation requirements"
)
return cls.omp_process_manager
@classmethod
def get_global_cpu_mask(cls) -> set[int]:
# get global cpu mask
if cls.global_cpu_mask is None:
if hasattr(os, "sched_getaffinity"):
cls.global_cpu_mask = os.sched_getaffinity(0)
else:
# macOS does not support sched_getaffinity
cpu_count = os.cpu_count() or 1
cls.global_cpu_mask = set(range(cpu_count))
return cls.global_cpu_mask
@classmethod
def reserve_cpus(cls, reserve: set[int]) -> bool:
# remove CPUs from global mask, for now there is no "release" mechanism
if cls.omp_process_manager is not None:
for place in cls.omp_process_manager.omp_places:
if not place["available"]:
return False
cls.global_cpu_mask = cls.get_global_cpu_mask() - reserve
# reinitialize OMP resource management
cls.omp_process_manager = OMPProcessManager(
affinity=cls.global_cpu_mask, smt=cls.smt
)
return True
@classmethod
def discover_numa_topology(cls) -> list[list[int]]:
"""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import platform
import subprocess
from dataclasses import dataclass
from functools import cache
import psutil
import regex as re
DEVICE_CONTROL_ENV_VAR = "CPU_VISIBLE_MEMORY_NODES"
@dataclass
class LogicalCPUInfo:
id: int = -1
physical_core: int = -1
numa_node: int = -1
@classmethod
def _int(cls, value: str) -> int:
try:
int_value = int(value)
except Exception:
int_value = -1
return int_value
@staticmethod
def json_decoder(obj_dict: dict):
id = obj_dict.get("cpu")
physical_core = obj_dict.get("core")
numa_node = obj_dict.get("node")
if not (id is None or physical_core is None or numa_node is None):
return LogicalCPUInfo(
id=LogicalCPUInfo._int(id),
physical_core=LogicalCPUInfo._int(physical_core),
numa_node=LogicalCPUInfo._int(numa_node),
)
else:
return obj_dict
@dataclass
class MemoryNodeInfo:
total_memory: int = -1
available_memory: int = -1
def get_memory_affinity(pid: int = 0) -> list[int]:
pid = os.getpid() if pid == 0 else pid
path = f"/proc/{pid}/status"
with open(path) as f:
for line in f:
if line.startswith("Mems_allowed_list:"):
# Extract the string part (e.g., "0-1,3")
raw_list = line.split(":")[1].strip()
return parse_id_list(raw_list)
return []
def parse_id_list(raw_str: str) -> list[int]:
"""Parses strings like '0-2,4,7-8' into [0, 1, 2, 4, 7, 8]"""
result: list[int] = []
if not raw_str:
return result
for part in raw_str.split(","):
if "-" in part:
start, end = map(int, part.split("-"))
result.extend(range(start, end + 1))
else:
result.append(int(part))
return sorted(list(set(result)))
def get_memory_node_info(node_id: int = 0) -> MemoryNodeInfo:
if platform.system() == "Darwin":
# MacOS has no memory node
return MemoryNodeInfo(
total_memory=psutil.virtual_memory().total,
available_memory=psutil.virtual_memory().available,
)
meminfo_path = f"/sys/devices/system/node/node{node_id}/meminfo"
if not os.path.exists(meminfo_path):
raise RuntimeError(f"{meminfo_path} doesn't exit.")
meminfo = {}
with open(meminfo_path) as f:
for line in f:
# Each line looks like: "Node 0 MemTotal: 97421888 kB"
parts = line.split()
key = parts[2].rstrip(":")
# convert to Bytes
value = int(parts[3]) * 1024
meminfo[key] = value
total_memory = meminfo["MemTotal"]
free_memory = meminfo["MemFree"]
active_file_memory = meminfo["Active(file)"]
inactive_file_memory = meminfo["Inactive(file)"]
reclaimable_memory = meminfo["SReclaimable"]
available_memory = (
free_memory + active_file_memory + inactive_file_memory + reclaimable_memory
)
return MemoryNodeInfo(
total_memory=total_memory,
available_memory=available_memory,
)
def get_allowed_cpu_list() -> list[LogicalCPUInfo]:
cpu_list = _get_cpu_list()
if platform.system() == "Darwin":
return cpu_list
global_allowed_cpu_id_list = os.sched_getaffinity(0)
logical_cpu_list = [x for x in cpu_list if x.id in global_allowed_cpu_id_list]
return logical_cpu_list
def get_visible_memory_node() -> list[int]:
if platform.system() == "Darwin":
return [0]
allowed_memory_node_list = get_memory_affinity()
env_key = DEVICE_CONTROL_ENV_VAR
if (
("VLLM_CPU_SIM_MULTI_NUMA" not in os.environ)
and env_key in os.environ
and os.environ[env_key] != ""
):
visible_nodes = [int(s) for s in os.environ[env_key].split(",")]
visible_nodes = [
node for node in visible_nodes if node in allowed_memory_node_list
]
return visible_nodes
return allowed_memory_node_list
@cache
def _get_cpu_list() -> list[LogicalCPUInfo]:
if platform.system() == "Darwin":
# For MacOS, no user-level CPU affinity and SMT, return all CPUs
cpu_count = os.cpu_count()
assert cpu_count
return [LogicalCPUInfo(i, i, 0) for i in range(cpu_count)]
lscpu_output = subprocess.check_output(
"lscpu -J -e=CPU,CORE,NODE", shell=True, text=True
)
# For platform without NUMA, replace '-' to '0'
lscpu_output = re.sub(r'"node":\s*-\s*(,|\n)', r'"node": 0\1', lscpu_output)
logical_cpu_list: list[LogicalCPUInfo] = json.loads(
lscpu_output, object_hook=LogicalCPUInfo.json_decoder
)["cpus"]
# Filter CPUs with invalid attributes
logical_cpu_list = [
x for x in logical_cpu_list if -1 not in (x.id, x.physical_core, x.numa_node)
]
return logical_cpu_list
......@@ -5,196 +5,280 @@ Copyright (c) 2026 Red Hat Inc
Copyright (c) 2026 Cambridge Greys Ltd
"""
import json
import os
import platform
import subprocess
from collections.abc import Callable
from contextlib import contextmanager
from typing import TYPE_CHECKING
import vllm.utils.cpu_resource_utils as cr_utils
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.cpu_resource_utils import LogicalCPUInfo
def _int(arg):
"""Relaxed parsing of ints which handles a - instead of a number.
The lscpu json may contain that for nodes in some cases. If that
is the case we parse it to zero
"""
try:
if int(arg) >= 0:
return int(arg)
except ValueError:
pass
return 0
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
def parse_mask(mask):
"""Expand a X-Y,Z list"""
result = []
for token in mask.split(","):
try:
start, finish = token.split("-")
if int(start) > int(finish):
raise IndexError("Invalid Indexes for cpu ranges")
for cpu in range(int(start), int(finish) + 1):
result.append(cpu)
except ValueError:
result.append(int(token))
return set(result)
def _get_default_affinity() -> set[int]:
"""Get the set of CPUs the process is allowed to run on."""
if hasattr(os, "sched_getaffinity"):
return os.sched_getaffinity(0)
# macOS does not support sched_getaffinity; fall back to cpu_count
cpu_count = os.cpu_count() or 1
return set(range(cpu_count))
def _get_cpu_topology_json() -> bytes:
"""Get CPU topology as JSON.
On Linux this uses ``lscpu -Je``. On other platforms (e.g. macOS) we
synthesize a simple topology where every logical CPU is its own core
on NUMA node 0, which is sufficient for the OMP place-list builder.
"""
if platform.system() == "Linux":
return subprocess.run(["lscpu", "-Je"], check=True, capture_output=True).stdout
# Fallback for non-Linux (macOS, etc.)
cpu_count = os.cpu_count() or 1
cpus = []
for i in range(cpu_count):
cpus.append({"cpu": str(i), "core": str(i), "node": "0"})
return json.dumps({"cpus": cpus}).encode()
class OMPProcessManager:
def __init__(self, config: "VllmConfig"):
if not current_platform.is_cpu():
return
self.local_world_size = config.parallel_config.local_world_size
self.local_dp_rank = config.parallel_config.data_parallel_rank_local
# This is a bit tricky because the internal DP size
# is always 1 for non-MoE models
self.internal_dp_size = config.parallel_config._api_process_count
def enumerate_resources(resource_map, mask=None, allowed=None):
"""Enumerate system resources"""
if allowed is None:
allowed = _get_default_affinity()
if mask is not None:
allowed = allowed & mask
self.simulate_multi_node = os.environ.get("VLLM_CPU_SIM_MULTI_NUMA", "0") != "0"
ld_preload_str = os.getenv("LD_PRELOAD", "")
self.use_iomp = "libiomp" in ld_preload_str or "libomp" in ld_preload_str
self.use_gomp = "libgomp" in ld_preload_str
try:
allowed_nodes = parse_mask(os.environ["CPU_VISIBLE_MEMORY_NODES"])
except KeyError:
allowed_nodes = None
lscpu: dict[str, dict] = {"cpus": {}, "cores": {}, "nodes": {}}
for cpu in resource_map["cpus"]:
cpunum = int(cpu["cpu"])
if (
cpunum in allowed
and cpunum >= 0
and (allowed_nodes is None or _int(cpu["node"]) in allowed_nodes)
):
lscpu["cpus"][cpunum] = [cpu]
core = _int(cpu["core"])
if lscpu["cores"].get(core, None) is None:
lscpu["cores"][core] = [cpu]
assert not (self.use_iomp and self.use_gomp)
# at least reserve 1/local_world_size(for ARM) core for scheduler
# proc as always use MP executor
# TODO: make scheduler proc sleep when idle
self.reserve_cpu_num = (
self.local_world_size
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM
else 1
)
# reserve at one more core for nixl_connector under p/d case
if config.kv_transfer_config:
self.reserve_cpu_num += 1
if envs.VLLM_CPU_NUM_OF_RESERVED_CPU is not None:
if self.reserve_cpu_num > envs.VLLM_CPU_NUM_OF_RESERVED_CPU:
msg = (
f"VLLM_CPU_NUM_OF_RESERVED_CPU is less than "
"the minimum requirement"
f": {self.reserve_cpu_num} cores"
)
logger.warning(msg=msg)
self.reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
self._parse_omp_threads_bind_env()
assert not self.simulate_multi_node or self.auto_setup
@contextmanager
def configure_omp_envs(self, rank: int, local_rank: int):
if not current_platform.is_cpu() or self.skip_setup:
yield
return
envs_dict = {}
cpu_list = [str(i) for i in self.cpu_lists[local_rank]]
envs_dict["OMP_NUM_THREADS"] = str(len(cpu_list))
if self.use_iomp:
# set IOMP envs
cpu_list_str = ",".join(cpu_list)
envs_dict["KMP_AFFINITY"] = (
f"granularity=fine,explicit,proclist=[{cpu_list_str}]"
)
# The time(milliseconds) that a thread should wait after
# completing the execution of a parallel region, before sleeping.
envs_dict["KMP_BLOCKTIME"] = "1"
# Prevents the CPU to run into low performance state
envs_dict["KMP_TPAUSE"] = "0"
# Provides fine granularity parallelism
envs_dict["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
envs_dict["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
envs_dict["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"
elif self.use_gomp:
# set GOMP envs
# likes '0 1 2 ...'
cpu_list_str = " ".join(cpu_list)
envs_dict["GOMP_CPU_AFFINITY"] = cpu_list_str
else:
lscpu["cores"][core].append(cpu)
node = _int(cpu["node"])
if lscpu["nodes"].get(node, None) is None:
lscpu["nodes"][node] = [cpu]
# set OMP envs
# likes '{0,1,2,...}'
cpu_list_str = ",".join(cpu_list)
envs_dict["OMP_PLACES"] = f"{{{cpu_list_str}}}"
envs_dict["OMP_PROC_BIND"] = "true"
# backup envs
old_envs_dict = {}
for k in envs_dict:
old_envs_dict[k] = os.environ.get(k)
try:
# set envs
for k, v in envs_dict.items():
os.environ[k] = v
yield
finally:
# restore old envs
for k, v in old_envs_dict.items(): # type: ignore
if v is None:
os.environ.pop(k, None)
else:
lscpu["nodes"][node].append(cpu)
return lscpu
def produce_cpu_list(cpus, smt=1):
"""Produce a CPU list with/without SMT pairs - main cpu list case"""
mask: list[int] = []
for key, value in cpus.items():
exists = 0
for cpu in mask:
if cpu == value[0]["core"]:
exists += 1
break
if exists < smt:
mask.append(int(key))
return {"mask": set(mask), "available": True}
def produce_cpu_sublist(scpus, smt=1):
"""Produce a CPU list with/without SMT pairs - resource leaf case"""
cpu_list: list[dict] = []
for value in scpus:
exists = 0
for cpu in cpu_list:
if int(cpu["core"]) == int(value["core"]):
exists += 1
break
if exists < smt:
cpu_list.append(value)
mask = []
for cpu in cpu_list:
mask.append(int(cpu["cpu"]))
return {"mask": set(mask), "available": True}
def create_omp_places(resources, strategy, smt=True):
"""Parse CPU topology and generate possible CPU masks"""
omp_places = []
if strategy == "all":
omp_places.append(produce_cpu_list(resources["cpus"], smt))
elif strategy == "cores":
for value in resources["cores"].values():
omp_places.append(produce_cpu_sublist(value, smt))
elif strategy == "nodes":
for value in resources["nodes"].values():
omp_places.append(produce_cpu_sublist(value, smt))
os.environ[k] = v
def _parse_omp_threads_bind_env(self):
vllm_mask = envs.VLLM_CPU_OMP_THREADS_BIND
self.skip_setup = vllm_mask == "nobind"
self.auto_setup = vllm_mask == "auto"
self.reserved_cpu_list = []
self.cpu_lists = []
if self.auto_setup:
# auto generate CPU lists
cpu_arch = current_platform.get_cpu_architecture()
if cpu_arch == CpuArchEnum.POWERPC:
# For POWERPC SMT-8/4/2
cpu_list, reserve_list = self._get_autobind_cpu_ids(
lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
)
elif cpu_arch in (CpuArchEnum.X86, CpuArchEnum.S390X):
# For x86/S390X SMT-2, use 1 logical CPU per physical core
cpu_list, reserve_list = self._get_autobind_cpu_ids(
lambda cpus: cpus[-1:]
)
elif cpu_arch == CpuArchEnum.ARM:
# For AArch64, no SMT, use all logical CPU
cpu_list, reserve_list = self._get_autobind_cpu_ids(lambda cpus: cpus)
else:
raise NotImplementedError("Unknown strategy")
cpu_list, reserve_list = [], []
raise RuntimeError(f"{cpu_arch} doesn't support auto CPU binding.")
return omp_places
for item in cpu_list:
self.cpu_lists.append([x.id for x in item])
self.reserved_cpu_list = [x.id for x in reserve_list]
elif not self.skip_setup:
# user defined CPU lists
omp_cpuids_list = vllm_mask.split("|")
if self.local_dp_rank is not None:
local_dp_rank = self.local_dp_rank
world_size = self.local_world_size
# Rank mapping [DP, PP, TP]
omp_cpuids_list = omp_cpuids_list[
local_dp_rank * world_size : (local_dp_rank + 1) * world_size
]
assert len(omp_cpuids_list) == self.local_world_size, (
"Given "
f"number of CPU id list {omp_cpuids_list} doesn't match "
f"local world size {self.local_world_size}."
)
# pylint: disable=too-few-public-methods
class OMPProcessManager:
"""OMP aware wrapper to run mp Process()"""
def __init__(self, strategy="nodes", smt=1, mock=None, affinity=None):
self.strategy = strategy
self.smt = smt
self.omp_places = []
vllm_mask = os.environ.get("VLLM_CPU_OMP_THREADS_BIND", None)
self.setup_omp = vllm_mask != "nobind"
if self.setup_omp:
omp_places = []
if vllm_mask is not None:
masks = []
for spec in vllm_mask.split("|"):
masks.append(parse_mask(spec))
# parse CPU list strings like "5,2-4" to [5, 2, 3, 4]
self.cpu_lists = [cr_utils.parse_id_list(s) for s in omp_cpuids_list]
else:
masks = [None]
if mock is None:
data = _get_cpu_topology_json()
# skip
self.cpu_lists = []
msg = "OpenMP thread binding info: \n"
for i in range(self.local_world_size):
msg += f"\tlocal_rank={i}, core ids={self.cpu_lists[i]}\n"
msg += f"\treserved_cpus={self.reserved_cpu_list}"
logger.info(msg)
def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]
) -> tuple[list[list[LogicalCPUInfo]], list[LogicalCPUInfo]]:
"""
Return CPU ids to bind based on NUMA nodes, and CPU ids reserved for
other processes.
Currently for rank N, only CPU ids on the N-th node in available NUMA
node list will be selected.
Args:
cpu_selector: a callable object to select CPUs from a CPU list
of a physical core. The input is a LogicalCPUInfo list contains
logical CPUs of a physical CPU, sorted by the LogicalCPUInfo.id.
A selected LogicalCPUInfo list should be returned.
"""
# this memory node list has been sliced for DP offset
allowed_numa_nodes = cr_utils.get_visible_memory_node()
logical_cpu_list = cr_utils.get_allowed_cpu_list()
local_world_size = self.local_world_size
assert (
len(allowed_numa_nodes) >= local_world_size or self.simulate_multi_node
), (
f"Not enough allowed NUMA nodes to bind threads of "
f"{local_world_size} local CPUWorkers. "
f"Allowed NUMA nodes are {allowed_numa_nodes}. "
"Please try to bind threads manually or decrease DP/TP/PP."
)
# Generate OMP CPU list for each rank
cpu_lists_of_ranks = []
reserved_cpu_list = []
total_cpu_num = 0
for local_rank in range(self.local_world_size):
if not self.simulate_multi_node:
selected_numa_node = allowed_numa_nodes[local_rank]
selected_logical_cpu_list = [
x for x in logical_cpu_list if x.numa_node == selected_numa_node
]
else:
with open(mock, mode="rb") as jf:
data = jf.read()
lscpu = json.loads(data)
for mask in masks:
resources = enumerate_resources(lscpu, mask, affinity)
omp_places.extend(create_omp_places(resources, strategy, smt))
self.omp_places = sorted(
omp_places,
key=lambda p: "{:04d}-{:04d}".format(len(p["mask"]), max(p["mask"])),
reverse=True,
world_size_across_dp = self.local_world_size * self.internal_dp_size
assert len(logical_cpu_list) >= world_size_across_dp
selected_logical_cpu_list = sorted(
logical_cpu_list, key=lambda x: x.numa_node
)
sim_cpu_num_per_node = (
len(selected_logical_cpu_list) // world_size_across_dp
)
assert self.local_dp_rank is not None
start_idx = (
local_rank + self.local_world_size * self.local_dp_rank
) * sim_cpu_num_per_node
selected_logical_cpu_list = selected_logical_cpu_list[
start_idx : (start_idx + sim_cpu_num_per_node)
]
# Select logical CPUs on same physical cores via cpu_selector
core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
for cpu_info in selected_logical_cpu_list:
if cpu_info.physical_core not in core_to_cpus:
core_to_cpus[cpu_info.physical_core] = []
core_to_cpus[cpu_info.physical_core].append(cpu_info)
selected_logical_cpu_list = []
for cpu_list in core_to_cpus.values():
cpu_list = sorted(cpu_list, key=lambda x: x.id)
selected_logical_cpu_list.extend(cpu_selector(cpu_list))
# sort selected cores based on core id
selected_logical_cpu_list = sorted(
selected_logical_cpu_list, key=lambda x: x.id
)
def run(self, what, *args, **kwargs):
"""Run arg with correct OMP environment"""
if self.setup_omp:
for place in self.omp_places:
if place["available"]:
reserve = int(os.environ.get("VLLM_CPU_NUM_OF_RESERVED_CPU", 0))
place["available"] = False
# pylint: disable=consider-using-f-string
os.environ["OMP_PLACES"] = "{}".format(place["mask"])
os.environ["OMP_NUM_THREADS"] = "{}".format(
len(place["mask"]) - reserve
cpu_lists_of_ranks.append(selected_logical_cpu_list)
total_cpu_num += len(selected_logical_cpu_list)
# Reserve CPUs for other processes
if total_cpu_num <= self.reserve_cpu_num:
logger.warning(
"Selected CPU core number (%s) "
"should be greater than reserved CPU core "
"number (%s).",
total_cpu_num,
self.reserve_cpu_num,
)
os.environ["OMP_PROC_BIND"] = "TRUE"
return what(*args, **kwargs)
raise IndexError("Out of OMP places")
return what(*args, **kwargs)
return cpu_lists_of_ranks, []
reserve_num_per_rank = [
self.reserve_cpu_num // self.local_world_size
] * self.local_world_size
# last rank first
for i in range(
self.local_world_size - 1,
self.local_world_size - 1 - self.reserve_cpu_num % self.local_world_size,
-1,
):
reserve_num_per_rank[i] += 1
for i in range(self.local_world_size):
num = reserve_num_per_rank[i]
if num > 0:
reserved_cpu_list.extend(cpu_lists_of_ranks[i][-num:])
cpu_lists_of_ranks[i] = cpu_lists_of_ranks[i][:-num]
return cpu_lists_of_ranks, reserved_cpu_list
......@@ -51,6 +51,7 @@ from vllm.utils.network_utils import (
get_loopback_ip,
get_open_port,
)
from vllm.utils.ompmultiprocessing import OMPProcessManager
from vllm.utils.system_utils import (
_maybe_force_spawn,
decorate_logs,
......@@ -169,24 +170,14 @@ class MultiprocExecutor(Executor):
[] if context.get_start_method() == "fork" else None
)
# For CPU backend only, to setup OpenMP threads affinity
cpu_omp_manager = OMPProcessManager(self.vllm_config)
for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank
is_driver_worker = self._is_driver_worker(global_rank)
if current_platform.is_cpu():
om = current_platform.get_omp_manager()
logger.info("Configured OMP PLACES %s", str(om.omp_places))
unready_worker_handle = om.run(
WorkerProc.make_worker_process,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=global_rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
is_driver_worker=is_driver_worker,
inherited_fds=inherited_fds,
)
else:
with cpu_omp_manager.configure_omp_envs(
rank=global_rank, local_rank=local_rank
):
unready_worker_handle = WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
local_rank=local_rank,
......
......@@ -116,21 +116,7 @@ class CPUModelRunner(GPUModelRunner):
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape
with _set_global_compilation_settings(self.vllm_config):
self._dummy_run(
min(
max(16, self.max_num_reqs),
self.scheduler_config.max_num_batched_tokens,
)
)
# Warm up drafter for speculative decoding
if self.speculative_config and (self.speculative_config.uses_draft_model()):
from vllm.v1.spec_decode.draft_model import DraftModelProposer
if isinstance(self.drafter, (DraftModelProposer)):
logger.info("Warming up drafter model...")
self.drafter.dummy_run(max(16, self.max_num_reqs))
self.profile_run()
logger.info("Warming up done.")
def initialize_kv_cache(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import os
import sys
from typing import Any
import psutil
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.cpu_resource_utils import (
get_allowed_cpu_list,
get_memory_node_info,
get_visible_memory_node,
)
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
......@@ -27,6 +35,46 @@ class CPUWorker(Worker):
distributed_init_method: str,
is_driver_worker: bool = False,
):
# TODO: use numactl for process setup
# TODO: optimize for `interleaved` policy
# Bind memory node
allowed_memory_nodes = get_visible_memory_node()
allowed_cpu_list = get_allowed_cpu_list()
cpu_core = allowed_cpu_list[0]
# TODO: some CI hosts are not correctly set, change to assertion
# after fix
if cpu_core.numa_node not in allowed_memory_nodes:
logger.warning(
"Node %s is not in available memory nodes %s.",
cpu_core.numa_node,
allowed_memory_nodes,
)
torch.ops._C.init_cpu_memory_env([cpu_core.numa_node])
memory_status = get_memory_node_info(cpu_core.numa_node)
memory_fraction = vllm_config.cache_config.gpu_memory_utilization
self.requested_cpu_memory = math.ceil(
memory_status.total_memory * memory_fraction
)
available_memory = memory_status.available_memory
if (
vllm_config.cache_config.kv_cache_memory_bytes is None
and self.requested_cpu_memory > available_memory
):
raise ValueError(
f"Available memory on node {cpu_core.numa_node} "
f"({format_gib(available_memory)}/"
f"{format_gib(memory_status.total_memory)} GiB) on startup "
f"is less than desired CPU memory utilization "
f"({vllm_config.cache_config.gpu_memory_utilization}, "
f"{format_gib(self.requested_cpu_memory)} GiB). "
"Decrease --gpu-memory-utilization"
f" or reduce CPU memory used by other processes."
)
super().__init__(
vllm_config,
local_rank,
......@@ -103,13 +151,69 @@ class CPUWorker(Worker):
pass
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0
self.model_runner.warming_up_model()
allowed_cpu_list = get_allowed_cpu_list()
cpu_core = allowed_cpu_list[0]
memory_status = get_memory_node_info(cpu_core.numa_node)
available_memory = memory_status.available_memory
explicit_kv_cache_size = self.cache_config.kv_cache_memory_bytes
kv_cache_size = None
msg = None
if explicit_kv_cache_size is not None:
if explicit_kv_cache_size > available_memory:
raise ValueError(
f"Available memory on node {cpu_core.numa_node} "
f"({format_gib(available_memory)}/"
f"{format_gib(memory_status.total_memory)} GiB) on kv cache"
f" allocation is less than requested memory for kv "
f"({format_gib(explicit_kv_cache_size)} GiB). "
"Decrease --kv-cache-memory-bytes, VLLM_CPU_KVCACHE_SPACE, "
"or reduce CPU memory used by other processes."
)
kv_cache_size = explicit_kv_cache_size
msg = (
f"Explicitly set ({format_gib(kv_cache_size)}/"
f"{format_gib(memory_status.total_memory)}) GiB for KV cache "
f"on node {cpu_core.numa_node}."
)
else:
consumed_memory = psutil.Process(os.getpid()).memory_info().rss
requested_memory_for_kv = int(self.requested_cpu_memory - consumed_memory)
if (
requested_memory_for_kv <= 0
or requested_memory_for_kv > available_memory
):
raise ValueError(
f"Available memory on node {cpu_core.numa_node} "
f"({format_gib(available_memory)}/"
f"{format_gib(memory_status.total_memory)} GiB) on kv cache"
f" allocation is less than requested memory for kv "
f"({format_gib(requested_memory_for_kv)}/"
f"{format_gib(self.requested_cpu_memory)} GiB). "
"Reduce CPU memory used by other processes."
)
kv_cache_size = requested_memory_for_kv
msg = (
f"Auto set ({format_gib(kv_cache_size)}/"
f"{format_gib(memory_status.total_memory)}) GiB for KV cache "
f"on node {cpu_core.numa_node}, with "
f"{format_gib(self.requested_cpu_memory)} GiB requested memory"
f" for the worker. {format_gib(consumed_memory)} GiB"
f" memory was consumed by non-kv usages."
)
logger.info(msg)
return kv_cache_size
def compile_or_warm_up_model(self) -> CompilationTimes:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
# Note: the model has been compiled in determine_available_memory()
return CompilationTimes(
language_model=self.compilation_config.compilation_time,
encoder=self.compilation_config.encoder_compilation_time,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment