Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a7c9f7b7
Unverified
Commit a7c9f7b7 authored Feb 23, 2026 by Xin Yang
Committed by GitHub, Feb 23, 2026
Browse files
[Bugfix] Fix lora_ids in FusedMoE LoRA test (#35135)
Signed-off-by: Xin Yang <xyangx@amazon.com>
parent
a4bd661f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 19 additions and 10 deletions
+19
-10
tests/lora/test_fused_moe_lora_kernel.py
tests/lora/test_fused_moe_lora_kernel.py
+19
-10
No files found.
tests/lora/test_fused_moe_lora_kernel.py
View file @
a7c9f7b7
...
@@ -118,7 +118,10 @@ def sample_data(
...
@@ -118,7 +118,10 @@ def sample_data(
num_tokens
,
num_experts
,
top_k_num
num_tokens
,
num_experts
,
top_k_num
)
)
token_lora_mapping
=
assign_loras_to_tokens
(
num_tokens
,
num_sequences
,
max_loras
)
token_lora_mapping
=
assign_loras_to_tokens
(
num_tokens
,
num_sequences
,
max_loras
)
return
topk_ids
,
topk_weights
,
token_lora_mapping
active_lora_ids
=
torch
.
full
((
max_loras
+
1
,),
-
1
,
dtype
=
torch
.
int32
)
lora_ids
=
torch
.
unique
(
token_lora_mapping
,
sorted
=
True
)
active_lora_ids
[:
lora_ids
.
size
(
0
)].
copy_
(
lora_ids
,
non_blocking
=
True
)
return
topk_ids
,
topk_weights
,
token_lora_mapping
,
active_lora_ids
def
use_fused_moe_lora_kernel
(
def
use_fused_moe_lora_kernel
(
...
@@ -127,6 +130,7 @@ def use_fused_moe_lora_kernel(
...
@@ -127,6 +130,7 @@ def use_fused_moe_lora_kernel(
token_lora_mapping
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_a_stacked
,
lora_a_stacked
,
lora_b_stacked
,
lora_b_stacked
,
hidden_states
,
hidden_states
,
...
@@ -149,7 +153,6 @@ def use_fused_moe_lora_kernel(
...
@@ -149,7 +153,6 @@ def use_fused_moe_lora_kernel(
expert_ids
=
torch
.
empty
((
max_loras
*
max_num_m_blocks
,),
dtype
=
torch
.
int32
)
expert_ids
=
torch
.
empty
((
max_loras
*
max_num_m_blocks
,),
dtype
=
torch
.
int32
)
num_tokens_post_padded
=
torch
.
empty
((
max_loras
,),
dtype
=
torch
.
int32
)
num_tokens_post_padded
=
torch
.
empty
((
max_loras
,),
dtype
=
torch
.
int32
)
adapter_enabled
=
torch
.
ones
(
max_loras
+
1
,
dtype
=
torch
.
int32
)
adapter_enabled
=
torch
.
ones
(
max_loras
+
1
,
dtype
=
torch
.
int32
)
lora_ids
=
torch
.
arange
(
max_loras
+
2
,
dtype
=
torch
.
int32
)
# call kernel
# call kernel
ops
.
moe_lora_align_block_size
(
ops
.
moe_lora_align_block_size
(
...
@@ -168,7 +171,7 @@ def use_fused_moe_lora_kernel(
...
@@ -168,7 +171,7 @@ def use_fused_moe_lora_kernel(
)
)
config
=
{
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
block_size
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
...
@@ -275,7 +278,7 @@ def test_fused_moe_lora_kernel(
...
@@ -275,7 +278,7 @@ def test_fused_moe_lora_kernel(
# the number of randomly generated sentences.
# the number of randomly generated sentences.
num_sequences
=
10
num_sequences
=
10
# generate data
# generate data
topk_ids
,
topk_weights
,
token_lora_mapping
=
sample_data
(
topk_ids
,
topk_weights
,
token_lora_mapping
,
lora_ids
=
sample_data
(
num_tokens
,
num_sequences
,
max_loras
,
num_experts
,
top_k_num
num_tokens
,
num_sequences
,
max_loras
,
num_experts
,
top_k_num
)
)
...
@@ -318,6 +321,7 @@ def test_fused_moe_lora_kernel(
...
@@ -318,6 +321,7 @@ def test_fused_moe_lora_kernel(
token_lora_mapping
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_a_stacked
,
lora_a_stacked
,
lora_b_stacked
,
lora_b_stacked
,
hidden_states
,
hidden_states
,
...
@@ -336,7 +340,7 @@ def test_fused_moe_lora_kernel(
...
@@ -336,7 +340,7 @@ def test_fused_moe_lora_kernel(
top_k_num
,
top_k_num
,
)
)
torch
.
testing
.
assert_close
(
output
,
output2
,
atol
=
1e-
1
,
rtol
=
1e-
1
)
torch
.
testing
.
assert_close
(
output
,
output2
,
atol
=
1e-
2
,
rtol
=
1e-
2
)
def
use_fused_moe_lora_kernel_naive
(
def
use_fused_moe_lora_kernel_naive
(
...
@@ -345,6 +349,7 @@ def use_fused_moe_lora_kernel_naive(
...
@@ -345,6 +349,7 @@ def use_fused_moe_lora_kernel_naive(
token_lora_mapping
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_a_stacked
,
lora_a_stacked
,
lora_b_stacked
,
lora_b_stacked
,
hidden_states
,
hidden_states
,
...
@@ -379,7 +384,6 @@ def use_fused_moe_lora_kernel_naive(
...
@@ -379,7 +384,6 @@ def use_fused_moe_lora_kernel_naive(
num_tokens_post_padded
=
None
num_tokens_post_padded
=
None
adapter_enabled
=
torch
.
ones
(
max_loras
+
1
,
dtype
=
torch
.
int32
)
adapter_enabled
=
torch
.
ones
(
max_loras
+
1
,
dtype
=
torch
.
int32
)
lora_ids
=
torch
.
arange
(
max_loras
+
2
,
dtype
=
torch
.
int32
)
# num_active_loras is the number of active LoRAs
# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
# (max_loras + 1 to include no-lora case)
...
@@ -463,7 +467,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
...
@@ -463,7 +467,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
# the number of randomly generated sentences.
# the number of randomly generated sentences.
num_sequences
=
min
(
num_tokens
,
4
)
num_sequences
=
min
(
num_tokens
,
4
)
# generate data
# generate data
topk_ids
,
topk_weights
,
token_lora_mapping
=
sample_data
(
topk_ids
,
topk_weights
,
token_lora_mapping
,
lora_ids
=
sample_data
(
num_tokens
,
num_sequences
,
max_loras
,
num_experts
,
top_k_num
num_tokens
,
num_sequences
,
max_loras
,
num_experts
,
top_k_num
)
)
...
@@ -506,6 +510,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
...
@@ -506,6 +510,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
token_lora_mapping
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
lora_a_stacked
,
lora_a_stacked
,
lora_b_stacked
,
lora_b_stacked
,
hidden_states
,
hidden_states
,
...
@@ -524,7 +529,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
...
@@ -524,7 +529,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
top_k_num
,
top_k_num
,
)
)
torch
.
testing
.
assert_close
(
output
,
output_ref
,
atol
=
1e-
1
,
rtol
=
1e-
1
)
torch
.
testing
.
assert_close
(
output
,
output_ref
,
atol
=
1e-
2
,
rtol
=
1e-
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
multi_gpu_test
(
num_gpus
=
2
)
...
@@ -556,7 +561,7 @@ def test_fused_moe_lora_kernel_fully_sharded(
...
@@ -556,7 +561,7 @@ def test_fused_moe_lora_kernel_fully_sharded(
# the number of randomly generated sentences.
# the number of randomly generated sentences.
num_sequences
=
10
num_sequences
=
10
# generate data
# generate data
topk_ids
,
topk_weights
,
token_lora_mapping
=
sample_data
(
topk_ids
,
topk_weights
,
token_lora_mapping
,
lora_ids
=
sample_data
(
num_tokens
,
num_sequences
,
max_loras
,
num_experts
,
top_k_num
num_tokens
,
num_sequences
,
max_loras
,
num_experts
,
top_k_num
)
)
...
@@ -576,6 +581,7 @@ def test_fused_moe_lora_kernel_fully_sharded(
...
@@ -576,6 +581,7 @@ def test_fused_moe_lora_kernel_fully_sharded(
token_lora_mapping
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
max_loras
,
max_loras
,
num_experts
,
num_experts
,
block_size
,
block_size
,
...
@@ -601,6 +607,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
...
@@ -601,6 +607,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
token_lora_mapping
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
max_loras
,
max_loras
,
num_experts
,
num_experts
,
block_size
,
block_size
,
...
@@ -660,6 +667,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
...
@@ -660,6 +667,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
topk_ids
=
topk_ids
.
to
(
device
)
topk_ids
=
topk_ids
.
to
(
device
)
topk_weights
=
topk_weights
.
to
(
device
)
topk_weights
=
topk_weights
.
to
(
device
)
token_lora_mapping
=
token_lora_mapping
.
to
(
device
)
token_lora_mapping
=
token_lora_mapping
.
to
(
device
)
lora_ids
=
lora_ids
.
to
(
device
)
ref_output
=
use_torch
(
ref_output
=
use_torch
(
hidden_states
,
hidden_states
,
...
@@ -698,6 +706,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
...
@@ -698,6 +706,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
token_lora_mapping
,
token_lora_mapping
,
max_lora_rank
,
max_lora_rank
,
top_k_num
,
top_k_num
,
lora_ids
,
[
lora_a
],
[
lora_a
],
[
lora_b
],
[
lora_b
],
hidden_states
,
hidden_states
,
...
@@ -714,4 +723,4 @@ def use_fused_moe_lora_kernel_tensor_parallel(
...
@@ -714,4 +723,4 @@ def use_fused_moe_lora_kernel_tensor_parallel(
else
:
else
:
output
=
tensor_model_parallel_all_reduce
(
output
)
output
=
tensor_model_parallel_all_reduce
(
output
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-
1
,
rtol
=
1e-
1
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-
2
,
rtol
=
1e-
2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment