change / sglang / Commits / cdae77b0

Commit cdae77b0 (unverified), authored Feb 07, 2025 by Xiaoyu Zhang, committed via GitHub on Feb 07, 2025
Parent: adeee152

optimize moe_align_kernel cuda (#3347)
Changes: 3 changed files, with 29 additions and 21 deletions (+29 / -21)

- benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py  (+4 / -4)
- python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  (+2 / -2)
- sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu  (+23 / -15)
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py

@@ -163,10 +163,10 @@ def calculate_diff(batch_size, seq_len):
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.empty(
+    token_cnts_buffer = torch.zeros(
         (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
     )
-    cumsum_buffer = torch.empty(
+    cumsum_buffer = torch.zeros(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )

@@ -260,10 +260,10 @@ def benchmark(batch_size, seq_len, provider):
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    token_cnts_buffer = torch.empty(
+    token_cnts_buffer = torch.zeros(
         (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
     )
-    cumsum_buffer = torch.empty(
+    cumsum_buffer = torch.zeros(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -417,12 +417,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.empty(
+        token_cnts_buffer = torch.zeros(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.empty(
+        cumsum_buffer = torch.zeros(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
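Both Python files make the same change: token_cnts_buffer and cumsum_buffer are now allocated with torch.zeros instead of torch.empty. The commit does not spell out the motivation, but a common reason for this switch is that buffers which downstream kernels accumulate into (for example with atomic adds) need a defined starting value, whereas torch.empty leaves the memory uninitialized. A minimal, hypothetical CUDA sketch of that pitfall (none of these names come from sgl-kernel):

// Hypothetical sketch: an accumulation target must be zeroed before atomicAdd,
// otherwise the result includes whatever garbage the allocation happened to hold.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void accumulate(const int* vals, int* total, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicAdd(total, vals[i]);  // adds on top of the existing value of *total
}

int main() {
  const int n = 1024;
  std::vector<int> host_vals(n, 1);  // n ones, so the expected sum is n

  int *vals = nullptr, *total = nullptr;
  cudaMalloc(&vals, n * sizeof(int));
  cudaMalloc(&total, sizeof(int));
  cudaMemcpy(vals, host_vals.data(), n * sizeof(int), cudaMemcpyHostToDevice);

  cudaMemset(total, 0, sizeof(int));  // the torch.zeros analogue; skip this and the sum is undefined
  accumulate<<<(n + 255) / 256, 256>>>(vals, total, n);

  int result = 0;
  cudaMemcpy(&result, total, sizeof(int), cudaMemcpyDeviceToHost);
  printf("sum = %d (expected %d)\n", result, n);

  cudaFree(vals);
  cudaFree(total);
  return 0;
}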
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu

@@ -24,12 +24,24 @@ limitations under the License.

 #define WARP_SIZE 32

+template <typename scalar_t>
+__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+                                      int32_t* cumsum_buffer, size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
+    sorted_token_ids[rank_post_pad] = i;
+  }
+}
+
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
                                             int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
                                             int32_t block_size, size_t numel, int32_t* cumsum) {
   __shared__ int32_t shared_counts[WARP_SIZE][8];
-  __shared__ int32_t local_offsets[256];

   const int warp_id = threadIdx.x / WARP_SIZE;
   const int experts_per_warp = 8;
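The added moe_token_sort_kernel moves the token scatter out of the single-block align kernel: once cumsum_buffer holds the per-expert starting offsets, each token claims its slot with one atomicAdd, and because the loop is grid-stride, any number of blocks can participate. The toy below is an illustration only; count_kernel, scatter_kernel, and the host-side scan are invented for the sketch and are not sgl-kernel code, but they show the same count-then-scatter pattern on plain integer buckets.

// Illustration only: a self-contained count-then-scatter toy with the same
// shape as the new kernels. None of these names come from sgl-kernel.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void count_kernel(const int* keys, int* counts, int n) {
  // Grid-stride histogram of bucket ids.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
    atomicAdd(&counts[keys[i]], 1);
}

__global__ void scatter_kernel(const int* keys, int* offsets, int* out, int n) {
  // Grid-stride scatter: every element claims the next free slot of its bucket.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    int slot = atomicAdd(&offsets[keys[i]], 1);
    out[slot] = i;  // indices end up grouped by bucket, like sorted_token_ids
  }
}

int main() {
  const int n = 1 << 16, buckets = 8;
  std::vector<int> host_keys(n);
  for (int i = 0; i < n; ++i) host_keys[i] = i % buckets;

  int *keys, *counts, *offsets, *out;
  cudaMalloc(&keys, n * sizeof(int));
  cudaMalloc(&counts, buckets * sizeof(int));
  cudaMalloc(&offsets, buckets * sizeof(int));
  cudaMalloc(&out, n * sizeof(int));
  cudaMemcpy(keys, host_keys.data(), n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemset(counts, 0, buckets * sizeof(int));  // zero-init before atomic accumulation

  count_kernel<<<64, 256>>>(keys, counts, n);

  // Exclusive scan on the host for the toy; the real align kernel builds its cumsum on the GPU.
  std::vector<int> host_counts(buckets), host_offsets(buckets, 0);
  cudaMemcpy(host_counts.data(), counts, buckets * sizeof(int), cudaMemcpyDeviceToHost);
  for (int b = 1; b < buckets; ++b) host_offsets[b] = host_offsets[b - 1] + host_counts[b - 1];
  cudaMemcpy(offsets, host_offsets.data(), buckets * sizeof(int), cudaMemcpyHostToDevice);

  scatter_kernel<<<64, 256>>>(keys, offsets, out, n);
  cudaDeviceSynchronize();
  printf("%d indices grouped into %d buckets\n", n, buckets);

  cudaFree(keys); cudaFree(counts); cudaFree(offsets); cudaFree(out);
  return 0;
}

As in the commit's kernels, the order of indices within a bucket depends on which thread wins each atomicAdd; only the grouping by bucket (expert) is deterministic.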
@@ -72,20 +84,6 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
-    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
   }
-
-  __syncthreads();
-
-  // Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here.
-  // If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this
-  // kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used,
-  // illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version
-  // results in the same issue, and a correct solution has not yet been found.
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-    int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
-    sorted_token_ids[rank_post_pad] = i;
-  }
 }
@@ -100,5 +98,15 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
     align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                                          experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
                                          num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
+    const int block_threads = 256;
+    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+    const int max_blocks = 65535;
+    const int actual_blocks = std::min(num_blocks, max_blocks);
+
+    auto sort_kernel = moe_token_sort_kernel<scalar_t>;
+    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
+                                                             sorted_token_ids.data_ptr<int32_t>(),
+                                                             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
   });
 }
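The host side caps the grid at 65535 blocks, and the kernel's grid-stride loop covers whatever the clamped grid cannot reach in one pass, so correctness does not depend on the clamp. A small, hypothetical standalone sketch of that interaction (names invented, not sgl-kernel code):

// Hypothetical sketch: a clamped grid plus a grid-stride loop still touches every element.
#include <cstdio>
#include <algorithm>
#include <cuda_runtime.h>

__global__ void touch_all(int* data, size_t n) {
  const size_t stride = (size_t)blockDim.x * gridDim.x;
  for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride)
    data[i] = 1;  // every index is visited, regardless of how small the grid is
}

int main() {
  const size_t n = (size_t)1 << 26;  // far more elements than threads in a clamped grid
  const int block_threads = 256;
  const int num_blocks = (int)((n + block_threads - 1) / block_threads);
  const int actual_blocks = std::min(num_blocks, 65535);  // same clamp value as the commit

  int* data = nullptr;
  cudaMalloc(&data, n * sizeof(int));
  cudaMemset(data, 0, n * sizeof(int));

  touch_all<<<actual_blocks, block_threads>>>(data, n);
  cudaDeviceSynchronize();

  // Spot-check the tail, which only the stride loop can have reached.
  int last = 0;
  cudaMemcpy(&last, data + (n - 1), sizeof(int), cudaMemcpyDeviceToHost);
  printf("last element = %d (expected 1)\n", last);

  cudaFree(data);
  return 0;
}

With 65535 blocks of 256 threads, one pass covers roughly 16.8M indices, so for the 67M elements here each thread loops about four times.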