Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2c1a695f
Unverified
Commit
2c1a695f
authored
Feb 04, 2025
by
HAI
Committed by
GitHub
Feb 04, 2025
Browse files
ROCm: sgl-kernel enablement starting with sgl_moe_align_block (#3287)
parent
d39899e8
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
131 additions
and
13 deletions
+131
-13
docker/Dockerfile.rocm
docker/Dockerfile.rocm
+3
-0
docs/start/install.md
docs/start/install.md
+3
-1
python/pyproject.toml
python/pyproject.toml
+1
-1
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+3
-11
sgl-kernel/setup_rocm.py
sgl-kernel/setup_rocm.py
+92
-0
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc
+29
-0
No files found.
docker/Dockerfile.rocm
View file @
2c1a695f
...
@@ -28,6 +28,9 @@ RUN git clone ${SGL_REPO} \
...
@@ -28,6 +28,9 @@ RUN git clone ${SGL_REPO} \
echo "Using ${SGL_BRANCH} branch."; \
echo "Using ${SGL_BRANCH} branch."; \
git checkout ${SGL_BRANCH}; \
git checkout ${SGL_BRANCH}; \
fi \
fi \
&& cd sgl-kernel \
&& python setup_rocm.py install \
&& cd .. \
&& if [ "$BUILD_TYPE" = "srt" ]; then \
&& if [ "$BUILD_TYPE" = "srt" ]; then \
python -m pip --no-cache-dir install -e "python[srt_hip]"; \
python -m pip --no-cache-dir install -e "python[srt_hip]"; \
else \
else \
...
...
docs/start/install.md
View file @
2c1a695f
...
@@ -32,7 +32,9 @@ git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git
...
@@ -32,7 +32,9 @@ git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git
cd sglang
cd sglang
pip install --upgrade pip
pip install --upgrade pip
pip install sgl-kernel --force-reinstall --no-deps
cd sgl-kernel
python setup_rocm.py install
cd ..
pip install -e "python[all_hip]"
pip install -e "python[all_hip]"
```
```
...
...
python/pyproject.toml
View file @
2c1a695f
...
@@ -31,7 +31,7 @@ srt = [
...
@@ -31,7 +31,7 @@ srt = [
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip
=
["sglang[runtime_common]
", "
torch
", "
vllm==
0.6.7
.dev
2
", "
outlines==
0.1.11
"]
srt_hip
=
["sglang[runtime_common]
", "
torch
", "
vllm==
0.6.7
.dev
2
", "
outlines==
0.1.11
"
, "
sgl-kernel>=
0.0.3
.post
1
"
]
# xpu is not enabled in public vllm and torch whl,
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
srt_xpu
=
["sglang[runtime_common]
", "
outlines>=
0.0.44
,
<
0.1.0
"]
srt_xpu
=
["sglang[runtime_common]
", "
outlines>=
0.0.44
,
<
0.1.0
"]
...
...
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
View file @
2c1a695f
...
@@ -15,18 +15,10 @@ from vllm import _custom_ops as ops
...
@@ -15,18 +15,10 @@ from vllm import _custom_ops as ops
from
sglang.srt.layers.moe.topk
import
select_experts
from
sglang.srt.layers.moe.topk
import
select_experts
from
sglang.srt.layers.quantization.fp8_kernel
import
per_token_group_quant_fp8
from
sglang.srt.layers.quantization.fp8_kernel
import
per_token_group_quant_fp8
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
direct_register_custom_op
,
get_device_name
,
is_hip
direct_register_custom_op
,
get_device_name
,
is_cuda_available
,
is_hip
,
)
is_cuda
=
is_cuda_available
()
is_hip_flag
=
is_hip
()
is_hip_flag
=
is_hip
()
if
is_cuda
:
from
sgl_kernel
import
moe_align_block_size
as
sgl_moe_align_block_size
from
sgl_kernel
import
moe_align_block_size
as
sgl_moe_align_block_size
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
padding_size
=
128
if
bool
(
int
(
os
.
getenv
(
"MOE_PADDING"
,
"0"
)))
else
0
padding_size
=
128
if
bool
(
int
(
os
.
getenv
(
"MOE_PADDING"
,
"0"
)))
else
0
...
@@ -415,7 +407,7 @@ def moe_align_block_size(
...
@@ -415,7 +407,7 @@ def moe_align_block_size(
)
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
num_experts
>=
224
:
if
num_experts
>=
224
:
if
enable_moe_align_block_size_triton
or
is_hip_flag
:
if
enable_moe_align_block_size_triton
:
moe_align_block_size_triton
(
moe_align_block_size_triton
(
topk_ids
,
topk_ids
,
num_experts
,
num_experts
,
...
...
sgl-kernel/setup_rocm.py
0 → 100644
View file @
2c1a695f
# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
multiprocessing
import
os
import
sys
from
pathlib
import
Path
import
torch
from
setuptools
import
find_packages
,
setup
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
root
=
Path
(
__file__
).
parent
.
resolve
()
if
"bdist_wheel"
in
sys
.
argv
and
"--plat-name"
not
in
sys
.
argv
:
sys
.
argv
.
extend
([
"--plat-name"
,
"manylinux2014_x86_64"
])
def
_get_version
():
with
open
(
root
/
"pyproject.toml"
)
as
f
:
for
line
in
f
:
if
line
.
startswith
(
"version"
):
return
line
.
split
(
"="
)[
1
].
strip
().
strip
(
'"'
)
operator_namespace
=
"sgl_kernels"
include_dirs
=
[
root
/
"src"
/
"sgl-kernel"
/
"include"
,
root
/
"src"
/
"sgl-kernel"
/
"csrc"
,
]
sources
=
[
"src/sgl-kernel/torch_extension_rocm.cc"
,
"src/sgl-kernel/csrc/moe_align_kernel.cu"
,
]
cxx_flags
=
[
"-O3"
]
libraries
=
[
"hiprtc"
,
"amdhip64"
,
"c10"
,
"torch"
,
"torch_python"
]
extra_link_args
=
[
"-Wl,-rpath,$ORIGIN/../../torch/lib"
,
"-L/usr/lib/x86_64-linux-gnu"
]
hipcc_flags
=
[
"-DNDEBUG"
,
f
"-DOPERATOR_NAMESPACE=
{
operator_namespace
}
"
,
"-O3"
,
"-Xcompiler"
,
"-fPIC"
,
"-std=c++17"
,
"-D__HIP_PLATFORM_AMD__=1"
,
"--amdgpu-target=gfx942"
,
"-DENABLE_BF16"
,
"-DENABLE_FP8"
,
]
setup
(
name
=
"sgl-kernel"
,
version
=
_get_version
(),
packages
=
find_packages
(),
package_dir
=
{
""
:
"src"
},
ext_modules
=
[
CUDAExtension
(
name
=
"sgl_kernel.ops._kernels"
,
sources
=
sources
,
include_dirs
=
include_dirs
,
extra_compile_args
=
{
"nvcc"
:
hipcc_flags
,
"cxx"
:
cxx_flags
,
},
libraries
=
libraries
,
extra_link_args
=
extra_link_args
,
py_limited_api
=
True
,
),
],
cmdclass
=
{
"build_ext"
:
BuildExtension
.
with_options
(
use_ninja
=
True
,
max_jobs
=
multiprocessing
.
cpu_count
()
)
},
options
=
{
"bdist_wheel"
:
{
"py_limited_api"
:
"cp39"
}},
install_requires
=
[
"torch"
],
)
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc
0 → 100644
View file @
2c1a695f
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include "sgl_kernels_ops.h"
TORCH_LIBRARY_EXPAND
(
sgl_kernels
,
m
)
{
// moe_align_block_size
m
.
def
(
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
}
REGISTER_EXTENSION
(
_kernels
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment