Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4e81540
Unverified
Commit
f4e81540
authored
Oct 27, 2025
by
Jee Jee Li
Committed by
GitHub
Oct 27, 2025
Browse files
[Kernel] Enable moe LoRA kernel support FP16 (#27468)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
a663f6ae
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
16 deletions
+26
-16
tests/lora/test_fused_moe_lora_kernel.py
tests/lora/test_fused_moe_lora_kernel.py
+17
-6
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+9
-10
No files found.
tests/lora/test_fused_moe_lora_kernel.py
View file @
f4e81540
...
...
@@ -204,6 +204,11 @@ def use_torch(
return
torch
.
stack
(
outputs
,
dim
=
0
)
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
DEVICES
=
[
f
"cuda:
{
0
}
"
]
SEED
=
[
42
]
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
100
])
@
pytest
.
mark
.
parametrize
(
"top_k_num"
,
[
6
,
12
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
64
])
...
...
@@ -212,6 +217,9 @@ def use_torch(
@
pytest
.
mark
.
parametrize
(
"K"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"max_lora_rank"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEED
)
def
test_fused_moe_lora_kernel
(
num_tokens
,
top_k_num
,
...
...
@@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
K
,
max_lora_rank
,
block_size
,
dtype
,
device
,
seed
,
):
torch
.
set_default_device
(
"cuda:0"
)
current_platform
.
seed_everything
(
42
)
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
# the number of randomly generated sentences.
num_sequences
=
10
# generate data
...
...
@@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
max_lora_rank
,
K
,
),
dtype
=
torch
.
bfloat16
,
dtype
=
dtype
,
)
]
lora_b_stacked
=
[
...
...
@@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
N
,
max_lora_rank
,
),
dtype
=
torch
.
bfloat16
,
dtype
=
dtype
,
)
]
hidden_states
=
torch
.
rand
(
...
...
@@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
num_tokens
,
K
,
),
dtype
=
torch
.
bfloat16
,
dtype
=
dtype
,
)
# fused_moe_lora_kernel output
output
=
torch
.
zeros
((
num_tokens
,
top_k_num
,
N
),
dtype
=
torch
.
bfloat16
)
output
=
torch
.
zeros
((
num_tokens
,
top_k_num
,
N
),
dtype
=
dtype
)
use_fused_moe_lora_kernel
(
topk_ids
,
topk_weights
,
...
...
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
View file @
f4e81540
...
...
@@ -2,9 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
triton
import
triton.language
as
tl
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
_LORA_PTR_DICT
:
dict
[
tuple
[
int
,
...],
torch
.
tensor
]
=
{}
...
...
@@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(
# get a_ptr,b_ptr,c_ptr
cur_a_ptr
=
a_ptr
+
(
slice_id
%
num_slice_a
)
*
slice_a_size
cur_b_ptr
=
tl
.
load
(
b_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
tl
.
bfloat16
))
cur_b_ptr
=
tl
.
load
(
b_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
c_ptr
.
dtype
.
element_ty
))
cur_c_ptr
=
c_ptr
+
(
slice_id
%
num_slice_c
)
*
slice_c_size
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
...
...
@@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
tl
.
bfloat16
)
accumulator
=
accumulator
.
to
(
c_ptr
.
dtype
.
element_ty
)
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
cur_c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
...
...
@@ -205,6 +204,10 @@ def _fused_moe_lora(
assert
output
.
shape
[
0
]
==
topk_weights
.
shape
[
0
]
assert
top_k_num
==
topk_weights
.
shape
[
1
]
for
lora_a
,
lora_b
in
zip
(
lora_a_stacked
,
lora_b_stacked
):
assert
lora_a
.
dtype
==
lora_b
.
dtype
==
output
.
dtype
==
qcurr_hidden_states
.
dtype
assert
lora_a
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
device
=
qcurr_hidden_states
.
device
num_slices
=
len
(
lora_a_stacked
)
...
...
@@ -227,9 +230,9 @@ def _fused_moe_lora(
num_tokens
=
M
*
top_k_num
w1_output_dim_size
=
w1_lora_b_stacked
.
shape
[
2
]
lora_intermediate_cache1
=
torch
.
zeros
(
lora_intermediate_cache1
=
torch
.
empty
(
(
num_slices
*
M
*
top_k_num
*
(
max_lora_rank
+
w1_output_dim_size
)),
dtype
=
torch
.
bfloat16
,
dtype
=
output
.
dtype
,
device
=
device
,
)
...
...
@@ -288,10 +291,6 @@ def _fused_moe_lora(
K
=
max_lora_rank
N
=
w1_output_dim_size
# a_intermediate_cache1 = a_intermediate_cache1.view(
# M, -1, a_intermediate_cache1.shape[3]
# )
a_intermediate_cache1
=
a_intermediate_cache1
.
view
(
-
1
,
a_intermediate_cache1
.
shape
[
3
]
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment