Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa7f37cc
Unverified
Commit
aa7f37cc
authored
Jan 19, 2026
by
danisereb
Committed by
GitHub
Jan 19, 2026
Browse files
Add support for LoRA adapters in Nemotron-H models (#30802)
Signed-off-by:
Daniel Serebrenik
<
daserebrenik@nvidia.com
>
parent
c88860d7
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
497 additions
and
27 deletions
+497
-27
tests/lora/test_layers.py
tests/lora/test_layers.py
+297
-0
vllm/lora/layers/__init__.py
vllm/lora/layers/__init__.py
+2
-0
vllm/lora/layers/column_parallel_linear.py
vllm/lora/layers/column_parallel_linear.py
+85
-4
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+25
-17
vllm/lora/lora_weights.py
vllm/lora/lora_weights.py
+26
-3
vllm/lora/model_manager.py
vllm/lora/model_manager.py
+47
-3
vllm/lora/utils.py
vllm/lora/utils.py
+6
-0
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+5
-0
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+1
-0
vllm/model_executor/models/nemotron_h.py
vllm/model_executor/models/nemotron_h.py
+3
-0
No files found.
tests/lora/test_layers.py
View file @
aa7f37cc
...
...
@@ -17,6 +17,7 @@ from vllm.lora.layers import (
ColumnParallelLinearWithShardedLoRA
,
LogitsProcessorWithLoRA
,
LoRAMapping
,
MergedColumnParallelLinearVariableSliceWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithLoRA
,
...
...
@@ -850,6 +851,116 @@ def test_column_parallel_packed(
torch
.
testing
.
assert_close
(
lora_result
,
expected_result
,
rtol
=
rtol
,
atol
=
atol
)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("num_slices", [3, 5])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_merged_column_parallel_variable_slice(
    default_vllm_config, dist_init, num_loras, num_slices, device, stage
) -> None:
    """Check MergedColumnParallelLinearVariableSliceWithLoRA against a reference.

    For every random seed the test:
      1. wraps a fresh MergedColumnParallelLinear with ``num_slices`` output
         slices of different widths,
      2. loads single-tensor LoRA A/B weights into the occupied slots,
      3. compares the wrapped layer's output with the base layer's output plus
         a per-slice ``input @ A.T @ B_slice.T`` correction,
      4. resets every slot and checks the wrapped layer matches the base layer.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)

    # Slice widths grow by 256 per slice so that no two slices share a size.
    output_sizes = [1024 + 256 * slice_idx for slice_idx in range(num_slices)]
    total_output = sum(output_sizes)

    def build_layers():
        # Base layer with random weights, plus its LoRA-wrapped counterpart.
        base = MergedColumnParallelLinear(
            4096, output_sizes, bias=False, params_dtype=torch.float16
        )
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = MergedColumnParallelLinearVariableSliceWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def sample_batch(active_ids, slot_table):
        # Draw random inputs for the given adapter ids and publish the
        # resulting mapping to the punica wrapper.
        batch, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(mapping, slot_table, max_loras, 512)
        return batch, prompt_mapping

    def check_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = build_layers()
        lora_linear.set_mapping(punica_wrapper)

        # Populate LoRA weights for every occupied slot.
        full_weights = {}
        b_slices_by_id = {}
        for slot_idx, lora_id in enumerate(id_to_index):
            if lora_id is None:
                continue
            # lora_a is drawn before lora_b; keep this order so the random
            # stream is consumed exactly as before.
            lora_a = torch.rand(8, 4096, dtype=torch.float16, device=device)
            lora_b = torch.rand(total_output, 8, dtype=torch.float16, device=device)
            lora_linear.set_lora(slot_idx, lora_a, lora_b)
            full_weights[lora_id] = (lora_a, lora_b)
            # Split lora_b per output slice for the reference computation.
            b_slices_by_id[lora_id] = torch.split(lora_b, output_sizes, dim=0)

        inputs, prompt_mapping = sample_batch(list(full_weights), id_to_index)

        # Output of the LoRA-wrapped layer.
        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference output: base layer plus the low-rank update per slice.
        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            reference = linear(input_)[0]
            lora_a, _ = full_weights[lora_id]
            low_rank = input_ @ lora_a.T
            col_start = 0
            for b_slice in b_slices_by_id[lora_id]:
                col_end = col_start + b_slice.shape[0]
                reference[:, col_start:col_end] += low_rank @ b_slice.T
                col_start = col_end
            expected_results.append(reference)

        check_close(lora_result, torch.cat(expected_results))

        # Clear every slot; the wrapped layer must then behave like the base
        # linear layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, _ = sample_batch([0], id_to_index)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]
        check_close(lora_result, expected_result)
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS
))
...
...
@@ -1119,3 +1230,189 @@ def test_get_masked_input_and_mask():
assert
torch
.
equal
(
modified_x_rank_3
,
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
4
])
)
def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
    """Verify which LoRA wrapper class claims each column-parallel layer shape.

    The cases below exercise ``can_replace_layer`` on
    ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA and
    MergedColumnParallelLinearVariableSliceWithLoRA, and check the class that
    ``from_layer`` actually instantiates:

    * 3 output sizes + 1 packed module  -> variable-slice class
    * 2 output sizes + 1 packed module  -> ColumnParallelLinearWithLoRA
    * 3 packed modules                  -> variable-slice class
    * 2 packed modules                  -> MergedColumnParallelLinearWithLoRA
    * plain ColumnParallelLinear        -> ColumnParallelLinearWithLoRA
    * 2 output sizes + empty packed list -> neither base nor variable-slice
    """
    from vllm.lora.utils import from_layer

    lora_config = LoRAConfig(max_loras=8, max_lora_rank=8, lora_dtype=torch.float16)
    variable_cls = MergedColumnParallelLinearVariableSliceWithLoRA

    def accepts(lora_cls, layer, packed):
        # Thin wrapper so each assertion below stays on one line.
        return lora_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed,
        )

    def wrap(layer, packed):
        return from_layer(
            layer,
            max_loras=8,
            lora_config=lora_config,
            packed_modules_list=packed,
        )

    # Case 1: three output sizes with a single packed module (nemotron-h
    # style). Only the variable-slice class may claim the layer.
    layer_3_slices = MergedColumnParallelLinear(
        4096, [1024, 1280, 1536], bias=False, params_dtype=torch.float16
    )
    packed_modules_single = ["mlp"]

    assert accepts(
        variable_cls, layer_3_slices, packed_modules_single
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ slices"
    # The base class must decline (its slice_lora_b assumes exactly 2 slices).
    assert not accepts(
        ColumnParallelLinearWithLoRA, layer_3_slices, packed_modules_single
    ), (
        "ColumnParallelLinearWithLoRA should NOT handle 3+ slices "
        "(slice_lora_b assumes 2 slices)"
    )

    chosen_for_3 = wrap(layer_3_slices, packed_modules_single)
    assert isinstance(chosen_for_3, variable_cls), (
        f"from_layer should select MergedColumnParallelLinearVariableSliceWithLoRA "
        f"for 3+ slices, got {type(chosen_for_3).__name__}"
    )

    # Case 2: two output sizes with a single packed module (standard gate_up
    # style). The base class claims it; the variable-slice class declines.
    layer_2_slices = MergedColumnParallelLinear(
        4096, [2048, 2048], bias=False, params_dtype=torch.float16
    )

    assert accepts(
        ColumnParallelLinearWithLoRA, layer_2_slices, packed_modules_single
    ), "ColumnParallelLinearWithLoRA should handle 2 slices"
    assert not accepts(
        variable_cls, layer_2_slices, packed_modules_single
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should NOT handle 2 slices"

    chosen_for_2 = wrap(layer_2_slices, packed_modules_single)
    assert isinstance(chosen_for_2, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for 2 slices, got {type(chosen_for_2).__name__}"
    )
    # ... and specifically not the variable-slice subclass.
    assert not isinstance(chosen_for_2, variable_cls), (
        "from_layer should NOT select "
        "MergedColumnParallelLinearVariableSliceWithLoRA for 2 slices"
    )

    # Case 3: three entries in packed_modules_list -> variable-slice class.
    packed_modules_three = ["gate_proj", "up_proj", "down_proj"]
    assert accepts(
        variable_cls, layer_3_slices, packed_modules_three
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ packed modules"

    # Case 4: two entries in packed_modules_list -> the regular merged class.
    packed_modules_two = ["gate_proj", "up_proj"]
    assert not accepts(variable_cls, layer_2_slices, packed_modules_two), (
        "MergedColumnParallelLinearVariableSliceWithLoRA"
        " should NOT handle 2 packed modules"
    )
    assert accepts(
        MergedColumnParallelLinearWithLoRA, layer_2_slices, packed_modules_two
    ), "MergedColumnParallelLinearWithLoRA should handle 2 packed modules"

    chosen_merged = wrap(layer_2_slices, packed_modules_two)
    assert isinstance(chosen_merged, MergedColumnParallelLinearWithLoRA), (
        f"from_layer should select MergedColumnParallelLinearWithLoRA "
        f"for 2 packed modules, got {type(chosen_merged).__name__}"
    )

    # Case 5: a plain (non-merged) ColumnParallelLinear -> base class.
    plain_column_parallel = ColumnParallelLinear(
        4096, 4096, bias=False, params_dtype=torch.float16
    )
    assert accepts(
        ColumnParallelLinearWithLoRA, plain_column_parallel, packed_modules_single
    ), "ColumnParallelLinearWithLoRA should handle plain ColumnParallelLinear"
    assert not accepts(variable_cls, plain_column_parallel, packed_modules_single), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle plain ColumnParallelLinear"
    )

    chosen_plain = wrap(plain_column_parallel, packed_modules_single)
    assert isinstance(chosen_plain, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for plain ColumnParallelLinear, got {type(chosen_plain).__name__}"
    )

    # Case 6: two output sizes with an empty packed_modules_list. The base
    # class declines (list length is not 1) and so does the variable-slice
    # class (fewer than 3 slices).
    assert not accepts(
        ColumnParallelLinearWithLoRA, layer_2_slices, []
    ), "ColumnParallelLinearWithLoRA should NOT handle empty packed_modules_list"
    assert not accepts(variable_cls, layer_2_slices, []), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle 2 slices even with empty packed_modules_list"
    )
vllm/lora/layers/__init__.py
View file @
aa7f37cc
...
...
@@ -4,6 +4,7 @@ from vllm.lora.layers.base import BaseLayerWithLoRA
from
vllm.lora.layers.column_parallel_linear
import
(
ColumnParallelLinearWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearVariableSliceWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithLoRA
,
...
...
@@ -29,6 +30,7 @@ __all__ = [
"ColumnParallelLinearWithShardedLoRA"
,
"MergedColumnParallelLinearWithLoRA"
,
"MergedColumnParallelLinearWithShardedLoRA"
,
"MergedColumnParallelLinearVariableSliceWithLoRA"
,
"MergedQKVParallelLinearWithLoRA"
,
"MergedQKVParallelLinearWithShardedLoRA"
,
"QKVParallelLinearWithLoRA"
,
...
...
vllm/lora/layers/column_parallel_linear.py
View file @
aa7f37cc
...
...
@@ -155,10 +155,19 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
return
type
(
source_layer
)
is
ColumnParallelLinear
or
(
type
(
source_layer
)
is
MergedColumnParallelLinear
and
len
(
packed_modules_list
)
==
1
if
type
(
source_layer
)
is
ColumnParallelLinear
:
return
True
if
type
(
source_layer
)
is
MergedColumnParallelLinear
:
if
len
(
packed_modules_list
)
!=
1
:
return
False
# Exclude layers with 3+ output sizes - those are handled by
# MergedColumnParallelLinearVariableSliceWithLoRA since this
# class's slice_lora_b assumes exactly 2 slices.
return
not
(
hasattr
(
source_layer
,
"output_sizes"
)
and
len
(
source_layer
.
output_sizes
)
>=
3
)
return
False
class
MergedColumnParallelLinearWithLoRA
(
ColumnParallelLinearWithLoRA
):
...
...
@@ -575,3 +584,75 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
model_config
=
model_config
,
decorate
=
False
,
)
class MergedColumnParallelLinearVariableSliceWithLoRA(
    MergedColumnParallelLinearWithLoRA
):
    """LoRA wrapper for a MergedColumnParallelLinear with three or more slices.

    Covers the situation where the adapter checkpoint stores one weight for
    the whole module while the layer itself is made of several output slices.
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True when this class should wrap ``source_layer``.

        Only an exact MergedColumnParallelLinear is accepted. The decision is
        then driven by ``packed_modules_list``: three or more entries select
        this class, exactly two are left to MergedColumnParallelLinearWithLoRA,
        and zero or one entries fall back to checking whether the layer
        exposes at least three ``output_sizes``.
        """
        if type(source_layer) is not MergedColumnParallelLinear:
            return False

        num_packed = len(packed_modules_list)
        if num_packed >= 3:
            return True
        if num_packed == 2:
            # Two packed modules belong to MergedColumnParallelLinearWithLoRA.
            return False

        # Empty or single-entry list: the checkpoint has one weight, so look
        # at how many slices the layer itself has.
        if not hasattr(source_layer, "output_sizes"):
            return False
        return len(source_layer.output_sizes) >= 3

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Load adapter weights into slot ``index``.

        Single-tensor inputs are expanded to per-slice lists before being
        handed to the parent implementation; list inputs are passed through
        unchanged.
        """
        self.reset_lora(index)

        if isinstance(lora_a, torch.Tensor):
            # One lora_a of shape (rank, input_size) is shared by every slice.
            lora_a = [lora_a for _ in range(self.n_slices)]

        if isinstance(lora_b, torch.Tensor):
            # lora_b of shape (total_output_size, rank) is cut along dim 0 at
            # the cumulative boundaries of the base layer's output_sizes.
            # Plain slicing is kept on purpose: a tensor with fewer rows than
            # sum(output_sizes) yields short/empty trailing slices instead of
            # raising.
            boundaries = [0]
            for slice_rows in self.base_layer.output_sizes:
                boundaries.append(boundaries[-1] + slice_rows)
            lora_b = [
                lora_b[row_lo:row_hi, :]
                for row_lo, row_hi in zip(boundaries, boundaries[1:])
            ]

        # The parent implementation expects lists for both arguments.
        super().set_lora(index, lora_a, lora_b)
vllm/lora/layers/fused_moe.py
View file @
aa7f37cc
...
...
@@ -52,7 +52,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
device
=
_get_lora_device
(
base_layer
)
self
.
_w13_slices
=
2
# For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
self
.
_w13_slices
=
2
if
base_layer
.
moe_config
.
is_act_and_mul
else
1
self
.
_inject_lora_into_fused_moe
()
def
_normalize_keys
(
self
,
config
:
dict
[
str
,
int
|
None
])
->
dict
[
str
,
int
|
None
]:
...
...
@@ -400,7 +402,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
lora_b_stacked
=
[]
for
lora_id
in
range
(
max_loras
):
for
experts_id
in
range
(
self
.
base_layer
.
local_num_experts
):
# gate_proj,down_proj,up_proj
# For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
# For non-gated MoE: up_proj (w1), down_proj (w2)
self
.
lora_a_stacked
.
append
(
self
.
w13_lora_a_stacked
[
0
][
lora_id
][
experts_id
]
)
...
...
@@ -415,6 +418,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
w2_lora_b_stacked
[
0
][
lora_id
][
experts_id
]
)
# Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
if
self
.
_w13_slices
==
2
:
self
.
lora_a_stacked
.
append
(
self
.
w13_lora_a_stacked
[
1
][
lora_id
][
experts_id
]
)
...
...
@@ -515,8 +520,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
slliced_w1_lora_a
=
self
.
_slice_w13_a
(
w1_lora_a
)
slliced_w1_lora_b
=
self
.
_slice_w13_b
(
w1_lora_b
)
slliced_w3_lora_a
=
self
.
_slice_w13_a
(
w3_lora_a
)
slliced_w3_lora_b
=
self
.
_slice_w13_b
(
w3_lora_b
)
sliced_w2_lora_a
=
self
.
_slice_w2_a
(
w2_lora_a
)
sliced_w2_lora_b
=
self
.
_slice_w2_b
(
w2_lora_b
)
...
...
@@ -525,14 +528,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index
,
:,
:
slliced_w1_lora_a
.
shape
[
1
],
:
slliced_w1_lora_a
.
shape
[
2
]
].
copy_
(
slliced_w1_lora_a
,
non_blocking
=
True
)
self
.
w13_lora_a_stacked
[
1
][
index
,
:,
:
slliced_w3_lora_a
.
shape
[
1
],
:
slliced_w3_lora_a
.
shape
[
2
]
].
copy_
(
slliced_w3_lora_a
,
non_blocking
=
True
)
self
.
w13_lora_b_stacked
[
0
][
index
,
:,
:
slliced_w1_lora_b
.
shape
[
1
],
:
slliced_w1_lora_b
.
shape
[
2
]
].
copy_
(
slliced_w1_lora_b
,
non_blocking
=
True
)
# Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
if
self
.
_w13_slices
==
2
:
slliced_w3_lora_a
=
self
.
_slice_w13_a
(
w3_lora_a
)
slliced_w3_lora_b
=
self
.
_slice_w13_b
(
w3_lora_b
)
self
.
w13_lora_a_stacked
[
1
][
index
,
:,
:
slliced_w3_lora_a
.
shape
[
1
],
:
slliced_w3_lora_a
.
shape
[
2
]
].
copy_
(
slliced_w3_lora_a
,
non_blocking
=
True
)
self
.
w13_lora_b_stacked
[
1
][
index
,
:,
:
slliced_w3_lora_b
.
shape
[
1
],
:
slliced_w3_lora_b
.
shape
[
2
]
].
copy_
(
slliced_w3_lora_b
,
non_blocking
=
True
)
...
...
vllm/lora/lora_weights.py
View file @
aa7f37cc
...
...
@@ -154,7 +154,10 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@
classmethod
def
pack_moe
(
cls
,
loras
:
GenericSequence
[
Optional
[
"LoRALayerWeights"
]],
module_name
:
str
cls
,
loras
:
GenericSequence
[
Optional
[
"LoRALayerWeights"
]],
module_name
:
str
,
is_non_gated_moe
:
bool
=
False
,
)
->
"PackedLoRALayerWeights"
:
"""Pack a list of LoRAs into a single LoRA.
...
...
@@ -177,6 +180,11 @@ class PackedLoRALayerWeights(LoRALayerWeights):
w1_lora
=
loras
[
eid
*
3
]
w2_lora
=
loras
[
eid
*
3
+
1
]
w3_lora
=
loras
[
eid
*
3
+
2
]
# For non-gated MoE, w3 is not used, so we use w1's LoRA weights
# This is determined by checking the expert mapping (get_expert_mapping)
# which indicates when ckpt_up_proj_name is empty.
if
w3_lora
is
None
and
is_non_gated_moe
:
w3_lora
=
w1_lora
assert
w1_lora
is
not
None
assert
w2_lora
is
not
None
assert
w3_lora
is
not
None
...
...
@@ -191,9 +199,23 @@ class PackedLoRALayerWeights(LoRALayerWeights):
w1_lora_a
=
torch
.
stack
(
w1_lora_a_lst
,
dim
=
0
)
# (num_experts,rank,input_size)
w2_lora_a
=
torch
.
stack
(
w2_lora_a_lst
,
dim
=
0
)
w3_lora_a
=
torch
.
stack
(
w3_lora_a_lst
,
dim
=
0
)
w1_lora_b
=
torch
.
stack
(
w1_lora_b_lst
,
dim
=
0
)
# (num_experts,output_size,rank)
w2_lora_b
=
torch
.
stack
(
w2_lora_b_lst
,
dim
=
0
)
# All w1, w2, w3 have the same scaling factor.
scaling
=
lora_alpha
/
rank
last_scaling
=
scaling
if
is_non_gated_moe
:
# For non-gated MoE, reuse w1 tensors for w3 to avoid memory waste
# w3_lora_a_lst and w3_lora_b_lst are not relevant in this case
w3_lora_a
=
w1_lora_a
w3_lora_b
=
w1_lora_b
# For non-gated MoE, avoid double-scaling by setting w3's scaling to 1.
last_scaling
=
1.0
else
:
w3_lora_a
=
torch
.
stack
(
w3_lora_a_lst
,
dim
=
0
)
w3_lora_b
=
torch
.
stack
(
w3_lora_b_lst
,
dim
=
0
)
obj
=
cls
(
...
...
@@ -202,6 +224,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[
lora_alpha
,
lora_alpha
,
lora_alpha
],
[
w1_lora_a
,
w2_lora_a
,
w3_lora_a
],
[
w1_lora_b
,
w2_lora_b
,
w3_lora_b
],
scaling
=
[
scaling
,
scaling
,
last_scaling
],
)
return
obj
...
...
vllm/lora/model_manager.py
View file @
aa7f37cc
...
...
@@ -104,7 +104,9 @@ class LoRAModelManager:
self
.
modules
:
dict
[
str
,
BaseLayerWithLoRA
]
=
{}
# Dict instead of a set for compatibility with LRUCache.
self
.
_last_mapping
:
LoRAMapping
|
None
=
None
self
.
_is_3d_moe_model
=
is_moe_model
(
self
.
model
)
and
self
.
model
.
is_3d_moe_weight
is_moe
=
is_moe_model
(
self
.
model
)
self
.
_is_3d_moe_model
=
is_moe
and
self
.
model
.
is_3d_moe_weight
self
.
_is_non_gated_moe
=
is_moe
and
self
.
model
.
is_non_gated_moe
self
.
_init_punica_wrapper
(
max_num_batched_tokens
,
vllm_config
)
self
.
_create_lora_modules
()
...
...
@@ -339,6 +341,20 @@ class LoRAModelManager:
)
continue
# TODO: Remove this restriction
# peft error when generating LoRA adapter with "gate" module:
# "Target module NemotronHTopkRouter() is not supported."
# Working LoRA adapter was created using peft with:
# LoraConfig(target_modules="all-linear", ...)
if
self
.
_is_non_gated_moe
and
module_name
.
endswith
(
"mixer.gate"
):
logger
.
debug_once
(
"LoRA is not supported for non-gated MoE gate module."
" %s will be ignored."
,
module_name
,
scope
=
"local"
,
)
continue
parts
=
module_name
.
split
(
"."
)[
-
1
]
packed_moduled_lst
=
self
.
packed_modules_mapping
.
get
(
parts
,
[])
if
isinstance
(
module
,
FusedMoE
):
...
...
@@ -405,6 +421,22 @@ class LoRAModelManager:
)
self
.
modules
[
module_name
]
=
module
@
staticmethod
def
_pad_lora_pairs_to_triplets
(
loras
:
list
[
LoRALayerWeights
|
None
],
)
->
list
[
LoRALayerWeights
|
None
]:
"""Pad LoRA weight pairs to triplets for non-gated MoE.
For non-gated MoE, each expert has 2 entries (w1, w2) that need to be
padded to triplets (w1, w2, None) to match pack_moe expectations.
"""
assert
len
(
loras
)
%
2
==
0
,
"Expected pairs of LoRA weights for non-gated MoE."
padded
:
list
[
LoRALayerWeights
|
None
]
=
[]
for
i
in
range
(
0
,
len
(
loras
),
2
):
padded
.
extend
(
loras
[
i
:
i
+
2
])
padded
.
append
(
None
)
return
padded
def
create_dummy_lora
(
self
,
lora_id
:
int
,
...
...
@@ -491,7 +523,13 @@ class LoRAModelManager:
)
subloras
.
append
(
lora
)
if
module
.
__class__
.
__name__
==
"FusedMoEWithLoRA"
:
lora
=
PackedLoRALayerWeights
.
pack_moe
(
subloras
,
module_name
)
# For non-gated MoE, pad subloras to 3 elements per expert
# to match pack_moe expectations (w1, w2, None for w3)
if
self
.
_is_non_gated_moe
and
len
(
subloras
)
>
0
:
subloras
=
self
.
_pad_lora_pairs_to_triplets
(
subloras
)
lora
=
PackedLoRALayerWeights
.
pack_moe
(
subloras
,
module_name
,
is_non_gated_moe
=
self
.
_is_non_gated_moe
)
else
:
lora
=
PackedLoRALayerWeights
.
pack
(
subloras
)
model
.
loras
[
module_name
]
=
lora
...
...
@@ -559,8 +597,14 @@ class LoRAModelManager:
if
lora_model
.
check_lora_name
(
module_name
):
module_name
=
replaced_module_name
if
module_name
.
endswith
(
".experts"
):
if
self
.
_is_non_gated_moe
and
len
(
replacement_loras
)
>
0
:
replacement_loras
=
self
.
_pad_lora_pairs_to_triplets
(
replacement_loras
)
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack_moe
(
replacement_loras
,
module_name
replacement_loras
,
module_name
,
is_non_gated_moe
=
self
.
_is_non_gated_moe
,
)
else
:
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
...
...
vllm/lora/utils.py
View file @
aa7f37cc
...
...
@@ -25,6 +25,7 @@ from vllm.lora.layers import (
FusedMoE3DWithLoRA
,
FusedMoEWithLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearVariableSliceWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithLoRA
,
...
...
@@ -68,6 +69,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
ColumnParallelLinearWithShardedLoRA
,
QKVParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearVariableSliceWithLoRA
,
MergedQKVParallelLinearWithShardedLoRA
,
RowParallelLinearWithShardedLoRA
,
FusedMoEWithLoRA
,
...
...
@@ -266,9 +268,13 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
packed_modules_mapping
=
get_packed_modules_mapping
(
model
)
if
not
model
.
is_3d_moe_weight
:
# 3D MoE LoRA does not need `packed_modules_mapping`
# Filter out malformed entries: non-gated MoE has empty
# ckpt_up_proj_name which results in weight_name containing ".."
# (e.g., "experts.0.." instead of "experts.0.layer_name.")
packed_modules_mapping
[
"experts"
]
=
[
weight_name
.
rstrip
(
"."
)
for
_
,
weight_name
,
_
,
_
in
moe_packed_mapping
if
".."
not
in
weight_name
]
return
packed_modules_mapping
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
aa7f37cc
...
...
@@ -227,6 +227,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_deepseek_fp8_block_scale
=
self
.
use_deepseek_fp8_block_scale
,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
    """Unsupported for this backend; always raises NotImplementedError.

    Neither ``input`` nor ``output`` is read or written.
    """
    # No support for LoRA in flashinfer_cutlass_fused_moe.
    # See TODOs in flashinfer functions runMoe and runMoeMinLantency.
    message = "LoRA is not supported for flashinfer_cutlass_moe"
    raise NotImplementedError(message)
def
flashinfer_cutlass_moe_fp4
(
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/interfaces.py
View file @
aa7f37cc
...
...
@@ -376,6 +376,7 @@ class SupportsLoRA(Protocol):
MRO of your model class.
"""
is_3d_moe_weight
:
ClassVar
[
bool
]
=
False
is_non_gated_moe
:
ClassVar
[
bool
]
=
False
# The `embedding_module` and `embedding_padding_modules`
# are empty by default.
embedding_modules
:
ClassVar
[
dict
[
str
,
str
]]
=
{}
...
...
vllm/model_executor/models/nemotron_h.py
View file @
aa7f37cc
...
...
@@ -747,6 +747,9 @@ class NemotronHForCausalLM(
MixtureOfExperts
,
SupportsMambaPrefixCaching
,
):
# Relevant only if self.has_moe is True
is_non_gated_moe
:
bool
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"backbone"
:
"model"
},
orig_to_new_substr
=
{
"A_log"
:
"A"
,
"embeddings"
:
"embed_tokens"
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment