sglang · Commits · 18bb216c

Commit 18bb216c (unverified), authored Feb 28, 2025 by Chayenne, committed by GitHub on Feb 28, 2025

Revert "[MOE] enable efficient moe_alignment multi-blocks execution (3x~6x)" (#3982)

Parent: 6b859e7d

Showing 5 changed files with 93 additions and 380 deletions (+93 -380):

  benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py   +38  -81
  sgl-kernel/pyproject.toml                                                       +1   -1
  sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu                             +52 -266
  sgl-kernel/src/sgl-kernel/include/utils.h                                       +0  -30
  sgl-kernel/tests/test_moe_align.py                                              +2   -2
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py  (View file @ 18bb216c)

@@ -99,12 +99,13 @@ def moe_align_block_size_triton(
     sorted_token_ids: torch.Tensor,
     expert_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
-    tokens_cnts: torch.Tensor,
-    cumsum: torch.Tensor,
 ) -> None:
     numel = topk_ids.numel()
     grid = (num_experts,)
+    tokens_cnts = torch.zeros(
+        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
     tokens_per_thread = ceil_div(numel, num_experts)
     moe_align_block_size_stage1[grid](
@@ -138,18 +139,11 @@ def moe_align_block_size_triton(
     )


-def calculate_diff(batch_size, seq_len, num_experts):
-    num_experts = num_experts
+def calculate_diff(batch_size, seq_len):
+    num_experts = 256
     block_size = 128
     topk = 8

-    assert batch_size >= 1
-    assert seq_len >= 1
-    assert num_experts >= 4
-    if topk > num_experts:
-        topk = num_experts
-
     topk_ids = torch.stack(
         [
             torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
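Post-revert, calculate_diff pins num_experts to 256, matching the TORCH_CHECK(num_experts == 256) guard that stays in the CUDA extension below. The hunk cuts off inside the topk_ids construction; a rough, hypothetical way to build an equivalent input (each token drawing topk distinct experts via randperm, as the visible fragment suggests; not necessarily the script's exact continuation) is:

```python
import torch


def make_topk_ids(batch_size: int, seq_len: int, num_experts: int = 256, topk: int = 8) -> torch.Tensor:
    # One row of `topk` distinct expert ids per token, mirroring the
    # randperm(...)[:topk] pattern visible in the truncated hunk above.
    return torch.stack(
        [
            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
            for _ in range(batch_size * seq_len)
        ]
    )
```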
@@ -181,13 +175,6 @@ def calculate_diff(batch_size, seq_len, num_experts):
     expert_ids_triton = torch.zeros_like(expert_ids_cuda)
     num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
-    token_cnts_buffer_triton = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum_buffer_triton = torch.zeros(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )

     # compare the performance of cuda and triton implementation
     moe_align_block_size(
         topk_ids,
@@ -206,27 +193,14 @@ def calculate_diff(batch_size, seq_len, num_experts):
         sorted_ids_triton,
         expert_ids_triton,
         num_tokens_post_pad_triton,
-        token_cnts_buffer_triton,
-        cumsum_buffer_triton,
     )

-    sorted_ids_cuda_snapshot = sorted_ids_cuda[: cumsum_buffer[1]].sort().values
-    sorted_ids_triton_snapshot = sorted_ids_triton[: cumsum_buffer[1]].sort().values
-
-    if (
-        torch.allclose(expert_ids_cuda, expert_ids_triton)
-        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton)
-        and torch.allclose(sorted_ids_cuda_snapshot, sorted_ids_triton_snapshot)
+    if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
+        num_tokens_post_pad_cuda, num_tokens_post_pad_triton
     ):
-        print(
-            "✅ CUDA and Triton implementations match : num_tokens={}, num_experts={}".format(
-                batch_size * seq_len, num_experts
-            )
-        )
+        print("✅ CUDA and Triton implementations match")
     else:
         print("❌ CUDA and Triton implementations do not match")
-        print("CUDA sorted ids:", sorted_ids_cuda_snapshot)
-        print("Triton sorted ids:", sorted_ids_triton_snapshot)
         print("CUDA expert_ids:", expert_ids_cuda)
         print("Triton expert_ids:", expert_ids_triton)
         print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda)
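The multi-block version also compared a sorted snapshot of sorted_ids up to the first expert boundary (intra-slice order is nondeterministic because ranks come from atomicAdd); the restored check keeps only the expert_ids and num_tokens_post_pad comparison. A small sketch of both checks, assuming tensors produced as in calculate_diff above:

```python
import torch


def compare_outputs(sorted_ids_cuda, sorted_ids_triton,
                    expert_ids_cuda, expert_ids_triton,
                    num_tokens_post_pad_cuda, num_tokens_post_pad_triton,
                    cumsum_buffer=None) -> bool:
    # Restored check: block-to-expert map and padded token count must agree.
    ok = torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
        num_tokens_post_pad_cuda, num_tokens_post_pad_triton
    )
    # Optional stricter check in the spirit of the removed code: compare the
    # first expert's slice after sorting, since intra-slice order may differ.
    if ok and cumsum_buffer is not None:
        end = int(cumsum_buffer[1])
        ok = torch.equal(
            sorted_ids_cuda[:end].sort().values, sorted_ids_triton[:end].sort().values
        )
    return ok
```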
@@ -282,7 +256,7 @@ def benchmark(batch_size, seq_len, provider):
     )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.zeros(
+    expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
@@ -293,37 +267,34 @@ def benchmark(batch_size, seq_len, provider):
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )

-    # Warm up
-    api_func = (
-        moe_align_block_size if provider == "cuda" else moe_align_block_size_triton
-    )
-    for _ in range(10):
-        api_func(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer.clone(),
-            cumsum_buffer.clone(),
-        )
-    torch.cuda.synchronize()
-
-    quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(
-        lambda: api_func(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer.clone(),
-            cumsum_buffer.clone(),
-        ),
-        quantiles=quantiles,
-    )
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "cuda":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids.clone(),
+                expert_ids.clone(),
+                num_tokens_post_pad.clone(),
+                token_cnts_buffer,
+                cumsum_buffer,
+            ),
+            quantiles=quantiles,
+        )
+    else:
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: moe_align_block_size_triton(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids.clone(),
+                expert_ids.clone(),
+                num_tokens_post_pad.clone(),
+            ),
+            quantiles=quantiles,
+        )
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
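triton.testing.do_bench runs the callable with its own warmup and repetitions and, when quantiles=[0.5, 0.2, 0.8] is passed, returns the median, 20th- and 80th-percentile times in milliseconds, which benchmark() above scales to microseconds. A minimal standalone sketch of the same pattern (illustrative only; the matmul workload is just a placeholder):

```python
import torch
import triton


def bench_op(fn, quantiles=(0.5, 0.2, 0.8)):
    # do_bench handles its own warmup and repetitions and reports milliseconds;
    # with quantiles it returns one timing per requested quantile.
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=list(quantiles))
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms  # microseconds, as benchmark() returns


if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    median_us, p80_us, p20_us = bench_op(lambda: a @ a)
    print(f"matmul: median {median_us:.1f} us (p20 {p20_us:.1f}, p80 {p80_us:.1f})")
```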
@@ -335,22 +306,8 @@ if __name__ == "__main__":
         default="./configs/benchmark_ops/moe_align_blocks/",
         help="Path to save moe align benchmark results",
     )
-    parser.add_argument(
-        "--verify",
-        action="store_true",
-        help="verify kernel",
-    )
     args = parser.parse_args()

-    if args.verify:
-        num_experts_range = [2**i for i in range(3, 9)]
-        configs = list(
-            itertools.product(batch_size_range, seq_length_range, num_experts_range)
-        )
-        for bs, seq, num_experts in configs:
-            calculate_diff(batch_size=bs, seq_len=seq, num_experts=num_experts)
-    else:
-        benchmark.run(print_data=True, save_path=args.save_path)
+    calculate_diff(batch_size=4, seq_len=1024)
+
+    benchmark.run(print_data=True)
sgl-kernel/pyproject.toml  (View file @ 18bb216c)

 [build-system]
-requires = ["setuptools>=61.0", "wheel", "torch<=2.5.1"]
+requires = ["setuptools>=61.0", "wheel", "torch"]
 build-backend = "setuptools.build_meta"

 [project]
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu  (View file @ 18bb216c)

@@ -16,284 +16,77 @@ limitations under the License.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <cooperative_groups.h>
 #include <torch/extension.h>

 #include <THC/THCAtomics.cuh>

 #include "utils.h"

-#define MAX_NUM_EXPERTS 256
-#define EXPERTS_PER_WARP ((MAX_NUM_EXPERTS) / (WARP_SIZE))
-#define FRAGS_PER_BLOCK 4
-#define FRAG_SIZE_M 16
-#define FRAG_SIZE_N 16
-#ifndef USE_ROCM
-#define kWarpsToLoad 2
-#else
-#define kWarpsToLoad 1
-#endif
-#define kElementsPerAccess 4
-#define kElementsPerThr 16
-#define SGLANG_FORCE_INLINE_DEVICE_FUNC static __forceinline__ __attribute__((always_inline)) __device__
-
-namespace cg = cooperative_groups;
+#define WARP_SIZE 32

-SGLANG_FORCE_INLINE_DEVICE_FUNC
-void store_global_cumsum(int* cumsum /*dest*/, int* total_tokens_post_pad /*dest*/,
-                         const int32_t* local_offsets, const int& tid, const int& num_experts,
-                         cg::grid_group& grid) {
-  int active_threads = CEILDIV(num_experts + 1, kElementsPerThr);
-  if (tid < active_threads - 1) {
-    for (int i = tid * kElementsPerThr; i < (tid + 1) * kElementsPerThr; i += kElementsPerAccess) {
-      *(int4*)(cumsum + i) = *(int4*)(local_offsets + i);
-    }
-  }
-  if (tid == active_threads - 1) {
-#pragma unroll
-    for (int i = tid * kElementsPerThr; i < num_experts + 1; i++) {
-      *(cumsum + i) = *(local_offsets + i);
-    }
-  }
-  if (tid == active_threads) {
-    *total_tokens_post_pad = local_offsets[num_experts];
-  }
-  __threadfence_system();
-  grid.sync();
-}
-
-SGLANG_FORCE_INLINE_DEVICE_FUNC
-void align_global_cumsum(int32_t* local_offsets /*src_and_dest*/, int32_t* local_offsets_buf, int* smem_ptr,
-                         const int tid, const int32_t& block_size, const int32_t& num_experts) {
-  int active_threads = CEILDIV(num_experts, kElementsPerThr);
-  int start = tid * kElementsPerThr + 1;
-  int end = MIN((tid + 1) * kElementsPerThr, num_experts) + 1;
-  if (tid == 0) {
-    smem_ptr[0] = 0;
-  }
-  if (tid < active_threads) {
-    for (int i = start; i < end; ++i) {
-      smem_ptr[i] = local_offsets[i] - local_offsets[i - 1];
-    }
-  }
-  __syncthreads();
-  if (tid < active_threads) {
-    for (int i = start; i < end; ++i) {
-      int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1];
-      local_offsets[i] = last_val + CEILDIV(smem_ptr[i], block_size) * block_size;
-    }
-    local_offsets_buf[tid] = local_offsets[end - 1];
-  }
-  __syncthreads();
-  if (tid < active_threads && tid > 0) {
-    int offset = 0;
-    for (int j = 0; j < tid; ++j) {
-      offset += local_offsets_buf[j];
-    }
-    for (int i = start; i < end; ++i) {
-      local_offsets[i] += offset;
-    }
-  }
-  __syncthreads();
-}
-
-SGLANG_FORCE_INLINE_DEVICE_FUNC
-void reduce_unaligned_cumsum(int* tokens_cnts_ptr /*src_and_dest*/, int* smem_ptr, int32_t* local_offsets,
-                             const int& tid, const int& lane_id, const int& warp_id,
-                             const int32_t& num_experts, cg::grid_group& grid) {
-  int total_fragments = CEILDIV(num_experts, FRAG_SIZE_N);
-  int fragments_per_block = CEILDIV(total_fragments, gridDim.x);
-  int fragments_per_warp = CEILDIV(fragments_per_block, FRAGS_PER_BLOCK);
-  for (int i = 0; i < gridDim.x; i += FRAG_SIZE_M) {
-    for (int j = 0; j < fragments_per_warp; j++) {
-      if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) {
-        const int kNumThrPerRow = WARP_SIZE / FRAG_SIZE_N;
-        int sRow = lane_id / kNumThrPerRow;
-        int sWarpColStride = kNumThrPerRow * kElementsPerAccess;
-        int sWarpColOff = warp_id * sWarpColStride;
-        int sThrColOff = lane_id % kNumThrPerRow * kElementsPerAccess;
-        int sCol = sThrColOff + sWarpColOff;
-        int gRow = i + sRow;
-        int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N;
-        int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N;
-        int gWarpColOff_1 = warp_id % kWarpsToLoad * sWarpColStride;
-        int gCol = gBlockColOff + gWarpColOff_0 + gWarpColOff_1 + sThrColOff;
-        if (gRow < num_experts && gCol < num_experts) {
-          int4* tokens_cnts_4i_ptr = (int4*)(tokens_cnts_ptr + (gRow + 1) * num_experts + gCol);
-          int4* smem_4i_ptr = (int4*)(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol);
-          *smem_4i_ptr = *tokens_cnts_4i_ptr;
-        }
-      }
-      __syncthreads();
-      if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) {
-        if (warp_id % kWarpsToLoad == 0) {
-          for (int k = 0; k < FRAG_SIZE_M; k += (WARP_SIZE / FRAG_SIZE_N)) {
-            int sRow = lane_id / FRAG_SIZE_N + k;
-            int sThrColOff = lane_id % FRAG_SIZE_N;
-            int sCol = sThrColOff + (warp_id / kWarpsToLoad) * FRAG_SIZE_N;
-            int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N;
-            int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N;
-            int gCol = gBlockColOff + gWarpColOff_0 + sThrColOff;
-            if (gCol < num_experts) {
-              atomicAdd(local_offsets + gCol + 1, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol));
-              // atomicAdd(tokens_cnts_ptr + gCol, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol));
-            }
-          }
-        }
-      }
-      __syncthreads();
-      __syncthreads();
-    }  // end of j
-  }  // end of i
-  if (threadIdx.x < num_experts) {
-    atomicAdd(tokens_cnts_ptr + threadIdx.x, *(local_offsets + threadIdx.x + 1));
-  }
-  __threadfence_system();
-  grid.sync();
-  if (tid < num_experts) {
-    *(local_offsets + tid + 1) = *(tokens_cnts_ptr + tid);
-  }
-  __syncthreads();
-}
-
-SGLANG_FORCE_INLINE_DEVICE_FUNC
-void parallel_unaligned_local_cumsum(const int& tid, int* tokens_cnts_ptr /*dest*/,
-                                     int32_t* local_offsets /*dest*/, int32_t* local_offsets_buf,
-                                     const int32_t (*shared_counts)[EXPERTS_PER_WARP] /*src*/,
-                                     const int& experts_per_warp, const int32_t& num_experts,
-                                     cg::grid_group& grid) {
-  int active_threads = CEILDIV(num_experts, kElementsPerThr);
-  if (threadIdx.x == 0) {
-    local_offsets[0] = 0;
-  }
-  if (threadIdx.x < active_threads) {
-    for (int i = threadIdx.x * kElementsPerThr + 1;
-         i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1; ++i) {
-      int warp_idx = (i - 1) / experts_per_warp;
-      int expert_offset = (i - 1) % experts_per_warp;
-      int expert_count = shared_counts[warp_idx][expert_offset];
-      int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1];
-      local_offsets[i] = last_val + expert_count;
-    }
-    local_offsets_buf[threadIdx.x] = local_offsets[MIN((threadIdx.x + 1) * kElementsPerThr, num_experts)];
-  }
-  __syncthreads();
-  if (threadIdx.x < active_threads && threadIdx.x > 0) {
-    int offset = 0;
-    for (int j = 0; j < threadIdx.x; ++j) {
-      offset += local_offsets_buf[j];
-    }
-    for (int i = threadIdx.x * kElementsPerThr + 1;
-         i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1; ++i) {
-      local_offsets[i] += offset;
-    }
-  }
-  __syncthreads();
-  if (tid < num_experts) {
-    *(tokens_cnts_ptr + tid) = 0;
-  }
-  if (threadIdx.x < num_experts) {
-    *(tokens_cnts_ptr + (blockIdx.x + 1) * num_experts + threadIdx.x) = *(local_offsets + threadIdx.x + 1);
-    *(local_offsets + threadIdx.x + 1) = 0;
-  } else if (threadIdx.x < MAX_NUM_EXPERTS) {
-    *(local_offsets + threadIdx.x + 1) = 0;
-  }
-  __threadfence_system();
-  grid.sync();
-}
-
-template <typename scalar_t>
-__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
-                                            int32_t* __restrict__ sorted_token_ids,
-                                            int32_t* __restrict__ expert_ids,
-                                            int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
-                                            int32_t block_size, size_t numel,
-                                            int32_t* __restrict__ tokens_cnts, int32_t* __restrict__ cumsum,
-                                            const int tokens_per_block, const int tokens_per_thread,
-                                            const int K) {
-  __shared__ int32_t smem[FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK];
-  int32_t(*shared_counts)[EXPERTS_PER_WARP] = (int32_t(*)[EXPERTS_PER_WARP]) & smem[0];
-  __shared__ int32_t local_offsets[MAX_NUM_EXPERTS + 1];
-  __shared__ int32_t local_offsets_buf[CEILDIV(MAX_NUM_EXPERTS, kElementsPerThr)];
-  const int tid = threadIdx.x + blockDim.x * blockIdx.x;
-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int lane_id = threadIdx.x % WARP_SIZE;
-  const int experts_per_warp = EXPERTS_PER_WARP;
-  int* tokens_cnts_ptr = &(tokens_cnts[0]);
-  int* smem_ptr = &(smem[0]);
-  cg::grid_group grid = cg::this_grid();
-
-  if (threadIdx.x < FRAG_SIZE_M * FRAG_SIZE_N) {
-    for (int i = 0; i < FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK; i += FRAG_SIZE_M * FRAG_SIZE_N) {
-      smem[threadIdx.x + i] = 0;
-    }
-  }
-  __syncthreads();
-
-  const size_t start_idx = tokens_per_block * blockIdx.x + tokens_per_thread * threadIdx.x;
-  const size_t end_idx = start_idx + tokens_per_thread;
-  if (threadIdx.x * tokens_per_thread < tokens_per_block) {
-    for (int i = start_idx; i < MIN(numel, end_idx); ++i) {
-      int expert_id = topk_ids[i];
-      int warp_idx = expert_id / experts_per_warp;
-      int expert_offset = expert_id % experts_per_warp;
-      atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
-    }
-  }
-  __syncthreads();
-
-  parallel_unaligned_local_cumsum(tid, tokens_cnts_ptr /*dest*/, local_offsets, local_offsets_buf, shared_counts,
-                                  experts_per_warp, num_experts, grid);
-  reduce_unaligned_cumsum(tokens_cnts_ptr /*src_and_dest*/, smem_ptr, local_offsets, tid, lane_id, warp_id,
-                          num_experts, grid);
-  align_global_cumsum(local_offsets /*src_and_dest*/, local_offsets_buf, smem_ptr, tid, block_size, num_experts);
-  store_global_cumsum(cumsum /*dest*/, total_tokens_post_pad /*dest*/, local_offsets /*src*/, tid, num_experts, grid);
-
-  if (tid < num_experts) {
-    for (int i = local_offsets[tid]; i < local_offsets[tid + 1]; i += block_size) {
-      expert_ids[i / block_size] = tid;
-    }
-  }
-  __syncthreads();
-
-  if (threadIdx.x * tokens_per_thread < tokens_per_block) {
-    for (int i = start_idx; i < MIN(numel, end_idx); ++i) {
-      int32_t expert_id = topk_ids[i];
-      int32_t rank_post_pad = atomicAdd(&cumsum[expert_id], 1);
-      sorted_token_ids[rank_post_pad] = i;
-    }
-  }
-}
+template <typename scalar_t>
+__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
+                                                    int32_t* __restrict__ sorted_token_ids,
+                                                    int32_t* __restrict__ cumsum_buffer, size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
+    sorted_token_ids[rank_post_pad] = i;
+  }
+}
+
+template <typename scalar_t>
+__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
+                                            int32_t* __restrict__ sorted_token_ids,
+                                            int32_t* __restrict__ expert_ids,
+                                            int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
+                                            int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
+  __shared__ int32_t shared_counts[WARP_SIZE][8];
+
+  const int warp_id = threadIdx.x / WARP_SIZE;
+  const int experts_per_warp = 8;
+  const int my_expert_start = warp_id * experts_per_warp;
+
+  for (int i = 0; i < experts_per_warp; ++i) {
+    if (my_expert_start + i < num_experts) {
+      shared_counts[warp_id][i] = 0;
+    }
+  }
+
+  __syncthreads();
+
+  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+  const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    int expert_id = topk_ids[i];
+    int warp_idx = expert_id / experts_per_warp;
+    int expert_offset = expert_id % experts_per_warp;
+    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    cumsum[0] = 0;
+    for (int i = 1; i <= num_experts; ++i) {
+      int expert_count = 0;
+      int warp_idx = (i - 1) / experts_per_warp;
+      int expert_offset = (i - 1) % experts_per_warp;
+      expert_count = shared_counts[warp_idx][expert_offset];
+
+      cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
+    }
+    *total_tokens_post_pad = cumsum[num_experts];
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
+    }
+  }
+}
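The restored path is two ordinary kernel launches instead of one cooperative multi-block launch: moe_align_block_size_kernel (a single block of 1024 threads) counts tokens per expert in shared memory, builds the block-size-aligned cumulative offsets and fills expert_ids, and count_and_sort_expert_tokens_kernel then scatters every token index into its padded slot via atomicAdd on the cumsum buffer. As a rough illustration of the outputs this pair produces, here is a pure-PyTorch sketch (not the project's code; padding slots are filled with numel, as the benchmark does):

```python
import torch


def moe_align_block_size_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    """Slow reference for (sorted_token_ids, expert_ids, num_tokens_post_pad)."""
    flat = topk_ids.flatten().to(torch.int64)
    device = flat.device
    numel = flat.numel()

    counts = torch.bincount(flat, minlength=num_experts)             # tokens routed to each expert
    padded = ((counts + block_size - 1) // block_size) * block_size  # round each count up to block_size
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64, device=device)
    cumsum[1:] = torch.cumsum(padded, dim=0)                         # padded start offset per expert
    num_tokens_post_pad = int(cumsum[-1])

    expert_ids = torch.empty(num_tokens_post_pad // block_size, dtype=torch.int32, device=device)
    sorted_token_ids = torch.full((num_tokens_post_pad,), numel, dtype=torch.int32, device=device)

    for e in range(num_experts):
        start, end = int(cumsum[e]), int(cumsum[e + 1])
        expert_ids[start // block_size : end // block_size] = e      # every block in this range belongs to e
        idx = torch.nonzero(flat == e, as_tuple=False).flatten()
        sorted_token_ids[start : start + idx.numel()] = idx.to(torch.int32)
        # The CUDA kernel's intra-expert order comes from atomicAdd and is not
        # deterministic, so only per-expert set equality is comparable.

    return sorted_token_ids, expert_ids, torch.tensor([num_tokens_post_pad], dtype=torch.int32, device=device)
```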
@@ -302,29 +95,22 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                           torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer,
                           torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-    auto kernel = moe_align_block_size_kernel<scalar_t>;
-    const int block_threads = 256;
-    const int num_blocks = MIN(CEILDIV(topk_ids.sizes()[0], block_threads), num_experts);
-    scalar_t* topk_ids_ptr = topk_ids.data_ptr<scalar_t>();
-    int32_t* sorted_token_ids_ptr = sorted_token_ids.data_ptr<int32_t>();
-    int32_t* experts_ids_ptr = experts_ids.data_ptr<int32_t>();
-    int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data_ptr<int32_t>();
-    size_t num_tokens = topk_ids.numel();
-    int32_t* token_cnts_buffer_ptr = token_cnts_buffer.data_ptr<int32_t>();
-    int32_t* cumsum_buffer_ptr = cumsum_buffer.data_ptr<int32_t>();
-    int tokens_per_block = CEILDIV(topk_ids.sizes()[0], num_blocks) * topk_ids.sizes()[1];
-    int tokens_per_thread = CEILDIV(tokens_per_block, block_threads);
-    int K = topk_ids.sizes()[1];
-    void* kernelArgs[] = {&topk_ids_ptr,      &sorted_token_ids_ptr, &experts_ids_ptr,   &num_tokens_post_pad_ptr,
-                          &num_experts,       &block_size,           &num_tokens,        &token_cnts_buffer_ptr,
-                          &cumsum_buffer_ptr, &tokens_per_block,     &tokens_per_thread, &K};
-    cudaLaunchCooperativeKernel((void*)kernel, num_blocks, block_threads, kernelArgs);
+    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
+    align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
+                                         experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
+                                         num_experts, block_size, topk_ids.numel(),
+                                         cumsum_buffer.data_ptr<int32_t>());
+
+    const int block_threads = 256;
+    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+    const int max_blocks = 65535;
+    const int actual_blocks = std::min(num_blocks, max_blocks);
+    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
+    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
+                                                             sorted_token_ids.data_ptr<int32_t>(),
+                                                             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
   });
 }
sgl-kernel/src/sgl-kernel/include/utils.h  (View file @ 18bb216c)

@@ -49,17 +49,6 @@ struct cuda_error : public std::runtime_error {
     }                        \
   } while (0)

-#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
-
-template <typename T>
-void check(T result, char const* const func, const char* const file, int const line) {
-  if (result) {
-    fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
-            static_cast<unsigned int>(result), cudaGetErrorString(result), func);
-    cudaDeviceReset();
-    exit(EXIT_FAILURE);
-  }
-}
-
 #define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_CUDA_INPUT(x) \

@@ -106,22 +95,3 @@ inline int getSMVersion() {
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

 #define CEILDIV(x, y) (((x) + (y)-1) / (y))
 #define MIN(x, y) ((x) < (y) ? (x) : (y))
-
-#ifndef USE_ROCM
-#define WARP_SIZE 32
-#else
-#define WARP_SIZE warpSize  // 64
-#endif
-
-#if defined(__HIP_PLATFORM_AMD__)
-#include <hip/hip_cooperative_groups.h>
-#include <hip/hip_runtime.h>
-static __inline__ __host__ __device__ hipError_t cudaLaunchCooperativeKernel(const void* f, dim3 gridDim,
-                                                                             dim3 blockDimX, void** kernelParams) {
-  return hipLaunchCooperativeKernel(f, gridDim, blockDimX, kernelParams, 0, hipStreamDefault);
-}
-#endif
sgl-kernel/tests/test_moe_align.py  (View file @ 18bb216c)

@@ -171,12 +171,12 @@ def test_moe_align_block_size_compare_implementations(block_size, num_tokens, to
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.zeros(
+    token_cnts_buffer = torch.empty(
         (num_experts + 1) * num_experts,
         dtype=torch.int32,
         device=topk_ids.device,
     )
-    cumsum_buffer = torch.zeros(
+    cumsum_buffer = torch.empty(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
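After the revert the scratch buffers can be allocated with torch.empty: the align kernel writes the whole cumsum before the sort kernel reads it, so nothing relies on pre-zeroed global memory. A minimal sketch of calling the restored op with the buffer shapes used by this test, assuming moe_align_block_size is importable from sgl_kernel as the test does and assuming the usual numel + num_experts * (block_size - 1) worst-case bound for the padded output:

```python
import torch

from sgl_kernel import moe_align_block_size  # assumed import path, as used by the test


def run_cuda_moe_align(topk_ids: torch.Tensor, num_experts: int = 256, block_size: int = 128):
    device = topk_ids.device
    numel = topk_ids.numel()
    max_num_tokens_padded = numel + num_experts * (block_size - 1)  # assumed worst-case bound

    sorted_ids = torch.full((max_num_tokens_padded,), numel, dtype=torch.int32, device=device)
    expert_ids = torch.empty((max_num_tokens_padded // block_size,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=device)
    # Scratch buffers: shapes follow the test above; empty is enough post-revert.
    token_cnts_buffer = torch.empty((num_experts + 1) * num_experts, dtype=torch.int32, device=device)
    cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device=device)

    moe_align_block_size(
        topk_ids, num_experts, block_size,
        sorted_ids, expert_ids, num_tokens_post_pad,
        token_cnts_buffer, cumsum_buffer,
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
```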