Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3bbb2046
Unverified
Commit
3bbb2046
authored
Feb 24, 2026
by
Xin Yang
Committed by
GitHub
Feb 24, 2026
Browse files
[Bugfix] Fix expert_ids padding values in moe_align_block_size kernel (#35161)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
576fe503
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
9 deletions
+12
-9
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+4
-4
tests/kernels/moe/test_moe_align_block_size.py
tests/kernels/moe/test_moe_align_block_size.py
+8
-5
No files found.
csrc/moe/moe_align_sum_kernels.cu
View file @
3bbb2046
...
@@ -172,7 +172,7 @@ __device__ void _moe_align_block_size(
...
@@ -172,7 +172,7 @@ __device__ void _moe_align_block_size(
}
}
}
}
// Fill remaining expert_ids with
0
// Fill remaining expert_ids with
-1
const
size_t
fill_start_idx
=
const
size_t
fill_start_idx
=
cumsum
[
cumsum_offset
+
num_experts
]
/
block_size
+
threadIdx
.
x
;
cumsum
[
cumsum_offset
+
num_experts
]
/
block_size
+
threadIdx
.
x
;
for
(
size_t
i
=
fill_start_idx
;
i
<
max_num_m_blocks
;
i
+=
blockDim
.
x
)
{
for
(
size_t
i
=
fill_start_idx
;
i
<
max_num_m_blocks
;
i
+=
blockDim
.
x
)
{
...
@@ -265,7 +265,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
...
@@ -265,7 +265,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
}
}
}
}
// Fill remaining expert_ids with
0
// Fill remaining expert_ids with
-1
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
tid
;
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
tid
;
for
(
size_t
i
=
fill_start_idx
;
i
<
max_num_m_blocks
;
i
+=
stride
)
{
for
(
size_t
i
=
fill_start_idx
;
i
<
max_num_m_blocks
;
i
+=
stride
)
{
expert_ids
[
expert_ids_offset
+
i
]
=
inactive_expert_id
;
expert_ids
[
expert_ids_offset
+
i
]
=
inactive_expert_id
;
...
@@ -332,7 +332,7 @@ __global__ void moe_align_block_size_kernel(
...
@@ -332,7 +332,7 @@ __global__ void moe_align_block_size_kernel(
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
numel
,
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
numel
,
cumsum
,
max_num_tokens_padded
,
CEILDIV
(
max_num_tokens_padded
,
block_size
),
cumsum
,
max_num_tokens_padded
,
CEILDIV
(
max_num_tokens_padded
,
block_size
),
0
,
0
,
topk_num
,
nullptr
,
has_expert_map
);
0
,
-
1
,
topk_num
,
nullptr
,
has_expert_map
);
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
@@ -373,7 +373,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
...
@@ -373,7 +373,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
_moe_align_block_size_small_batch_expert
<
scalar_t
,
fill_threads
>
(
_moe_align_block_size_small_batch_expert
<
scalar_t
,
fill_threads
>
(
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
total_tokens_post_pad
,
expert_map
,
num_experts
,
block_size
,
numel
,
max_num_tokens_padded
,
num_experts
,
block_size
,
numel
,
max_num_tokens_padded
,
CEILDIV
(
max_num_tokens_padded
,
block_size
),
0
,
0
,
topk_num
,
nullptr
,
CEILDIV
(
max_num_tokens_padded
,
block_size
),
-
1
,
0
,
topk_num
,
nullptr
,
has_expert_map
);
has_expert_map
);
}
}
...
...
tests/kernels/moe/test_moe_align_block_size.py
View file @
3bbb2046
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size
,
batched_moe_align_block_size
,
moe_align_block_size
,
moe_align_block_size
,
)
)
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
NUM_TOKENS
=
[
1
,
3
,
256
,
2256
,
4096
]
NUM_TOKENS
=
[
1
,
3
,
256
,
2256
,
4096
]
...
@@ -142,7 +142,9 @@ def torch_moe_align_block_size(
...
@@ -142,7 +142,9 @@ def torch_moe_align_block_size(
device
=
topk_ids
.
device
,
device
=
topk_ids
.
device
,
)
)
max_num_blocks
=
(
max_num_tokens_padded
+
block_size
-
1
)
//
block_size
max_num_blocks
=
(
max_num_tokens_padded
+
block_size
-
1
)
//
block_size
expert_ids
=
torch
.
zeros
(
max_num_blocks
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
expert_ids
=
torch
.
full
(
(
max_num_blocks
,),
-
1
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
current_pos
=
0
current_pos
=
0
current_block
=
0
current_block
=
0
...
@@ -234,9 +236,10 @@ def test_moe_align_block_size(
...
@@ -234,9 +236,10 @@ def test_moe_align_block_size(
assert
len
(
valid_tokens
)
==
total_tokens
,
(
assert
len
(
valid_tokens
)
==
total_tokens
,
(
f
"Should have exactly
{
total_tokens
}
valid tokens, got
{
len
(
valid_tokens
)
}
"
f
"Should have exactly
{
total_tokens
}
valid tokens, got
{
len
(
valid_tokens
)
}
"
)
)
assert
(
actual_expert_ids
>=
0
).
all
()
and
(
actual_expert_ids
<
num_experts
).
all
(),
(
actual_num_blocks
=
cdiv
(
int
(
actual_num_tokens
.
item
()),
block_size
)
"expert_ids should contain valid expert indices"
assert
(
actual_expert_ids
[:
actual_num_blocks
]
>=
0
).
all
()
and
(
)
actual_expert_ids
[:
actual_num_blocks
]
<
num_experts
).
all
(),
"expert_ids should contain valid expert indices"
@
pytest
.
mark
.
parametrize
(
"m"
,
[
16
,
32
,
2048
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
16
,
32
,
2048
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment