Unverified Commit a250f1bd authored by ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟's avatar ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟 Committed by GitHub
Browse files

[Bugfix] LoRA for DeepSeek V3.2 (#35077)


Signed-off-by: default avatarHollow Man <hollowman@opensuse.org>
Co-authored-by: default avatarmergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 04eac6ba
......@@ -1620,18 +1620,13 @@ def _parallel_worker(
else:
print("F", end="")
finally:
# Note: for some reason DeepEP buffers don't seem to be
# entirely reusable on B200. In order to work around this
# we clear the all2all manager's cache after each testpoint.
cap = current_platform.get_device_capability()
if (
cap is not None
and cap.major == 10
and (
test_config.backend == "deepep_low_latency"
or test_config.backend == "deepep_high_throughput"
)
):
# DeepEP managers are not reliably reusable across many subtests in
# a single worker process. Tear them down after each DeepEP case so
# later subtests do not inherit stale communication state.
if test_config.backend in {
"deepep_low_latency",
"deepep_high_throughput",
}:
torch.accelerator.synchronize()
all2all_manager = get_ep_group().device_communicator.all2all_manager
if all2all_manager is not None:
......
......@@ -44,6 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
get_masked_input_and_mask,
)
from vllm.model_executor.models.deepseek_v2 import DeepSeekV2FusedQkvAProjLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
......@@ -1422,7 +1423,107 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
f"for 2 packed modules, got {type(selected_layer_merged).__name__}"
)
# Case 5: Plain ColumnParallelLinear (not merged) - common in many models
fully_sharded_tp_lora_config = LoRAConfig(
max_loras=8,
max_lora_rank=16,
lora_dtype=torch.float16,
fully_sharded_loras=True,
)
fully_sharded_tp_layer = MergedColumnParallelLinear(
4096, [2048, 2048], bias=False, params_dtype=torch.float16
)
fully_sharded_tp_layer.tp_size = 2
assert not MergedColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=fully_sharded_tp_layer,
lora_config=fully_sharded_tp_lora_config,
packed_modules_list=packed_modules_two,
), "Generic merged wrapper should reject fully sharded TP layers"
assert MergedColumnParallelLinearWithShardedLoRA.can_replace_layer(
source_layer=fully_sharded_tp_layer,
lora_config=fully_sharded_tp_lora_config,
packed_modules_list=packed_modules_two,
), "Sharded merged wrapper should remain eligible for fully sharded TP layers"
selected_fully_sharded_tp_layer = from_layer(
fully_sharded_tp_layer,
max_loras=8,
lora_config=fully_sharded_tp_lora_config,
packed_modules_list=packed_modules_two,
)
assert isinstance(
selected_fully_sharded_tp_layer,
MergedColumnParallelLinearWithShardedLoRA,
), (
"from_layer should select MergedColumnParallelLinearWithShardedLoRA "
"for fully sharded TP merged layers, got "
f"{type(selected_fully_sharded_tp_layer).__name__}"
)
# Case 5: DeepSeek's fused_qkv_a_proj should reuse the generic merged
# wrapper while preserving its custom base forward path.
deepseek_fused_layer = DeepSeekV2FusedQkvAProjLinear(
4096, [2048, 2048], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
)
selected_deepseek_layer = from_layer(
deepseek_fused_layer,
max_loras=8,
lora_config=lora_config,
packed_modules_list=packed_modules_two,
)
assert isinstance(selected_deepseek_layer, MergedColumnParallelLinearWithLoRA), (
"from_layer should select MergedColumnParallelLinearWithLoRA "
f"for DeepSeek fused_qkv_a_proj, got {type(selected_deepseek_layer).__name__}"
)
fully_sharded_lora_config = LoRAConfig(
max_loras=8,
max_lora_rank=16,
lora_dtype=torch.float16,
fully_sharded_loras=True,
)
selected_fully_sharded_deepseek_layer = from_layer(
deepseek_fused_layer,
max_loras=8,
lora_config=fully_sharded_lora_config,
packed_modules_list=packed_modules_two,
)
assert isinstance(
selected_fully_sharded_deepseek_layer,
MergedColumnParallelLinearWithLoRA,
), (
"from_layer should keep using MergedColumnParallelLinearWithLoRA "
"for fused_qkv_a_proj when the base layer is effectively unsharded, got "
f"{type(selected_fully_sharded_deepseek_layer).__name__}"
)
# Case 6: Generic subclass of MergedColumnParallelLinear with 2 packed
# modules should still use the generic merged wrapper.
class CustomMergedColumnParallelLinear(MergedColumnParallelLinear):
pass
custom_merged_layer = CustomMergedColumnParallelLinear(
4096, [2048, 2048], bias=False, params_dtype=torch.float16
)
assert MergedColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=custom_merged_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_two,
), "MergedColumnParallelLinearWithLoRA should handle subclasses"
selected_custom_layer = from_layer(
custom_merged_layer,
max_loras=8,
lora_config=lora_config,
packed_modules_list=packed_modules_two,
)
assert isinstance(selected_custom_layer, MergedColumnParallelLinearWithLoRA), (
f"from_layer should select MergedColumnParallelLinearWithLoRA "
f"for subclassed merged layers, got {type(selected_custom_layer).__name__}"
)
# Case 7: Plain ColumnParallelLinear (not merged) - common in many models
# -> ColumnParallelLinearWithLoRA should be selected
plain_column_parallel = ColumnParallelLinear(
4096, 4096, bias=False, params_dtype=torch.float16
......@@ -1455,7 +1556,7 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}"
)
# Case 6: MergedColumnParallelLinear with exactly 2 output sizes
# Case 8: MergedColumnParallelLinear with exactly 2 output sizes
# and empty packed_modules_list
# -> ColumnParallelLinearWithLoRA should NOT match (packed_modules_list != 1)
# -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match (< 3 slices)
......@@ -1473,3 +1574,170 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
"MergedColumnParallelLinearVariableSliceWithLoRA "
"should NOT handle 2 slices even with empty packed_modules_list"
)
@pytest.mark.parametrize(
"wrapper_cls",
[ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA],
)
def test_get_and_maybe_dequant_weights_accepts_lora_wrappers(dist_init, wrapper_cls):
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights,
)
linear = ColumnParallelLinear(4096, 4096, bias=False, params_dtype=torch.float16)
lora_linear = wrapper_cls(linear)
# Should work with LoRA wrappers and return [out, in] weights.
dequant_weight = get_and_maybe_dequant_weights(lora_linear, out_dtype=torch.float16)
assert dequant_weight.shape == linear.weight.shape
@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("fully_sharded", [False, True])
def test_deepseek_fused_qkv_a_proj_lora_preserves_base_forward(
default_vllm_config, dist_init, device, stage, fully_sharded
):
if current_platform.is_cuda_alike():
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
dtype = torch.float16 if current_platform.is_cuda_alike() else torch.float32
max_loras = 8
lora_config = LoRAConfig(
max_loras=max_loras,
max_lora_rank=8,
lora_dtype=dtype,
fully_sharded_loras=fully_sharded,
)
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper)
class OffsetDeepSeekFusedQkvAProjLinear(DeepSeekV2FusedQkvAProjLinear):
def forward(self, input_):
output, output_bias = super().forward(input_)
return output + 1, output_bias
layer = OffsetDeepSeekFusedQkvAProjLinear(
32, [16, 16], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
)
layer.weight.data = torch.rand_like(layer.weight.data, dtype=dtype)
lora_layer = MergedColumnParallelLinearWithLoRA(layer)
lora_layer.create_lora_weights(max_loras, lora_config)
lora_layer.set_mapping(punica_wrapper)
id_to_index = get_random_id_to_index(1, max_loras, log=False)
active_slot = next(i for i, lora_id in enumerate(id_to_index) if lora_id == 1)
lora_a = [
torch.rand(8, 32, dtype=dtype, device=device),
torch.rand(8, 32, dtype=dtype, device=device),
]
lora_b = [
torch.rand(16, 8, dtype=dtype, device=device),
torch.rand(16, 8, dtype=dtype, device=device),
]
lora_layer.set_lora(active_slot, lora_a=lora_a, lora_b=lora_b)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[1],
num_inputs=4,
input_size=(1, 32),
input_range=(0, 1),
input_type=dtype,
device=device,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)
lora_result = lora_layer(torch.cat(inputs))[0]
expected_results = []
for input_ in inputs:
result = layer(input_)[0]
result[:, :16] += input_ @ lora_a[0].T @ lora_b[0].T
result[:, 16:] += input_ @ lora_a[1].T @ lora_b[1].T
expected_results.append(result)
rtol, atol = TOLERANCES[lora_result.dtype]
torch.testing.assert_close(
lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
)
merged_layer = OffsetDeepSeekFusedQkvAProjLinear(
32, [16, 16], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
)
merged_layer.weight.data = layer.weight.data.clone()
merged_layer.weight.data[:16].add_(lora_b[0] @ lora_a[0])
merged_layer.weight.data[16:].add_(lora_b[1] @ lora_a[1])
merged_result = merged_layer(torch.cat(inputs))[0]
torch.testing.assert_close(lora_result, merged_result, rtol=rtol, atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_replicated_lora_preserves_base_forward_for_subclasses(
default_vllm_config, dist_init, device, stage
):
if current_platform.is_cuda_alike():
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
dtype = torch.float16 if current_platform.is_cuda_alike() else torch.float32
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=dtype)
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper)
class OffsetReplicatedLinear(ReplicatedLinear):
def forward(self, input_):
output, output_bias = super().forward(input_)
return output + 1, output_bias
layer = OffsetReplicatedLinear(32, 16, bias=False, params_dtype=dtype)
layer.weight.data = torch.rand_like(layer.weight.data, dtype=dtype)
lora_layer = ReplicatedLinearWithLoRA(layer)
lora_layer.create_lora_weights(max_loras, lora_config)
lora_layer.set_mapping(punica_wrapper)
id_to_index = get_random_id_to_index(1, max_loras, log=False)
active_slot = next(i for i, lora_id in enumerate(id_to_index) if lora_id == 1)
lora_a = torch.rand(8, 32, dtype=dtype, device=device)
lora_b = torch.rand(16, 8, dtype=dtype, device=device)
lora_layer.set_lora(active_slot, lora_a=lora_a, lora_b=lora_b)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[1],
num_inputs=4,
input_size=(1, 32),
input_range=(0, 1),
input_type=dtype,
device=device,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)
lora_result = lora_layer(torch.cat(inputs))[0]
expected_results = []
for input_ in inputs:
result = layer(input_)[0]
result += input_ @ lora_a.T @ lora_b.T
expected_results.append(result)
rtol, atol = TOLERANCES[lora_result.dtype]
torch.testing.assert_close(
lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
)
merged_layer = OffsetReplicatedLinear(32, 16, bias=False, params_dtype=dtype)
merged_layer.weight.data = layer.weight.data.clone()
merged_layer.weight.data.add_(lora_b @ lora_a)
merged_result = merged_layer(torch.cat(inputs))[0]
torch.testing.assert_close(lora_result, merged_result, rtol=rtol, atol=atol)
......@@ -13,6 +13,7 @@ from vllm.config.lora import LoRAConfig
from vllm.lora.layers import (
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
)
from vllm.lora.lora_model import LoRAModel
......@@ -26,6 +27,7 @@ from vllm.lora.model_manager import (
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
from vllm.model_executor.layers.fused_moe import GateLinear
from vllm.platforms import current_platform
from .utils import create_peft_lora
......@@ -132,6 +134,135 @@ def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA)
def test_wrap_replicated_linear_subclasses(default_vllm_config, dist_init, dummy_model):
from vllm.model_executor.layers.linear import ReplicatedLinear
class CustomReplicatedLinear(ReplicatedLinear):
pass
model = dummy_model
model.add_module("custom_gate", CustomReplicatedLinear(10, 10, bias=False))
manager = LoRAModelManager(
model,
1,
1,
1,
LoRAConfig(
max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
),
torch.device(DEVICES[0]),
)
assert isinstance(
manager.model.get_submodule("custom_gate"), ReplicatedLinearWithLoRA
)
def test_wrap_gate_linear(default_vllm_config, dist_init, dummy_model):
model = dummy_model
model.add_module("router_gate", GateLinear(10, 4, bias=False))
manager = LoRAModelManager(
model,
1,
1,
1,
LoRAConfig(
max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
),
torch.device(DEVICES[0]),
)
assert isinstance(
manager.model.get_submodule("router_gate"), ReplicatedLinearWithLoRA
)
def test_skip_unsupported_matched_modules(default_vllm_config, dist_init, dummy_model):
class UnsupportedContainer(nn.Module):
def __init__(self):
super().__init__()
# This name matches a supported target suffix ("dense1"),
# but nn.Linear is not currently a LoRA-wrappable layer type.
self.dense1 = nn.Linear(10, 10, bias=False)
model = dummy_model
model.add_module("unsupported", UnsupportedContainer())
manager = LoRAModelManager(
model,
1,
1,
1,
LoRAConfig(
max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
),
torch.device(DEVICES[0]),
)
# Should not crash and should keep unsupported matched modules unchanged.
assert isinstance(manager.model.get_submodule("unsupported.dense1"), nn.Linear)
assert "unsupported.dense1" not in manager.modules
def test_target_modules_fail_closed_on_unsupported_matched_modules(
default_vllm_config, dist_init, dummy_model
):
class UnsupportedContainer(nn.Module):
def __init__(self):
super().__init__()
self.dense1 = nn.Linear(10, 10, bias=False)
model = dummy_model
model.add_module("unsupported", UnsupportedContainer())
with pytest.raises(ValueError, match="unsupported.dense1"):
LoRAModelManager(
model,
1,
1,
1,
LoRAConfig(
max_lora_rank=8,
max_cpu_loras=8,
max_loras=8,
lora_dtype=DEFAULT_DTYPE,
target_modules=["dense1"],
),
torch.device(DEVICES[0]),
)
def test_get_dummy_lora_warmup_rank_for_fully_sharded_moe():
manager = LoRAModelManager.__new__(LoRAModelManager)
manager.lora_config = LoRAConfig(
max_lora_rank=64,
max_cpu_loras=1,
max_loras=1,
lora_dtype=DEFAULT_DTYPE,
fully_sharded_loras=True,
)
class DummyModule:
def __init__(self, tp_size: int, fully_sharded: bool):
self.tp_size = tp_size
self.fully_sharded = fully_sharded
manager.modules = {
"model.layers.0.self_attn.q_proj": DummyModule(
tp_size=32,
fully_sharded=True,
),
"model.layers.0.mlp.experts": DummyModule(
tp_size=32,
fully_sharded=True,
),
}
assert manager.get_dummy_lora_warmup_rank(8) == 32
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
model = dummy_model
......@@ -795,6 +926,25 @@ def test_target_modules_none_uses_all(
)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_match_packed_runtime_modules(
default_vllm_config, dist_init, dummy_model_gate_up, device
):
"""Packed runtime modules should be selected by their adapter-visible names."""
_test_target_modules(
dummy_model_gate_up,
["gate_proj"],
device,
expected_lora=[("gate_up_proj", MergedColumnParallelLinearWithLoRA)],
expected_no_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
)
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
......
......@@ -58,3 +58,24 @@ class TestIsInTargetModules:
def test_exact_name_no_match(self):
assert not is_in_target_modules("dense3", ["dense1", "dense2"])
def test_packed_parent_matches_child_target_modules(self):
assert is_in_target_modules(
"model.layers.0.mlp.gate_up_proj",
["gate_proj", "up_proj"],
{"gate_up_proj": ["gate_proj", "up_proj"]},
)
def test_packed_child_matches_parent_target_modules(self):
assert is_in_target_modules(
"model.layers.0.mlp.gate_proj",
["gate_up_proj"],
{"gate_up_proj": ["gate_proj", "up_proj"]},
)
def test_fused_parent_matches_child_target_modules(self):
assert is_in_target_modules(
"model.layers.0.self_attn.fused_qkv_a_proj",
["q_a_proj", "kv_a_proj_with_mqa"],
{"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]},
)
......@@ -203,7 +203,16 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
return self._apply_lora_to_output(x, output)
def _apply_base_forward(self, x: torch.Tensor) -> torch.Tensor:
base_output = self.base_layer(x)
output = base_output[0] if isinstance(base_output, tuple) else base_output
return self._apply_lora_to_output(x, output)
def _apply_lora_to_output(
self, x: torch.Tensor, output: torch.Tensor
) -> torch.Tensor:
original_shape = output.shape if output.ndim == 3 else None
# In transformers backend, x and output have extra batch dimension like
......
......@@ -40,11 +40,19 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
# Since communication is needed, the buffer is directly initialized as a
# tensor rather than a tuple of tensor.
buffers = torch.zeros(
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
local_lora_rank = layer.lora_a_stacked[0].shape[2]
buffer_shape = (layer.n_slices, x.shape[0], local_lora_rank)
# Under torch.compile, the local-rank-1 fully-sharded path can otherwise
# get lowered to a reinterpret view with a non-canonical layout. The
# Triton shrink op mutates this buffer in place and expects the standard
# contiguous [slice, token, rank] stride contract.
buffers = torch.empty_strided(
buffer_shape,
(x.shape[0] * local_lora_rank, local_lora_rank, 1),
dtype=torch.float32,
device=x.device,
)
buffers.zero_()
shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink(
buffers, x, layer.lora_a_stacked, 1.0
......@@ -86,7 +94,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# The base_layer type is ColumnParallelLinear or
# MergedColumnParallelLinear, their weight sharding logic is
# inconsistent when TP is greater than 1.
self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear
self.is_merged_col_linear = isinstance(base_layer, MergedColumnParallelLinear)
self.output_size = self.base_layer.output_size_per_partition
# There is only one LoRA layer
self.n_slices = 1
......@@ -158,7 +166,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
) -> bool:
if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
return True
if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear):
if isinstance(source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)):
if len(packed_modules_list) != 1:
return False
# Exclude layers with 3+ output sizes - those are handled by
......@@ -275,19 +283,41 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
].copy_(lora_b_i, non_blocking=True)
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
# Effectively unsharded subclasses can safely reuse their custom
# forward() implementation before applying the LoRA delta.
if (
self.tp_size == 1
and type(self.base_layer) is not merged_cls
and type(self.base_layer).forward is not merged_cls.forward
):
return self._apply_base_forward(x)
return _mcp_apply(x, bias, self)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
decorate: bool = True,
) -> bool:
return (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 2
)
merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
if not isinstance(source_layer, merged_cls) or len(packed_modules_list) != 2:
return False
tp_size = getattr(source_layer, "tp_size", 1)
if type(source_layer) is merged_cls:
if not decorate:
return True
return not lora_config.fully_sharded_loras or tp_size == 1
# Only support effectively unsharded subclasses here. Sharded
# subclasses may have custom communication semantics that the generic
# merged-column LoRA path does not know how to preserve.
return tp_size == 1
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
......@@ -607,7 +637,9 @@ class MergedColumnParallelLinearVariableSliceWithLoRA(
) -> bool:
# Support MergedColumnParallelLinear with 3 or more slices
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear):
if not isinstance(
source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)
):
return False
# If packed_modules_list has 3+ items, use this class
......
......@@ -46,6 +46,12 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
return output, output_bias
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
# ReplicatedLinear subclasses such as GateLinear override forward() to
# dispatch custom kernels and/or adjust the output dtype. Apply LoRA on
# top of the actual base-layer output instead of bypassing that path.
return self._apply_base_forward(x)
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod
......@@ -56,7 +62,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear)
return isinstance(source_layer, maybe_get_oot_by_class(ReplicatedLinear))
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]
......
......@@ -437,12 +437,21 @@ class LoRAModelManager:
),
)
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
# Some matched modules can be unsupported by LoRA wrappers
# (e.g. subclasses with specialized forward behavior).
if not isinstance(new_module, BaseLayerWithLoRA):
error_msg = (
"LoRA target module "
f"{module_name} ({type(module).__name__}) matched the "
"deployment configuration but could not be wrapped by any "
"LoRA layer implementation."
)
if self.lora_config.target_modules is not None:
raise ValueError(
f"{error_msg} target_modules="
f"{sorted(self.lora_config.target_modules)}"
)
logger.warning_once("%s It will be ignored.", error_msg)
continue
self.register_module(module_name, new_module)
......@@ -578,6 +587,38 @@ class LoRAModelManager:
model.loras[module_name] = lora
return model
def get_dummy_lora_warmup_rank(self, default_rank: int) -> int:
"""Return a dummy LoRA rank compatible with wrapped modules.
Dummy LoRAs keep warmup memory low by using a small rank. Fully
sharded MoE wrappers additionally require the dummy rank to be divisible
by tensor parallel size because they shard W13 along the rank axis.
"""
if not self.lora_config.fully_sharded_loras:
return default_rank
required_multiple = 1
for module in self.modules.values():
if not getattr(module, "fully_sharded", False):
continue
required_multiple = math.lcm(required_multiple, module.tp_size)
if required_multiple == 1 or default_rank % required_multiple == 0:
return default_rank
adjusted_rank = (
(default_rank + required_multiple - 1) // required_multiple
) * required_multiple
if adjusted_rank > self.lora_config.max_lora_rank:
raise ValueError(
"Unable to choose a dummy LoRA warmup rank compatible with "
"fully sharded MoE modules: "
f"default_rank={default_rank}, "
f"required_multiple={required_multiple}, "
f"max_lora_rank={self.lora_config.max_lora_rank}"
)
return adjusted_rank
def _match_target_modules(self, module_name: str) -> bool:
"""Check if a module should have LoRA applied.
......@@ -594,7 +635,11 @@ class LoRAModelManager:
"""
if not is_supported_lora_module(module_name, self.supported_lora_modules):
return False
return is_in_target_modules(module_name, self.lora_config.target_modules)
return is_in_target_modules(
module_name,
self.lora_config.target_modules,
self.packed_modules_mapping,
)
def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
"""
......
......@@ -73,7 +73,9 @@ def get_lora_id():
return _GLOBAL_LORA_ID
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
# Order matters here: more specific wrappers must be checked before generic
# merged/column-parallel wrappers in from_layer().
_all_lora_classes: tuple[type[BaseLayerWithLoRA], ...] = (
VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
......@@ -90,7 +92,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
FusedMoE3DWithLoRA,
}
)
def is_moe_model(model: nn.Module) -> bool:
......@@ -258,6 +260,7 @@ def is_supported_lora_module(
def is_in_target_modules(
module_name: str,
target_modules: list[str] | None,
packed_modules_mapping: dict[str, list[str]] | None = None,
) -> bool:
"""Check if a module passes the deployment-time target_modules filter.
......@@ -268,14 +271,33 @@ def is_in_target_modules(
module_name: Full dot-separated module name.
target_modules: Optional deployment-time restriction list from
LoRAConfig.target_modules.
packed_modules_mapping: Optional model-defined mapping from packed
runtime module names to their adapter-visible submodule names
(e.g. ``{"gate_up_proj": ["gate_proj", "up_proj"]}``).
Returns:
True if the module passes the filter, False otherwise.
"""
if target_modules is None:
return True
target_module_set = set(target_modules)
module_suffix = module_name.split(".")[-1]
return module_suffix in set(target_modules)
if module_suffix in target_module_set or module_name in target_module_set:
return True
if not packed_modules_mapping:
return False
# Runtime packed parent matched by deployment-time child targets.
packed_children = packed_modules_mapping.get(module_suffix)
if packed_children and any(child in target_module_set for child in packed_children):
return True
# Adapter-visible packed child matched by deployment-time parent target.
return any(
module_suffix in children and packed_parent in target_module_set
for packed_parent, children in packed_modules_mapping.items()
)
def get_adapter_absolute_path(lora_path: str) -> str:
......
......@@ -160,7 +160,11 @@ class WorkerLoRAManager:
lora_request.lora_path,
", ".join(sorted(expected_lora_modules_lst)),
)
elif not is_in_target_modules(module_name, target_modules):
elif not is_in_target_modules(
module_name,
target_modules,
packed_modules_mapping,
):
logger.warning_once(
"LoRA module '%s' in adapter '%s' is not in the "
"deployment-time target_modules restriction [%s]."
......@@ -197,6 +201,9 @@ class WorkerLoRAManager:
self._cached_dummy_lora = dummy_lora
return self._adapter_manager.add_adapter(dummy_lora)
def get_dummy_lora_warmup_rank(self, default_rank: int) -> int:
return self._adapter_manager.get_dummy_lora_warmup_rank(default_rank)
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
......
......@@ -214,6 +214,19 @@ def select_unquantized_moe_backend(
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# LoRA needs Triton's unfused activation/reduction hooks. Selecting the
# backend here ensures weights stay in a LoRA-compatible layout instead of
# being permuted for a backend like FlashInfer or AITER during load.
if moe_config.is_lora_enabled:
backend = UnquantizedMoeBackend.TRITON
if activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
backend = UnquantizedMoeBackend.BATCHED_TRITON
return _return_or_raise(
backend,
moe_config,
activation_format,
)
runner_backend = moe_config.moe_backend
if runner_backend != "auto":
requested_backend = map_unquantized_backend(runner_backend)
......
......@@ -356,6 +356,12 @@ def get_and_maybe_dequant_weights(
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
# LoRA linear wrappers store quantization metadata on `base_layer`.
# Unwrap here so callers can pass either a raw linear layer or its LoRA
# wrapper without special-casing.
while hasattr(layer, "base_layer") and hasattr(layer.base_layer, "quant_method"):
layer = layer.base_layer
weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"])
# Unquantized layer: just return base weights
......
......@@ -101,9 +101,12 @@ class LoRAModelRunnerMixin:
assert self.lora_manager is not None, "LoRA is not enabled"
num_loras = lora_config.max_loras
lora_warmup_rank = (
lora_warmup_rank: int = (
lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8
)
lora_warmup_rank = self.lora_manager.get_dummy_lora_warmup_rank(
lora_warmup_rank
)
# Make dummy lora requests
lora_requests: set[LoRARequest] = {
LoRARequest(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment