Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ecc3dd66
Unverified
Commit
ecc3dd66
authored
Jan 23, 2026
by
Xin Yang
Committed by
GitHub
Jan 24, 2026
Browse files
[Bugfix] Fix FusedMoE LoRA kernel offs_token out of bound value (#32279)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
7e1f10d5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+4
-2
No files found.
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
View file @
ecc3dd66
...
...
@@ -139,7 +139,9 @@ def _fused_moe_lora_kernel(
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int64
)
token_ind
=
stride_tl
*
lora_id
+
offs_token_id
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
token_ind
,
token_ind
<
max_loras
*
stride_tl
,
0
sorted_token_ids_ptr
+
token_ind
,
mask
=
token_ind
<
max_loras
*
stride_tl
,
other
=
num_valid_tokens
,
)
token_mask
=
offs_token
<
num_valid_tokens
...
...
@@ -185,7 +187,7 @@ def _fused_moe_lora_kernel(
b_ptrs
+=
BLOCK_SIZE_K
*
SPLIT_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
.0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
c_ptr
.
dtype
.
element_ty
)
# Write back the block of the output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment