Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2998c4bd
Unverified
Commit
2998c4bd
authored
Jul 04, 2025
by
Yi Zhang
Committed by
GitHub
Jul 03, 2025
Browse files
[optimize] fuse renormalize into moe_topk_softmax (#7744)
Co-authored-by:
ispobock
<
ispobaoke@gmail.com
>
parent
6840a7bb
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
254 additions
and
101 deletions
+254
-101
sgl-kernel/benchmark/bench_moe_topk_softmax.py
sgl-kernel/benchmark/bench_moe_topk_softmax.py
+0
-4
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/common_extension.cc
+1
-3
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
+157
-81
sgl-kernel/csrc/torch_extension_rocm.cc
sgl-kernel/csrc/torch_extension_rocm.cc
+1
-3
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+1
-4
sgl-kernel/python/sgl_kernel/moe.py
sgl-kernel/python/sgl_kernel/moe.py
+2
-2
sgl-kernel/tests/test_moe_topk_softmax.py
sgl-kernel/tests/test_moe_topk_softmax.py
+92
-4
No files found.
sgl-kernel/benchmark/bench_moe_topk_softmax.py
View file @
2998c4bd
...
@@ -34,14 +34,10 @@ def sglang_topk_softmax(gating_output, topk):
...
@@ -34,14 +34,10 @@ def sglang_topk_softmax(gating_output, topk):
topk_indices
=
torch
.
empty
(
topk_indices
=
torch
.
empty
(
(
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
gating_output
.
device
(
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
gating_output
.
device
)
)
token_expert_indices
=
torch
.
empty
(
(
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
gating_output
.
device
)
topk_softmax
(
topk_softmax
(
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_indices
,
topk_ids
=
topk_indices
,
token_expert_indices
=
token_expert_indices
,
gating_output
=
gating_output
,
gating_output
=
gating_output
,
)
)
...
...
sgl-kernel/csrc/common_extension.cc
View file @
2998c4bd
...
@@ -169,9 +169,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
...
@@ -169,9 +169,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"pad_sorted_token_ids) -> ()"
);
"pad_sorted_token_ids) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
m
.
def
(
m
.
def
(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()"
);
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
def
(
m
.
def
(
...
...
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
View file @
2998c4bd
...
@@ -41,15 +41,29 @@ template <
...
@@ -41,15 +41,29 @@ template <
/// Alignment requirement in bytes
/// Alignment requirement in bytes
int
Alignment
=
sizeof
(
T
)
*
N
>
int
Alignment
=
sizeof
(
T
)
*
N
>
class
alignas
(
Alignment
)
AlignedArray
{
class
alignas
(
Alignment
)
AlignedArray
{
float
data
[
N
];
T
data
[
N
];
};
};
// ========================== Util functions to convert types ==========================
template
<
typename
T
>
__device__
float
convert_to_float
(
T
x
)
{
if
constexpr
(
std
::
is_same_v
<
T
,
__half
>
)
{
return
__half2float
(
x
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
__nv_bfloat16
>
)
{
return
__bfloat162float
(
x
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
return
x
;
}
else
{
return
static_cast
<
float
>
(
x
);
}
}
// ====================== Softmax things ===============================
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
// in the softmax kernel when we extend this module to support expert-choice routing.
template
<
int
TPB
>
template
<
typename
T
,
int
TPB
>
__launch_bounds__
(
TPB
)
__global__
__launch_bounds__
(
TPB
)
__global__
void
moeSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_cols
)
{
void
moeSoftmax
(
const
T
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_cols
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
TPB
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmpStorage
;
__shared__
typename
BlockReduce
::
TempStorage
tmpStorage
;
...
@@ -68,7 +82,7 @@ __launch_bounds__(TPB) __global__
...
@@ -68,7 +82,7 @@ __launch_bounds__(TPB) __global__
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
const
int
idx
=
thread_row_offset
+
ii
;
threadData
=
max
(
static_cast
<
float
>
(
input
[
idx
]),
threadData
);
threadData
=
max
(
convert_to_
float
<
T
>
(
input
[
idx
]),
threadData
);
}
}
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
cub
::
Max
());
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
cub
::
Max
());
...
@@ -82,7 +96,7 @@ __launch_bounds__(TPB) __global__
...
@@ -82,7 +96,7 @@ __launch_bounds__(TPB) __global__
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
const
int
idx
=
thread_row_offset
+
ii
;
threadData
+=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
));
threadData
+=
exp
((
convert_to_
float
<
T
>
(
input
[
idx
])
-
float_max
));
}
}
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
sum
);
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
sum
);
...
@@ -94,7 +108,7 @@ __launch_bounds__(TPB) __global__
...
@@ -94,7 +108,7 @@ __launch_bounds__(TPB) __global__
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
const
int
idx
=
thread_row_offset
+
ii
;
const
float
val
=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
))
*
normalizing_factor
;
const
float
val
=
exp
((
convert_to_
float
<
T
>
(
input
[
idx
])
-
float_max
))
*
normalizing_factor
;
output
[
idx
]
=
val
;
output
[
idx
]
=
val
;
}
}
}
}
...
@@ -105,11 +119,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
...
@@ -105,11 +119,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const
bool
*
finished
,
const
bool
*
finished
,
float
*
output
,
float
*
output
,
int
*
indices
,
int
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
num_experts
,
const
int
k
,
const
int
k
,
const
int
start_expert
,
const
int
start_expert
,
const
int
end_expert
)
{
const
int
end_expert
,
const
bool
renormalize
)
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
cub_kvp
,
TPB
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
cub_kvp
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmpStorage
;
__shared__
typename
BlockReduce
::
TempStorage
tmpStorage
;
...
@@ -117,11 +131,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
...
@@ -117,11 +131,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
cub_kvp
thread_kvp
;
cub_kvp
thread_kvp
;
cub
::
ArgMax
arg_max
;
cub
::
ArgMax
arg_max
;
const
int
num_rows
=
gridDim
.
x
;
const
int
block_row
=
blockIdx
.
x
;
const
int
block_row
=
blockIdx
.
x
;
const
bool
row_is_active
=
finished
?
!
finished
[
block_row
]
:
true
;
const
bool
row_is_active
=
finished
?
!
finished
[
block_row
]
:
true
;
const
int
thread_read_offset
=
blockIdx
.
x
*
num_experts
;
const
int
thread_read_offset
=
blockIdx
.
x
*
num_experts
;
float
row_sum_for_renormalize
=
0
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
thread_kvp
.
key
=
0
;
thread_kvp
.
key
=
0
;
thread_kvp
.
value
=
-
1.
f
;
// This is OK because inputs are probabilities
thread_kvp
.
value
=
-
1.
f
;
// This is OK because inputs are probabilities
...
@@ -154,10 +168,18 @@ __launch_bounds__(TPB) __global__ void moeTopK(
...
@@ -154,10 +168,18 @@ __launch_bounds__(TPB) __global__ void moeTopK(
output
[
idx
]
=
result_kvp
.
value
;
output
[
idx
]
=
result_kvp
.
value
;
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
num_experts
;
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
num_experts
;
assert
(
indices
[
idx
]
>=
0
);
assert
(
indices
[
idx
]
>=
0
);
source_rows
[
idx
]
=
k_idx
*
num_rows
+
block_row
;
row_sum_for_renormalize
+=
result_kvp
.
value
;
}
}
__syncthreads
();
__syncthreads
();
}
}
if
(
renormalize
&&
threadIdx
.
x
==
0
)
{
float
row_sum_for_renormalize_inv
=
1.
f
/
row_sum_for_renormalize
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
const
int
idx
=
k
*
block_row
+
k_idx
;
output
[
idx
]
=
output
[
idx
]
*
row_sum_for_renormalize_inv
;
}
}
}
}
// ====================== TopK softmax things ===============================
// ====================== TopK softmax things ===============================
...
@@ -174,17 +196,17 @@ __launch_bounds__(TPB) __global__ void moeTopK(
...
@@ -174,17 +196,17 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
2) This implementation assumes k is small, but will work for any k.
*/
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
>
template
<
typename
T
,
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
void
topkGatingSoftmax
(
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
T
*
input
,
const
bool
*
finished
,
const
bool
*
finished
,
float
*
output
,
float
*
output
,
const
int
num_rows
,
const
int
num_rows
,
int
*
indices
,
int
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
k
,
const
int
start_expert
,
const
int
start_expert
,
const
int
end_expert
)
{
const
int
end_expert
,
const
bool
renormalize
)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert
(
VPT
==
(
VPT
&
-
VPT
),
"VPT must be power of 2"
);
static_assert
(
VPT
==
(
VPT
&
-
VPT
),
"VPT must be power of 2"
);
static_assert
(
NUM_EXPERTS
==
(
NUM_EXPERTS
&
-
NUM_EXPERTS
),
"NUM_EXPERTS must be power of 2"
);
static_assert
(
NUM_EXPERTS
==
(
NUM_EXPERTS
&
-
NUM_EXPERTS
),
"NUM_EXPERTS must be power of 2"
);
...
@@ -192,7 +214,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
...
@@ -192,7 +214,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
static_assert
(
BYTES_PER_LDG
<=
16
,
"BYTES_PER_LDG must be leq 16"
);
static_assert
(
BYTES_PER_LDG
<=
16
,
"BYTES_PER_LDG must be leq 16"
);
// Number of bytes each thread pulls in per load
// Number of bytes each thread pulls in per load
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
T
);
static
constexpr
int
ELTS_PER_ROW
=
NUM_EXPERTS
;
static
constexpr
int
ELTS_PER_ROW
=
NUM_EXPERTS
;
static
constexpr
int
THREADS_PER_ROW
=
ELTS_PER_ROW
/
VPT
;
static
constexpr
int
THREADS_PER_ROW
=
ELTS_PER_ROW
/
VPT
;
static
constexpr
int
LDG_PER_THREAD
=
VPT
/
ELTS_PER_LDG
;
static
constexpr
int
LDG_PER_THREAD
=
VPT
/
ELTS_PER_LDG
;
...
@@ -233,28 +255,34 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
...
@@ -233,28 +255,34 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
// row it will read.
const
float
*
thread_row_ptr
=
input
+
thread_row
*
ELTS_PER_ROW
;
const
T
*
thread_row_ptr
=
input
+
thread_row
*
ELTS_PER_ROW
;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const
int
thread_group_idx
=
threadIdx
.
x
%
THREADS_PER_ROW
;
const
int
thread_group_idx
=
threadIdx
.
x
%
THREADS_PER_ROW
;
const
int
first_elt_read_by_thread
=
thread_group_idx
*
ELTS_PER_LDG
;
const
int
first_elt_read_by_thread
=
thread_group_idx
*
ELTS_PER_LDG
;
const
float
*
thread_read_ptr
=
thread_row_ptr
+
first_elt_read_by_thread
;
const
T
*
thread_read_ptr
=
thread_row_ptr
+
first_elt_read_by_thread
;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using
AccessType
=
AlignedArray
<
float
,
ELTS_PER_LDG
>
;
using
AccessType
=
AlignedArray
<
T
,
ELTS_PER_LDG
>
;
// Finally, we pull in the data from global mem
// Finally, we pull in the data from global mem
float
row_chunk
[
VPT
];
T
row_chunk
_temp
[
VPT
];
AccessType
*
row_chunk_vec_ptr
=
reinterpret_cast
<
AccessType
*>
(
&
row_chunk
);
AccessType
*
row_chunk_vec_ptr
=
reinterpret_cast
<
AccessType
*>
(
&
row_chunk
_temp
);
const
AccessType
*
vec_thread_read_ptr
=
reinterpret_cast
<
const
AccessType
*>
(
thread_read_ptr
);
const
AccessType
*
vec_thread_read_ptr
=
reinterpret_cast
<
const
AccessType
*>
(
thread_read_ptr
);
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
row_chunk_vec_ptr
[
ii
]
=
vec_thread_read_ptr
[
ii
*
THREADS_PER_ROW
];
row_chunk_vec_ptr
[
ii
]
=
vec_thread_read_ptr
[
ii
*
THREADS_PER_ROW
];
}
}
float
row_chunk
[
VPT
];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
convert_to_float
<
T
>
(
row_chunk_temp
[
ii
]);
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
// convert to float afterwards for the exp + sum reduction.
float
thread_max
=
row_chunk
[
0
];
float
thread_max
=
row_chunk
[
0
];
...
@@ -301,6 +329,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
...
@@ -301,6 +329,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
int
start_col
=
first_elt_read_by_thread
;
int
start_col
=
first_elt_read_by_thread
;
static
constexpr
int
COLS_PER_GROUP_LDG
=
ELTS_PER_LDG
*
THREADS_PER_ROW
;
static
constexpr
int
COLS_PER_GROUP_LDG
=
ELTS_PER_LDG
*
THREADS_PER_ROW
;
float
row_sum_for_renormalize
=
0
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
// First, each thread does the local argmax
// First, each thread does the local argmax
float
max_val
=
row_chunk
[
0
];
float
max_val
=
row_chunk
[
0
];
...
@@ -346,7 +376,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
...
@@ -346,7 +376,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
const
int
idx
=
k
*
thread_row
+
k_idx
;
const
int
idx
=
k
*
thread_row
+
k_idx
;
output
[
idx
]
=
max_val
;
output
[
idx
]
=
max_val
;
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
NUM_EXPERTS
;
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
NUM_EXPERTS
;
source_rows
[
idx
]
=
k_idx
*
num_rows
+
thread_row
;
row_sum_for_renormalize
+=
max_val
;
}
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
...
@@ -362,13 +392,23 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
...
@@ -362,13 +392,23 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
}
}
}
}
}
}
// Fuse renormalization of topk_weights into this kernel
if
(
renormalize
&&
thread_group_idx
==
0
)
{
float
row_sum_for_renormalize_inv
=
1.
f
/
row_sum_for_renormalize
;
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
const
int
idx
=
k
*
thread_row
+
k_idx
;
output
[
idx
]
=
output
[
idx
]
*
row_sum_for_renormalize_inv
;
}
}
}
}
namespace
detail
{
namespace
detail
{
// Constructs some constants needed to partition the work across threads at compile time.
// Constructs some constants needed to partition the work across threads at compile time.
template
<
int
EXPERTS
,
int
BYTES_PER_LDG
>
template
<
typename
T
,
int
EXPERTS
,
int
BYTES_PER_LDG
>
struct
TopkConstants
{
struct
TopkConstants
{
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
T
);
static_assert
(
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
||
EXPERTS
%
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
,
""
);
static_assert
(
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
||
EXPERTS
%
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
,
""
);
static
constexpr
int
VECs_PER_THREAD
=
MAX
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
));
static
constexpr
int
VECs_PER_THREAD
=
MAX
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
));
static
constexpr
int
VPT
=
VECs_PER_THREAD
*
ELTS_PER_LDG
;
static
constexpr
int
VPT
=
VECs_PER_THREAD
*
ELTS_PER_LDG
;
...
@@ -377,100 +417,84 @@ struct TopkConstants {
...
@@ -377,100 +417,84 @@ struct TopkConstants {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
>
template
<
typename
T
,
int
EXPERTS
,
int
WARPS_PER_TB
>
void
topkGatingSoftmaxLauncherHelper
(
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
T
*
input
,
const
bool
*
finished
,
const
bool
*
finished
,
float
*
output
,
float
*
output
,
int
*
indices
,
int
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
num_rows
,
const
int
k
,
const
int
k
,
const
int
start_expert
,
const
int
start_expert
,
const
int
end_expert
,
const
int
end_expert
,
const
bool
renormalize
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
float
)
*
EXPERTS
);
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
T
)
*
EXPERTS
);
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
>
;
using
Constants
=
detail
::
TopkConstants
<
T
,
EXPERTS
,
BYTES_PER_LDG
>
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
const
int
num_warps
=
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
int
num_warps
=
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
dim3
block_dim
(
WARP_SIZE
,
WARPS_PER_TB
);
dim3
block_dim
(
WARP_SIZE
,
WARPS_PER_TB
);
topkGatingSoftmax
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
topkGatingSoftmax
<
T
,
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
);
input
,
finished
,
output
,
num_rows
,
indices
,
k
,
start_expert
,
end_expert
,
renormalize
);
}
}
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, \
gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);
nullptr, \
topk_weights, \
topk_indices, \
token_expert_indices, \
num_tokens, \
topk, \
0, \
num_experts, \
stream);
template
<
typename
T
>
void
topkGatingSoftmaxKernelLauncher
(
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
const
T
*
gating_output
,
float
*
topk_weights
,
float
*
topk_weights
,
int
*
topk_indices
,
int
*
topk_indices
,
int
*
token_expert_indices
,
float
*
softmax_workspace
,
float
*
softmax_workspace
,
const
int
num_tokens
,
const
int
num_tokens
,
const
int
num_experts
,
const
int
num_experts
,
const
int
topk
,
const
int
topk
,
const
bool
renormalize
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
static
constexpr
int
WARPS_PER_TB
=
4
;
static
constexpr
int
WARPS_PER_TB
=
4
;
switch
(
num_experts
)
{
switch
(
num_experts
)
{
case
1
:
case
1
:
LAUNCH_SOFTMAX
(
1
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
1
,
WARPS_PER_TB
);
break
;
break
;
case
2
:
case
2
:
LAUNCH_SOFTMAX
(
2
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
2
,
WARPS_PER_TB
);
break
;
break
;
case
4
:
case
4
:
LAUNCH_SOFTMAX
(
4
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
4
,
WARPS_PER_TB
);
break
;
break
;
case
8
:
case
8
:
LAUNCH_SOFTMAX
(
8
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
8
,
WARPS_PER_TB
);
break
;
break
;
case
16
:
case
16
:
LAUNCH_SOFTMAX
(
16
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
16
,
WARPS_PER_TB
);
break
;
break
;
case
32
:
case
32
:
LAUNCH_SOFTMAX
(
32
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
32
,
WARPS_PER_TB
);
break
;
break
;
case
64
:
case
64
:
LAUNCH_SOFTMAX
(
64
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
64
,
WARPS_PER_TB
);
break
;
break
;
case
128
:
case
128
:
LAUNCH_SOFTMAX
(
128
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
128
,
WARPS_PER_TB
);
break
;
break
;
case
256
:
case
256
:
LAUNCH_SOFTMAX
(
256
,
WARPS_PER_TB
);
LAUNCH_SOFTMAX
(
T
,
256
,
WARPS_PER_TB
);
break
;
break
;
default:
{
default:
{
TORCH_CHECK
(
TORCH_CHECK
(
softmax_workspace
!=
nullptr
,
softmax_workspace
!=
nullptr
,
"softmax_workspace must be provided for num_experts that are not a power of 2."
);
"softmax_workspace must be provided for num_experts that are not a power of 2."
);
static
constexpr
int
TPB
=
256
;
static
constexpr
int
TPB
=
256
;
moeSoftmax
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
softmax_workspace
,
num_experts
);
moeSoftmax
<
T
,
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
softmax_workspace
,
num_experts
);
moeTopK
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
moeTopK
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
softmax_workspace
,
softmax_workspace
,
nullptr
,
topk_weights
,
topk_indices
,
num_experts
,
topk
,
0
,
num_experts
,
renormalize
);
nullptr
,
topk_weights
,
topk_indices
,
token_expert_indices
,
num_experts
,
topk
,
0
,
num_experts
);
}
}
}
}
}
}
...
@@ -478,12 +502,35 @@ void topkGatingSoftmaxKernelLauncher(
...
@@ -478,12 +502,35 @@ void topkGatingSoftmaxKernelLauncher(
void
topk_softmax
(
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
// [num_tokens, topk]
torch
::
Tensor
&
topk_weights
,
// [num_tokens, topk]
torch
::
Tensor
&
topk_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
topk_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
token_expert_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
gating_output
,
torch
::
T
en
s
or
&
gating_output
)
// [num_tokens, num_experts]
const
bool
r
enor
malize
)
// [num_tokens, num_experts]
{
{
const
int
num_experts
=
gating_output
.
size
(
-
1
);
// Check data type
const
int
num_tokens
=
gating_output
.
numel
()
/
num_experts
;
TORCH_CHECK
(
const
int
topk
=
topk_weights
.
size
(
-
1
);
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Float
||
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Half
||
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
,
"gating_output must be float32, float16, or bfloat16"
);
// Check dimensions
TORCH_CHECK
(
gating_output
.
dim
()
==
2
,
"gating_output must be 2D tensor [num_tokens, num_experts]"
);
TORCH_CHECK
(
topk_weights
.
dim
()
==
2
,
"topk_weights must be 2D tensor [num_tokens, topk]"
);
TORCH_CHECK
(
topk_indices
.
dim
()
==
2
,
"topk_indices must be 2D tensor [num_tokens, topk]"
);
// Check shapes
TORCH_CHECK
(
gating_output
.
size
(
0
)
==
topk_weights
.
size
(
0
),
"First dimension of topk_weights must match num_tokens in gating_output"
);
TORCH_CHECK
(
gating_output
.
size
(
0
)
==
topk_indices
.
size
(
0
),
"First dimension of topk_indices must match num_tokens in gating_output"
);
TORCH_CHECK
(
topk_weights
.
size
(
-
1
)
==
topk_indices
.
size
(
-
1
),
"Second dimension of topk_indices must match topk in topk_weights"
);
TORCH_CHECK
(
topk_weights
.
size
(
-
1
)
<=
gating_output
.
size
(
-
1
),
"topk must be less than or equal to num_experts"
);
const
int
num_experts
=
static_cast
<
int
>
(
gating_output
.
size
(
-
1
));
const
int
num_tokens
=
static_cast
<
int
>
(
gating_output
.
size
(
0
));
const
int
topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
-
1
));
const
bool
is_pow_2
=
(
num_experts
!=
0
)
&&
((
num_experts
&
(
num_experts
-
1
))
==
0
);
const
bool
is_pow_2
=
(
num_experts
!=
0
)
&&
((
num_experts
&
(
num_experts
-
1
))
==
0
);
const
bool
needs_workspace
=
!
is_pow_2
||
num_experts
>
256
;
const
bool
needs_workspace
=
!
is_pow_2
||
num_experts
>
256
;
...
@@ -491,15 +538,44 @@ void topk_softmax(
...
@@ -491,15 +538,44 @@ void topk_softmax(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
torch
::
Tensor
softmax_workspace
=
topkGatingSoftmaxKernelLauncher
(
torch
::
empty
({
workspace_size
},
gating_output
.
options
().
dtype
(
at
::
ScalarType
::
Float
));
const
at
::
ScalarType
dtype
=
gating_output
.
scalar_type
();
if
(
dtype
==
at
::
ScalarType
::
Float
)
{
topkGatingSoftmaxKernelLauncher
<
float
>
(
gating_output
.
data_ptr
<
float
>
(),
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_tokens
,
num_experts
,
num_experts
,
topk
,
topk
,
renormalize
,
stream
);
stream
);
}
else
if
(
dtype
==
at
::
ScalarType
::
Half
)
{
topkGatingSoftmaxKernelLauncher
<
__half
>
(
reinterpret_cast
<
const
__half
*>
(
gating_output
.
data_ptr
<
at
::
Half
>
()),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
}
else
if
(
dtype
==
at
::
ScalarType
::
BFloat16
)
{
topkGatingSoftmaxKernelLauncher
<
__nv_bfloat16
>
(
reinterpret_cast
<
const
__nv_bfloat16
*>
(
gating_output
.
data_ptr
<
at
::
BFloat16
>
()),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported gating_output dtype: "
,
dtype
);
}
}
}
sgl-kernel/csrc/torch_extension_rocm.cc
View file @
2998c4bd
...
@@ -63,9 +63,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
...
@@ -63,9 +63,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"pad_sorted_token_ids) -> ()"
);
"pad_sorted_token_ids) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
m
.
def
(
m
.
def
(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()"
);
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
/*
/*
...
...
sgl-kernel/include/sgl_kernel_ops.h
View file @
2998c4bd
...
@@ -222,10 +222,7 @@ void moe_align_block_size(
...
@@ -222,10 +222,7 @@ void moe_align_block_size(
bool
pad_sorted_token_ids
);
bool
pad_sorted_token_ids
);
void
topk_softmax
(
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
gating_output
,
bool
renormalize
);
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
std
::
vector
<
at
::
Tensor
>
moe_fused_gate
(
std
::
vector
<
at
::
Tensor
>
moe_fused_gate
(
at
::
Tensor
&
input
,
at
::
Tensor
&
input
,
...
...
sgl-kernel/python/sgl_kernel/moe.py
View file @
2998c4bd
...
@@ -30,11 +30,11 @@ def moe_align_block_size(
...
@@ -30,11 +30,11 @@ def moe_align_block_size(
def
topk_softmax
(
def
topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
float
,
gating_output
:
float
,
renormalize
:
bool
=
False
,
)
->
None
:
)
->
None
:
torch
.
ops
.
sgl_kernel
.
topk_softmax
.
default
(
torch
.
ops
.
sgl_kernel
.
topk_softmax
.
default
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
topk_weights
,
topk_ids
,
gating_output
,
renormalize
)
)
...
...
sgl-kernel/tests/test_moe_topk_softmax.py
View file @
2998c4bd
...
@@ -22,14 +22,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
...
@@ -22,14 +22,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
topk_weights
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
topk_weights
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
topk_indices
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
topk_indices
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
token_expert_indices
=
torch
.
empty
(
(
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
topk_softmax
(
topk_softmax
(
topk_weights
,
topk_weights
,
topk_indices
,
topk_indices
,
token_expert_indices
,
gating_output
,
gating_output
,
)
)
...
@@ -47,5 +43,97 @@ def test_topk_softmax(num_tokens, num_experts, topk):
...
@@ -47,5 +43,97 @@ def test_topk_softmax(num_tokens, num_experts, topk):
),
f
"Indices mismatch: torch=
{
topk_indices_ref
}
, SGLang=
{
topk_indices
}
"
),
f
"Indices mismatch: torch=
{
topk_indices_ref
}
, SGLang=
{
topk_indices
}
"
@
pytest
.
mark
.
parametrize
(
"num_tokens, num_experts, topk, dtype"
,
list
(
itertools
.
product
(
[
1
,
16
,
128
,
512
,
1024
,
2048
],
# num_tokens
[
4
,
8
,
16
,
32
,
64
,
128
,
256
],
# num_experts
[
1
,
2
,
4
],
# topk
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
# dtype
)
),
)
def
test_topk_softmax_dtype_regression
(
num_tokens
,
num_experts
,
topk
,
dtype
):
gating_output
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
dtype
,
device
=
"cuda"
)
topk_weights
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
topk_indices
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
topk_softmax
(
topk_weights
,
topk_indices
,
gating_output
,
)
topk_weights_ref
=
torch
.
empty
(
(
num_tokens
,
topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
topk_indices_ref
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
topk_softmax
(
topk_weights_ref
,
topk_indices_ref
,
gating_output
.
float
(),
)
assert
torch
.
allclose
(
topk_weights_ref
,
topk_weights
,
atol
=
1e-3
,
rtol
=
1e-3
),
f
"Weights mismatch: SGLang old interface=
{
topk_indices_ref
}
vs SGLang new interface=
{
topk_weights
}
"
assert
torch
.
allclose
(
topk_indices_ref
.
int
(),
topk_indices
,
atol
=
0
,
rtol
=
0
),
f
"Indices mismatch: SGLang old interface=
{
topk_indices_ref
}
, SGLang new interface=
{
topk_indices
}
"
@
pytest
.
mark
.
parametrize
(
"num_tokens, num_experts, topk"
,
list
(
itertools
.
product
(
[
1
,
16
,
128
,
512
,
1024
,
2048
],
# num_tokens
[
4
,
8
,
16
,
32
,
64
,
128
,
256
],
# num_experts
[
1
,
2
,
4
],
# topk
)
),
)
def
test_topk_softmax_renormalize
(
num_tokens
,
num_experts
,
topk
):
gating_output
=
torch
.
randn
(
(
num_tokens
,
num_experts
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
topk_weights
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
topk_indices
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
topk_softmax
(
topk_weights
,
topk_indices
,
gating_output
,
renormalize
=
True
,
)
topk_weights_ref
=
torch
.
empty
(
(
num_tokens
,
topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
topk_indices_ref
=
torch
.
empty
((
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
token_expert_indices_ref
=
torch
.
empty
(
(
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
topk_softmax
(
topk_weights_ref
,
topk_indices_ref
,
gating_output
,
)
topk_weights_ref
=
topk_weights_ref
/
topk_weights_ref
.
sum
(
dim
=-
1
,
keepdim
=
True
)
assert
torch
.
allclose
(
topk_weights_ref
,
topk_weights
,
atol
=
1e-3
,
rtol
=
1e-3
),
f
"Weights mismatch: SGLang w/o fused renormalize=
{
topk_indices_ref
}
vs SGLang w/ fused renormalize=
{
topk_weights
}
"
assert
torch
.
allclose
(
topk_indices_ref
.
int
(),
topk_indices
,
atol
=
0
,
rtol
=
0
),
f
"Indices mismatch: SGLang w/o fused renormalize=
{
topk_indices_ref
}
, SGLang w/ fused renormalize=
{
topk_indices
}
"
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment