Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2344192a
Unverified
Commit
2344192a
authored
Feb 13, 2025
by
Michael Goin
Committed by
GitHub
Feb 13, 2025
Browse files
Optimize moe_align_block_size for deepseek_v3 (#12850)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
bffddd9a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
16 deletions
+39
-16
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+37
-15
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-1
No files found.
csrc/moe/moe_align_sum_kernels.cu
View file @
2344192a
...
@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
...
@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
}
}
// taken from
// taken from
// https://github.com/sgl-project/sglang/commit/
ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
// https://github.com/sgl-project/sglang/commit/
cdae77b03dfc6fec3863630550b45bbfc789f957
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
sgl_moe_align_block_size_kernel
(
__global__
void
sgl_moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
cumsum
)
{
int32_t
block_size
,
size_t
numel
,
int32_t
*
cumsum
)
{
__shared__
int32_t
shared_counts
[
32
][
8
];
__shared__
int32_t
shared_counts
[
32
][
8
];
__shared__
int32_t
local_offsets
[
256
];
const
int
warp_id
=
threadIdx
.
x
/
32
;
const
int
warp_id
=
threadIdx
.
x
/
32
;
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
experts_per_warp
=
8
;
const
int
experts_per_warp
=
8
;
const
int
my_expert_start
=
warp_id
*
experts_per_warp
;
const
int
my_expert_start
=
warp_id
*
experts_per_warp
;
// Initialize shared_counts for this warp's experts
for
(
int
i
=
0
;
i
<
experts_per_warp
;
++
i
)
{
for
(
int
i
=
0
;
i
<
experts_per_warp
;
++
i
)
{
if
(
my_expert_start
+
i
<
num_experts
)
{
if
(
my_expert_start
+
i
<
num_experts
)
{
shared_counts
[
warp_id
][
i
]
=
0
;
shared_counts
[
warp_id
][
i
]
=
0
;
}
}
}
}
__syncthreads
();
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
...
@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
...
@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
__syncthreads
();
__syncthreads
();
// Single thread computes cumulative sum and total tokens
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
...
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
...
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
__syncthreads
();
__syncthreads
();
// Assign expert IDs to blocks
if
(
threadIdx
.
x
<
num_experts
)
{
if
(
threadIdx
.
x
<
num_experts
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
local_offsets
[
threadIdx
.
x
]
=
cumsum
[
threadIdx
.
x
];
}
}
}
__syncthreads
();
// taken from
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
template
<
typename
scalar_t
>
__global__
void
sgl_moe_token_sort_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
cumsum_buffer
,
size_t
numel
)
{
const
size_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(
in
t
i
=
start_
id
x
;
i
<
numel
&&
i
<
st
art_idx
+
tokens_per_thread
;
++
i
)
{
for
(
size_
t
i
=
t
id
;
i
<
numel
;
i
+=
st
ride
)
{
int32_t
expert_id
=
topk_ids
[
i
];
int32_t
expert_id
=
topk_ids
[
i
];
int32_t
rank_post_pad
=
atomicAdd
(
&
local_offsets
[
expert_id
],
1
);
int32_t
rank_post_pad
=
atomicAdd
(
&
cumsum_buffer
[
expert_id
],
1
);
sorted_token_ids
[
rank_post_pad
]
=
i
;
sorted_token_ids
[
rank_post_pad
]
=
i
;
}
}
}
}
...
@@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
torch
::
Tensor
num_tokens_post_pad
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
TORCH_CHECK
(
num_experts
==
256
,
"sgl_moe_align_block_size kernel only supports deepseek v3."
);
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// calc needed amount of shared mem for `cumsum` tensors
// tensors
auto
options_int
=
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
topk_ids
.
device
());
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
topk_ids
.
device
());
// torch::Tensor token_cnts_buffer =
// torch::empty({(num_experts + 1) * num_experts}, options_int);
torch
::
Tensor
cumsum_buffer
=
torch
::
Tensor
cumsum_buffer
=
torch
::
empty
({
num_experts
+
1
},
options_int
);
torch
::
zeros
({
num_experts
+
1
},
options_int
);
auto
kernel
=
vllm
::
moe
::
sgl_moe_align_block_size_kernel
<
scalar_t
>
;
auto
align_kernel
=
kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
vllm
::
moe
::
sgl_moe_align_block_size_kernel
<
scalar_t
>
;
align_kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
const
int
block_threads
=
256
;
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
const
int
max_blocks
=
65535
;
const
int
actual_blocks
=
std
::
min
(
num_blocks
,
max_blocks
);
auto
sort_kernel
=
vllm
::
moe
::
sgl_moe_token_sort_kernel
<
scalar_t
>
;
sort_kernel
<<<
actual_blocks
,
block_threads
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
topk_ids
.
numel
());
});
});
}
}
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
2344192a
...
@@ -596,7 +596,7 @@ def moe_align_block_size(
...
@@ -596,7 +596,7 @@ def moe_align_block_size(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
device
=
topk_ids
.
device
)
if
num_experts
>=
224
:
if
num_experts
>=
224
:
if
envs
.
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
if
envs
.
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
or
num_experts
!=
256
:
moe_align_block_size_triton
(
moe_align_block_size_triton
(
topk_ids
,
topk_ids
,
num_experts
,
num_experts
,
...
@@ -606,6 +606,7 @@ def moe_align_block_size(
...
@@ -606,6 +606,7 @@ def moe_align_block_size(
num_tokens_post_pad
,
num_tokens_post_pad
,
)
)
else
:
else
:
# Currently requires num_experts=256
ops
.
sgl_moe_align_block_size
(
ops
.
sgl_moe_align_block_size
(
topk_ids
,
topk_ids
,
num_experts
,
num_experts
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment