Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ded9fcd0
Unverified
Commit
ded9fcd0
authored
Jan 06, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Jan 06, 2025
Browse files
improve moe_align_kernel for deepseek v3 (#2735)
parent
bc6ad367
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
51 deletions
+32
-51
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
+32
-50
sgl-kernel/tests/test_trt_reduce.py
sgl-kernel/tests/test_trt_reduce.py
+0
-1
No files found.
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
View file @
ded9fcd0
...
...
@@ -46,74 +46,61 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
tokens_cnts
,
int32_t
*
cumsum
)
{
int32_t
block_size
,
size_t
numel
,
int32_t
*
cumsum
)
{
__shared__
int32_t
shared_counts
[
32
][
8
];
__shared__
int32_t
local_offsets
[
256
];
const
int
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
experts_per_warp
=
8
;
const
int
my_expert_start
=
warp_id
*
experts_per_warp
;
for
(
int
i
=
0
;
i
<
experts_per_warp
;
++
i
)
{
if
(
my_expert_start
+
i
<
num_experts
)
{
shared_counts
[
warp_id
][
i
]
=
0
;
}
}
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
/**
* In the first step we compute tokens_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
int
expert_id
=
topk_ids
[
i
];
int
warp_idx
=
expert_id
/
experts_per_warp
;
int
expert_offset
=
expert_id
%
experts_per_warp
;
atomicAdd
(
&
shared_counts
[
warp_idx
][
expert_offset
],
1
);
}
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
if
(
threadIdx
.
x
<
num_experts
)
{
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
}
__syncthreads
();
// We accumulate the token counts of all experts in thread 0.
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
int
expert_count
=
0
;
int
warp_idx
=
(
i
-
1
)
/
experts_per_warp
;
int
expert_offset
=
(
i
-
1
)
%
experts_per_warp
;
expert_count
=
shared_counts
[
warp_idx
][
expert_offset
];
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
expert_count
,
block_size
)
*
block_size
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
}
__syncthreads
();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
local_offsets
[
threadIdx
.
x
]
=
cumsum
[
threadIdx
.
x
];
}
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value (preset in Python).
*/
__syncthreads
();
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the offset, within that
* expert's range, of the next token from the current thread's token
* shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
int32_t
rank_post_pad
=
atomicAdd
(
&
local_offsets
[
expert_id
],
1
);
sorted_token_ids
[
rank_post_pad
]
=
i
;
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
...
...
@@ -122,14 +109,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
torch
::
Tensor
token_cnts_buffer
,
torch
::
Tensor
cumsum_buffer
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const
int32_t
num_thread
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
auto
kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
kernel
<<<
1
,
num_thread
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
token_cnts_buffer
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
});
}
sgl-kernel/tests/test_trt_reduce.py
View file @
ded9fcd0
...
...
@@ -36,7 +36,6 @@ def multi_process_parallel(
cls
:
Any
,
test_target
:
Any
,
)
->
None
:
# Using ray helps debugging the error when it failed
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment