Unverified Commit 10189d08 authored by HAI's avatar HAI Committed by GitHub
Browse files

[Performance]: Process affinity to CPU cores with multiple sockets support (#2171)

parent c4336b2b
...@@ -72,6 +72,7 @@ from sglang.srt.utils import ( ...@@ -72,6 +72,7 @@ from sglang.srt.utils import (
configure_logger, configure_logger,
crash_on_warnings, crash_on_warnings,
get_zmq_socket, get_zmq_socket,
gpu_proc_affinity,
kill_parent_process, kill_parent_process,
set_random_seed, set_random_seed,
suppress_other_loggers, suppress_other_loggers,
...@@ -1393,6 +1394,9 @@ def run_scheduler_process( ...@@ -1393,6 +1394,9 @@ def run_scheduler_process(
dp_rank: Optional[int], dp_rank: Optional[int],
pipe_writer, pipe_writer,
): ):
# set cpu affinity to this gpu process
gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
# [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "DP_RANK" in os.environ: if dp_rank is None and "DP_RANK" in os.environ:
dp_rank = int(os.environ["DP_RANK"]) dp_rank = int(os.environ["DP_RANK"])
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import base64 import base64
import ipaddress import ipaddress
import itertools
import json import json
import logging import logging
import os import os
...@@ -987,3 +988,37 @@ def direct_register_custom_op( ...@@ -987,3 +988,37 @@ def direct_register_custom_op(
my_lib.impl(op_name, op_func, "CUDA") my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None: if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl) my_lib._register_fake(op_name, fake_impl)
def gpu_proc_affinity(
tp_size: int,
nnodes: int,
gpu_id: int,
):
# current process
pid = os.getpid()
p = psutil.Process(pid)
tp_size_per_node = tp_size // nnodes
# total physical cores
total_pcores = psutil.cpu_count(logical=False)
# physical cores per TP (N.B. more Cores than GPUs on node)
num_cores_bind = total_pcores // tp_size_per_node
# able to handle multiple DP per node
start_cpu_id = (gpu_id * num_cores_bind) % total_pcores
end_cpu_id = start_cpu_id + num_cores_bind
if psutil.cpu_count() != psutil.cpu_count(logical=False):
# HT on
upper_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
lower_cpu_ids = [id + total_pcores for id in range(start_cpu_id, end_cpu_id)]
bind_cpu_ids = list(itertools.chain(upper_cpu_ids, lower_cpu_ids))
else:
# HT off
bind_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
# set cpu_affinity to current process
p.cpu_affinity(bind_cpu_ids)
logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment