Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c5131f7a
Unverified
Commit
c5131f7a
authored
Jun 30, 2025
by
Chunyuan WU
Committed by
GitHub
Jun 29, 2025
Browse files
[CPU] add c++ kernel to bind CPU cores and memory node (#7524)
parent
78700893
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
159 additions
and
8 deletions
+159
-8
docker/Dockerfile.xeon
docker/Dockerfile.xeon
+8
-4
python/pyproject.toml
python/pyproject.toml
+1
-3
sgl-kernel/csrc/cpu/CMakeLists.txt
sgl-kernel/csrc/cpu/CMakeLists.txt
+19
-1
sgl-kernel/csrc/cpu/numa_utils.cpp
sgl-kernel/csrc/cpu/numa_utils.cpp
+91
-0
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+10
-0
test/srt/cpu/test_binding.py
test/srt/cpu/test_binding.py
+28
-0
test/srt/run_suite.py
test/srt/run_suite.py
+2
-0
No files found.
docker/Dockerfile.xeon
View file @
c5131f7a
...
...
@@ -2,8 +2,9 @@ FROM ubuntu:24.04
SHELL ["/bin/bash", "-c"]
ARG VER_SGLANG=main
ARG VER_TORCH=2.6.0
ARG VER_TORCHVISION=0.21.0
ARG VER_TORCH=2.7.1
ARG VER_TORCHVISION=0.22.1
ARG VER_TRITON=3.3.1
RUN apt-get update && \
apt-get full-upgrade -y && \
...
...
@@ -27,14 +28,17 @@ RUN curl -fsSL -v -o miniforge.sh -O https://github.com/conda-forge/miniforge/re
ENV PATH=/sgl-workspace/miniforge3/bin:/sgl-workspace/miniforge3/condabin:${PATH}
ENV PIP_ROOT_USER_ACTION=ignore
ENV CONDA_PREFIX=/sgl-workspace/miniforge3
RUN pip install intel-openmp
RUN pip config set global.index-url https://download.pytorch.org/whl/cpu && \
pip config set global.extra-index-url https://pypi.org/simple && \
pip install intel-openmp
RUN git clone https://github.com/sgl-project/sglang.git && \
cd sglang && \
git checkout ${VER_SGLANG} && \
pip install -e "python[all_cpu]" && \
pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION}
--index-url https://download.pytorch.org/whl/cpu
--force-reinstall && \
pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION}
triton==${VER_TRITON}
--force-reinstall && \
cd sgl-kernel && \
cp pyproject_cpu.toml pyproject.toml && \
pip install -v .
...
...
python/pyproject.toml
View file @
c5131f7a
...
...
@@ -88,9 +88,7 @@ srt_xpu = ["sglang[runtime_common]"]
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
srt_hpu
=
["sglang[runtime_common]"]
# CPU: currently, there are no pre-built vllm wheels for CPU.
# To install vllm for CPU, please follow the instruction here:
# https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html
# CPU: torch wheel for CPU needs to be installed from https://download.pytorch.org/whl/cpu
srt_cpu
=
["sglang[runtime_common]
", "
einops
"]
# https://vllm-ascend.readthedocs.io/en/latest/installation.html
srt_npu
=
["sglang[runtime_common]"]
...
...
sgl-kernel/csrc/cpu/CMakeLists.txt
View file @
c5131f7a
...
...
@@ -38,6 +38,24 @@ else()
endif
()
link_directories
(
${
PLAT_LIB_DIR
}
)
# Conda library path support
if
(
DEFINED ENV{CONDA_PREFIX}
)
set
(
CONDA_LIB_DIR
"$ENV{CONDA_PREFIX}/lib"
)
message
(
STATUS
"Using Conda lib dir:
${
CONDA_LIB_DIR
}
"
)
link_directories
(
${
CONDA_LIB_DIR
}
)
set
(
CONDA_INCLUDE_DIR
"$ENV{CONDA_PREFIX}/include"
)
include_directories
(
${
CONDA_INCLUDE_DIR
}
)
# Look for libnuma in Conda's lib directory
find_library
(
NUMA_LIB numa HINTS
"
${
CONDA_LIB_DIR
}
"
)
if
(
NUMA_LIB
)
message
(
STATUS
"Found libnuma:
${
NUMA_LIB
}
"
)
else
()
message
(
FATAL_ERROR
"libnuma not found in Conda environment at
${
CONDA_LIB_DIR
}
\n
"
"Please install it using: conda install libnuma numactl
\n
"
)
endif
()
endif
()
file
(
GLOB SOURCES
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/*.cpp"
)
add_compile_options
(
...
...
@@ -48,7 +66,7 @@ add_compile_options(
)
Python_add_library
(
common_ops MODULE USE_SABI
${
SKBUILD_SABI_VERSION
}
WITH_SOABI
${
SOURCES
}
)
target_link_libraries
(
common_ops PRIVATE
${
TORCH_LIBRARIES
}
)
target_link_libraries
(
common_ops PRIVATE
${
TORCH_LIBRARIES
}
${
NUMA_LIB
}
)
target_include_directories
(
common_ops PRIVATE
${
TORCH_INCLUDE_DIRS
}
)
install
(
TARGETS common_ops
...
...
sgl-kernel/csrc/cpu/numa_utils.cpp
0 → 100644
View file @
c5131f7a
#include <numa.h>
#include <sched.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
#include <string>
#include "common.h"
std
::
string
init_cpu_threads_env
(
const
std
::
string
&
cpu_ids
)
{
bitmask
*
omp_cpu_mask
=
numa_parse_cpustring
(
cpu_ids
.
c_str
());
TORCH_CHECK
(
omp_cpu_mask
->
size
>
0
);
std
::
vector
<
int
>
omp_cpu_ids
;
omp_cpu_ids
.
reserve
(
omp_cpu_mask
->
size
);
constexpr
int
group_size
=
8
*
sizeof
(
*
omp_cpu_mask
->
maskp
);
for
(
int
offset
=
0
;
offset
<
omp_cpu_mask
->
size
;
offset
+=
group_size
)
{
unsigned
long
group_mask
=
omp_cpu_mask
->
maskp
[
offset
/
group_size
];
int
i
=
0
;
while
(
group_mask
)
{
if
(
group_mask
&
1
)
{
omp_cpu_ids
.
emplace_back
(
offset
+
i
);
}
++
i
;
group_mask
>>=
1
;
}
}
// Memory node binding
if
(
numa_available
()
!=
-
1
)
{
int
mem_node_id
=
numa_node_of_cpu
(
omp_cpu_ids
.
front
());
bitmask
*
mask
=
numa_parse_nodestring
(
std
::
to_string
(
mem_node_id
).
c_str
());
bitmask
*
src_mask
=
numa_get_membind
();
int
pid
=
getpid
();
// move all existing pages to the specified numa node.
*
(
src_mask
->
maskp
)
=
*
(
src_mask
->
maskp
)
^
*
(
mask
->
maskp
);
int
page_num
=
numa_migrate_pages
(
pid
,
src_mask
,
mask
);
if
(
page_num
==
-
1
)
{
TORCH_WARN
(
false
,
"numa_migrate_pages failed. errno: "
+
std
::
to_string
(
errno
));
}
// restrict memory allocation node.
numa_set_membind
(
mask
);
numa_set_strict
(
1
);
}
// OMP threads binding
omp_set_num_threads
((
int
)
omp_cpu_ids
.
size
());
at
::
set_num_threads
((
int
)
omp_cpu_ids
.
size
());
TORCH_CHECK_EQ
(
omp_cpu_ids
.
size
(),
at
::
get_num_threads
());
TORCH_CHECK_EQ
(
omp_cpu_ids
.
size
(),
omp_get_max_threads
());
std
::
vector
<
std
::
pair
<
int
,
int
>>
thread_core_mapping
;
thread_core_mapping
.
reserve
(
omp_cpu_ids
.
size
());
omp_lock_t
writelock
;
omp_init_lock
(
&
writelock
);
#pragma omp parallel for schedule(static, 1)
for
(
size_t
i
=
0
;
i
<
omp_cpu_ids
.
size
();
++
i
)
{
cpu_set_t
mask
;
CPU_ZERO
(
&
mask
);
CPU_SET
(
omp_cpu_ids
[
i
],
&
mask
);
int
ret
=
sched_setaffinity
(
0
,
sizeof
(
cpu_set_t
),
&
mask
);
if
(
ret
==
-
1
)
{
TORCH_CHECK
(
false
,
"sched_setaffinity failed. errno: "
+
std
::
to_string
(
errno
));
}
omp_set_lock
(
&
writelock
);
thread_core_mapping
.
emplace_back
(
syscall
(
SYS_gettid
),
omp_cpu_ids
[
i
]);
omp_unset_lock
(
&
writelock
);
}
omp_destroy_lock
(
&
writelock
);
numa_free_nodemask
(
omp_cpu_mask
);
std
::
stringstream
ss
;
ss
<<
"OMP threads binding of Process "
<<
getpid
()
<<
":
\n
"
;
std
::
sort
(
thread_core_mapping
.
begin
(),
thread_core_mapping
.
end
(),
[](
auto
&&
a
,
auto
&&
b
)
{
return
a
.
second
<
b
.
second
;
});
for
(
auto
&&
item
:
thread_core_mapping
)
{
ss
<<
"
\t
"
<<
"OMP tid: "
<<
item
.
first
<<
", core "
<<
item
.
second
<<
"
\n
"
;
}
return
ss
.
str
();
}
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
View file @
c5131f7a
...
...
@@ -227,6 +227,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
at
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
// CPU and memory binding
// Pins this process's OMP threads to the cores in `cpu_ids` (a libnuma
// cpustring) and binds memory allocation to the matching NUMA node.
// Returns a human-readable thread->core report. Defined in numa_utils.cpp.
std::string init_cpu_threads_env(const std::string& cpu_ids);
TORCH_LIBRARY_FRAGMENT
(
sgl_kernel
,
m
)
{
// activation
m
.
def
(
"silu_and_mul_cpu(Tensor input) -> Tensor"
);
...
...
@@ -353,6 +356,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
"bool is_neox) -> (Tensor, Tensor)"
);
m
.
impl
(
"rotary_embedding_cpu"
,
torch
::
kCPU
,
&
rotary_embedding_cpu
);
// CPU and memory binding
m
.
def
(
"init_cpu_threads_env(str cpu_ids) -> str"
);
}
// Register init_cpu_threads_env under the CatchAll dispatch key: it has no
// tensor arguments, so it is not routed by device and must be callable from
// any dispatch context.
TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
  m.impl("init_cpu_threads_env", init_cpu_threads_env);
}

// Expose the compiled extension module as `common_ops`.
REGISTER_EXTENSION(common_ops)
test/srt/cpu/test_binding.py
0 → 100644
View file @
c5131f7a
import re
import unittest

import sgl_kernel  # noqa: F401 -- importing registers the torch.ops.sgl_kernel ops
import torch

kernel = torch.ops.sgl_kernel

from sglang.test.test_utils import CustomTestCase


class TestCPUBinding(CustomTestCase):
    """Tests that init_cpu_threads_env pins one OMP thread per requested core.

    Renamed from the copy-pasted ``TestGemm`` -- this file tests CPU/NUMA
    binding, not GEMM. unittest discovery matches the ``Test`` prefix, so the
    rename is discovery-safe.
    """

    def test_binding(self):
        # NOTE(review): assumes the machine running this test has cores
        # 1..6 online and available to this process -- confirm on
        # constrained/containerized CI runners.
        start_id = 1
        n_cpu = 6
        expected_cores = [str(core) for core in range(start_id, start_id + n_cpu)]
        cpu_ids = ",".join(expected_cores)
        output = kernel.init_cpu_threads_env(cpu_ids)
        # The kernel reports one "OMP tid: <tid>, core <core>" line per bound
        # thread, sorted by core id; parse the core ids back out and compare.
        bindings = re.findall(r"OMP tid: \d+, core (\d+)", output)
        self.assertEqual(len(bindings), n_cpu)
        self.assertEqual(bindings, expected_cores)


if __name__ == "__main__":
    unittest.main()
test/srt/run_suite.py
View file @
c5131f7a
...
...
@@ -183,6 +183,7 @@ suites = {
],
"per-commit-cpu"
:
[
TestFile
(
"cpu/test_activation.py"
),
TestFile
(
"cpu/test_binding.py"
),
TestFile
(
"cpu/test_decode.py"
),
TestFile
(
"cpu/test_extend.py"
),
TestFile
(
"cpu/test_gemm.py"
),
...
...
@@ -192,6 +193,7 @@ suites = {
TestFile
(
"cpu/test_qkv_proj_with_rope.py"
),
TestFile
(
"cpu/test_rope.py"
),
TestFile
(
"cpu/test_shared_expert.py"
),
TestFile
(
"cpu/test_topk.py"
),
],
"nightly"
:
[
TestFile
(
"test_nightly_gsm8k_eval.py"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment