Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f5f9acd
Unverified
Commit
2f5f9acd
authored
Nov 27, 2025
by
Jee Jee Li
Committed by
GitHub
Nov 27, 2025
Browse files
[LoRA] Continue optimizing MoE LoRA weight loading (#29322)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
cf348c8d
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
236 additions
and
165 deletions
+236
-165
tests/lora/test_lora_checkpoints.py
tests/lora/test_lora_checkpoints.py
+8
-7
tests/lora/test_lora_huggingface.py
tests/lora/test_lora_huggingface.py
+4
-4
vllm/lora/layers/base.py
vllm/lora/layers/base.py
+1
-1
vllm/lora/layers/column_parallel_linear.py
vllm/lora/layers/column_parallel_linear.py
+8
-8
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+115
-103
vllm/lora/layers/logits_processor.py
vllm/lora/layers/logits_processor.py
+1
-1
vllm/lora/layers/replicated_linear.py
vllm/lora/layers/replicated_linear.py
+1
-1
vllm/lora/layers/row_parallel_linear.py
vllm/lora/layers/row_parallel_linear.py
+2
-2
vllm/lora/layers/vocal_parallel_embedding.py
vllm/lora/layers/vocal_parallel_embedding.py
+1
-1
vllm/lora/lora_weights.py
vllm/lora/lora_weights.py
+53
-0
vllm/lora/models.py
vllm/lora/models.py
+27
-23
vllm/lora/utils.py
vllm/lora/utils.py
+8
-9
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+5
-5
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+1
-0
vllm/model_executor/models/qwen3_vl_moe.py
vllm/model_executor/models/qwen3_vl_moe.py
+1
-0
No files found.
tests/lora/test_lora_checkpoints.py
View file @
2f5f9acd
...
@@ -28,12 +28,13 @@ def test_load_checkpoints(
...
@@ -28,12 +28,13 @@ def test_load_checkpoints(
packed_modules_mapping
=
BaiChuanBaseForCausalLM
.
packed_modules_mapping
packed_modules_mapping
=
BaiChuanBaseForCausalLM
.
packed_modules_mapping
embedding_modules
=
BaiChuanBaseForCausalLM
.
embedding_modules
embedding_modules
=
BaiChuanBaseForCausalLM
.
embedding_modules
embed_padding_modules
=
BaiChuanBaseForCausalLM
.
embedding_padding_modules
embed_padding_modules
=
BaiChuanBaseForCausalLM
.
embedding_padding_modules
expected_lora_
modules
:
list
[
str
]
=
[]
expected_lora_
lst
:
list
[
str
]
=
[]
for
module
in
BAICHUAN_LORA_MODULES
:
for
module
in
BAICHUAN_LORA_MODULES
:
if
module
in
packed_modules_mapping
:
if
module
in
packed_modules_mapping
:
expected_lora_
modules
.
extend
(
packed_modules_mapping
[
module
])
expected_lora_
lst
.
extend
(
packed_modules_mapping
[
module
])
else
:
else
:
expected_lora_modules
.
append
(
module
)
expected_lora_lst
.
append
(
module
)
expected_lora_modules
=
set
(
expected_lora_lst
)
if
lora_name
==
"baichuan7B"
:
if
lora_name
==
"baichuan7B"
:
peft_helper
=
PEFTHelper
.
from_local_dir
(
peft_helper
=
PEFTHelper
.
from_local_dir
(
baichuan_lora_files
,
max_position_embeddings
=
4096
baichuan_lora_files
,
max_position_embeddings
=
4096
...
@@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files):
...
@@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files):
packed_modules_mapping
=
BaiChuanBaseForCausalLM
.
packed_modules_mapping
packed_modules_mapping
=
BaiChuanBaseForCausalLM
.
packed_modules_mapping
embedding_modules
=
BaiChuanBaseForCausalLM
.
embedding_modules
embedding_modules
=
BaiChuanBaseForCausalLM
.
embedding_modules
embed_padding_modules
=
BaiChuanBaseForCausalLM
.
embedding_padding_modules
embed_padding_modules
=
BaiChuanBaseForCausalLM
.
embedding_padding_modules
expected_lora_
modules
:
list
[
str
]
=
[]
expected_lora_
lst
:
list
[
str
]
=
[]
for
module
in
BAICHUAN_LORA_MODULES
:
for
module
in
BAICHUAN_LORA_MODULES
:
if
module
in
packed_modules_mapping
:
if
module
in
packed_modules_mapping
:
expected_lora_
modules
.
extend
(
packed_modules_mapping
[
module
])
expected_lora_
lst
.
extend
(
packed_modules_mapping
[
module
])
else
:
else
:
expected_lora_
modules
.
append
(
module
)
expected_lora_
lst
.
append
(
module
)
expected_lora_modules
=
set
(
expected_lora_lst
)
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
orig_to_new_prefix
=
{
"model."
:
"language_model.model."
,
"model."
:
"language_model.model."
,
...
...
tests/lora/test_lora_huggingface.py
View file @
2f5f9acd
...
@@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
...
@@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
packed_modules_mapping
=
LlamaForCausalLM
.
packed_modules_mapping
packed_modules_mapping
=
LlamaForCausalLM
.
packed_modules_mapping
embedding_modules
=
LlamaForCausalLM
.
embedding_modules
embedding_modules
=
LlamaForCausalLM
.
embedding_modules
embed_padding_modules
=
LlamaForCausalLM
.
embedding_padding_modules
embed_padding_modules
=
LlamaForCausalLM
.
embedding_padding_modules
expected_lora_
modules
:
list
[
str
]
=
[]
expected_lora_
lst
:
list
[
str
]
=
[]
for
module
in
LLAMA_LORA_MODULES
:
for
module
in
LLAMA_LORA_MODULES
:
if
module
in
packed_modules_mapping
:
if
module
in
packed_modules_mapping
:
expected_lora_
modules
.
extend
(
packed_modules_mapping
[
module
])
expected_lora_
lst
.
extend
(
packed_modules_mapping
[
module
])
else
:
else
:
expected_lora_
modules
.
append
(
module
)
expected_lora_
lst
.
append
(
module
)
expected_lora_modules
=
set
(
expected_lora_lst
)
lora_path
=
get_adapter_absolute_path
(
lora_name
)
lora_path
=
get_adapter_absolute_path
(
lora_name
)
# lora loading should work for either absolute path and huggingface id.
# lora loading should work for either absolute path and huggingface id.
...
...
vllm/lora/layers/base.py
View file @
2f5f9acd
...
@@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module):
...
@@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
"""Returns True if the layer can be replaced by this LoRA layer."""
raise
NotImplementedError
raise
NotImplementedError
vllm/lora/layers/column_parallel_linear.py
View file @
2f5f9acd
...
@@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
return
type
(
source_layer
)
is
ColumnParallelLinear
or
(
return
type
(
source_layer
)
is
ColumnParallelLinear
or
(
type
(
source_layer
)
is
MergedColumnParallelLinear
type
(
source_layer
)
is
MergedColumnParallelLinear
...
@@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
return
(
return
(
type
(
source_layer
)
is
MergedColumnParallelLinear
type
(
source_layer
)
is
MergedColumnParallelLinear
...
@@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
return
type
(
source_layer
)
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
1
return
type
(
source_layer
)
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
1
...
@@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
...
@@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
return
type
(
source_layer
)
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
3
return
type
(
source_layer
)
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
3
...
@@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
...
@@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
# specifying kwargs so they can be easily accessed in decorator
# specifying kwargs so they can be easily accessed in decorator
return
super
().
can_replace_layer
(
return
super
().
can_replace_layer
(
...
@@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
...
@@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
# specifying kwargs so they can be easily accessed in decorator
# specifying kwargs so they can be easily accessed in decorator
return
super
().
can_replace_layer
(
return
super
().
can_replace_layer
(
...
@@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
...
@@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
# specifying kwargs so they can be easily accessed in decorator
# specifying kwargs so they can be easily accessed in decorator
return
super
().
can_replace_layer
(
return
super
().
can_replace_layer
(
...
@@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
...
@@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
# specifying kwargs so they can be easily accessed in decorator
# specifying kwargs so they can be easily accessed in decorator
return
super
().
can_replace_layer
(
return
super
().
can_replace_layer
(
...
...
vllm/lora/layers/fused_moe.py
View file @
2f5f9acd
...
@@ -401,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -401,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
w13_lora_b_stacked
[
1
][
lora_id
][
experts_id
]
self
.
w13_lora_b_stacked
[
1
][
lora_id
][
experts_id
]
)
)
def
_slice_w13_a
(
self
,
w13_lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if
self
.
tp_size
==
1
or
not
self
.
fully_sharded
:
return
w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank
=
w13_lora_a
.
shape
[
1
]
assert
current_lora_rank
%
self
.
tp_size
==
0
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
sliced_rank
=
current_lora_rank
//
self
.
tp_size
start_idx
=
self
.
tp_rank
*
sliced_rank
end_idx
=
(
self
.
tp_rank
+
1
)
*
sliced_rank
return
w13_lora_a
[:,
start_idx
:
end_idx
,
:]
def
_slice_w13_b
(
self
,
w13_lora_b
:
torch
.
Tensor
):
if
self
.
tp_size
==
1
:
return
w13_lora_b
# w13_lora_b shape (num_experts,output_size,rank)
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
return
w13_lora_b
[:,
start_idx
:
end_idx
,
:]
def
_slice_w2_a
(
self
,
w2_lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if
self
.
tp_size
==
1
:
return
w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
return
w2_lora_a
[:,
:,
start_idx
:
end_idx
]
def
_slice_w2_b
(
self
,
w2_lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if
self
.
tp_size
==
1
or
not
self
.
fully_sharded
:
return
w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size
=
w2_lora_b
.
shape
[
1
]
sliced_size
=
current_lora_size
//
self
.
tp_size
start_idx
=
self
.
tp_rank
*
sliced_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
sliced_size
return
w2_lora_b
[:,
start_idx
:
end_idx
,
:]
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
"""Resets the lora weights at index back to 0."""
"""Resets the lora weights at index back to 0."""
for
pos
in
range
(
self
.
_w13_slices
):
for
pos
in
range
(
self
.
_w13_slices
):
...
@@ -411,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -411,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
w2_lora_b_stacked
[
0
][
index
]
=
0
self
.
w2_lora_b_stacked
[
0
][
index
]
=
0
self
.
adapter_enabled
[
index
]
=
0
self
.
adapter_enabled
[
index
]
=
0
#
def
set_lora
(
def
set_lora
(
self
,
self
,
index
:
int
,
index
:
int
,
...
@@ -418,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -418,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
):
):
"""Overwrites lora tensors at index."""
"""Overwrites lora tensors at index."""
# Make mypy happy
assert
isinstance
(
lora_a
,
list
)
assert
isinstance
(
lora_a
,
list
)
assert
isinstance
(
lora_b
,
list
)
assert
isinstance
(
lora_b
,
list
)
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
self
.
adapter_enabled
[
index
]
=
1
self
.
adapter_enabled
[
index
]
=
1
for
eid
in
range
(
len
(
lora_a
)
//
3
):
w1_lora_a
=
lora_a
[
eid
*
3
]
w2_lora_a
=
lora_a
[
eid
*
3
+
1
]
w3_lora_a
=
lora_a
[
eid
*
3
+
2
]
w1_lora_b
=
lora_b
[
eid
*
3
]
w2_lora_b
=
lora_b
[
eid
*
3
+
1
]
w3_lora_b
=
lora_b
[
eid
*
3
+
2
]
# Handle the case of adding LoRA to only a subset of experts
if
w1_lora_a
is
None
or
w2_lora_a
is
None
or
w3_lora_a
is
None
:
continue
if
self
.
tp_size
>
1
:
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
w1_lora_b
=
w1_lora_b
[
start_idx
:
end_idx
,
:]
num_experts
=
self
.
w13_lora_a_stacked
[
0
].
shape
[
1
]
w3_lora_b
=
w3_lora_b
[
start_idx
:
end_idx
,
:]
w2_lora_a
=
w2_lora_a
[:,
start_idx
:
end_idx
]
w1_lora_a
,
w2_lora_a
,
w3_lora_a
=
lora_a
w1_lora_b
,
w2_lora_b
,
w3_lora_b
=
lora_b
if
self
.
fully_sharded
:
assert
(
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
num_experts
# and W2 B along the hidden_size dim.
==
w1_lora_a
.
shape
[
0
]
w13_shard_size
=
self
.
w13_lora_a_stacked
[
0
][
index
,
eid
].
shape
[
0
]
==
w2_lora_a
.
shape
[
0
]
w13_start_idx
=
self
.
tp_rank
*
w13_shard_size
==
w3_lora_a
.
shape
[
0
]
w13_end_idx
=
(
self
.
tp_rank
+
1
)
*
w13_shard_size
)
w1_lora_a
=
w1_lora_a
[
w13_start_idx
:
w13_end_idx
,
:]
w3_lora_a
=
w3_lora_a
[
w13_start_idx
:
w13_end_idx
,
:]
slliced_w1_lora_a
=
self
.
_slice_w13_a
(
w1_lora_a
)
slliced_w1_lora_b
=
self
.
_slice_w13_b
(
w1_lora_b
)
w2_shard_size
=
self
.
w2_lora_b_stacked
[
0
][
index
,
eid
].
shape
[
0
]
slliced_w3_lora_a
=
self
.
_slice_w13_a
(
w3_lora_a
)
w2_start_idx
=
self
.
tp_rank
*
w2_shard_size
slliced_w3_lora_b
=
self
.
_slice_w13_b
(
w3_lora_b
)
w2_end_idx
=
(
self
.
tp_rank
+
1
)
*
w2_shard_size
w2_lora_b
=
w2_lora_b
[
w2_start_idx
:
w2_end_idx
,
:]
sliced_w2_lora_a
=
self
.
_slice_w2_a
(
w2_lora_a
)
# w1 lora_a
sliced_w2_lora_b
=
self
.
_slice_w2_b
(
w2_lora_b
)
self
.
w13_lora_a_stacked
[
0
][
self
.
w13_lora_a_stacked
[
0
][
index
,
eid
,
:
w1_lora_a
.
shape
[
0
],
:
w1_lora_a
.
shape
[
1
]
index
,
:
,
:
slliced_
w1_lora_a
.
shape
[
1
],
:
slliced_
w1_lora_a
.
shape
[
2
]
].
copy_
(
w1_lora_a
,
non_blocking
=
True
)
].
copy_
(
slliced_
w1_lora_a
,
non_blocking
=
True
)
# w3 lora_a
self
.
w13_lora_a_stacked
[
1
][
self
.
w13_lora_a_stacked
[
1
][
index
,
eid
,
:
w3_lora_a
.
shape
[
0
],
:
w3_lora_a
.
shape
[
1
]
index
,
:
,
:
slliced_
w3_lora_a
.
shape
[
1
],
:
slliced_
w3_lora_a
.
shape
[
2
]
].
copy_
(
w3_lora_a
,
non_blocking
=
True
)
].
copy_
(
slliced_
w3_lora_a
,
non_blocking
=
True
)
# w1 lora_b
self
.
w13_lora_b_stacked
[
0
][
self
.
w13_lora_b_stacked
[
0
][
index
,
eid
,
:
w1_lora_b
.
shape
[
0
],
:
w1_lora_b
.
shape
[
1
]
index
,
:
,
:
slliced_
w1_lora_b
.
shape
[
1
],
:
slliced_
w1_lora_b
.
shape
[
2
]
].
copy_
(
w1_lora_b
,
non_blocking
=
True
)
].
copy_
(
slliced_
w1_lora_b
,
non_blocking
=
True
)
# w3 lora_b
self
.
w13_lora_b_stacked
[
1
][
self
.
w13_lora_b_stacked
[
1
][
index
,
eid
,
:
w3_lora_b
.
shape
[
0
],
:
w3_lora_b
.
shape
[
1
]
index
,
:
,
:
slliced_
w3_lora_b
.
shape
[
1
],
:
slliced_
w3_lora_b
.
shape
[
2
]
].
copy_
(
w3_lora_b
,
non_blocking
=
True
)
].
copy_
(
slliced_
w3_lora_b
,
non_blocking
=
True
)
self
.
w2_lora_a_stacked
[
0
][
self
.
w2_lora_a_stacked
[
0
][
index
,
eid
,
:
w2_lora_a
.
shape
[
0
],
:
w2_lora_a
.
shape
[
1
]
index
,
:
,
:
sliced_
w2_lora_a
.
shape
[
1
],
:
sliced_
w2_lora_a
.
shape
[
2
]
].
copy_
(
w2_lora_a
,
non_blocking
=
True
)
].
copy_
(
sliced_
w2_lora_a
,
non_blocking
=
True
)
self
.
w2_lora_b_stacked
[
0
][
self
.
w2_lora_b_stacked
[
0
][
index
,
eid
,
:
w2_lora_b
.
shape
[
0
],
:
w2_lora_b
.
shape
[
1
]
index
,
:
,
:
sliced_
w2_lora_b
.
shape
[
1
],
:
sliced_
w2_lora_b
.
shape
[
2
]
].
copy_
(
w2_lora_b
,
non_blocking
=
True
)
].
copy_
(
sliced_
w2_lora_b
,
non_blocking
=
True
)
def
forward
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
base_layer
.
forward
(
*
args
,
**
kwargs
)
return
self
.
base_layer
.
forward
(
*
args
,
**
kwargs
)
...
@@ -506,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -506,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return
type
(
source_layer
)
is
FusedMoE
and
len
(
packed_modules_list
)
==
2
# source_layer is FusedMoE or SharedFusedMoE
return
isinstance
(
source_layer
,
FusedMoE
)
and
len
(
packed_modules_list
)
==
2
class
FusedMoE3DWithLoRA
(
FusedMoEWithLoRA
):
class
FusedMoE3DWithLoRA
(
FusedMoEWithLoRA
):
...
@@ -555,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
...
@@ -555,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
model_config
:
PretrainedConfig
|
None
=
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
None
:
)
->
None
:
"""Initializes lora matrices."""
"""Initializes lora matrices."""
assert
isinstance
(
model_config
,
PretrainedConfig
)
self
.
_base_model
=
model_config
.
architectures
[
0
]
self
.
max_loras
=
lora_config
.
max_loras
self
.
max_loras
=
lora_config
.
max_loras
self
.
fully_sharded
=
lora_config
.
fully_sharded_loras
self
.
fully_sharded
=
lora_config
.
fully_sharded_loras
...
@@ -565,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
...
@@ -565,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
self
.
_create_lora_a_weights
(
max_loras
,
lora_config
)
self
.
_create_lora_a_weights
(
max_loras
,
lora_config
)
self
.
_create_lora_b_weights
(
max_loras
,
lora_config
)
self
.
_create_lora_b_weights
(
max_loras
,
lora_config
)
def
_slice_w13_a
(
self
,
w13_lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_slice_w13_b
(
self
,
w13_lora_b
:
torch
.
Tensor
):
if
self
.
tp_size
==
1
or
not
self
.
fully_sharded
:
return
w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank
=
w13_lora_a
.
shape
[
1
]
assert
current_lora_rank
%
self
.
tp_size
==
0
sliced_rank
=
current_lora_rank
//
self
.
tp_size
start_idx
=
self
.
tp_rank
*
sliced_rank
end_idx
=
(
self
.
tp_rank
+
1
)
*
sliced_rank
return
w13_lora_a
[:,
start_idx
:
end_idx
,
:]
def
_slice_w13_b
(
self
,
w13_lora_b
:
torch
.
Tensor
,
is_interleave
:
bool
=
True
):
if
self
.
tp_size
==
1
:
if
self
.
tp_size
==
1
:
return
w13_lora_b
return
w13_lora_b
...
@@ -586,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
...
@@ -586,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
start_idx
=
self
.
tp_rank
*
shard_size
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
if
is_interleave
:
# HACK: Currently, only GPT-OSS is in interleaved order
if
self
.
_base_model
==
"GptOssForCausalLM"
:
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# in the interleaved order, and corresponding LoRA need to be processed.
# in the interleaved order, and corresponding LoRA need to be processed.
w1_lora_b
=
w13_lora_b
[:,
::
2
,
:]
w1_lora_b
=
w13_lora_b
[:,
::
2
,
:]
...
@@ -606,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
...
@@ -606,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
return
torch
.
cat
([
sliced_w1_lora_b
,
sliced_w3_lora_b
],
dim
=
1
)
return
torch
.
cat
([
sliced_w1_lora_b
,
sliced_w3_lora_b
],
dim
=
1
)
def
_slice_w2_a
(
self
,
w2_lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
tp_size
==
1
:
return
w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
return
w2_lora_a
[:,
:,
start_idx
:
end_idx
]
def
_slice_w2_b
(
self
,
w2_lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
tp_size
==
1
or
not
self
.
fully_sharded
:
return
w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size
=
w2_lora_b
.
shape
[
1
]
sliced_size
=
current_lora_size
//
self
.
tp_size
start_idx
=
self
.
tp_rank
*
sliced_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
sliced_size
return
w2_lora_b
[:,
start_idx
:
end_idx
,
:]
def
set_lora
(
def
set_lora
(
self
,
self
,
index
:
int
,
index
:
int
,
...
@@ -658,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
...
@@ -658,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
w2_lora_b
=
w2_lora_b
.
permute
(
1
,
0
,
2
)
w2_lora_b
=
w2_lora_b
.
permute
(
1
,
0
,
2
)
sliced_w13_lora_a
=
self
.
_slice_w13_a
(
w13_lora_a
)
sliced_w13_lora_a
=
self
.
_slice_w13_a
(
w13_lora_a
)
sliced_w13_lora_b
=
self
.
_slice_w13_b
(
w13_lora_b
,
is_interleave
=
True
)
sliced_w13_lora_b
=
self
.
_slice_w13_b
(
w13_lora_b
)
sliced_w2_lora_a
=
self
.
_slice_w2_a
(
w2_lora_a
)
sliced_w2_lora_a
=
self
.
_slice_w2_a
(
w2_lora_a
)
sliced_w2_lora_b
=
self
.
_slice_w2_b
(
w2_lora_b
)
sliced_w2_lora_b
=
self
.
_slice_w2_b
(
w2_lora_b
)
...
@@ -711,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
...
@@ -711,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
"""Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE or SharedFusedMoE
return
typ
e
(
source_layer
)
is
FusedMoE
and
len
(
packed_modules_list
)
==
1
return
isinstanc
e
(
source_layer
,
FusedMoE
)
and
len
(
packed_modules_list
)
==
1
vllm/lora/layers/logits_processor.py
View file @
2f5f9acd
...
@@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
# Special handling for the LogitsProcessor.
# Special handling for the LogitsProcessor.
return
False
return
False
vllm/lora/layers/replicated_linear.py
View file @
2f5f9acd
...
@@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
return
type
(
source_layer
)
is
ReplicatedLinear
return
type
(
source_layer
)
is
ReplicatedLinear
...
...
vllm/lora/layers/row_parallel_linear.py
View file @
2f5f9acd
...
@@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
return
type
(
source_layer
)
is
RowParallelLinear
return
type
(
source_layer
)
is
RowParallelLinear
...
@@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
...
@@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
# specifying kwargs so they can be easily accessed in decorator
# specifying kwargs so they can be easily accessed in decorator
return
super
().
can_replace_layer
(
return
super
().
can_replace_layer
(
...
...
vllm/lora/layers/vocal_parallel_embedding.py
View file @
2f5f9acd
...
@@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
source_layer
:
nn
.
Module
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
return
type
(
source_layer
)
is
VocabParallelEmbedding
return
type
(
source_layer
)
is
VocabParallelEmbedding
...
...
vllm/lora/lora_weights.py
View file @
2f5f9acd
...
@@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights):
...
@@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights):
)
)
return
obj
return
obj
@
classmethod
def
pack_moe
(
cls
,
loras
:
GenericSequence
[
Optional
[
"LoRALayerWeights"
]],
module_name
:
str
)
->
"PackedLoRALayerWeights"
:
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora
=
next
(
lora
for
lora
in
loras
if
lora
is
not
None
)
assert
first_lora
is
not
None
rank
=
first_lora
.
rank
lora_alpha
=
first_lora
.
lora_alpha
assert
len
(
loras
)
%
3
==
0
w1_lora_a_lst
=
[]
w2_lora_a_lst
=
[]
w3_lora_a_lst
=
[]
w1_lora_b_lst
=
[]
w2_lora_b_lst
=
[]
w3_lora_b_lst
=
[]
# TODO: Consider the case where some experts don't have LoRA added.
for
eid
in
range
(
len
(
loras
)
//
3
):
w1_lora
=
loras
[
eid
*
3
]
w2_lora
=
loras
[
eid
*
3
+
1
]
w3_lora
=
loras
[
eid
*
3
+
2
]
assert
w1_lora
is
not
None
assert
w2_lora
is
not
None
assert
w3_lora
is
not
None
w1_lora_a_lst
.
append
(
w1_lora
.
lora_a
)
w2_lora_a_lst
.
append
(
w2_lora
.
lora_a
)
w3_lora_a_lst
.
append
(
w3_lora
.
lora_a
)
w1_lora_b_lst
.
append
(
w1_lora
.
lora_b
)
w2_lora_b_lst
.
append
(
w2_lora
.
lora_b
)
w3_lora_b_lst
.
append
(
w3_lora
.
lora_b
)
w1_lora_a
=
torch
.
stack
(
w1_lora_a_lst
,
dim
=
0
)
# (num_experts,rank,input_size)
w2_lora_a
=
torch
.
stack
(
w2_lora_a_lst
,
dim
=
0
)
w3_lora_a
=
torch
.
stack
(
w3_lora_a_lst
,
dim
=
0
)
w1_lora_b
=
torch
.
stack
(
w1_lora_b_lst
,
dim
=
0
)
# (num_experts,output_size,rank)
w2_lora_b
=
torch
.
stack
(
w2_lora_b_lst
,
dim
=
0
)
w3_lora_b
=
torch
.
stack
(
w3_lora_b_lst
,
dim
=
0
)
obj
=
cls
(
module_name
,
rank
,
[
lora_alpha
,
lora_alpha
,
lora_alpha
],
[
w1_lora_a
,
w2_lora_a
,
w3_lora_a
],
[
w1_lora_b
,
w2_lora_b
,
w3_lora_b
],
)
return
obj
def
optimize
(
self
)
->
"PackedLoRALayerWeights"
:
def
optimize
(
self
)
->
"PackedLoRALayerWeights"
:
"""Optimize the LoRA by merging the scaling into lora_b."""
"""Optimize the LoRA by merging the scaling into lora_b."""
for
i
in
range
(
len
(
self
.
lora_b
)):
for
i
in
range
(
len
(
self
.
lora_b
)):
...
...
vllm/lora/models.py
View file @
2f5f9acd
...
@@ -13,7 +13,7 @@ from torch import nn
...
@@ -13,7 +13,7 @@ from torch import nn
from
vllm.config.lora
import
LoRAConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
BaseLayerWithLoRA
,
FusedMoEWithLoRA
,
LoRAMapping
from
vllm.lora.layers
import
BaseLayerWithLoRA
,
FusedMoE
3D
WithLoRA
,
LoRAMapping
from
vllm.lora.lora_weights
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.lora_weights
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.peft_helper
import
PEFTHelper
from
vllm.lora.peft_helper
import
PEFTHelper
from
vllm.lora.punica_wrapper
import
get_punica_wrapper
from
vllm.lora.punica_wrapper
import
get_punica_wrapper
...
@@ -151,16 +151,13 @@ class LoRAModel:
...
@@ -151,16 +151,13 @@ class LoRAModel:
if
pin_memory
:
if
pin_memory
:
loras
[
module_name
].
lora_b
=
loras
[
module_name
].
lora_b
.
pin_memory
()
loras
[
module_name
].
lora_b
=
loras
[
module_name
].
lora_b
.
pin_memory
()
for
lora
in
loras
.
values
():
lora
.
optimize
()
return
cls
(
lora_model_id
,
peft_helper
.
r
,
loras
)
return
cls
(
lora_model_id
,
peft_helper
.
r
,
loras
)
@
classmethod
@
classmethod
def
from_local_checkpoint
(
def
from_local_checkpoint
(
cls
,
cls
,
lora_dir
:
str
,
lora_dir
:
str
,
expected_lora_modules
:
li
st
[
str
],
expected_lora_modules
:
s
e
t
[
str
],
peft_helper
:
PEFTHelper
,
peft_helper
:
PEFTHelper
,
*
,
*
,
lora_model_id
:
int
|
None
=
None
,
lora_model_id
:
int
|
None
=
None
,
...
@@ -190,10 +187,7 @@ class LoRAModel:
...
@@ -190,10 +187,7 @@ class LoRAModel:
lora_tensor_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.safetensors"
)
lora_tensor_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.safetensors"
)
lora_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.bin"
)
lora_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.bin"
)
lora_pt_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.pt"
)
lora_pt_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.pt"
)
# new_embeddings_tensor_path = os.path.join(
# lora_dir, "new_embeddings.safetensors"
# )
# new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
unexpected_modules
:
list
[
list
[
str
]
|
str
]
=
[]
unexpected_modules
:
list
[
list
[
str
]
|
str
]
=
[]
...
@@ -201,18 +195,19 @@ class LoRAModel:
...
@@ -201,18 +195,19 @@ class LoRAModel:
for
lora_module
in
modules
.
keys
():
# noqa
for
lora_module
in
modules
.
keys
():
# noqa
if
is_base_embeddding_weights
(
lora_module
):
if
is_base_embeddding_weights
(
lora_module
):
continue
continue
module_name
,
_
=
parse_fine_tuned_lora_name
(
lora_module
,
weights_mapper
)
# Handle PEFT file format where experts.base_layer is the
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
# gate_up_proj and experts is the down_proj
if
"base_layer"
in
lora_module
:
if
"base_layer"
in
lora_module
:
continue
continue
module_name
,
_
=
parse_fine_tuned_lora_name
(
lora_module
,
weights_mapper
)
# Case for expert lora weights
# Case for expert lora weights
if
".experts"
in
module_name
:
if
".experts"
in
module_name
:
if
not
any
(
expert_idx
=
module_name
.
find
(
".experts"
)
module_name
.
endswith
(
ele
)
for
ele
in
expected_lora_modules
expert_suffix
=
module_name
[
expert_idx
+
1
:]
)
:
if
expert_suffix
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module_name
)
unexpected_modules
.
append
(
module_name
)
elif
module_name
.
split
(
"."
)[
-
1
]
not
in
expected_lora_modules
:
elif
module_name
.
rsplit
(
"."
,
1
)[
-
1
]
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module_name
)
unexpected_modules
.
append
(
module_name
)
if
unexpected_modules
:
if
unexpected_modules
:
...
@@ -358,9 +353,7 @@ class LoRAModelManager:
...
@@ -358,9 +353,7 @@ class LoRAModelManager:
self
.
modules
:
dict
[
str
,
BaseLayerWithLoRA
]
=
{}
self
.
modules
:
dict
[
str
,
BaseLayerWithLoRA
]
=
{}
# Dict instead of a set for compatibility with LRUCache.
# Dict instead of a set for compatibility with LRUCache.
self
.
_last_mapping
:
LoRAMapping
|
None
=
None
self
.
_last_mapping
:
LoRAMapping
|
None
=
None
self
.
_is_3d_moe_model
=
is_moe_model
(
self
.
model
)
and
hasattr
(
self
.
_is_3d_moe_model
=
is_moe_model
(
self
.
model
)
and
self
.
model
.
is_3d_moe_weight
self
.
model
,
"is_3d_moe_weight"
)
self
.
_create_lora_modules
()
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
self
.
model
.
lora_manager
=
self
...
@@ -411,7 +404,7 @@ class LoRAModelManager:
...
@@ -411,7 +404,7 @@ class LoRAModelManager:
continue
continue
# Note (gnovack) - If MOE lora weights are not split into
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
# num_experts chunks, we split them here
if
isinstance
(
module
,
FusedMoEWithLoRA
)
and
torch
.
is_tensor
(
if
isinstance
(
module
,
FusedMoE
3D
WithLoRA
)
and
torch
.
is_tensor
(
module_lora
.
lora_a
module_lora
.
lora_a
):
):
# Handle PEFT file format where experts.base_layer is the
# Handle PEFT file format where experts.base_layer is the
...
@@ -679,6 +672,9 @@ class LoRAModelManager:
...
@@ -679,6 +672,9 @@ class LoRAModelManager:
"cpu"
,
"cpu"
,
)
)
subloras
.
append
(
lora
)
subloras
.
append
(
lora
)
if
module
.
__class__
.
__name__
==
"FusedMoEWithLoRA"
:
lora
=
PackedLoRALayerWeights
.
pack_moe
(
subloras
,
module_name
)
else
:
lora
=
PackedLoRALayerWeights
.
pack
(
subloras
)
lora
=
PackedLoRALayerWeights
.
pack
(
subloras
)
model
.
loras
[
module_name
]
=
lora
model
.
loras
[
module_name
]
=
lora
return
model
return
model
...
@@ -739,6 +735,11 @@ class LoRAModelManager:
...
@@ -739,6 +735,11 @@ class LoRAModelManager:
replaced_module_name
=
module_name
.
replace
(
"model."
,
""
)
replaced_module_name
=
module_name
.
replace
(
"model."
,
""
)
if
lora_model
.
check_lora_name
(
module_name
):
if
lora_model
.
check_lora_name
(
module_name
):
module_name
=
replaced_module_name
module_name
=
replaced_module_name
if
module_name
.
endswith
(
".experts"
):
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack_moe
(
replacement_loras
,
module_name
)
else
:
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
replacement_loras
replacement_loras
)
)
...
@@ -746,6 +747,9 @@ class LoRAModelManager:
...
@@ -746,6 +747,9 @@ class LoRAModelManager:
for
module
in
replaced_module
:
for
module
in
replaced_module
:
lora_model
.
loras
.
pop
(
module
,
None
)
lora_model
.
loras
.
pop
(
module
,
None
)
for
lora
in
lora_model
.
loras
.
values
():
lora
.
optimize
()
def
_get_lora_layer_weights
(
def
_get_lora_layer_weights
(
self
,
lora_model
:
LoRAModel
,
module_name
:
str
self
,
lora_model
:
LoRAModel
,
module_name
:
str
)
->
LoRALayerWeights
|
None
:
)
->
LoRALayerWeights
|
None
:
...
...
vllm/lora/utils.py
View file @
2f5f9acd
...
@@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name(
...
@@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name(
def
is_base_embeddding_weights
(
name
:
str
)
->
bool
:
def
is_base_embeddding_weights
(
name
:
str
)
->
bool
:
# hardcoded subfixes for input & output embedding weights
# hardcoded subfixes for input & output embedding weights
input_embedding_subfix
=
".embed_tokens.base_layer.weight"
embedding_suffixes
=
(
output_embedding_subfix
=
".lm_head.base_layer.weight"
".embed_tokens.base_layer.weight"
,
".lm_head.base_layer.weight"
,
return
name
.
endswith
(
input_embedding_subfix
)
or
name
.
endswith
(
output_embedding_subfix
)
)
return
name
.
endswith
(
embedding_suffixes
)
def
is_regex_target_modules
(
def
is_regex_target_modules
(
load_modules
:
str
|
list
[
str
],
expected_lora_modules
:
li
st
[
str
]
load_modules
:
str
|
list
[
str
],
expected_lora_modules
:
s
e
t
[
str
]
)
->
bool
:
)
->
bool
:
"""
"""
PEFT supports passing `target_modules` in the form of regular expressions,
PEFT supports passing `target_modules` in the form of regular expressions,
...
@@ -195,8 +194,8 @@ def is_regex_target_modules(
...
@@ -195,8 +194,8 @@ def is_regex_target_modules(
except
re
.
error
:
except
re
.
error
:
return
False
return
False
def
is_subset
(
sub_list
,
full_
li
st
):
def
is_subset
(
sub_list
,
full_s
e
t
):
return
set
(
sub_list
).
issubset
(
set
(
full_
li
st
)
)
return
set
(
sub_list
).
issubset
(
full_s
e
t
)
# Similar to PEFT's processing logic, regex-related operations are only
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
# executed when the load_modules is a `str`.
...
@@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
...
@@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number
# the expert indices are expanded based on the configured number
# of routed experts.
# of routed experts.
packed_modules_mapping
=
get_packed_modules_mapping
(
model
)
packed_modules_mapping
=
get_packed_modules_mapping
(
model
)
if
not
hasattr
(
model
,
"
is_3d_moe_weight
"
)
:
if
not
model
.
is_3d_moe_weight
:
# 3D MoE LoRA does not need `packed_modules_mapping`
# 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping
[
"experts"
]
=
[
packed_modules_mapping
[
"experts"
]
=
[
weight_name
.
rstrip
(
"."
)
weight_name
.
rstrip
(
"."
)
...
...
vllm/lora/worker_manager.py
View file @
2f5f9acd
...
@@ -88,15 +88,15 @@ class WorkerLoRAManager:
...
@@ -88,15 +88,15 @@ class WorkerLoRAManager:
try
:
try
:
supported_lora_modules
=
self
.
_adapter_manager
.
supported_lora_modules
supported_lora_modules
=
self
.
_adapter_manager
.
supported_lora_modules
packed_modules_mapping
=
self
.
_adapter_manager
.
packed_modules_mapping
packed_modules_mapping
=
self
.
_adapter_manager
.
packed_modules_mapping
expected_lora_
modules
:
list
[
str
]
=
[]
expected_lora_
lst
:
list
[
str
]
=
[]
for
module
in
supported_lora_modules
:
for
module
in
supported_lora_modules
:
if
module
in
packed_modules_mapping
:
if
module
in
packed_modules_mapping
:
expected_lora_
modules
.
extend
(
packed_modules_mapping
[
module
])
expected_lora_
lst
.
extend
(
packed_modules_mapping
[
module
])
else
:
else
:
expected_lora_
modules
.
append
(
module
)
expected_lora_
lst
.
append
(
module
)
if
module
==
"experts"
:
if
module
==
"experts"
:
expected_lora_
modules
.
append
(
module
)
expected_lora_
lst
.
append
(
module
)
expected_lora_modules
=
list
(
set
(
expected_lora_
modules
)
)
expected_lora_modules
=
set
(
expected_lora_
lst
)
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
peft_helper
=
PEFTHelper
.
from_local_dir
(
peft_helper
=
PEFTHelper
.
from_local_dir
(
...
...
vllm/model_executor/models/interfaces.py
View file @
2f5f9acd
...
@@ -336,6 +336,7 @@ class SupportsLoRA(Protocol):
...
@@ -336,6 +336,7 @@ class SupportsLoRA(Protocol):
There is no need to redefine this flag if this class is in the
There is no need to redefine this flag if this class is in the
MRO of your model class.
MRO of your model class.
"""
"""
is_3d_moe_weight
:
ClassVar
[
bool
]
=
False
# The `embedding_module` and `embedding_padding_modules`
# The `embedding_module` and `embedding_padding_modules`
# are empty by default.
# are empty by default.
embedding_modules
:
ClassVar
[
dict
[
str
,
str
]]
=
{}
embedding_modules
:
ClassVar
[
dict
[
str
,
str
]]
=
{}
...
...
vllm/model_executor/models/qwen3_vl_moe.py
View file @
2f5f9acd
...
@@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
...
@@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
class
Qwen3VLMoeForConditionalGeneration
(
class
Qwen3VLMoeForConditionalGeneration
(
Qwen3VLForConditionalGeneration
,
Qwen3VLMoeMixtureOfExperts
Qwen3VLForConditionalGeneration
,
Qwen3VLMoeMixtureOfExperts
):
):
is_3d_moe_weight
:
bool
=
True
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment