Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1489902b
Unverified
Commit
1489902b
authored
Nov 22, 2025
by
Jee Jee Li
Committed by
GitHub
Nov 22, 2025
Browse files
[LoRA] Cleanup FusedMoEWithLoRA (#29187)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
933f67ec
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
94 additions
and
99 deletions
+94
-99
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+90
-95
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+2
-2
vllm/lora/punica_wrapper/punica_gpu.py
vllm/lora/punica_wrapper/punica_gpu.py
+2
-2
No files found.
vllm/lora/layers/fused_moe.py
View file @
1489902b
...
...
@@ -42,6 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
device
=
base_layer
.
w2_weight
.
device
self
.
w13_slices
=
2
self
.
_inject_lora_into_fused_moe
()
def
_normalize_keys
(
self
,
config
:
dict
[
str
,
int
|
None
])
->
dict
[
str
,
int
|
None
]:
...
...
@@ -60,8 +61,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def
_get_lora_moe_configs
(
self
,
op_prefix
:
str
,
lora_a_stacked
:
torch
.
Tensor
,
lora_b_stacked
:
torch
.
Tensor
,
num_loras
:
int
,
rank
:
int
,
num_slices
:
int
,
M
:
int
,
layer
:
FusedMoE
,
...
...
@@ -69,23 +70,25 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
config_dtype
:
str
,
):
if
envs
.
VLLM_TUNED_CONFIG_FOLDER
:
hidden_size
=
layer
.
hidden_size
intermediate_size
=
layer
.
intermediate_size_per_partition
shrink_config
=
get_lora_op_configs
(
op_type
=
f
"fused_moe_lora_
{
op_prefix
}
_shrink"
,
max_loras
=
lora_a_stacked
.
shape
[
0
]
,
max_loras
=
num_loras
,
batch
=
M
,
hidden_size
=
lora_a_stacked
.
shape
[
-
1
]
,
rank
=
lora_a_stacked
.
shape
[
-
2
]
,
hidden_size
=
hidden_size
,
rank
=
rank
,
num_slices
=
num_slices
,
moe_intermediate_size
=
lora_b_stacked
.
shape
[
-
2
]
,
moe_intermediate_size
=
intermediate_size
,
)
expand_config
=
get_lora_op_configs
(
op_type
=
f
"fused_moe_lora_
{
op_prefix
}
_expand"
,
max_loras
=
lora_a_stacked
.
shape
[
0
]
,
max_loras
=
num_loras
,
batch
=
M
,
hidden_size
=
lora_a_stacked
.
shape
[
-
1
],
rank
=
lora_a_stacked
.
shape
[
-
2
]
,
hidden_size
=
hidden_size
,
#
lora_a_stacked.shape[-1],
rank
=
rank
,
num_slices
=
num_slices
,
moe_intermediate_size
=
lora_b_stacked
.
shape
[
-
2
],
moe_intermediate_size
=
intermediate_size
,
#
lora_b_stacked.shape[-2],
)
else
:
# fall back to the default config
get_config_func
=
functools
.
partial
(
...
...
@@ -152,12 +155,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens
=
hidden_states
.
size
(
0
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
max_lora_rank
=
self
.
w13_lora_a_stacked
[
0
].
shape
[
-
2
]
shrink_config
,
expand_config
=
self
.
_get_lora_moe_configs
(
op_prefix
=
"w13"
,
lora_a_stacked
=
self
.
w1
_lora
_a_stacked
,
lora_b_stacked
=
self
.
w1_lora_b_stacked
,
num_slices
=
2
,
num_loras
=
self
.
max
_lora
s
,
rank
=
max_lora_rank
,
num_slices
=
self
.
w13_slices
,
M
=
M
,
layer
=
layer
,
top_k
=
top_k
,
...
...
@@ -165,7 +168,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
)
# get the block size of m from customized config or default config
max_loras
=
self
.
w1_lora_a_stacked
.
shape
[
0
]
(
sorted_token_ids_lora
,
expert_ids_lora
,
...
...
@@ -175,7 +177,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens
,
shrink_config
[
"BLOCK_SIZE_M"
],
self
.
base_layer
.
local_num_experts
,
max_loras
,
self
.
max_loras
,
self
.
adapter_enabled
,
expert_map
,
)
...
...
@@ -186,17 +188,15 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora
)
w13_lora_a_stacked
=
[
self
.
w1_lora_a_stacked
,
self
.
w3_lora_a_stacked
]
w13_lora_b_stacked
=
[
self
.
w1_lora_b_stacked
,
self
.
w3_lora_b_stacked
]
max_lora_rank
=
self
.
w1_lora_a_stacked
.
shape
[
-
2
]
expert_ids_lora
=
expert_ids_lora
.
view
(
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
max_loras
,
-
1
)
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
#
self
.
punica_wrapper
.
add_lora_fused_moe
(
input
.
view
(
-
1
,
top_k
,
input
.
shape
[
-
1
]),
hidden_states
,
w13_lora_a_stacked
,
w13_lora_b_stacked
,
self
.
w13_lora_a_stacked
,
self
.
w13_lora_b_stacked
,
topk_weights
,
sorted_token_ids_lora
,
expert_ids_lora
,
...
...
@@ -230,11 +230,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens
=
hidden_states
.
size
(
0
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
max_lora_rank
=
self
.
w2_lora_a_stacked
.
shape
[
-
2
]
shrink_config
,
expand_config
=
self
.
_get_lora_moe_configs
(
op_prefix
=
"w2"
,
lora_a_stacked
=
self
.
w2
_lora
_a_stacked
,
lora_b_stacked
=
self
.
w2_lora_b_stacked
,
num_loras
=
self
.
max
_lora
s
,
rank
=
max_lora_rank
,
num_slices
=
1
,
M
=
M
,
layer
=
layer
,
...
...
@@ -247,20 +247,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora
=
moe_state_dict
[
"num_tokens_post_padded_lora"
]
max_loras
=
self
.
w1_lora_a_stacked
.
shape
[
0
]
expert_ids_lora
=
expert_ids_lora
.
view
(
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
max_loras
,
-
1
)
expert_ids_lora
=
expert_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
sorted_token_ids_lora
=
sorted_token_ids_lora
.
view
(
self
.
max_loras
,
-
1
)
intermediate_cache2
=
moe_state_dict
[
"intermediate_cache2"
]
intermediate_cache3
=
args
[
0
]
max_lora_rank
=
self
.
w2_lora_a_stacked
.
shape
[
-
2
]
shard_size_w2
=
divide
(
self
.
base_layer
.
hidden_size
,
self
.
tp_size
)
self
.
punica_wrapper
.
add_lora_fused_moe
(
intermediate_cache3
,
intermediate_cache2
,
[
self
.
w2_lora_a_stacked
]
,
[
self
.
w2_lora_b_stacked
]
,
(
self
.
w2_lora_a_stacked
,)
,
(
self
.
w2_lora_b_stacked
,)
,
topk_weights
,
sorted_token_ids_lora
,
expert_ids_lora
,
...
...
@@ -289,7 +288,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
fused_experts
.
moe_sum
=
moe_sum_decorator
(
self
.
base_layer
,
fused_experts
.
moe_sum
)
self
.
base_layer
.
quant_method
=
FusedMoEModularMethod
(
self
.
base_layer
.
quant_method
,
m_fused_moe_fn
)
...
...
@@ -301,33 +299,42 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
None
:
"""Initializes lora matrices."""
assert
self
.
w13_slices
==
2
self
.
max_loras
=
lora_config
.
max_loras
self
.
fully_sharded
=
lora_config
.
fully_sharded_loras
self
.
adapter_enabled
=
torch
.
tensor
(
[
0
]
*
(
max_loras
+
1
),
dtype
=
torch
.
int
,
device
=
self
.
device
)
self
.
w1_lora_a_stacked
=
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
lora_config
.
max_lora_rank
if
not
self
.
fully_sharded
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
),
self
.
base_layer
.
hidden_size
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
self
.
w13_lora_a_stacked
=
tuple
(
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
lora_config
.
max_lora_rank
if
not
self
.
fully_sharded
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
),
self
.
base_layer
.
hidden_size
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
w13_slices
)
)
self
.
w1_lora_b_stacked
=
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
self
.
base_layer
.
intermediate_size_per_partition
,
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
self
.
w13_lora_b_stacked
=
tuple
(
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
self
.
base_layer
.
intermediate_size_per_partition
,
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
w13_slices
)
)
self
.
w2_lora_a_stacked
=
torch
.
zeros
(
...
...
@@ -353,29 +360,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
device
=
self
.
device
,
)
self
.
w3_lora_a_stacked
=
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
lora_config
.
max_lora_rank
if
not
self
.
fully_sharded
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
),
self
.
base_layer
.
hidden_size
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
self
.
w3_lora_b_stacked
=
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
self
.
base_layer
.
intermediate_size_per_partition
,
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights.
self
.
lora_a_stacked
=
[]
...
...
@@ -383,20 +367,28 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
for
lora_id
in
range
(
max_loras
):
for
experts_id
in
range
(
self
.
base_layer
.
local_num_experts
):
# gate_proj,down_proj,up_proj
self
.
lora_a_stacked
.
append
(
self
.
w1_lora_a_stacked
[
lora_id
][
experts_id
])
self
.
lora_a_stacked
.
append
(
self
.
w13_lora_a_stacked
[
0
][
lora_id
][
experts_id
]
)
self
.
lora_a_stacked
.
append
(
self
.
w2_lora_a_stacked
[
lora_id
][
experts_id
])
self
.
lora_a_stacked
.
append
(
self
.
w3_lora_a_stacked
[
lora_id
][
experts_id
])
self
.
lora_a_stacked
.
append
(
self
.
w13_lora_a_stacked
[
1
][
lora_id
][
experts_id
]
)
self
.
lora_b_stacked
.
append
(
self
.
w1_lora_b_stacked
[
lora_id
][
experts_id
])
self
.
lora_b_stacked
.
append
(
self
.
w13_lora_b_stacked
[
0
][
lora_id
][
experts_id
]
)
self
.
lora_b_stacked
.
append
(
self
.
w2_lora_b_stacked
[
lora_id
][
experts_id
])
self
.
lora_b_stacked
.
append
(
self
.
w3_lora_b_stacked
[
lora_id
][
experts_id
])
self
.
lora_b_stacked
.
append
(
self
.
w13_lora_b_stacked
[
1
][
lora_id
][
experts_id
]
)
def
reset_lora
(
self
,
index
:
int
):
"""Resets the lora weights at index back to 0."""
self
.
w1_lora_a_stacked
[
index
]
=
0
self
.
w1_lora_
b
_stacked
[
index
]
=
0
self
.
w3_lora_
a
_stacked
[
index
]
=
0
self
.
w3_lora_b_stacked
[
index
]
=
0
for
pos
in
range
(
self
.
w13_slices
):
self
.
w1
3
_lora_
a
_stacked
[
pos
][
index
]
=
0
self
.
w
1
3_lora_
b
_stacked
[
pos
][
index
]
=
0
self
.
w2_lora_a_stacked
[
index
]
=
0
self
.
w2_lora_b_stacked
[
index
]
=
0
self
.
adapter_enabled
[
index
]
=
0
...
...
@@ -434,7 +426,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if
self
.
fully_sharded
:
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
# and W2 B along the hidden_size dim.
w13_shard_size
=
self
.
w1_lora_a_stacked
[
index
,
eid
].
shape
[
0
]
w13_shard_size
=
self
.
w1
3
_lora_a_stacked
[
0
][
index
,
eid
].
shape
[
0
]
w13_start_idx
=
self
.
tp_rank
*
w13_shard_size
w13_end_idx
=
(
self
.
tp_rank
+
1
)
*
w13_shard_size
w1_lora_a
=
w1_lora_a
[
w13_start_idx
:
w13_end_idx
,
:]
...
...
@@ -444,29 +436,32 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w2_start_idx
=
self
.
tp_rank
*
w2_shard_size
w2_end_idx
=
(
self
.
tp_rank
+
1
)
*
w2_shard_size
w2_lora_b
=
w2_lora_b
[
w2_start_idx
:
w2_end_idx
,
:]
self
.
w1_lora_a_stacked
[
# w1 lora_a
self
.
w1
3
_lora_a_stacked
[
0
][
index
,
eid
,
:
w1_lora_a
.
shape
[
0
],
:
w1_lora_a
.
shape
[
1
]
].
copy_
(
w1_lora_a
,
non_blocking
=
True
)
self
.
w3_lora_a_stacked
[
# w3 lora_a
self
.
w
1
3_lora_a_stacked
[
1
][
index
,
eid
,
:
w3_lora_a
.
shape
[
0
],
:
w3_lora_a
.
shape
[
1
]
].
copy_
(
w3_lora_a
,
non_blocking
=
True
)
self
.
w2_lora_b_stacked
[
index
,
eid
,
:
w2_lora_b
.
shape
[
0
],
:
w2_lora_b
.
shape
[
1
]
].
copy_
(
w2_lora_b
,
non_blocking
=
True
)
self
.
w1_lora_b_stacked
[
# w1 lora_b
self
.
w13_lora_b_stacked
[
0
][
index
,
eid
,
:
w1_lora_b
.
shape
[
0
],
:
w1_lora_b
.
shape
[
1
]
].
copy_
(
w1_lora_b
,
non_blocking
=
True
)
self
.
w3_lora_b_stacked
[
# w3 lora_b
self
.
w13_lora_b_stacked
[
1
][
index
,
eid
,
:
w3_lora_b
.
shape
[
0
],
:
w3_lora_b
.
shape
[
1
]
].
copy_
(
w3_lora_b
,
non_blocking
=
True
)
self
.
w2_lora_a_stacked
[
index
,
eid
,
:
w2_lora_a
.
shape
[
0
],
:
w2_lora_a
.
shape
[
1
]
].
copy_
(
w2_lora_a
,
non_blocking
=
True
)
self
.
w2_lora_b_stacked
[
index
,
eid
,
:
w2_lora_b
.
shape
[
0
],
:
w2_lora_b
.
shape
[
1
]
].
copy_
(
w2_lora_b
,
non_blocking
=
True
)
@
classmethod
def
can_replace_layer
(
cls
,
...
...
vllm/lora/punica_wrapper/punica_base.py
View file @
1489902b
...
...
@@ -470,8 +470,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
list
[
torch
.
Tensor
],
lora_b_stacked
:
list
[
torch
.
Tensor
],
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...
],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...
],
topk_weights
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
...
...
vllm/lora/punica_wrapper/punica_gpu.py
View file @
1489902b
...
...
@@ -360,8 +360,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
list
[
torch
.
Tensor
],
lora_b_stacked
:
list
[
torch
.
Tensor
],
lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...
],
lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...
],
topk_weights
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment