Unverified Commit 40e5cb7a authored by Chunyuan WU, committed by GitHub

[CPU] Bind threads and numa node for each TP rank (#6549)


Co-authored-by: srinarayan-srikanthan <srinarayan.srikanthan@intel.com>
parent 8e64140e
@@ -102,6 +102,7 @@ from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    get_cpu_ids_by_node,
    init_custom_process_group,
    is_cuda,
    is_fa3_default_architecture,
@@ -211,6 +212,10 @@ class ModelRunner:
        # CPU offload
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Init OpenMP threads binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()
@@ -497,6 +502,15 @@ class ModelRunner:
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
                else:
                    logger.warning(
                        "init_cpu_threads_env is skipped since intel amx backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
@@ -1308,6 +1322,30 @@ class ModelRunner:
            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def init_threads_binding(self):
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
        if omp_cpuids == "all":
            cpu_ids_by_node = get_cpu_ids_by_node()
            n_numa_node = len(cpu_ids_by_node)

            assert self.tp_size <= n_numa_node, (
                f"SGLANG_CPU_OMP_THREADS_BIND is not set. In this case, tp_size ({self.tp_size}) "
                f"must be less than or equal to the number of NUMA nodes on the machine ({n_numa_node}), "
                f"and each TP rank is bound to all the physical cores of one NUMA node. "
                f"For example, on a machine with 2 NUMA nodes where cores 0-31 are on node 0 and cores 32-63 are on node 1, "
                f"it is suggested to run with -tp 2, binding TP rank 0 to cores 0-31 and TP rank 1 to cores 32-63; "
                f"this default behavior is equivalent to setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
                f"If you need tp_size to be larger than the number of NUMA nodes, set the CPU cores for each TP rank "
                f"explicitly via SGLANG_CPU_OMP_THREADS_BIND, e.g. SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 with -tp 4. "
                f"If you don't want each TP rank to use all the cores of one NUMA node, you can also set, "
                f"for example, SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
            )

            if self.tp_size < n_numa_node:
                logger.warning(
                    f"Detected {n_numa_node} NUMA nodes on this machine, but tp_size is set to {self.tp_size}, "
                    f"so only {self.tp_size} NUMA nodes are used."
                )
            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
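
A minimal sketch (not part of the diff) of how the "|"-separated SGLANG_CPU_OMP_THREADS_BIND value maps to per-rank core lists; the two-NUMA-node layout and the helper name cpuid_for_rank are assumptions for illustration.

import os

# Hypothetical machine: cores 0-31 on NUMA node 0, cores 32-63 on NUMA node 1.
os.environ["SGLANG_CPU_OMP_THREADS_BIND"] = "0-31|32-63"

def cpuid_for_rank(tp_rank: int) -> str:
    # Each TP rank takes the tp_rank-th "|"-separated entry, mirroring the else branch above.
    return os.environ["SGLANG_CPU_OMP_THREADS_BIND"].split("|")[tp_rank]

assert cpuid_for_rank(0) == "0-31"   # rank 0 is pinned to NUMA node 0
assert cpuid_for_rank(1) == "32-63"  # rank 1 is pinned to NUMA node 1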

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel
@@ -40,6 +40,7 @@ import threading
import time
import traceback
import warnings
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
@@ -2545,3 +2546,69 @@ def align(x: int, y: int) -> int:
# COPIED FROM DeepGEMM
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def parse_lscpu_topology():
    try:
        # Get CPU topology: CPU,Core,Socket,Node
        output = subprocess.check_output(
            ["lscpu", "-p=CPU,Core,Socket,Node"], text=True
        )
    except Exception as e:
        raise RuntimeError(f"Unexpected error running 'lscpu': {e}")

    # Parse only data lines (skip comments)
    cpu_info = []
    for line in output.splitlines():
        if not line.startswith("#"):
            cpu, core, socket, node = map(int, line.strip().split(","))
            cpu_info.append((cpu, core, socket, node))

    # e.g. [(0,0,0,0), (1,1,0,0), ..., (43,43,0,1), ..., (256,0,0,0), ...]
    return cpu_info
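
For reference, a short sketch (not part of the diff) of the kind of lscpu -p=CPU,Core,Socket,Node output this parser expects and the tuples it produces; the sample values are made up.

# Illustrative only: hypothetical `lscpu -p=CPU,Core,Socket,Node` output.
SAMPLE_OUTPUT = """\
# CPU,Core,Socket,Node
0,0,0,0
1,1,0,0
2,2,0,1
3,3,0,1
"""

# Same parsing as parse_lscpu_topology(), applied to the canned sample.
cpu_info = [
    tuple(map(int, line.split(",")))
    for line in SAMPLE_OUTPUT.splitlines()
    if not line.startswith("#")
]
print(cpu_info)  # [(0, 0, 0, 0), (1, 1, 0, 0), (2, 2, 0, 1), (3, 3, 0, 1)]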

def get_physical_cpus_by_numa():
    cpu_info = parse_lscpu_topology()

    # Map NUMA node -> (core_id, socket) -> cpu_id to avoid duplicates,
    # e.g. 0: {(0,0): 0, (1,0): 1, ...}
    #      ...
    #      5: {(214,1): 214, (215,1): 215}
    physical_by_node = defaultdict(dict)
    for cpu, core, socket, node in cpu_info:
        key = (core, socket)
        if key not in physical_by_node[node]:
            # Pick the first CPU seen for that physical core
            physical_by_node[node][key] = cpu

    # CPUs that the current process is allowed to run on
    cpus_allowed_list = psutil.Process().cpu_affinity()

    # Convert to the set of allowed physical CPUs per node,
    # e.g. 0: {0, 1, 2, ..., 42}
    #      ...
    #      2: {86, 87, ..., 127}
    #      ...
    #      5: {214, 215, ..., 255}
    node_to_cpus = {}
    for node, core_to_cpu in physical_by_node.items():
        cpus = sorted(core_to_cpu.values())
        allowed_cpus = set(cpus).intersection(cpus_allowed_list)
        node_to_cpus[node] = allowed_cpus

    return node_to_cpus

# Only physical cores are used; logical (SMT sibling) cores are excluded.
def get_cpu_ids_by_node():
    node_to_cpus = get_physical_cpus_by_numa()
    # Sort by NUMA node index and join each node's CPU ids into a comma-separated string
    cpu_ids = [
        ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus)
    ]

    # e.g. ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
    return cpu_ids
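
A small end-to-end sketch (not part of the diff) of the dedup-and-group step: it takes parsed (cpu, core, socket, node) tuples, keeps one logical CPU per physical core, and joins the ids into per-node strings like those get_cpu_ids_by_node() returns. The sample topology is hypothetical and the cpu_affinity() filter is skipped here.

from collections import defaultdict

# Hypothetical 8-CPU, 2-node topology; CPUs 4-7 are SMT siblings of CPUs 0-3.
cpu_info = [
    (0, 0, 0, 0), (1, 1, 0, 0),  # node 0, physical cores 0 and 1
    (2, 2, 0, 1), (3, 3, 0, 1),  # node 1, physical cores 2 and 3
    (4, 0, 0, 0), (5, 1, 0, 0),  # SMT siblings on node 0 (same (core, socket) keys)
    (6, 2, 0, 1), (7, 3, 0, 1),  # SMT siblings on node 1
]

physical_by_node = defaultdict(dict)  # node -> (core, socket) -> first cpu id seen
for cpu, core, socket, node in cpu_info:
    physical_by_node[node].setdefault((core, socket), cpu)

cpu_ids = [
    ",".join(map(str, sorted(core_to_cpu.values())))
    for _, core_to_cpu in sorted(physical_by_node.items())
]
print(cpu_ids)  # ['0,1', '2,3'] -- one entry per NUMA node, SMT siblings dropped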