Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1073ba68
Unverified
Commit
1073ba68
authored
Nov 24, 2025
by
Jee Jee Li
Committed by
GitHub
Nov 24, 2025
Browse files
[LoRA] Optimize 3D MoE logic (#29222)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
c309bb52
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
395 additions
and
103 deletions
+395
-103
tests/lora/test_gptoss_tp.py
tests/lora/test_gptoss_tp.py
+6
-1
vllm/lora/layers/__init__.py
vllm/lora/layers/__init__.py
+2
-1
vllm/lora/layers/base.py
vllm/lora/layers/base.py
+2
-2
vllm/lora/layers/base_linear.py
vllm/lora/layers/base_linear.py
+4
-2
vllm/lora/layers/column_parallel_linear.py
vllm/lora/layers/column_parallel_linear.py
+2
-2
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+286
-63
vllm/lora/layers/logits_processor.py
vllm/lora/layers/logits_processor.py
+4
-2
vllm/lora/layers/vocal_parallel_embedding.py
vllm/lora/layers/vocal_parallel_embedding.py
+5
-2
vllm/lora/models.py
vllm/lora/models.py
+75
-24
vllm/lora/utils.py
vllm/lora/utils.py
+8
-4
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+1
-0
No files found.
tests/lora/test_gptoss_tp.py
View file @
1073ba68
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
vllm
from
vllm.lora.request
import
LoRARequest
...
...
@@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_gpt_oss_lora_tp2
(
gptoss20b_lora_files
):
@
pytest
.
mark
.
parametrize
(
"fully_sharded_loras"
,
[
False
,
True
])
def
test_gpt_oss_lora_tp2
(
gptoss20b_lora_files
,
fully_sharded_loras
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
max_model_len
=
1024
,
enable_lora
=
True
,
max_loras
=
2
,
max_lora_rank
=
8
,
max_num_seqs
=
16
,
tensor_parallel_size
=
2
,
fully_sharded_loras
=
fully_sharded_loras
,
compilation_config
=
vllm
.
config
.
CompilationConfig
(
# Avoid OOM
cudagraph_specialize_lora
=
False
,
),
...
...
vllm/lora/layers/__init__.py
View file @
1073ba68
...
...
@@ -11,7 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithLoRA
,
QKVParallelLinearWithShardedLoRA
,
)
from
vllm.lora.layers.fused_moe
import
FusedMoEWithLoRA
from
vllm.lora.layers.fused_moe
import
FusedMoE3DWithLoRA
,
FusedMoEWithLoRA
from
vllm.lora.layers.logits_processor
import
LogitsProcessorWithLoRA
from
vllm.lora.layers.replicated_linear
import
ReplicatedLinearWithLoRA
from
vllm.lora.layers.row_parallel_linear
import
(
...
...
@@ -38,4 +38,5 @@ __all__ = [
"ReplicatedLinearWithLoRA"
,
"LoRAMapping"
,
"FusedMoEWithLoRA"
,
"FusedMoE3DWithLoRA"
,
]
vllm/lora/layers/base.py
View file @
1073ba68
...
...
@@ -42,8 +42,8 @@ class BaseLayerWithLoRA(nn.Module):
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
):
"""Overwrites lora tensors at index."""
...
...
...
vllm/lora/layers/base_linear.py
View file @
1073ba68
...
...
@@ -94,13 +94,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
assert
isinstance
(
lora_a
,
torch
.
Tensor
)
assert
isinstance
(
lora_b
,
torch
.
Tensor
)
assert
(
len
(
self
.
lora_a_stacked
)
==
len
(
self
.
lora_b_stacked
)
==
self
.
n_slices
==
1
)
...
...
vllm/lora/layers/column_parallel_linear.py
View file @
1073ba68
...
...
@@ -246,8 +246,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
):
self
.
reset_lora
(
index
)
...
...
vllm/lora/layers/fused_moe.py
View file @
1073ba68
...
...
@@ -42,7 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
device
=
base_layer
.
w2_weight
.
device
self
.
w13_slices
=
2
self
.
_
w13_slices
=
2
self
.
_inject_lora_into_fused_moe
()
def
_normalize_keys
(
self
,
config
:
dict
[
str
,
int
|
None
])
->
dict
[
str
,
int
|
None
]:
...
...
@@ -160,7 +160,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
op_prefix
=
"w13"
,
num_loras
=
self
.
max_loras
,
rank
=
max_lora_rank
,
num_slices
=
self
.
w13_slices
,
num_slices
=
self
.
_
w13_slices
,
M
=
M
,
layer
=
layer
,
top_k
=
top_k
,
...
...
@@ -230,7 +230,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens
=
hidden_states
.
size
(
0
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
max_lora_rank
=
self
.
w2_lora_a_stacked
.
shape
[
-
2
]
max_lora_rank
=
self
.
w2_lora_a_stacked
[
0
]
.
shape
[
-
2
]
shrink_config
,
expand_config
=
self
.
_get_lora_moe_configs
(
op_prefix
=
"w2"
,
num_loras
=
self
.
max_loras
,
...
...
@@ -258,8 +258,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
punica_wrapper
.
add_lora_fused_moe
(
intermediate_cache3
,
intermediate_cache2
,
(
self
.
w2_lora_a_stacked
,
),
(
self
.
w2_lora_b_stacked
,
),
self
.
w2_lora_a_stacked
,
self
.
w2_lora_b_stacked
,
topk_weights
,
sorted_token_ids_lora
,
expert_ids_lora
,
...
...
@@ -292,22 +292,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
base_layer
.
quant_method
,
m_fused_moe_fn
)
def
create_lora_weights
(
def
_
create_lora_
a_
weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
None
:
"""Initializes lora matrices."""
assert
self
.
w13_slices
==
2
self
.
max_loras
=
lora_config
.
max_loras
self
.
fully_sharded
=
lora_config
.
fully_sharded_loras
self
.
adapter_enabled
=
torch
.
tensor
(
[
0
]
*
(
max_loras
+
1
),
dtype
=
torch
.
int
,
device
=
self
.
device
)
self
.
w13_lora_a_stacked
=
tuple
(
):
self
.
w13_lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...]
=
tuple
(
torch
.
zeros
(
(
max_loras
,
...
...
@@ -320,34 +310,37 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
w13_slices
)
for
_
in
range
(
self
.
_
w13_slices
)
)
self
.
w13_lora_b_stacked
=
tuple
(
self
.
w2_lora_a_stacked
:
tuple
[
torch
.
Tensor
,
...]
=
(
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
self
.
base_layer
.
intermediate_size_per_partition
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
intermediate_size_per_partition
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
w13_slices
)
),
)
self
.
w2_lora_a_stacked
=
torch
.
zeros
(
def
_create_lora_b_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
):
self
.
w13_lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...]
=
tuple
(
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
intermediate_size_per_partition
,
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
self
.
w2_lora_b_stacked
=
torch
.
zeros
(
for
_
in
range
(
self
.
_w13_slices
)
)
self
.
w2_lora_b_stacked
:
tuple
[
torch
.
Tensor
,
...]
=
(
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
...
...
@@ -358,10 +351,28 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
)
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
None
:
"""Initializes lora matrices."""
self
.
max_loras
=
lora_config
.
max_loras
self
.
fully_sharded
=
lora_config
.
fully_sharded_loras
self
.
adapter_enabled
=
torch
.
tensor
(
[
0
]
*
(
max_loras
+
1
),
dtype
=
torch
.
int
,
device
=
self
.
device
)
self
.
_create_lora_a_weights
(
max_loras
,
lora_config
)
self
.
_create_lora_b_weights
(
max_loras
,
lora_config
)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights.
# TODO Optimize this section
self
.
lora_a_stacked
=
[]
self
.
lora_b_stacked
=
[]
for
lora_id
in
range
(
max_loras
):
...
...
@@ -370,36 +381,43 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
lora_a_stacked
.
append
(
self
.
w13_lora_a_stacked
[
0
][
lora_id
][
experts_id
]
)
self
.
lora_a_stacked
.
append
(
self
.
w2_lora_a_stacked
[
lora_id
][
experts_id
])
self
.
lora_a_stacked
.
append
(
self
.
w
13
_lora_a_stacked
[
1
][
lora_id
][
experts_id
]
self
.
w
2
_lora_a_stacked
[
0
][
lora_id
][
experts_id
]
)
self
.
lora_b_stacked
.
append
(
self
.
w13_lora_b_stacked
[
0
][
lora_id
][
experts_id
]
)
self
.
lora_b_stacked
.
append
(
self
.
w2_lora_b_stacked
[
lora_id
][
experts_id
])
self
.
lora_b_stacked
.
append
(
self
.
w2_lora_b_stacked
[
0
][
lora_id
][
experts_id
]
)
self
.
lora_a_stacked
.
append
(
self
.
w13_lora_a_stacked
[
1
][
lora_id
][
experts_id
]
)
self
.
lora_b_stacked
.
append
(
self
.
w13_lora_b_stacked
[
1
][
lora_id
][
experts_id
]
)
def
reset_lora
(
self
,
index
:
int
):
"""Resets the lora weights at index back to 0."""
for
pos
in
range
(
self
.
w13_slices
):
for
pos
in
range
(
self
.
_
w13_slices
):
self
.
w13_lora_a_stacked
[
pos
][
index
]
=
0
self
.
w13_lora_b_stacked
[
pos
][
index
]
=
0
self
.
w2_lora_a_stacked
[
index
]
=
0
self
.
w2_lora_b_stacked
[
index
]
=
0
self
.
w2_lora_a_stacked
[
0
][
index
]
=
0
self
.
w2_lora_b_stacked
[
0
][
index
]
=
0
self
.
adapter_enabled
[
index
]
=
0
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
):
"""Overwrites lora tensors at index."""
assert
isinstance
(
lora_a
,
list
)
assert
isinstance
(
lora_b
,
list
)
self
.
reset_lora
(
index
)
self
.
adapter_enabled
[
index
]
=
1
for
eid
in
range
(
len
(
lora_a
)
//
3
):
...
...
@@ -432,7 +450,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w1_lora_a
=
w1_lora_a
[
w13_start_idx
:
w13_end_idx
,
:]
w3_lora_a
=
w3_lora_a
[
w13_start_idx
:
w13_end_idx
,
:]
w2_shard_size
=
self
.
w2_lora_b_stacked
[
index
,
eid
].
shape
[
0
]
w2_shard_size
=
self
.
w2_lora_b_stacked
[
0
][
index
,
eid
].
shape
[
0
]
w2_start_idx
=
self
.
tp_rank
*
w2_shard_size
w2_end_idx
=
(
self
.
tp_rank
+
1
)
*
w2_shard_size
w2_lora_b
=
w2_lora_b
[
w2_start_idx
:
w2_end_idx
,
:]
...
...
@@ -454,14 +472,32 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index
,
eid
,
:
w3_lora_b
.
shape
[
0
],
:
w3_lora_b
.
shape
[
1
]
].
copy_
(
w3_lora_b
,
non_blocking
=
True
)
self
.
w2_lora_a_stacked
[
self
.
w2_lora_a_stacked
[
0
][
index
,
eid
,
:
w2_lora_a
.
shape
[
0
],
:
w2_lora_a
.
shape
[
1
]
].
copy_
(
w2_lora_a
,
non_blocking
=
True
)
self
.
w2_lora_b_stacked
[
self
.
w2_lora_b_stacked
[
0
][
index
,
eid
,
:
w2_lora_b
.
shape
[
0
],
:
w2_lora_b
.
shape
[
1
]
].
copy_
(
w2_lora_b
,
non_blocking
=
True
)
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
base_layer
.
forward
(
*
args
,
**
kwargs
)
def
maybe_all_reduce_tensor_model_parallel
(
self
,
*
args
,
**
kwargs
):
return
self
.
base_layer
.
maybe_all_reduce_tensor_model_parallel
(
*
args
,
**
kwargs
)
@
property
def
_shared_experts
(
self
):
return
self
.
base_layer
.
_shared_experts
@
property
def
quant_method
(
self
):
return
self
.
base_layer
.
quant_method
@
property
def
is_internal_router
(
self
)
->
bool
:
return
self
.
base_layer
.
is_internal_router
@
classmethod
def
can_replace_layer
(
cls
,
...
...
@@ -472,22 +508,209 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
)
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return
isinstance
(
source_layer
,
FusedMoE
)
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
base_layer
.
forward
(
*
args
,
**
kwargs
)
return
type
(
source_layer
)
is
FusedMoE
and
len
(
packed_modules_list
)
==
2
def
maybe_all_reduce_tensor_model_parallel
(
self
,
*
args
,
**
kwargs
):
return
self
.
base_layer
.
maybe_all_reduce_tensor_model_parallel
(
*
args
,
**
kwargs
)
class
FusedMoE3DWithLoRA
(
FusedMoEWithLoRA
):
def
__init__
(
self
,
base_layer
):
super
().
__init__
(
base_layer
)
self
.
_w13_slices
=
1
def
_create_lora_b_weights
(
self
,
max_loras
,
lora_config
):
self
.
w13_lora_b_stacked
:
tuple
[
torch
.
Tensor
]
=
tuple
(
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
self
.
base_layer
.
intermediate_size_per_partition
*
2
,
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
for
_
in
range
(
self
.
_w13_slices
)
)
self
.
w2_lora_b_stacked
:
tuple
[
torch
.
Tensor
]
=
(
torch
.
zeros
(
(
max_loras
,
self
.
base_layer
.
local_num_experts
,
self
.
base_layer
.
hidden_size
if
not
self
.
fully_sharded
else
divide
(
self
.
base_layer
.
hidden_size
,
self
.
tp_size
),
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
),
)
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
None
:
"""Initializes lora matrices."""
self
.
max_loras
=
lora_config
.
max_loras
self
.
fully_sharded
=
lora_config
.
fully_sharded_loras
self
.
adapter_enabled
=
torch
.
tensor
(
[
0
]
*
(
max_loras
+
1
),
dtype
=
torch
.
int
,
device
=
self
.
device
)
self
.
_create_lora_a_weights
(
max_loras
,
lora_config
)
self
.
_create_lora_b_weights
(
max_loras
,
lora_config
)
def
_slice_w13_a
(
self
,
w13_lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
tp_size
==
1
or
not
self
.
fully_sharded
:
return
w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank
=
w13_lora_a
.
shape
[
1
]
assert
current_lora_rank
%
self
.
tp_size
==
0
sliced_rank
=
current_lora_rank
//
self
.
tp_size
start_idx
=
self
.
tp_rank
*
sliced_rank
end_idx
=
(
self
.
tp_rank
+
1
)
*
sliced_rank
return
w13_lora_a
[:,
start_idx
:
end_idx
,
:]
def
_slice_w13_b
(
self
,
w13_lora_b
:
torch
.
Tensor
,
is_interleave
:
bool
=
True
):
if
self
.
tp_size
==
1
:
return
w13_lora_b
# w13_lora_b shape (num_experts,output_size,rank)
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
if
is_interleave
:
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# in the interleaved order, and corresponding LoRA need to be processed.
w1_lora_b
=
w13_lora_b
[:,
::
2
,
:]
w3_lora_b
=
w13_lora_b
[:,
1
::
2
,
:]
sliced_w1_lora_b
=
w1_lora_b
[:,
start_idx
:
end_idx
,
:]
sliced_w3_lora_b
=
w3_lora_b
[:,
start_idx
:
end_idx
,
:]
return
torch
.
stack
([
sliced_w1_lora_b
,
sliced_w3_lora_b
],
dim
=
2
).
flatten
(
1
,
2
)
else
:
slice_size
=
w13_lora_b
.
shape
[
1
]
//
2
w1_lora_b
=
w13_lora_b
[:,
:
slice_size
,
:]
w3_lora_b
=
w13_lora_b
[:,
slice_size
:,
:]
sliced_w1_lora_b
=
w1_lora_b
[:,
start_idx
:
end_idx
,
:]
sliced_w3_lora_b
=
w3_lora_b
[:,
start_idx
:
end_idx
,
:]
return
torch
.
cat
([
sliced_w1_lora_b
,
sliced_w3_lora_b
],
dim
=
1
)
def
_slice_w2_a
(
self
,
w2_lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
tp_size
==
1
:
return
w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size
=
self
.
base_layer
.
intermediate_size_per_partition
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
return
w2_lora_a
[:,
:,
start_idx
:
end_idx
]
def
_slice_w2_b
(
self
,
w2_lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
tp_size
==
1
or
not
self
.
fully_sharded
:
return
w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size
=
w2_lora_b
.
shape
[
1
]
sliced_size
=
current_lora_size
//
self
.
tp_size
start_idx
=
self
.
tp_rank
*
sliced_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
sliced_size
return
w2_lora_b
[:,
start_idx
:
end_idx
,
:]
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
):
"""Overwrites lora tensors at index."""
# Make mypy happy
assert
isinstance
(
lora_a
,
list
)
assert
isinstance
(
lora_b
,
list
)
assert
len
(
lora_a
)
==
len
(
lora_b
)
==
2
self
.
reset_lora
(
index
)
self
.
adapter_enabled
[
index
]
=
1
num_experts
=
self
.
w13_lora_a_stacked
[
0
].
shape
[
1
]
w13_lora_a
,
w2_lora_a
=
lora_a
w13_lora_b
,
w2_lora_b
=
lora_b
# (num_experts,rank,input_size)
w13_lora_a
=
w13_lora_a
.
reshape
(
num_experts
,
-
1
,
w13_lora_a
.
shape
[
-
1
])
w2_lora_a
=
w2_lora_a
.
reshape
(
num_experts
,
-
1
,
w2_lora_a
.
shape
[
-
1
])
# (output_size,num_experts,rank)
w13_lora_b
=
w13_lora_b
.
reshape
(
w13_lora_b
.
shape
[
0
],
num_experts
,
-
1
)
w2_lora_b
=
w2_lora_b
.
reshape
(
w2_lora_b
.
shape
[
0
],
num_experts
,
-
1
)
# (num_experts,output_size,rank)
w13_lora_b
=
w13_lora_b
.
permute
(
1
,
0
,
2
)
w2_lora_b
=
w2_lora_b
.
permute
(
1
,
0
,
2
)
sliced_w13_lora_a
=
self
.
_slice_w13_a
(
w13_lora_a
)
sliced_w13_lora_b
=
self
.
_slice_w13_b
(
w13_lora_b
,
is_interleave
=
True
)
sliced_w2_lora_a
=
self
.
_slice_w2_a
(
w2_lora_a
)
sliced_w2_lora_b
=
self
.
_slice_w2_b
(
w2_lora_b
)
self
.
w13_lora_a_stacked
[
0
][
index
,
:,
:
sliced_w13_lora_a
.
shape
[
1
],
:
sliced_w13_lora_a
.
shape
[
2
]
].
copy_
(
sliced_w13_lora_a
,
non_blocking
=
True
)
self
.
w2_lora_a_stacked
[
0
][
index
,
:,
:
sliced_w2_lora_a
.
shape
[
1
],
:
sliced_w2_lora_a
.
shape
[
2
]
].
copy_
(
sliced_w2_lora_a
,
non_blocking
=
True
)
self
.
w13_lora_b_stacked
[
0
][
index
,
:,
:
sliced_w13_lora_b
.
shape
[
1
],
:
sliced_w13_lora_b
.
shape
[
2
]
].
copy_
(
sliced_w13_lora_b
,
non_blocking
=
True
)
self
.
w2_lora_b_stacked
[
0
][
index
,
:,
:
sliced_w2_lora_b
.
shape
[
1
],
:
sliced_w2_lora_b
.
shape
[
2
]
].
copy_
(
sliced_w2_lora_b
,
non_blocking
=
True
)
@
property
def
_shared_experts
(
self
):
return
self
.
base_layer
.
_shared_experts
def
w13_input_size
(
self
):
"""
Full size
"""
return
self
.
w13_lora_a_stacked
[
0
].
shape
[
-
1
]
@
property
def
quant_method
(
self
):
return
self
.
base_layer
.
quant_method
def
w13_output_size
(
self
):
"""
Full size
"""
return
self
.
w13_lora_b_stacked
[
0
].
shape
[
-
2
]
*
self
.
tp_size
@
property
def
is_internal_router
(
self
)
->
bool
:
return
self
.
base_layer
.
is_internal_router
def
w2_input_size
(
self
):
"""
Full size
"""
return
self
.
w2_lora_a_stacked
[
0
].
shape
[
-
1
]
*
self
.
tp_size
@
property
def
w2_output_size
(
self
):
"""
Full size
"""
return
self
.
w2_lora_a_stacked
[
0
].
shape
[
-
2
]
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
,
)
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
return
type
(
source_layer
)
is
FusedMoE
and
len
(
packed_modules_list
)
==
1
vllm/lora/layers/logits_processor.py
View file @
1073ba68
...
...
@@ -128,9 +128,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
):
assert
isinstance
(
lora_a
,
torch
.
Tensor
)
assert
isinstance
(
lora_b
,
torch
.
Tensor
)
self
.
reset_lora
(
index
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
0
],
:
lora_a
.
shape
[
1
]].
copy_
(
lora_a
,
non_blocking
=
True
...
...
vllm/lora/layers/vocal_parallel_embedding.py
View file @
1073ba68
...
...
@@ -77,12 +77,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
lora_b
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
):
assert
isinstance
(
lora_a
,
torch
.
Tensor
)
assert
isinstance
(
lora_b
,
torch
.
Tensor
)
self
.
reset_lora
(
index
)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here
self
.
lora_a_stacked
[
index
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
...
...
vllm/lora/models.py
View file @
1073ba68
...
...
@@ -22,11 +22,13 @@ from vllm.lora.utils import (
from_layer_logits_processor
,
get_supported_lora_modules
,
is_base_embeddding_weights
,
is_moe_model
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
process_packed_modules_mapping
,
replace_submodule
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.interfaces
import
is_pooling_model
...
...
@@ -356,7 +358,11 @@ class LoRAModelManager:
self
.
modules
:
dict
[
str
,
BaseLayerWithLoRA
]
=
{}
# Dict instead of a set for compatibility with LRUCache.
self
.
_last_mapping
:
LoRAMapping
|
None
=
None
self
.
_is_3d_moe_model
=
is_moe_model
(
self
.
model
)
and
hasattr
(
self
.
model
,
"is_3d_moe_weight"
)
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
def
__len__
(
self
)
->
int
:
...
...
@@ -400,22 +406,36 @@ class LoRAModelManager:
self
.
lora_index_to_id
[
index
]
=
lora_model
.
id
for
module_name
,
module
in
self
.
modules
.
items
():
module_lora
=
self
.
_get_lora_layer_weights
(
lora_model
,
module_name
)
if
module_lora
:
if
not
module_lora
:
module
.
reset_lora
(
index
)
continue
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if
isinstance
(
module
,
FusedMoEWithLoRA
)
and
torch
.
is_tensor
(
module_lora
.
lora_a
):
# Handle
FSDP
file format where experts.base_layer is the
# Handle
PEFT
file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
gate_up_proj_lora
=
self
.
_get_lora_layer_weights
(
lora_model
,
module_name
+
".base_layer"
)
assert
gate_up_proj_lora
is
not
None
assert
module_lora
is
not
None
down_proj_lora
=
module_lora
# FIXME Edge case where LoRA is not added to gate_up_proj
# or down_proj
assert
gate_up_proj_lora
is
not
None
assert
down_proj_lora
is
not
None
if
self
.
_is_3d_moe_model
:
module_lora
.
lora_a
=
[
gate_up_proj_lora
.
lora_a
,
down_proj_lora
.
lora_a
,
]
module_lora
.
lora_b
=
[
gate_up_proj_lora
.
lora_b
,
down_proj_lora
.
lora_b
,
]
else
:
# Some 3D MoE models haven't added the `is_3d_moe_weight`
# attribute yet, so fallback here
num_experts
=
module_lora
.
lora_a
.
shape
[
0
]
//
module_lora
.
rank
gate_proj_a
=
gate_up_proj_lora
.
lora_a
.
chunk
(
num_experts
,
dim
=
0
)
...
...
@@ -444,14 +464,12 @@ class LoRAModelManager:
module_lora
.
lora_a
=
lora_a
module_lora
.
lora_b
=
lora_b
module
.
set_lora
(
index
,
module_lora
.
lora_a
,
module_lora
.
lora_b
,
)
else
:
module
.
reset_lora
(
index
)
return
True
def
_deactivate_adapter
(
self
,
lora_id
:
int
):
...
...
@@ -512,6 +530,13 @@ class LoRAModelManager:
continue
parts
=
module_name
.
split
(
"."
)[
-
1
]
packed_moduled_lst
=
self
.
packed_modules_mapping
.
get
(
parts
,
[])
if
isinstance
(
module
,
FusedMoE
):
# packed_moduled_lst is used here to just determine whether to
# instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
# difference between these two LoRA layers is whether the
# LoRA weights of w1 and w3 have already been fused on disk.
packed_moduled_lst
=
[
"w13"
]
if
self
.
_is_3d_moe_model
else
[
"w1"
,
"w3"
]
new_module
=
replace_submodule
(
self
.
model
,
module_name
,
...
...
@@ -560,6 +585,7 @@ class LoRAModelManager:
self
.
_register_packed_modules
(
module_name
)
# All lora layers share the same punica_wrapper based on reference.
new_module
.
set_mapping
(
self
.
punica_wrapper
)
pass
def
register_module
(
self
,
module_name
:
str
,
module
:
"BaseLayerWithLoRA"
):
assert
isinstance
(
module
,
BaseLayerWithLoRA
),
(
...
...
@@ -605,6 +631,30 @@ class LoRAModelManager:
module
.
lora_a_stacked
[
0
].
dtype
,
"cpu"
,
)
model
.
loras
[
module_name
]
=
lora
elif
module
.
__class__
.
__name__
==
"FusedMoE3DWithLoRA"
:
# Case for 3D moe model
# w2
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
module_name
,
module
.
w2_input_size
,
module
.
w2_output_size
,
rank
*
module
.
w2_lora_a_stacked
[
0
].
shape
[
1
],
# rank*num_experts
module
.
w2_lora_a_stacked
[
0
].
dtype
,
"cpu"
,
)
model
.
loras
[
module_name
]
=
lora
# w13
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
module_name
,
module
.
w13_input_size
,
module
.
w13_output_size
,
rank
*
module
.
w13_lora_a_stacked
[
0
].
shape
[
1
],
# rank*num_experts
module
.
w13_lora_a_stacked
[
0
].
dtype
,
"cpu"
,
)
model
.
loras
[
module_name
+
".base_layer"
]
=
lora
else
:
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
module_name
,
...
...
@@ -614,6 +664,7 @@ class LoRAModelManager:
module
.
lora_a_stacked
[
0
].
dtype
,
"cpu"
,
)
model
.
loras
[
module_name
]
=
lora
else
:
parts
=
module_name
.
split
(
"."
)
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
...
...
vllm/lora/utils.py
View file @
1073ba68
...
...
@@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
FusedMoE3DWithLoRA
,
FusedMoEWithLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
...
...
@@ -62,6 +63,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedQKVParallelLinearWithShardedLoRA
,
RowParallelLinearWithShardedLoRA
,
FusedMoEWithLoRA
,
FusedMoE3DWithLoRA
,
}
...
...
@@ -288,9 +290,11 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping
=
get_packed_modules_mapping
(
model
)
if
not
hasattr
(
model
,
"is_3d_moe_weight"
):
# 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping
[
"experts"
]
=
[
weight_name
.
rstrip
(
"."
)
for
_
,
weight_name
,
_
,
_
in
moe_packed_mapping
weight_name
.
rstrip
(
"."
)
for
_
,
weight_name
,
_
,
_
in
moe_packed_mapping
]
return
packed_modules_mapping
...
...
vllm/model_executor/models/gpt_oss.py
View file @
1073ba68
...
...
@@ -656,6 +656,7 @@ class GptOssModel(nn.Module):
class
GptOssForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsEagle3
,
SupportsLoRA
):
is_3d_moe_weight
:
bool
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
hf_to_vllm_mapper
=
WeightsMapper
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment