Project: change/sglang

Commit 8f7453e3 authored Sep 30, 2025 by maxiao
Parent: 1237aa19

adapt to ds3.2
Showing 9 changed files with 199 additions and 49 deletions (+199 -49)
python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py  +1 -1
python/sglang/srt/layers/layernorm.py  +45 -10
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  +21 -7
python/sglang/srt/model_executor/model_runner.py  +1 -1
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh  +7 -7
sgl-kernel/csrc/moe/moe_align_kernel.cu  +2 -1
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu  +3 -3
sgl-kernel/include/utils.h  +19 -19
sgl-kernel/setup_hip.py (new file)  +100 -0
python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py

@@ -4,7 +4,7 @@ import time
 import torch

-from sglang import ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
python/sglang/srt/layers/layernorm.py

@@ -127,21 +127,45 @@ class RMSNorm(CustomOp):
             return output, residual_out
         return rms_norm(x, self.weight.data, self.variance_epsilon)

+    # def forward_hip(
+    #     self,
+    #     x: torch.Tensor,
+    #     residual: Optional[torch.Tensor] = None,
+    # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    #     if not x.is_contiguous():
+    #         # NOTE: Remove this if aiter kernel supports discontinuous input
+    #         x = x.contiguous()
+    #     if residual is not None:
+    #         if _vllm_version < Version("0.9"):
+    #             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+    #             return x, residual
+    #         else:
+    #             residual_out = torch.empty_like(x)
+    #             output = torch.empty_like(x)
+    #             fused_add_rms_norm(
+    #                 output,
+    #                 x,
+    #                 residual_out,
+    #                 residual,
+    #                 self.weight.data,
+    #                 self.variance_epsilon,
+    #             )
+    #             return output, residual_out
+    #     out = torch.empty_like(x)
+    #     rms_norm(out, x, self.weight.data, self.variance_epsilon)
+    #     return out
     def forward_hip(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
         if not x.is_contiguous():
             # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-                return x, residual
-            else:
-                residual_out = torch.empty_like(x)
-                output = torch.empty_like(x)
+            try:
+                output = torch.empty_like(x)
+                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,

@@ -151,10 +175,21 @@ class RMSNorm(CustomOp):
                     self.variance_epsilon,
                 )
                 return output, residual_out
+            except TypeError:
+                fused_add_rms_norm(
+                    x,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return x, residual
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out

     def forward_native(
         self,
         x: torch.Tensor,
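In the new forward_hip above, the vLLM version check is replaced by exception-based dispatch: the out-of-place six-argument call to fused_add_rms_norm is tried first, and a TypeError triggers a fallback to the older in-place four-argument form. Below is a minimal, self-contained sketch of that pattern; fused_add_rms_norm_new, fused_add_rms_norm_old, and forward_hip_like are illustrative stand-ins, not the actual aiter/vLLM kernels.

import torch

# Stand-in for the out-of-place kernel: writes into pre-allocated buffers.
def fused_add_rms_norm_new(output, x, residual_out, residual, weight, eps):
    residual_out.copy_(x + residual)
    var = residual_out.pow(2).mean(-1, keepdim=True)
    output.copy_(residual_out * torch.rsqrt(var + eps) * weight)

# Stand-in for the older in-place kernel: mutates x and residual.
def fused_add_rms_norm_old(x, residual, weight, eps):
    residual.add_(x)
    var = residual.pow(2).mean(-1, keepdim=True)
    x.copy_(residual * torch.rsqrt(var + eps) * weight)

fused_add_rms_norm = fused_add_rms_norm_new  # point at _old to exercise the fallback

def forward_hip_like(x, residual, weight, eps):
    try:
        # Newer signature: kernel fills pre-allocated output / residual_out.
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        fused_add_rms_norm(output, x, residual_out, residual, weight, eps)
        return output, residual_out
    except TypeError:
        # Older signature: the kernel updates x and residual in place.
        fused_add_rms_norm(x, residual, weight, eps)
        return x, residual

x, residual = torch.randn(2, 8), torch.randn(2, 8)
out, res = forward_hip_like(x, residual, torch.ones(8), 1e-6)

One plausible motivation for the change is that dispatch then follows whichever kernel signature is actually installed rather than a version string.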
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -61,7 +61,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,

@@ -79,6 +79,8 @@ def inplace_fused_experts(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ) -> None:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     fused_experts_impl(
         hidden_states,
         w1,

@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,

@@ -154,7 +156,7 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,

@@ -173,6 +175,8 @@ def outplace_fused_experts(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ) -> torch.Tensor:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     return fused_experts_impl(
         hidden_states,
         w1,

@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,

@@ -263,6 +267,13 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
+    act_id = (
+        0
+        if (moe_runner_config.activation == 0
+            or (isinstance(moe_runner_config.activation, str)
+                and moe_runner_config.activation.lower() == "silu"))
+        else 1
+    )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(

@@ -273,7 +284,7 @@ def fused_experts(
            topk_ids,
            b1,
            b2,
-           moe_runner_config.activation,
+           act_id,
            moe_runner_config.apply_router_weight_on_input,
            use_fp8_w8a8,
            use_int8_w8a8,

@@ -301,7 +312,7 @@ def fused_experts(
            topk_ids,
            b1,
            b2,
-           moe_runner_config.activation,
+           act_id,
            moe_runner_config.apply_router_weight_on_input,
            use_fp8_w8a8,
            use_int8_w8a8,

@@ -345,7 +356,7 @@ def fused_experts_impl(
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
-    activation: str = "silu",
+    activation: int = 0,  # 0 silu 1 gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,

@@ -364,6 +375,9 @@ def fused_experts_impl(
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
 ):
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
+
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
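Across these hunks the activation argument crosses the torch.ops.sglang custom-op boundary as an integer code (0 for silu, 1 for gelu) and is normalized back to a string inside each wrapper before fused_experts_impl runs. The sketch below isolates just that round trip; activation_to_id and activation_from_id are hypothetical helper names, not functions in sglang.

# Hypothetical helpers mirroring the act_id encoding and the isinstance(..., int) decode above.
def activation_to_id(activation) -> int:
    # "silu" (any case) or 0 maps to 0; everything else is treated as gelu (1).
    if activation == 0 or (isinstance(activation, str) and activation.lower() == "silu"):
        return 0
    return 1

def activation_from_id(activation) -> str:
    # Undo the encoding on the callee side, exactly as the wrappers do.
    if isinstance(activation, int):
        return "silu" if activation == 0 else "gelu"
    return activation

assert activation_to_id("SiLU") == 0
assert activation_from_id(activation_to_id("gelu")) == "gelu"
assert activation_from_id("silu") == "silu"  # strings pass through unchanged

Passing an int presumably keeps the registered custom-op schema to plain scalar types; either form is still accepted because the wrappers re-derive the string.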
python/sglang/srt/model_executor/model_runner.py

@@ -516,7 +516,7 @@ class ModelRunner:
         ):
             server_args.attention_backend = "fa3"
         elif _is_hip:
-            server_args.attention_backend = "aiter"
+            server_args.attention_backend = "triton"
         elif _is_npu:
             server_args.attention_backend = "ascend"
         else:
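With this change, HIP builds default to the Triton attention backend instead of aiter when no backend is specified. The aiter backend can still be requested explicitly; a hedged sketch, assuming ServerArgs is constructed directly with its attention_backend field (the model path is a placeholder):

from sglang.srt.server_args import ServerArgs

# Illustrative only: override the HIP default selected above.
args = ServerArgs(model_path="path/to/model", attention_backend="aiter")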
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh

@@ -165,10 +165,10 @@ DINLINE void start_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
-        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __hip_atomic_store(
+        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(
-               &self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < flag)
+    while (__hip_atomic_load(
+               &self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) < flag)
       ;
   }

@@ -211,16 +211,16 @@ DINLINE void end_sync(
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(
+    __hip_atomic_store(
         &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
         flag,
         final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
-        __MEMORY_SCOPE_SYSTEM);
+        __HIP_MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__scoped_atomic_load_n(
+    while (__hip_atomic_load(
               &self_sg->end[blockIdx.x][threadIdx.x],
              final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
-             __MEMORY_SCOPE_DEVICE) < flag)
+             __HIP_MEMORY_SCOPE_AGENT) < flag)
       ;
   }
   __syncthreads();
sgl-kernel/csrc/moe/moe_align_kernel.cu

@@ -21,6 +21,7 @@ limitations under the License.
 #include "utils.h"

+#define WARP_SIZE 64
 #define VEC_SIZE 4
 using Vec = int4;
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
   int original = v;
 #pragma unroll
   for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
-    int n = __shfl_up_sync(mask, v, offset);
+    int n = __shfl_up(v, offset);
     if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
   }
   return v - original;
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu

@@ -60,7 +60,7 @@ template <typename T>
 __device__ float convert_to_float(T x) {
   if constexpr (std::is_same_v<T, __half>) {
     return __half2float(x);
-  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+  } else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
     return __bfloat162float(x);
   } else if constexpr (std::is_same_v<T, float>) {
     return x;

@@ -575,8 +575,8 @@ void topk_softmax(
         renormalize,
         stream);
   } else if (dtype == at::ScalarType::BFloat16) {
-    topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
-        reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
+    topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
+        reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
         topk_weights.data_ptr<float>(),
         topk_indices.data_ptr<int>(),
         softmax_workspace.data_ptr<float>(),
sgl-kernel/include/utils.h

@@ -358,25 +358,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
 #endif

 // add FP8 support
-#ifndef USE_ROCM
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else  // USE_ROCM
-#if HIP_FP8_TYPE_FNUZ
-#include <c10/util/Float8_e4m3fnuz.h>
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-constexpr auto FP8_E4M3_MAX = 224.0f;
-#else
-#if HIP_FP8_TYPE_E4M3
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#error "fp8 is not supported in this processor (arch < gfx942)."
-#endif  // HIP_FP8_TYPE_E4M3
-#endif  // HIP_FP8_TYPE_FNUZ
-#endif  // USE_ROCM
+// #ifndef USE_ROCM
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else  // USE_ROCM
+// #if HIP_FP8_TYPE_FNUZ
+// #include <c10/util/Float8_e4m3fnuz.h>
+// using FP8_TYPE = c10::Float8_e4m3fnuz;
+// constexpr auto FP8_E4M3_MAX = 224.0f;
+// #else
+// #if HIP_FP8_TYPE_E4M3
+// #include <c10/util/Float8_e4m3fn.h>
+// using FP8_TYPE = c10::Float8_e4m3fn;
+// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+// #else
+// #error "fp8 is not supported in this processor (arch < gfx942)."
+// #endif  // HIP_FP8_TYPE_E4M3
+// #endif  // HIP_FP8_TYPE_FNUZ
+// #endif  // USE_ROCM

 #define FULL_MASK 0xffffffff
sgl-kernel/setup_hip.py (new file, mode 100644)

# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import platform
import sys
from pathlib import Path

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

root = Path(__file__).parent.resolve()
arch = platform.machine().lower()


def _get_version():
    with open(root / "pyproject.toml") as f:
        for line in f:
            if line.startswith("version"):
                return line.split("=")[1].strip().strip('"')


operator_namespace = "sgl_kernel"
include_dirs = [
    root / "include",
    root / "csrc",
]

sources = [
    "csrc/allreduce/custom_all_reduce.hip",
    "csrc/allreduce/quick_all_reduce.cu",
    "csrc/common_extension_rocm.cc",
    "csrc/elementwise/activation.cu",
    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",
    "csrc/speculative/eagle_utils.cu",
    "csrc/kvcacheio/transfer.cu",
]

cxx_flags = [
    "-O3",
    "-Wno-switch-bool",
    "-Wno-macro-redefined",
    "-Wno-deprecated-declarations",
    "-w",
]

libraries = ["c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]

hipcc_flags = [
    "-fPIC",
    "-O3",
    "-std=c++17",
    "-D__HIP_PLATFORM_HCC__=1",
    "--offload-arch=gfx928",
    "--offload-arch=gfx936",
    "--gpu-max-threads-per-block=1024",
    "-Wno-macro-redefined",
    "-Wno-deprecated-declarations",
    "-funroll-loops",
    "-Rpass-analysis=unroll-loops",
    "-w",
]

ext_modules = [
    CUDAExtension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={
            "nvcc": hipcc_flags,
            "cxx": cxx_flags,
        },
        libraries=libraries,
        extra_link_args=extra_link_args,
        py_limited_api=False,
    ),
]

setup(
    name="sgl-kernel",
    version=_get_version(),
    packages=find_packages(where="python"),
    package_dir={"": "python"},
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
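setup_hip.py mirrors the CUDA setup script but routes the HIP compiler flags (including --offload-arch=gfx928 and --offload-arch=gfx936) through torch's CUDAExtension/BuildExtension machinery under the "nvcc" key. A hedged build sketch follows; the working directory and the bdist_wheel target are assumptions for illustration, not part of the commit.

# Hypothetical build driver: run the new setup script from the sgl-kernel/ checkout
# to produce a wheel, assuming a ROCm/HIP-enabled PyTorch install.
import subprocess
import sys
from pathlib import Path

kernel_dir = Path("sgl-kernel")  # adjust to the actual checkout location
subprocess.run(
    [sys.executable, "setup_hip.py", "bdist_wheel"],
    cwd=kernel_dir,
    check=True,
)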