Unverified Commit aa7f37cc authored by danisereb's avatar danisereb Committed by GitHub
Browse files

Add support for LoRA adapters in Nemotron-H models (#30802)


Signed-off-by: default avatarDaniel Serebrenik <daserebrenik@nvidia.com>
parent c88860d7
...@@ -17,6 +17,7 @@ from vllm.lora.layers import ( ...@@ -17,6 +17,7 @@ from vllm.lora.layers import (
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
LoRAMapping, LoRAMapping,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA,
...@@ -850,6 +851,116 @@ def test_column_parallel_packed( ...@@ -850,6 +851,116 @@ def test_column_parallel_packed(
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("num_slices", [3, 5])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_merged_column_parallel_variable_slice(
default_vllm_config, dist_init, num_loras, num_slices, device, stage
) -> None:
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
max_loras = 8
torch.set_default_device(device)
lora_config = LoRAConfig(
max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
)
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
# Set number of output slices
output_sizes = [1024 + i * 256 for i in range(num_slices)]
total_output = sum(output_sizes)
def create_layer():
# Create linear layer
linear = MergedColumnParallelLinear(
4096, output_sizes, bias=False, params_dtype=torch.float16
)
linear.weight.data = torch.rand_like(linear.weight.data)
# Create linear layer with LoRA adapter
lora_linear = MergedColumnParallelLinearVariableSliceWithLoRA(linear)
lora_linear.create_lora_weights(max_loras, lora_config)
return linear, lora_linear
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_layer()
lora_linear.set_mapping(punica_wrapper)
# Populate LoRA weights
lora_dict, sublora_dict = {}, {}
for slot_idx, lora_id in enumerate(id_to_index):
if lora_id is not None:
# Create random LoRA weights
lora_a = torch.rand(8, 4096, dtype=torch.float16, device=device)
lora_b = torch.rand(total_output, 8, dtype=torch.float16, device=device)
lora_linear.set_lora(slot_idx, lora_a, lora_b)
lora_dict[lora_id] = (lora_a, lora_b)
# Split lora_b for expected computation
sublora_dict[lora_id] = torch.split(lora_b, output_sizes, dim=0)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
device=device,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)
# Compute LoRA result
lora_result = lora_linear(torch.cat(inputs))[0]
# Compute expected result
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
result = linear(input_)[0]
lora_a, _ = lora_dict[lora_id]
offset = 0
# Compute expected result for each sublora
for lora_b_slice in sublora_dict[lora_id]:
sz = lora_b_slice.shape[0]
result[:, offset : offset + sz] += input_ @ lora_a.T @ lora_b_slice.T
offset += sz
expected_results.append(result)
# Check that the LoRA result is close to the expected result
rtol, atol = TOLERANCES[lora_result.dtype]
torch.testing.assert_close(
lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
)
# Reset LoRA weights and check results with zero LoRA weights
for slot_idx in range(max_loras):
lora_linear.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
device=device,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)
# After resetting LoRA weights,
# lora_linear should behave like the base linear layer
lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype]
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) @pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)) "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
...@@ -1119,3 +1230,189 @@ def test_get_masked_input_and_mask(): ...@@ -1119,3 +1230,189 @@ def test_get_masked_input_and_mask():
assert torch.equal( assert torch.equal(
modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]) modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])
) )
def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
"""Test that MergedColumnParallelLinearVariableSliceWithLoRA is selected
only for nemotron-h style models (checkpoint has single weight but layer
has 3+ output slices).
This verifies that from_layer selects
MergedColumnParallelLinearVariableSliceWithLoRA
before ColumnParallelLinearWithLoRA for layers with 3+ output sizes, since
ColumnParallelLinearWithLoRA's slice_lora_b assumes exactly 2 slices.
"""
from vllm.lora.utils import from_layer
lora_config = LoRAConfig(max_loras=8, max_lora_rank=8, lora_dtype=torch.float16)
# Case 1: MergedColumnParallelLinear with 3+ output sizes and
# packed_modules_list with 1 item (nemotron-h style)
# -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
layer_3_slices = MergedColumnParallelLinear(
4096, [1024, 1280, 1536], bias=False, params_dtype=torch.float16
)
packed_modules_single = ["mlp"]
assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
source_layer=layer_3_slices,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ slices"
# ColumnParallelLinearWithLoRA should NOT match 3+ slices
# (its slice_lora_b assumes exactly 2 slices)
assert not ColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=layer_3_slices,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
), (
"ColumnParallelLinearWithLoRA should NOT handle 3+ slices "
"(slice_lora_b assumes 2 slices)"
)
# Verify from_layer selects the correct class (Variable, not base)
selected_layer = from_layer(
layer_3_slices,
max_loras=8,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
)
assert isinstance(
selected_layer, MergedColumnParallelLinearVariableSliceWithLoRA
), (
f"from_layer should select MergedColumnParallelLinearVariableSliceWithLoRA "
f"for 3+ slices, got {type(selected_layer).__name__}"
)
# Case 2: MergedColumnParallelLinear with 2 output sizes and
# packed_modules_list with 1 item (standard gate_up style)
# -> ColumnParallelLinearWithLoRA should be selected
# -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match
layer_2_slices = MergedColumnParallelLinear(
4096, [2048, 2048], bias=False, params_dtype=torch.float16
)
assert ColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=layer_2_slices,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
), "ColumnParallelLinearWithLoRA should handle 2 slices"
assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
source_layer=layer_2_slices,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
), "MergedColumnParallelLinearVariableSliceWithLoRA should NOT handle 2 slices"
# Verify from_layer selects ColumnParallelLinearWithLoRA for 2 slices
selected_layer_2 = from_layer(
layer_2_slices,
max_loras=8,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
)
assert isinstance(selected_layer_2, ColumnParallelLinearWithLoRA), (
f"from_layer should select ColumnParallelLinearWithLoRA "
f"for 2 slices, got {type(selected_layer_2).__name__}"
)
# But NOT the Variable subclass
assert not isinstance(
selected_layer_2, MergedColumnParallelLinearVariableSliceWithLoRA
), (
"from_layer should NOT select "
"MergedColumnParallelLinearVariableSliceWithLoRA for 2 slices"
)
# Case 3: MergedColumnParallelLinear with 3+ items in packed_modules_list
# -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
packed_modules_three = ["gate_proj", "up_proj", "down_proj"]
assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
source_layer=layer_3_slices,
lora_config=lora_config,
packed_modules_list=packed_modules_three,
), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ packed modules"
# Case 4: MergedColumnParallelLinear with 2 items in packed_modules_list
# -> MergedColumnParallelLinearWithLoRA should handle this (not Variable)
packed_modules_two = ["gate_proj", "up_proj"]
assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
source_layer=layer_2_slices,
lora_config=lora_config,
packed_modules_list=packed_modules_two,
), (
"MergedColumnParallelLinearVariableSliceWithLoRA"
" should NOT handle 2 packed modules"
)
assert MergedColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=layer_2_slices,
lora_config=lora_config,
packed_modules_list=packed_modules_two,
), "MergedColumnParallelLinearWithLoRA should handle 2 packed modules"
# Verify from_layer selects MergedColumnParallelLinearWithLoRA for 2 packed modules
selected_layer_merged = from_layer(
layer_2_slices,
max_loras=8,
lora_config=lora_config,
packed_modules_list=packed_modules_two,
)
assert isinstance(selected_layer_merged, MergedColumnParallelLinearWithLoRA), (
f"from_layer should select MergedColumnParallelLinearWithLoRA "
f"for 2 packed modules, got {type(selected_layer_merged).__name__}"
)
# Case 5: Plain ColumnParallelLinear (not merged) - common in many models
# -> ColumnParallelLinearWithLoRA should be selected
plain_column_parallel = ColumnParallelLinear(
4096, 4096, bias=False, params_dtype=torch.float16
)
assert ColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=plain_column_parallel,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
), "ColumnParallelLinearWithLoRA should handle plain ColumnParallelLinear"
assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
source_layer=plain_column_parallel,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
), (
"MergedColumnParallelLinearVariableSliceWithLoRA "
"should NOT handle plain ColumnParallelLinear"
)
# Verify from_layer selects ColumnParallelLinearWithLoRA for plain layer
selected_plain = from_layer(
plain_column_parallel,
max_loras=8,
lora_config=lora_config,
packed_modules_list=packed_modules_single,
)
assert isinstance(selected_plain, ColumnParallelLinearWithLoRA), (
f"from_layer should select ColumnParallelLinearWithLoRA "
f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}"
)
# Case 6: MergedColumnParallelLinear with exactly 2 output sizes
# and empty packed_modules_list
# -> ColumnParallelLinearWithLoRA should NOT match (packed_modules_list != 1)
# -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match (< 3 slices)
assert not ColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=layer_2_slices,
lora_config=lora_config,
packed_modules_list=[],
), "ColumnParallelLinearWithLoRA should NOT handle empty packed_modules_list"
assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
source_layer=layer_2_slices,
lora_config=lora_config,
packed_modules_list=[],
), (
"MergedColumnParallelLinearVariableSliceWithLoRA "
"should NOT handle 2 slices even with empty packed_modules_list"
)
...@@ -4,6 +4,7 @@ from vllm.lora.layers.base import BaseLayerWithLoRA ...@@ -4,6 +4,7 @@ from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import ( from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA, ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA,
...@@ -29,6 +30,7 @@ __all__ = [ ...@@ -29,6 +30,7 @@ __all__ = [
"ColumnParallelLinearWithShardedLoRA", "ColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearWithLoRA", "MergedColumnParallelLinearWithLoRA",
"MergedColumnParallelLinearWithShardedLoRA", "MergedColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearVariableSliceWithLoRA",
"MergedQKVParallelLinearWithLoRA", "MergedQKVParallelLinearWithLoRA",
"MergedQKVParallelLinearWithShardedLoRA", "MergedQKVParallelLinearWithShardedLoRA",
"QKVParallelLinearWithLoRA", "QKVParallelLinearWithLoRA",
......
...@@ -155,10 +155,19 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -155,10 +155,19 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is ColumnParallelLinear or ( if type(source_layer) is ColumnParallelLinear:
type(source_layer) is MergedColumnParallelLinear return True
and len(packed_modules_list) == 1 if type(source_layer) is MergedColumnParallelLinear:
) if len(packed_modules_list) != 1:
return False
# Exclude layers with 3+ output sizes - those are handled by
# MergedColumnParallelLinearVariableSliceWithLoRA since this
# class's slice_lora_b assumes exactly 2 slices.
return not (
hasattr(source_layer, "output_sizes")
and len(source_layer.output_sizes) >= 3
)
return False
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...@@ -575,3 +584,75 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): ...@@ -575,3 +584,75 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
model_config=model_config, model_config=model_config,
decorate=False, decorate=False,
) )
class MergedColumnParallelLinearVariableSliceWithLoRA(
MergedColumnParallelLinearWithLoRA
):
"""MergedColumnParallelLinear with variable number of slices (3+).
This handles cases where the checkpoint has a single weight for the whole
module (not split into slices), but the layer itself has multiple slices.
"""
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# Support MergedColumnParallelLinear with 3 or more slices
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
if type(source_layer) is not MergedColumnParallelLinear:
return False
# If packed_modules_list has 3+ items, use this class
if len(packed_modules_list) >= 3:
return True
# If packed_modules_list has exactly 2 items, let
# MergedColumnParallelLinearWithLoRA handle it
if len(packed_modules_list) == 2:
return False
# If packed_modules_list is empty or has 1 item,
# check the layer's output_sizes.
# This handles cases where the checkpoint has a single weight
# but the layer has multiple slices (3+)
return (
hasattr(source_layer, "output_sizes")
and len(source_layer.output_sizes) >= 3
)
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Override to handle single tensor weights
that need to be split into slices."""
self.reset_lora(index)
# Handle case where checkpoint has single tensor weights
# lora_a shape: (rank, input_size) - same for all slices, duplicate it
if isinstance(lora_a, torch.Tensor):
lora_a = [lora_a] * self.n_slices
# lora_b shape: (total_output_size, rank) -
# split along dim 0 based on output_sizes
if isinstance(lora_b, torch.Tensor):
output_sizes = self.base_layer.output_sizes
lora_b_list = []
start_idx = 0
for output_size in output_sizes:
end_idx = start_idx + output_size
lora_b_list.append(lora_b[start_idx:end_idx, :])
start_idx = end_idx
lora_b = lora_b_list
# Now call parent's set_lora which expects lists
super().set_lora(index, lora_a, lora_b)
...@@ -52,7 +52,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -52,7 +52,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.device = _get_lora_device(base_layer) self.device = _get_lora_device(base_layer)
self._w13_slices = 2 # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
self._inject_lora_into_fused_moe() self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]: def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
...@@ -400,7 +402,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -400,7 +402,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked = [] self.lora_b_stacked = []
for lora_id in range(max_loras): for lora_id in range(max_loras):
for experts_id in range(self.base_layer.local_num_experts): for experts_id in range(self.base_layer.local_num_experts):
# gate_proj,down_proj,up_proj # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
# For non-gated MoE: up_proj (w1), down_proj (w2)
self.lora_a_stacked.append( self.lora_a_stacked.append(
self.w13_lora_a_stacked[0][lora_id][experts_id] self.w13_lora_a_stacked[0][lora_id][experts_id]
) )
...@@ -415,12 +418,14 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -415,12 +418,14 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked[0][lora_id][experts_id] self.w2_lora_b_stacked[0][lora_id][experts_id]
) )
self.lora_a_stacked.append( # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
self.w13_lora_a_stacked[1][lora_id][experts_id] if self._w13_slices == 2:
) self.lora_a_stacked.append(
self.lora_b_stacked.append( self.w13_lora_a_stacked[1][lora_id][experts_id]
self.w13_lora_b_stacked[1][lora_id][experts_id] )
) self.lora_b_stacked.append(
self.w13_lora_b_stacked[1][lora_id][experts_id]
)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor: def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
""" """
...@@ -515,8 +520,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -515,8 +520,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
slliced_w1_lora_a = self._slice_w13_a(w1_lora_a) slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
slliced_w1_lora_b = self._slice_w13_b(w1_lora_b) slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a) sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b) sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
...@@ -525,17 +528,22 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -525,17 +528,22 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2] index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
].copy_(slliced_w1_lora_a, non_blocking=True) ].copy_(slliced_w1_lora_a, non_blocking=True)
self.w13_lora_a_stacked[1][
index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
].copy_(slliced_w3_lora_a, non_blocking=True)
self.w13_lora_b_stacked[0][ self.w13_lora_b_stacked[0][
index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2] index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
].copy_(slliced_w1_lora_b, non_blocking=True) ].copy_(slliced_w1_lora_b, non_blocking=True)
self.w13_lora_b_stacked[1][ # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2] if self._w13_slices == 2:
].copy_(slliced_w3_lora_b, non_blocking=True) slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
self.w13_lora_a_stacked[1][
index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
].copy_(slliced_w3_lora_a, non_blocking=True)
self.w13_lora_b_stacked[1][
index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
].copy_(slliced_w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[0][ self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2] index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
......
...@@ -154,7 +154,10 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -154,7 +154,10 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@classmethod @classmethod
def pack_moe( def pack_moe(
cls, loras: GenericSequence[Optional["LoRALayerWeights"]], module_name: str cls,
loras: GenericSequence[Optional["LoRALayerWeights"]],
module_name: str,
is_non_gated_moe: bool = False,
) -> "PackedLoRALayerWeights": ) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA. """Pack a list of LoRAs into a single LoRA.
...@@ -177,6 +180,11 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -177,6 +180,11 @@ class PackedLoRALayerWeights(LoRALayerWeights):
w1_lora = loras[eid * 3] w1_lora = loras[eid * 3]
w2_lora = loras[eid * 3 + 1] w2_lora = loras[eid * 3 + 1]
w3_lora = loras[eid * 3 + 2] w3_lora = loras[eid * 3 + 2]
# For non-gated MoE, w3 is not used, so we use w1's LoRA weights
# This is determined by checking the expert mapping (get_expert_mapping)
# which indicates when ckpt_up_proj_name is empty.
if w3_lora is None and is_non_gated_moe:
w3_lora = w1_lora
assert w1_lora is not None assert w1_lora is not None
assert w2_lora is not None assert w2_lora is not None
assert w3_lora is not None assert w3_lora is not None
...@@ -191,10 +199,24 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -191,10 +199,24 @@ class PackedLoRALayerWeights(LoRALayerWeights):
w1_lora_a = torch.stack(w1_lora_a_lst, dim=0) # (num_experts,rank,input_size) w1_lora_a = torch.stack(w1_lora_a_lst, dim=0) # (num_experts,rank,input_size)
w2_lora_a = torch.stack(w2_lora_a_lst, dim=0) w2_lora_a = torch.stack(w2_lora_a_lst, dim=0)
w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
w1_lora_b = torch.stack(w1_lora_b_lst, dim=0) # (num_experts,output_size,rank) w1_lora_b = torch.stack(w1_lora_b_lst, dim=0) # (num_experts,output_size,rank)
w2_lora_b = torch.stack(w2_lora_b_lst, dim=0) w2_lora_b = torch.stack(w2_lora_b_lst, dim=0)
w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)
# All w1, w2, w3 have the same scaling factor.
scaling = lora_alpha / rank
last_scaling = scaling
if is_non_gated_moe:
# For non-gated MoE, reuse w1 tensors for w3 to avoid memory waste
# w3_lora_a_lst and w3_lora_b_lst are not relevant in this case
w3_lora_a = w1_lora_a
w3_lora_b = w1_lora_b
# For non-gated MoE, avoid double-scaling by setting w3's scaling to 1.
last_scaling = 1.0
else:
w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)
obj = cls( obj = cls(
module_name, module_name,
...@@ -202,6 +224,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -202,6 +224,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora_alpha, lora_alpha, lora_alpha], [lora_alpha, lora_alpha, lora_alpha],
[w1_lora_a, w2_lora_a, w3_lora_a], [w1_lora_a, w2_lora_a, w3_lora_a],
[w1_lora_b, w2_lora_b, w3_lora_b], [w1_lora_b, w2_lora_b, w3_lora_b],
scaling=[scaling, scaling, last_scaling],
) )
return obj return obj
......
...@@ -104,7 +104,9 @@ class LoRAModelManager: ...@@ -104,7 +104,9 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {} self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache. # Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None self._last_mapping: LoRAMapping | None = None
self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight is_moe = is_moe_model(self.model)
self._is_3d_moe_model = is_moe and self.model.is_3d_moe_weight
self._is_non_gated_moe = is_moe and self.model.is_non_gated_moe
self._init_punica_wrapper(max_num_batched_tokens, vllm_config) self._init_punica_wrapper(max_num_batched_tokens, vllm_config)
self._create_lora_modules() self._create_lora_modules()
...@@ -339,6 +341,20 @@ class LoRAModelManager: ...@@ -339,6 +341,20 @@ class LoRAModelManager:
) )
continue continue
# TODO: Remove this restriction
# peft error when generating LoRA adapter with "gate" module:
# "Target module NemotronHTopkRouter() is not supported."
# Working LoRA adapter was created using peft with:
# LoraConfig(target_modules="all-linear", ...)
if self._is_non_gated_moe and module_name.endswith("mixer.gate"):
logger.debug_once(
"LoRA is not supported for non-gated MoE gate module."
" %s will be ignored.",
module_name,
scope="local",
)
continue
parts = module_name.split(".")[-1] parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, []) packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
if isinstance(module, FusedMoE): if isinstance(module, FusedMoE):
...@@ -405,6 +421,22 @@ class LoRAModelManager: ...@@ -405,6 +421,22 @@ class LoRAModelManager:
) )
self.modules[module_name] = module self.modules[module_name] = module
@staticmethod
def _pad_lora_pairs_to_triplets(
loras: list[LoRALayerWeights | None],
) -> list[LoRALayerWeights | None]:
"""Pad LoRA weight pairs to triplets for non-gated MoE.
For non-gated MoE, each expert has 2 entries (w1, w2) that need to be
padded to triplets (w1, w2, None) to match pack_moe expectations.
"""
assert len(loras) % 2 == 0, "Expected pairs of LoRA weights for non-gated MoE."
padded: list[LoRALayerWeights | None] = []
for i in range(0, len(loras), 2):
padded.extend(loras[i : i + 2])
padded.append(None)
return padded
def create_dummy_lora( def create_dummy_lora(
self, self,
lora_id: int, lora_id: int,
...@@ -491,7 +523,13 @@ class LoRAModelManager: ...@@ -491,7 +523,13 @@ class LoRAModelManager:
) )
subloras.append(lora) subloras.append(lora)
if module.__class__.__name__ == "FusedMoEWithLoRA": if module.__class__.__name__ == "FusedMoEWithLoRA":
lora = PackedLoRALayerWeights.pack_moe(subloras, module_name) # For non-gated MoE, pad subloras to 3 elements per expert
# to match pack_moe expectations (w1, w2, None for w3)
if self._is_non_gated_moe and len(subloras) > 0:
subloras = self._pad_lora_pairs_to_triplets(subloras)
lora = PackedLoRALayerWeights.pack_moe(
subloras, module_name, is_non_gated_moe=self._is_non_gated_moe
)
else: else:
lora = PackedLoRALayerWeights.pack(subloras) lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora model.loras[module_name] = lora
...@@ -559,8 +597,14 @@ class LoRAModelManager: ...@@ -559,8 +597,14 @@ class LoRAModelManager:
if lora_model.check_lora_name(module_name): if lora_model.check_lora_name(module_name):
module_name = replaced_module_name module_name = replaced_module_name
if module_name.endswith(".experts"): if module_name.endswith(".experts"):
if self._is_non_gated_moe and len(replacement_loras) > 0:
replacement_loras = self._pad_lora_pairs_to_triplets(
replacement_loras
)
lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe( lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
replacement_loras, module_name replacement_loras,
module_name,
is_non_gated_moe=self._is_non_gated_moe,
) )
else: else:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
......
...@@ -25,6 +25,7 @@ from vllm.lora.layers import ( ...@@ -25,6 +25,7 @@ from vllm.lora.layers import (
FusedMoE3DWithLoRA, FusedMoE3DWithLoRA,
FusedMoEWithLoRA, FusedMoEWithLoRA,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA,
...@@ -68,6 +69,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { ...@@ -68,6 +69,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedQKVParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA, FusedMoEWithLoRA,
...@@ -266,9 +268,13 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]: ...@@ -266,9 +268,13 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
packed_modules_mapping = get_packed_modules_mapping(model) packed_modules_mapping = get_packed_modules_mapping(model)
if not model.is_3d_moe_weight: if not model.is_3d_moe_weight:
# 3D MoE LoRA does not need `packed_modules_mapping` # 3D MoE LoRA does not need `packed_modules_mapping`
# Filter out malformed entries: non-gated MoE has empty
# ckpt_up_proj_name which results in weight_name containing ".."
# (e.g., "experts.0.." instead of "experts.0.layer_name.")
packed_modules_mapping["experts"] = [ packed_modules_mapping["experts"] = [
weight_name.rstrip(".") weight_name.rstrip(".")
for _, weight_name, _, _ in moe_packed_mapping for _, weight_name, _, _ in moe_packed_mapping
if ".." not in weight_name
] ]
return packed_modules_mapping return packed_modules_mapping
......
...@@ -227,6 +227,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -227,6 +227,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
) )
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
# No support for LoRA in flashinfer_cutlass_fused_moe.
# See TODOs in flashinfer functions runMoe and runMoeMinLantency.
raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe")
def flashinfer_cutlass_moe_fp4( def flashinfer_cutlass_moe_fp4(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -376,6 +376,7 @@ class SupportsLoRA(Protocol): ...@@ -376,6 +376,7 @@ class SupportsLoRA(Protocol):
MRO of your model class. MRO of your model class.
""" """
is_3d_moe_weight: ClassVar[bool] = False is_3d_moe_weight: ClassVar[bool] = False
is_non_gated_moe: ClassVar[bool] = False
# The `embedding_module` and `embedding_padding_modules` # The `embedding_module` and `embedding_padding_modules`
# are empty by default. # are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {} embedding_modules: ClassVar[dict[str, str]] = {}
......
...@@ -747,6 +747,9 @@ class NemotronHForCausalLM( ...@@ -747,6 +747,9 @@ class NemotronHForCausalLM(
MixtureOfExperts, MixtureOfExperts,
SupportsMambaPrefixCaching, SupportsMambaPrefixCaching,
): ):
# Relevant only if self.has_moe is True
is_non_gated_moe: bool = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"backbone": "model"}, orig_to_new_prefix={"backbone": "model"},
orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"}, orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment