Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8cd174fa
Unverified
Commit
8cd174fa
authored
Apr 26, 2026
by
Jee Jee Li
Committed by
GitHub
Apr 26, 2026
Browse files
[LoRA] MoE LoRA Refactor (#40338)
parent
c798593f
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
736 additions
and
328 deletions
+736
-328
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+36
-297
vllm/lora/layers/utils.py
vllm/lora/layers/utils.py
+5
-4
vllm/lora/ops/triton_ops/utils.py
vllm/lora/ops/triton_ops/utils.py
+17
-0
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+62
-0
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+236
-0
vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
...or/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
+52
-3
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+106
-6
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+48
-1
vllm/model_executor/layers/fused_moe/lora_context.py
vllm/model_executor/layers/fused_moe/lora_context.py
+44
-0
vllm/model_executor/layers/fused_moe/lora_experts_mixin.py
vllm/model_executor/layers/fused_moe/lora_experts_mixin.py
+111
-0
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+14
-0
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+0
-3
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+0
-13
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+5
-1
No files found.
vllm/lora/layers/fused_moe.py
View file @
8cd174fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
torch
import
torch.nn
as
nn
...
...
@@ -14,31 +13,17 @@ from vllm.distributed.parallel_state import (
)
from
vllm.distributed.utils
import
divide
from
vllm.lora.layers.base
import
BaseLayerWithLoRA
from
vllm.lora.ops.triton_ops.utils
import
get_lora_op_configs
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.config
import
(
_get_config_dtype_str
,
)
from
vllm.model_executor.layers.fused_moe.experts.gpt_oss_triton_kernels_moe
import
(
UnfusedOAITritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
MarlinExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
TritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_modular_method
import
(
FusedMoEModularMethod
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEKernel
,
)
from
vllm.model_executor.layers.fused_moe.lora_context
import
MoELoRAContext
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEKernel
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoDPEPModular
,
)
from
.utils
import
_get_lora_device
,
try_get_optimal_moe_lora_config
from
.utils
import
_get_lora_device
class
FusedMoEWithLoRA
(
BaseLayerWithLoRA
):
...
...
@@ -58,299 +43,49 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
# For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
self
.
_w13_slices
=
2
if
base_layer
.
moe_config
.
is_act_and_mul
else
1
self
.
_inject_lora_into_fused_moe
()
def
_normalize_keys
(
self
,
config
:
dict
[
str
,
int
|
None
])
->
dict
[
str
,
int
|
None
]:
normalized_config
=
{}
for
key
,
value
in
config
.
items
():
if
key
.
islower
():
if
key
.
startswith
(
"block_"
):
normalized_key
=
"BLOCK_SIZE_"
+
key
.
split
(
"_"
)[
-
1
].
upper
()
else
:
normalized_key
=
key
.
upper
()
else
:
normalized_key
=
key
normalized_config
[
normalized_key
]
=
value
return
normalized_config
def _get_lora_moe_configs(
    self,
    op_prefix: str,
    num_loras: int,
    rank: int,
    num_slices: int,
    M: int,
    layer: FusedMoE,
    top_k: int,
    config_dtype: str,
):
    """Pick the shrink and expand kernel configs for one fused-MoE LoRA op.

    ``op_prefix`` is spliced into the op names
    ``fused_moe_lora_{op_prefix}_shrink`` / ``..._expand`` (the callers in
    this file pass ``"w13"`` or ``"w2"``).

    When ``envs.VLLM_TUNED_CONFIG_FOLDER`` is truthy the configs come from
    ``get_lora_op_configs``; otherwise they come from
    ``try_get_optimal_moe_lora_config`` using the base layer's weight shapes.

    Returns:
        ``(shrink_config, expand_config)``, both passed through
        ``self._normalize_keys``.
    """
    shrink_op = f"fused_moe_lora_{op_prefix}_shrink"
    expand_op = f"fused_moe_lora_{op_prefix}_expand"

    if not envs.VLLM_TUNED_CONFIG_FOLDER:
        # No tuned-config folder configured: fall back to the default config.
        default_lookup = functools.partial(
            try_get_optimal_moe_lora_config,
            w1_shape=layer.w13_weight.shape,
            w2_shape=layer.w2_weight.shape,
            rank=rank,
            top_k=top_k,
            dtype=config_dtype,
            M=M,
            block_shape=layer.quant_method.moe_quant_config.block_shape,
        )
        raw_shrink = default_lookup(op_type=shrink_op)
        raw_expand = default_lookup(op_type=expand_op)
    else:
        hidden = layer.hidden_size
        # The intermediate size is read off the stacked LoRA weights:
        # last dim of w2's A matrix, or second-to-last dim of w13's B matrix.
        if op_prefix == "w2":
            moe_intermediate = self.w2_lora_a_stacked[0].shape[-1]
        else:
            moe_intermediate = self.w13_lora_b_stacked[0].shape[-2]
        tuned_kwargs = {
            "max_loras": num_loras,
            "batch": M,
            "hidden_size": hidden,
            "rank": rank,
            "num_slices": num_slices,
            "moe_intermediate_size": moe_intermediate,
        }
        raw_shrink = get_lora_op_configs(op_type=shrink_op, **tuned_kwargs)
        raw_expand = get_lora_op_configs(op_type=expand_op, **tuned_kwargs)

    return self._normalize_keys(raw_shrink), self._normalize_keys(raw_expand)
def
_inject_lora_into_fused_moe
(
self
):
moe_state_dict
=
{}
top_k
=
self
.
base_layer
.
top_k
self
.
base_layer
.
ensure_moe_quant_config_init
()
quant_config
=
self
.
base_layer
.
quant_method
.
moe_quant_config
if
getattr
(
self
.
base_layer
.
quant_method
,
"supports_internal_mk"
,
False
):
# Use the existing modular kernel from the quant method
m_fused_moe_fn
=
self
.
base_layer
.
quant_method
.
moe_kernel
moe_kernel
=
self
.
base_layer
.
quant_method
.
moe_kernel
# Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream.
m
_fused_moe_fn
.
shared_experts
=
None
m
oe_kernel
.
shared_experts
=
None
else
:
# Create a new modular kernel via select_gemm_impl.
# Don't pass shared_experts to the kernel so the runner can
# overlap them with routed experts via a separate CUDA stream.
prepare_finalize
=
MoEPrepareAndFinalizeNoDPEPModular
()
m
_fused_moe_fn
=
FusedMoEKernel
(
m
oe_kernel
=
FusedMoEKernel
(
prepare_finalize
,
self
.
base_layer
.
quant_method
.
select_gemm_impl
(
prepare_finalize
,
self
.
base_layer
),
)
if
quant_config
.
use_mxfp4_w4a16
:
assert
isinstance
(
m_fused_moe_fn
.
impl
.
fused_experts
,
(
MarlinExperts
,
UnfusedOAITritonExperts
),
)
else
:
assert
isinstance
(
m_fused_moe_fn
.
impl
.
fused_experts
,
TritonExperts
)
def
fwd_decorator
(
layer
,
func
):
def
wrapper
(
*
args
,
**
kwargs
):
moe_state_dict
[
"hidden_states"
]
=
kwargs
[
"hidden_states"
]
moe_state_dict
[
"topk_ids"
]
=
kwargs
[
"topk_ids"
]
moe_state_dict
[
"topk_weights"
]
=
kwargs
[
"topk_weights"
]
moe_state_dict
[
"expert_map"
]
=
kwargs
[
"expert_map"
]
moe_state_dict
[
"apply_router_weight_on_input"
]
=
kwargs
[
"apply_router_weight_on_input"
]
result
=
func
(
*
args
,
**
kwargs
)
return
result
return
wrapper
def
act_decorator
(
layer
,
func
):
def
wrapper
(
*
args
,
**
kwargs
):
_
,
output
,
input
=
args
hidden_states
=
moe_state_dict
[
"hidden_states"
]
topk_weights
=
moe_state_dict
[
"topk_weights"
]
curr_topk_ids
=
moe_state_dict
[
"topk_ids"
]
expert_map
=
moe_state_dict
[
"expert_map"
]
config_dtype
=
_get_config_dtype_str
(
dtype
=
hidden_states
.
dtype
,
use_fp8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
)
num_tokens
=
hidden_states
.
size
(
0
)
M
=
num_tokens
max_lora_rank
=
self
.
w13_lora_a_stacked
[
0
].
shape
[
-
2
]
shrink_config
,
expand_config
=
self
.
_get_lora_moe_configs
(
op_prefix
=
"w13"
,
num_loras
=
self
.
max_loras
,
rank
=
max_lora_rank
,
num_slices
=
self
.
_w13_slices
,
M
=
M
,
layer
=
layer
,
top_k
=
top_k
,
config_dtype
=
config_dtype
,
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
# activates only a small fraction of total experts * loras.
SPARSITY_FACTOR
=
8
naive_block_assignment
=
(
expert_map
is
None
and
num_tokens
*
top_k
*
SPARSITY_FACTOR
<=
self
.
base_layer
.
local_num_experts
*
self
.
max_loras
)
# get the block size of m from customized config or default config
(
token_lora_mapping
,
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
,
)
=
self
.
punica_wrapper
.
moe_lora_align_block_size
(
curr_topk_ids
,
num_tokens
,
shrink_config
[
"BLOCK_SIZE_M"
],
self
.
base_layer
.
local_num_experts
,
self
.
max_loras
,
self
.
adapter_enabled
,
expert_map
,
naive_block_assignment
=
naive_block_assignment
,
)
moe_state_dict
[
"sorted_token_ids_lora"
]
=
sorted_token_ids_lora
moe_state_dict
[
"expert_ids_lora"
]
=
expert_ids_lora
moe_state_dict
[
"num_tokens_post_padded_lora"
]
=
(
num_tokens_post_padded_lora
)
moe_state_dict
[
"token_lora_mapping"
]
=
token_lora_mapping
if
sorted_token_ids_lora
is
not
None
:
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
#
self
.
punica_wrapper
.
add_lora_fused_moe
(
input
.
view
(
-
1
,
top_k
,
input
.
shape
[
-
1
]),
hidden_states
,
self
.
w13_lora_a_stacked
,
self
.
w13_lora_b_stacked
,
topk_weights
,
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
,
max_lora_rank
,
top_k
,
shrink_config
,
## pass the shrink config
expand_config
,
## pass the expand config
self
.
adapter_enabled
,
fully_sharded
=
self
.
fully_sharded
,
token_lora_mapping
=
token_lora_mapping
,
)
result
=
func
(
*
args
,
**
kwargs
)
moe_state_dict
[
"intermediate_cache2"
]
=
output
return
result
return
wrapper
def
moe_sum_decorator
(
layer
,
func
):
def
wrapper
(
*
args
,
**
kwargs
):
hidden_states
=
moe_state_dict
[
"hidden_states"
]
topk_weights
=
moe_state_dict
[
"topk_weights"
]
config_dtype
=
_get_config_dtype_str
(
dtype
=
hidden_states
.
dtype
,
use_fp8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
)
num_tokens
=
hidden_states
.
size
(
0
)
M
=
num_tokens
max_lora_rank
=
self
.
w2_lora_a_stacked
[
0
].
shape
[
-
2
]
shrink_config
,
expand_config
=
self
.
_get_lora_moe_configs
(
op_prefix
=
"w2"
,
num_loras
=
self
.
max_loras
,
rank
=
max_lora_rank
,
num_slices
=
1
,
M
=
M
,
layer
=
layer
,
top_k
=
top_k
,
config_dtype
=
config_dtype
,
)
sorted_token_ids_lora
=
moe_state_dict
[
"sorted_token_ids_lora"
]
expert_ids_lora
=
moe_state_dict
[
"expert_ids_lora"
]
num_tokens_post_padded_lora
=
moe_state_dict
[
"num_tokens_post_padded_lora"
]
token_lora_mapping
=
moe_state_dict
.
get
(
"token_lora_mapping"
)
if
sorted_token_ids_lora
is
not
None
:
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
intermediate_cache2
=
moe_state_dict
[
"intermediate_cache2"
]
intermediate_cache3
=
args
[
0
]
shard_size_w2
=
divide
(
self
.
base_layer
.
hidden_size
,
self
.
tp_size
)
self
.
punica_wrapper
.
add_lora_fused_moe
(
intermediate_cache3
,
intermediate_cache2
,
self
.
w2_lora_a_stacked
,
self
.
w2_lora_b_stacked
,
topk_weights
,
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
,
max_lora_rank
,
top_k
,
shrink_config
,
## pass the shrink config
expand_config
,
## pass the expand config
self
.
adapter_enabled
,
True
,
fully_sharded
=
self
.
fully_sharded
,
offset
=
shard_size_w2
*
self
.
tp_rank
if
self
.
fully_sharded
else
0
,
token_lora_mapping
=
token_lora_mapping
,
)
result
=
func
(
*
args
,
**
kwargs
)
return
result
return
wrapper
fused_experts
=
m_fused_moe_fn
.
impl
.
fused_experts
m_fused_moe_fn
.
apply
=
fwd_decorator
(
self
.
base_layer
,
m_fused_moe_fn
.
apply
)
fused_experts
.
activation
=
act_decorator
(
self
.
base_layer
,
fused_experts
.
activation
)
fused_experts
.
moe_sum
=
moe_sum_decorator
(
self
.
base_layer
,
fused_experts
.
moe_sum
assert
moe_kernel
.
supports_lora
(),
(
f
"
{
type
(
moe_kernel
.
fused_experts
).
__name__
}
does not support LoRA. "
"For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
"(auto selects Triton automatically when LoRA is enabled). "
"For quantized MoE, mix LoRAExpertsMixin into the experts class "
"and consume self._lora_context in apply()."
)
# TODO(bnell): find a less intrusive way to handle this.
self
.
_fused_experts
=
moe_kernel
.
fused_experts
self
.
base_layer
.
_replace_quant_method
(
FusedMoEModularMethod
(
self
.
base_layer
.
quant_method
,
m_fused_moe_fn
)
FusedMoEModularMethod
(
self
.
base_layer
.
quant_method
,
moe_kernel
)
)
def _build_lora_context(self):
    """Bundle this layer's LoRA state into a ``MoELoRAContext``.

    Every field is copied straight from an attribute of ``self`` or of
    ``self.base_layer``; ``use_tuned_config`` is the truthiness of
    ``envs.VLLM_TUNED_CONFIG_FOLDER`` at the time of the call.
    """
    context_fields = {
        # Stacked LoRA A/B weights for the w13 and w2 projections.
        "w13_lora_a_stacked": self.w13_lora_a_stacked,
        "w13_lora_b_stacked": self.w13_lora_b_stacked,
        "w2_lora_a_stacked": self.w2_lora_a_stacked,
        "w2_lora_b_stacked": self.w2_lora_b_stacked,
        # Adapter bookkeeping.
        "adapter_enabled": self.adapter_enabled,
        "max_loras": self.max_loras,
        "top_k": self.base_layer.top_k,
        "w13_num_slices": self._w13_slices,
        # Tensor-parallel layout.
        "fully_sharded": self.fully_sharded,
        "tp_rank": self.tp_rank,
        "tp_size": self.tp_size,
        "local_num_experts": self.base_layer.local_num_experts,
        "punica_wrapper": self.punica_wrapper,
        "use_tuned_config": bool(envs.VLLM_TUNED_CONFIG_FOLDER),
    }
    return MoELoRAContext(**context_fields)
def
_create_lora_a_weights
(
...
...
@@ -589,6 +324,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index
,
:,
:
sliced_w2_lora_b
.
shape
[
1
],
:
sliced_w2_lora_b
.
shape
[
2
]
].
copy_
(
sliced_w2_lora_b
,
non_blocking
=
True
)
def set_mapping(self, punica_wrapper):
    """Install ``punica_wrapper`` and refresh the experts' LoRA context.

    The parent ``set_mapping`` runs first; a context built by
    ``self._build_lora_context()`` is then handed to
    ``self._fused_experts.set_lora_context``.
    """
    super().set_mapping(punica_wrapper)
    experts = self._fused_experts
    experts.set_lora_context(self._build_lora_context())
def forward(self, *args, **kwargs):
    """Delegate to ``self.base_layer.forward`` with the arguments unchanged."""
    delegate = self.base_layer.forward
    return delegate(*args, **kwargs)
...
...
vllm/lora/layers/utils.py
View file @
8cd174fa
...
...
@@ -90,11 +90,12 @@ def try_get_optimal_moe_lora_config(
top_k
:
int
,
dtype
:
str
|
None
,
M
:
int
,
block_shape
:
list
[
int
]
|
None
=
None
,
)
->
dict
[
str
,
int
|
None
]:
config
=
try_get_optimal_moe_config
(
w1_shape
,
w2_shape
,
top_k
,
dtype
,
M
,
block_shape
).
copy
()
# LoRA shrink/expand operates on bf16/fp16 adapters regardless of the
# base MoE weight's block-wise quantization, so block_shape is omitted
# from the config lookup — the non-quantized branch in get_default_config
# ignores it anyway.
config
=
try_get_optimal_moe_config
(
w1_shape
,
w2_shape
,
top_k
,
dtype
,
M
).
copy
()
if
op_type
in
[
"fused_moe_lora_w13_shrink"
,
"fused_moe_lora_w2_shrink"
,
...
...
vllm/lora/ops/triton_ops/utils.py
View file @
8cd174fa
...
...
@@ -321,3 +321,20 @@ def supports_pdl(device: torch.device | None = None) -> bool:
def supports_tma(device: torch.device | None = None) -> bool:
    """Report whether TMA can be used on the current platform.

    TMA requires compute capability SM90 or above, so this is true only when
    ``current_platform`` is CUDA and reports device capability 90.

    Note: the ``device`` argument is accepted but not consulted by this
    function; the answer depends solely on ``current_platform``.
    """
    on_cuda = current_platform.is_cuda()
    if not on_cuda:
        return on_cuda
    return current_platform.has_device_capability(90)
def
_normalize_lora_config_keys
(
config
:
dict
[
str
,
int
|
None
],
)
->
dict
[
str
,
int
|
None
]:
"""Normalize Triton config dict keys to uppercase BLOCK_SIZE_* format."""
out
:
dict
[
str
,
int
|
None
]
=
{}
for
key
,
val
in
config
.
items
():
if
key
.
islower
():
if
key
.
startswith
(
"block_"
):
nk
=
"BLOCK_SIZE_"
+
key
.
split
(
"_"
)[
-
1
].
upper
()
else
:
nk
=
key
.
upper
()
else
:
nk
=
key
out
[
nk
]
=
val
return
out
vllm/lora/punica_wrapper/punica_base.py
View file @
8cd174fa
...
...
@@ -493,3 +493,65 @@ class PunicaWrapperBase(PunicaWrapperABC):
"""
# TODO: implement it based on torch ops
raise
NotImplementedError
def
add_lora_w13
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...],
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
num_tokens
:
int
,
top_k_num
:
int
,
max_loras
:
int
,
adapter_enabled
:
torch
.
Tensor
,
local_num_experts
:
int
,
top_k
:
int
,
num_slices
:
int
,
fully_sharded
:
bool
,
use_tuned_config
:
bool
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
]:
"""Apply w13 LoRA to y (intermediate_cache1) in-place before activation.
Returns (sorted_token_ids_lora, expert_ids_lora,
num_tokens_post_padded_lora, token_lora_mapping)
for reuse by add_lora_w2.
"""
raise
NotImplementedError
def
add_lora_w2
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...],
topk_weights
:
torch
.
Tensor
,
sorted_token_ids_lora
:
torch
.
Tensor
|
None
,
expert_ids_lora
:
torch
.
Tensor
|
None
,
num_tokens_post_padded_lora
:
torch
.
Tensor
|
None
,
token_lora_mapping
:
torch
.
Tensor
|
None
,
num_tokens
:
int
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
top_k_num
:
int
,
max_loras
:
int
,
adapter_enabled
:
torch
.
Tensor
,
top_k
:
int
,
fully_sharded
:
bool
,
tp_rank
:
int
,
use_tuned_config
:
bool
,
)
->
None
:
"""Apply w2 LoRA to y (intermediate_cache3) in-place before moe_sum.
Reuses routing tensors returned by add_lora_w13.
"""
raise
NotImplementedError
vllm/lora/punica_wrapper/punica_gpu.py
View file @
8cd174fa
...
...
@@ -459,3 +459,239 @@ class PunicaWrapperGPU(PunicaWrapperBase):
fully_sharded
,
offset
,
)
def add_lora_w13(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    lora_a_stacked: tuple[torch.Tensor, ...],
    lora_b_stacked: tuple[torch.Tensor, ...],
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor | None,
    w1: torch.Tensor,
    w2: torch.Tensor,
    num_tokens: int,
    top_k_num: int,
    max_loras: int,
    adapter_enabled: torch.Tensor,
    local_num_experts: int,
    top_k: int,
    num_slices: int,
    fully_sharded: bool,
    use_tuned_config: bool,
) -> tuple[
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Run the w13 fused-MoE LoRA on ``y`` and return the routing tensors.

    Steps:
      1. Choose shrink/expand kernel configs: ``get_lora_op_configs`` when
         ``use_tuned_config`` is set, otherwise
         ``try_get_optimal_moe_lora_config`` keyed on the ``w1``/``w2``
         shapes. Both are passed through ``_normalize_lora_config_keys``.
      2. Call ``self.moe_lora_align_block_size`` with the shrink config's
         ``BLOCK_SIZE_M`` (64 if that entry is missing or falsy).
      3. Call ``self.add_lora_fused_moe`` with ``y`` viewed as
         ``(-1, top_k_num, y.shape[-1])``.

    Returns:
        ``(sorted_token_ids_lora, expert_ids_lora,
        num_tokens_post_padded_lora, token_lora_mapping)`` exactly as
        produced by the alignment call (i.e. *before* the per-LoRA
        ``view`` used for the kernel launch).
    """
    # Imported at call time rather than at module level; presumably this
    # avoids an import cycle -- TODO confirm.
    import functools

    from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
    from vllm.lora.ops.triton_ops.utils import (
        _normalize_lora_config_keys,
        get_lora_op_configs,
    )
    from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str

    shrink_op = "fused_moe_lora_w13_shrink"
    expand_op = "fused_moe_lora_w13_expand"

    # The config dtype string is derived from the activation dtype only; all
    # quantization flags are forced off.
    config_dtype = _get_config_dtype_str(
        dtype=x.dtype,
        use_fp8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
    )
    max_lora_rank = lora_a_stacked[0].shape[-2]

    if not use_tuned_config:
        default_lookup = functools.partial(
            try_get_optimal_moe_lora_config,
            w1_shape=w1.shape,
            w2_shape=w2.shape,
            rank=max_lora_rank,
            top_k=top_k,
            dtype=config_dtype,
            M=num_tokens,
        )
        raw_shrink = default_lookup(op_type=shrink_op)
        raw_expand = default_lookup(op_type=expand_op)
    else:
        tuned_kwargs = {
            "max_loras": max_loras,
            "batch": num_tokens,
            "hidden_size": x.shape[-1],
            "rank": max_lora_rank,
            "num_slices": num_slices,
            "moe_intermediate_size": lora_b_stacked[0].shape[-2],
        }
        raw_shrink = get_lora_op_configs(op_type=shrink_op, **tuned_kwargs)
        raw_expand = get_lora_op_configs(op_type=expand_op, **tuned_kwargs)
    shrink_config = _normalize_lora_config_keys(raw_shrink)
    expand_config = _normalize_lora_config_keys(raw_expand)

    # Heuristic margin: take the naive assignment path only when
    # tokens * top_k is a small fraction of experts * loras and there is no
    # expert map.
    sparsity_factor = 8
    routed_budget = num_tokens * top_k * sparsity_factor
    naive_block_assignment = (
        expert_map is None and routed_budget <= local_num_experts * max_loras
    )

    block_size_m = int(shrink_config.get("BLOCK_SIZE_M") or 64)
    (
        token_lora_mapping,
        sorted_token_ids_lora,
        expert_ids_lora,
        num_tokens_post_padded_lora,
    ) = self.moe_lora_align_block_size(
        topk_ids,
        num_tokens,
        block_size_m,
        local_num_experts,
        max_loras,
        adapter_enabled,
        expert_map,
        naive_block_assignment=naive_block_assignment,
    )

    # Reshape to one row per LoRA slot for the kernel launch only; the
    # original tensors are what gets returned to the caller.
    if sorted_token_ids_lora is None:
        sorted_view = None
        expert_view = expert_ids_lora
    else:
        expert_view = expert_ids_lora.view(max_loras, -1)
        sorted_view = sorted_token_ids_lora.view(max_loras, -1)

    self.add_lora_fused_moe(
        y.view(-1, top_k_num, y.shape[-1]),
        x,
        lora_a_stacked,
        lora_b_stacked,
        topk_weights,
        sorted_view,
        expert_view,
        num_tokens_post_padded_lora,
        max_lora_rank,
        top_k,
        shrink_config,
        expand_config,
        adapter_enabled,
        fully_sharded=fully_sharded,
        token_lora_mapping=token_lora_mapping,
    )

    return (
        sorted_token_ids_lora,
        expert_ids_lora,
        num_tokens_post_padded_lora,
        token_lora_mapping,
    )
def add_lora_w2(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    lora_a_stacked: tuple[torch.Tensor, ...],
    lora_b_stacked: tuple[torch.Tensor, ...],
    topk_weights: torch.Tensor,
    sorted_token_ids_lora: torch.Tensor | None,
    expert_ids_lora: torch.Tensor | None,
    num_tokens_post_padded_lora: torch.Tensor | None,
    token_lora_mapping: torch.Tensor | None,
    num_tokens: int,
    w1: torch.Tensor,
    w2: torch.Tensor,
    top_k_num: int,
    max_loras: int,
    adapter_enabled: torch.Tensor,
    top_k: int,
    fully_sharded: bool,
    tp_rank: int,
    use_tuned_config: bool,
) -> None:
    """Run the w2 fused-MoE LoRA on ``y`` using routing tensors from w13.

    Steps:
      1. Choose shrink/expand kernel configs: ``get_lora_op_configs`` when
         ``use_tuned_config`` is set (with ``num_slices=1``), otherwise
         ``try_get_optimal_moe_lora_config`` keyed on the ``w1``/``w2``
         shapes. Both are passed through ``_normalize_lora_config_keys``.
      2. If ``sorted_token_ids_lora`` is not ``None``, view it and
         ``expert_ids_lora`` as ``(max_loras, -1)``.
      3. Call ``self.add_lora_fused_moe`` with the routed-weight
         multiplication flag set to ``True`` and, when ``fully_sharded``,
         an output offset of ``lora_b_stacked[0].shape[-2] * tp_rank``.

    ``top_k_num`` is accepted for interface parity but is not read here.
    """
    # Imported at call time rather than at module level; presumably this
    # avoids an import cycle -- TODO confirm.
    import functools

    from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
    from vllm.lora.ops.triton_ops.utils import (
        _normalize_lora_config_keys,
        get_lora_op_configs,
    )
    from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str

    shrink_op = "fused_moe_lora_w2_shrink"
    expand_op = "fused_moe_lora_w2_expand"

    # The config dtype string is derived from the activation dtype only; all
    # quantization flags are forced off.
    config_dtype = _get_config_dtype_str(
        dtype=x.dtype,
        use_fp8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
    )
    max_lora_rank = lora_a_stacked[0].shape[-2]

    if not use_tuned_config:
        default_lookup = functools.partial(
            try_get_optimal_moe_lora_config,
            w1_shape=w1.shape,
            w2_shape=w2.shape,
            rank=max_lora_rank,
            top_k=top_k,
            dtype=config_dtype,
            M=num_tokens,
        )
        raw_shrink = default_lookup(op_type=shrink_op)
        raw_expand = default_lookup(op_type=expand_op)
    else:
        tuned_kwargs = {
            "max_loras": max_loras,
            "batch": num_tokens,
            "hidden_size": y.shape[-1],
            "rank": max_lora_rank,
            "num_slices": 1,
            "moe_intermediate_size": lora_a_stacked[0].shape[-1],
        }
        raw_shrink = get_lora_op_configs(op_type=shrink_op, **tuned_kwargs)
        raw_expand = get_lora_op_configs(op_type=expand_op, **tuned_kwargs)
    shrink_config = _normalize_lora_config_keys(raw_shrink)
    expand_config = _normalize_lora_config_keys(raw_expand)

    # Reshape the routing tensors to one row per LoRA slot when present.
    if sorted_token_ids_lora is None:
        sorted_view = None
        expert_view = expert_ids_lora
    else:
        assert expert_ids_lora is not None
        expert_view = expert_ids_lora.view(max_loras, -1)
        sorted_view = sorted_token_ids_lora.view(max_loras, -1)

    # w2_lora_b shape[-2] is hidden_size // tp_size when fully_sharded, so
    # each rank writes at its own shard offset; otherwise no offset is used.
    shard_size = lora_b_stacked[0].shape[-2]
    if fully_sharded:
        offset = shard_size * tp_rank
    else:
        offset = 0

    self.add_lora_fused_moe(
        y,
        x,
        lora_a_stacked,
        lora_b_stacked,
        topk_weights,
        sorted_view,
        expert_view,
        num_tokens_post_padded_lora,
        max_lora_rank,
        top_k,
        shrink_config,
        expand_config,
        adapter_enabled,
        True,  # mul_routed_weight
        fully_sharded=fully_sharded,
        offset=offset,
        token_lora_mapping=token_lora_mapping,
    )
vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
View file @
8cd174fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
...
...
@@ -16,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.lora_experts_mixin
import
LoRAExpertsMixin
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
...
...
@@ -654,7 +654,7 @@ class OAITritonExperts(BaseOAITritonExperts):
)
class
UnfusedOAITritonExperts
(
BaseOAITritonExperts
):
class
UnfusedOAITritonExperts
(
LoRAExpertsMixin
,
BaseOAITritonExperts
):
"""
A Triton based MoE expert class that operates on expert standard
format and explicitly keeps the activation and reduction (moe_sum) steps
...
...
@@ -721,6 +721,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
if
quant_config
is
None
:
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
global_topk_ids
=
topk_ids
if
expert_map
is
not
None
:
topk_ids
=
expert_map
[
topk_ids
]
...
...
@@ -775,10 +776,40 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
y
=
intermediate_cache1
,
)
# w13 LoRA: gather the activation input from expert-sorted
# intermediate_cache1, then add the LoRA delta in-place on that copy
# before passing it to activation — exactly mirroring the old
# decorator approach which modified the gathered tensor in-place.
act_input
=
intermediate_cache1
.
view
(
-
1
,
N
)[
gather_indx
.
dst_indx
]
sorted_token_ids_lora
=
None
expert_ids_lora
=
None
num_tokens_post_padded_lora
=
None
token_lora_mapping
=
None
lora_context
=
self
.
_lora_context
if
lora_context
is
not
None
:
(
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
,
token_lora_mapping
,
)
=
self
.
apply_w13_lora
(
lora_context
,
y
=
act_input
,
x
=
hidden_states
,
topk_ids
=
global_topk_ids
,
topk_weights
=
topk_weights
,
expert_map
=
expert_map
,
w1
=
w1
,
w2
=
w2
,
num_tokens
=
M
,
top_k_num
=
topk
,
)
self
.
activation
(
activation
,
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
)[
gather_indx
.
dst_indx
]
,
act_input
,
)
# matmul_ogs grouped reduction fuse sum across multiple experts:
...
...
@@ -797,6 +828,24 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
y
=
intermediate_cache3
,
)
# w2 LoRA: after matmul_ogs with scatter_indx, intermediate_cache3 is
# in token-topk order, matching the (M, topk, K) layout add_lora_w2 expects.
if
lora_context
is
not
None
:
self
.
apply_w2_lora
(
lora_context
,
y
=
intermediate_cache3
.
view
(
-
1
,
topk
,
K
),
x
=
intermediate_cache2
,
topk_weights
=
topk_weights
,
sorted_token_ids_lora
=
sorted_token_ids_lora
,
expert_ids_lora
=
expert_ids_lora
,
num_tokens_post_padded_lora
=
num_tokens_post_padded_lora
,
token_lora_mapping
=
token_lora_mapping
,
num_tokens
=
M
,
w1
=
w1
,
w2
=
w2
,
top_k_num
=
topk
,
)
self
.
moe_sum
(
intermediate_cache3
.
view
(
-
1
,
topk
,
K
),
output
)
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
8cd174fa
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.lora_experts_mixin
import
LoRAExpertsMixin
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
batched_moe_align_block_size
,
moe_align_block_size
,
...
...
@@ -655,7 +656,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
return
E
,
M
,
N
,
K
,
topk
class
MarlinExperts
(
MarlinExpertsBase
):
class
MarlinExperts
(
LoRAExpertsMixin
,
MarlinExpertsBase
):
"""Marlin-based fused MoE expert implementation."""
def
supports_expert_map
(
self
)
->
bool
:
...
...
@@ -720,7 +721,108 @@ class MarlinExperts(MarlinExpertsBase):
):
assert
self
.
w1_scale
is
not
None
assert
self
.
w2_scale
is
not
None
fused_marlin_moe
(
ctx
=
self
.
_lora_context
if
ctx
is
None
:
fused_marlin_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
bias1
=
self
.
w1_bias
,
bias2
=
self
.
w2_bias
,
w1_scale
=
self
.
w1_scale
,
w2_scale
=
self
.
w2_scale
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
global_scale1
=
self
.
g1_alphas
,
global_scale2
=
self
.
g2_alphas
,
quant_type_id
=
self
.
quant_type_id
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
activation
=
activation
,
activation_func
=
self
.
activation
,
moe_sum
=
self
.
moe_sum
,
expert_map
=
expert_map
,
output
=
output
,
# Workspaces are swapped in workspace_shapes() to account for proper
# output buffer allocation. Please refer to workspace_shapes().
intermediate_cache13
=
workspace2
,
intermediate_cache2
=
workspace13
,
g_idx1
=
self
.
w13_g_idx
,
g_idx2
=
self
.
w2_g_idx
,
sort_indices1
=
self
.
w13_g_idx_sort_indices
,
sort_indices2
=
self
.
w2_g_idx_sort_indices
,
is_k_full
=
self
.
is_k_full
,
input_dtype
=
self
.
input_dtype
,
)
return
# LoRA path: wrap activation_func and moe_sum to inject LoRA at the
# two natural injection points.
#
# Marlin uses moe_align_block_size (same as TritonExperts) so
# intermediate_cache1 is indexed by flat (token, expert) pair index,
# which is compatible with add_lora_fused_moe's scatter mechanism.
M
=
hidden_states
.
size
(
0
)
top_k_num
=
topk_ids
.
size
(
1
)
lora_state
:
dict
=
{}
def
activation_with_lora
(
act_enum
:
MoEActivation
,
act_output
:
torch
.
Tensor
,
act_input
:
torch
.
Tensor
,
)
->
None
:
# act_input = intermediate_cache1 (M*topk, 2N for gated)
# act_output = intermediate_cache2 (M*topk, N)
(
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
,
token_lora_mapping
,
)
=
self
.
apply_w13_lora
(
ctx
,
y
=
act_input
,
x
=
hidden_states
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
expert_map
=
expert_map
,
w1
=
w1
,
w2
=
w2
,
num_tokens
=
M
,
top_k_num
=
top_k_num
,
)
lora_state
.
update
(
{
"sorted"
:
sorted_token_ids_lora
,
"eids"
:
expert_ids_lora
,
"npad"
:
num_tokens_post_padded_lora
,
"tlm"
:
token_lora_mapping
,
}
)
self
.
activation
(
act_enum
,
act_output
,
act_input
)
lora_state
[
"cache2"
]
=
act_output
def
moe_sum_with_lora
(
moe_out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
# moe_out shape: (M, topk, K)
self
.
apply_w2_lora
(
ctx
,
y
=
moe_out
,
x
=
lora_state
[
"cache2"
],
topk_weights
=
topk_weights
,
sorted_token_ids_lora
=
lora_state
[
"sorted"
],
expert_ids_lora
=
lora_state
[
"eids"
],
num_tokens_post_padded_lora
=
lora_state
[
"npad"
],
token_lora_mapping
=
lora_state
[
"tlm"
],
num_tokens
=
M
,
w1
=
w1
,
w2
=
w2
,
top_k_num
=
top_k_num
,
)
self
.
moe_sum
(
moe_out
,
out
)
return
fused_marlin_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
...
...
@@ -736,12 +838,10 @@ class MarlinExperts(MarlinExpertsBase):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
activation
=
activation
,
activation_func
=
self
.
activation
,
moe_sum
=
self
.
moe_sum
,
activation_func
=
activation
_with_lora
,
moe_sum
=
moe_sum
_with_lora
,
expert_map
=
expert_map
,
output
=
output
,
# Workspaces are swapped in workspace_shapes() to account for proper
# output buffer allocation. Please refer to workspace_shapes().
intermediate_cache13
=
workspace2
,
intermediate_cache2
=
workspace13
,
g_idx1
=
self
.
w13_g_idx
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
8cd174fa
...
...
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
_get_config_dtype_str
,
)
from
vllm.model_executor.layers.fused_moe.lora_experts_mixin
import
LoRAExpertsMixin
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
,
)
...
...
@@ -1886,7 +1887,7 @@ def fused_experts_impl(
return
out_hidden_states
class
TritonExperts
(
mk
.
FusedMoEExpertsModular
):
class
TritonExperts
(
LoRAExpertsMixin
,
mk
.
FusedMoEExpertsModular
):
"""Triton-based fused MoE expert implementation."""
def
__init__
(
...
...
@@ -2094,6 +2095,33 @@ class TritonExperts(mk.FusedMoEExpertsModular):
B_bias
=
self
.
w1_bias
,
)
# LoRA w13: applied to intermediate_cache1 before activation, using
# hidden_states as the lora_a input. moe_lora_align_block_size is
# called once here and results reused for the w2 LoRA below.
sorted_token_ids_lora
=
None
expert_ids_lora
=
None
num_tokens_post_padded_lora
=
None
token_lora_mapping
=
None
lora_context
=
self
.
_lora_context
if
lora_context
is
not
None
:
(
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
,
token_lora_mapping
,
)
=
self
.
apply_w13_lora
(
lora_context
,
y
=
intermediate_cache1
,
x
=
hidden_states
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
expert_map
=
expert_map
,
w1
=
w1
,
w2
=
w2
,
num_tokens
=
num_tokens
,
top_k_num
=
top_k_num
,
)
self
.
activation
(
activation
,
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
)
)
...
...
@@ -2132,6 +2160,25 @@ class TritonExperts(mk.FusedMoEExpertsModular):
B_bias
=
self
.
w2_bias
,
)
# LoRA w2: applied to intermediate_cache3 before moe_sum, using the
# unquantized intermediate_cache2 as the lora_a input. Reuses the
# sorted_token_ids_lora computed above.
if
lora_context
is
not
None
:
self
.
apply_w2_lora
(
lora_context
,
y
=
intermediate_cache3
,
x
=
intermediate_cache2
,
topk_weights
=
topk_weights
,
sorted_token_ids_lora
=
sorted_token_ids_lora
,
expert_ids_lora
=
expert_ids_lora
,
num_tokens_post_padded_lora
=
num_tokens_post_padded_lora
,
token_lora_mapping
=
token_lora_mapping
,
num_tokens
=
num_tokens
,
w1
=
w1
,
w2
=
w2
,
top_k_num
=
top_k_num
,
)
# separate function is required for MoE + LoRA
self
.
moe_sum
(
intermediate_cache3
,
output
)
...
...
vllm/model_executor/layers/fused_moe/lora_context.py
0 → 100644
View file @
8cd174fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.lora.punica_wrapper.punica_base
import
PunicaWrapperBase
@dataclass
class MoELoRAContext:
    """
    Bundle of the LoRA state needed by a single MoE forward pass.

    An instance is assembled in FusedMoEWithLoRA.forward() and handed
    explicitly down the modular kernel path
    (FusedMoEKernel -> FusedMoEExpertsModular.apply). That lets
    TritonExperts.apply() add the LoRA contribution inline, which replaces
    the earlier decorator-based monkey-patch approach.
    """

    # Stacked LoRA A/B weight tensors for the w13 and w2 projections
    # (same shapes as the corresponding FusedMoEWithLoRA attributes).
    w13_lora_a_stacked: tuple[torch.Tensor, ...]
    w13_lora_b_stacked: tuple[torch.Tensor, ...]
    w2_lora_a_stacked: tuple[torch.Tensor, ...]
    w2_lora_b_stacked: tuple[torch.Tensor, ...]

    # (max_loras + 1,) int32; slot 0 is the "no-adapter" sentinel.
    adapter_enabled: torch.Tensor

    # Metadata.
    max_loras: int
    top_k: int
    # 2 = gated (gate + up), 1 = non-gated or 3D-fused.
    w13_num_slices: int
    fully_sharded: bool
    tp_rank: int
    tp_size: int
    local_num_experts: int
    punica_wrapper: PunicaWrapperBase

    # True when VLLM_TUNED_CONFIG_FOLDER is set; chooses between
    # get_lora_op_configs and try_get_optimal_moe_lora_config when picking
    # Triton kernel tile configs.
    use_tuned_config: bool
vllm/model_executor/layers/fused_moe/lora_experts_mixin.py
0 → 100644
View file @
8cd174fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.model_executor.layers.fused_moe.lora_context
import
MoELoRAContext
class LoRAExpertsMixin:
    """
    Mixin for FusedMoEExpertsModular subclasses whose apply() handles a
    MoELoRAContext natively.

    What mixing this class in provides:
      - supports_lora() reports True, so _can_fused_experts_support lets
        LoRA through the gate check.
      - set_lora_context() stashes a MoELoRAContext on the experts
        instance; apply() reads it back from self._lora_context.
      - apply_w13_lora / apply_w2_lora forward to the PunicaWrapper
        kernels.

    The helpers keep no state of their own: everything they forward comes
    either from lora_context or from the call arguments.
    """

    # Class-level default: no LoRA context until set_lora_context() is called.
    _lora_context: MoELoRAContext | None = None

    def set_lora_context(self, ctx: MoELoRAContext) -> None:
        """Store ``ctx`` on this instance as ``self._lora_context``."""
        self._lora_context = ctx

    @staticmethod
    def supports_lora() -> bool:
        """Report that experts using this mixin handle LoRA."""
        return True

    def apply_w13_lora(
        self,
        lora_context: MoELoRAContext,
        *,
        y: torch.Tensor,
        x: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor | None,
        w1: torch.Tensor,
        w2: torch.Tensor,
        num_tokens: int,
        top_k_num: int,
    ) -> tuple[
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
    ]:
        """Forward the w13 LoRA computation to the punica wrapper.

        Calls ``lora_context.punica_wrapper.add_lora_w13`` with the call
        arguments plus the w13 weights and metadata held on
        ``lora_context`` and returns its 4-tuple result unchanged.
        """
        punica = lora_context.punica_wrapper
        # Positional order must match add_lora_w13's signature.
        return punica.add_lora_w13(
            y,
            x,
            lora_context.w13_lora_a_stacked,
            lora_context.w13_lora_b_stacked,
            topk_ids,
            topk_weights,
            expert_map,
            w1,
            w2,
            num_tokens,
            top_k_num,
            lora_context.max_loras,
            lora_context.adapter_enabled,
            lora_context.local_num_experts,
            lora_context.top_k,
            lora_context.w13_num_slices,
            lora_context.fully_sharded,
            lora_context.use_tuned_config,
        )

    def apply_w2_lora(
        self,
        lora_context: MoELoRAContext,
        *,
        y: torch.Tensor,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        sorted_token_ids_lora: torch.Tensor | None,
        expert_ids_lora: torch.Tensor | None,
        num_tokens_post_padded_lora: torch.Tensor | None,
        token_lora_mapping: torch.Tensor | None,
        num_tokens: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        top_k_num: int,
    ) -> None:
        """Forward the w2 LoRA computation to the punica wrapper.

        Calls ``lora_context.punica_wrapper.add_lora_w2`` with the call
        arguments plus the w2 weights and metadata held on
        ``lora_context``. Nothing is returned.
        """
        punica = lora_context.punica_wrapper
        # Positional order must match add_lora_w2's signature.
        punica.add_lora_w2(
            y,
            x,
            lora_context.w2_lora_a_stacked,
            lora_context.w2_lora_b_stacked,
            topk_weights,
            sorted_token_ids_lora,
            expert_ids_lora,
            num_tokens_post_padded_lora,
            token_lora_mapping,
            num_tokens,
            w1,
            w2,
            top_k_num,
            lora_context.max_loras,
            lora_context.adapter_enabled,
            lora_context.top_k,
            lora_context.fully_sharded,
            lora_context.tp_rank,
            lora_context.use_tuned_config,
        )
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
8cd174fa
...
...
@@ -570,6 +570,8 @@ class FusedMoEExperts(ABC):
return
False
,
_make_reason
(
f
"
{
activation_format
.
value
}
activation format"
)
elif
envs
.
VLLM_BATCH_INVARIANT
and
not
cls
.
_supports_batch_invariance
():
return
False
,
_make_reason
(
"batch invariance"
)
elif
moe_config
.
is_lora_enabled
and
not
cls
.
supports_lora
():
return
False
,
_make_reason
(
"LoRA"
)
return
True
,
None
@
staticmethod
...
...
@@ -734,6 +736,15 @@ class FusedMoEExperts(ABC):
def g2_alphas(self) -> torch.Tensor | None:
    """Return the ``g2_alphas`` value held by this object's quant config."""
    quant_config = self.quant_config
    return quant_config.g2_alphas
@
staticmethod
def
supports_lora
()
->
bool
:
"""Return True if this expert impl natively handles LoRA.
LoRA-aware experts should mix in LoRAExpertsMixin, which flips this
to True and provides the per-forward LoRA state plumbing.
"""
return
False
@
abstractmethod
def
supports_expert_map
(
self
)
->
bool
:
"""
...
...
@@ -1527,6 +1538,9 @@ class FusedMoEKernel:
def fused_experts(self) -> FusedMoEExperts:
    """Return the fused-experts object exposed by ``self.impl``."""
    impl = self.impl
    return impl.fused_experts
def supports_lora(self) -> bool:
    """Return whatever ``self.fused_experts.supports_lora()`` reports."""
    experts = self.fused_experts
    return experts.supports_lora()
def
_post_init_setup
(
self
):
"""
Resolve any leftover setup dependencies between self.prepare_finalize
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
8cd174fa
...
...
@@ -220,9 +220,6 @@ def select_fp8_moe_backend(
Note: Shape-specific fallbacks may still occur at runtime.
"""
if
config
.
is_lora_enabled
:
return
Fp8MoeBackend
.
TRITON
,
backend_to_kernel_cls
(
Fp8MoeBackend
.
TRITON
)[
0
]
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS
=
_get_priority_backends
(
config
,
weight_key
,
activation_key
)
...
...
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
View file @
8cd174fa
...
...
@@ -214,19 +214,6 @@ def select_unquantized_moe_backend(
return
backend
,
k_cls
raise
ValueError
(
_make_log_unsupported
(
backend
,
reason
))
# LoRA needs Triton's unfused activation/reduction hooks. Selecting the
# backend here ensures weights stay in a LoRA-compatible layout instead of
# being permuted for a backend like FlashInfer or AITER during load.
if
moe_config
.
is_lora_enabled
:
backend
=
UnquantizedMoeBackend
.
TRITON
if
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
:
backend
=
UnquantizedMoeBackend
.
BATCHED_TRITON
return
_return_or_raise
(
backend
,
moe_config
,
activation_format
,
)
runner_backend
=
moe_config
.
moe_backend
if
runner_backend
!=
"auto"
:
requested_backend
=
map_unquantized_backend
(
runner_backend
)
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
8cd174fa
...
...
@@ -297,7 +297,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
return
self
.
forward_native
(
layer
,
x
,
topk_weights
,
topk_ids
,
shared_experts_input
layer
,
x
,
topk_weights
,
topk_ids
,
shared_experts_input
,
)
def
apply_monolithic
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment