Unverified Commit 3717a4dd authored by Bhoomit's avatar Bhoomit Committed by GitHub
Browse files

[Misc][LoRA] Add --lora-target-modules to restrict LoRA to specific modules (#34984)


Signed-off-by: default avatarBhoomit Vasani <bhoomit.2010@gmail.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent ecfcdd2c
......@@ -389,3 +389,17 @@ vllm serve model --enable-lora --max-lora-rank 64
# Bad: unnecessarily high, wastes memory
vllm serve model --enable-lora --max-lora-rank 256
```
### Restricting LoRA to Specific Modules
The `--lora-target-modules` parameter allows you to restrict which model modules have LoRA applied at deployment time. This is useful for performance tuning when you only need LoRA on specific layers:
```bash
# Apply LoRA only to output projection layers
vllm serve model --enable-lora --lora-target-modules o_proj
# Apply LoRA to multiple specific modules
vllm serve model --enable-lora --lora-target-modules o_proj qkv_proj down_proj
```
When `--lora-target-modules` is not specified, LoRA will be applied to all supported modules in the model. This parameter accepts module suffixes (the last component of the module name), such as `o_proj`, `qkv_proj`, `gate_proj`, etc.
......@@ -291,3 +291,32 @@ def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
else:
with pytest.raises(raises):
vllm_parser.parse_args(args=args)
### Tests for LoRA target modules parsing
def test_lora_target_modules_single(serve_parser):
"""Test parsing single lora-target-modules argument"""
args = serve_parser.parse_args(
args=["--enable-lora", "--lora-target-modules", "o_proj"]
)
assert args.lora_target_modules == ["o_proj"]
def test_lora_target_modules_multiple(serve_parser):
"""Test parsing multiple lora-target-modules arguments"""
args = serve_parser.parse_args(
args=[
"--enable-lora",
"--lora-target-modules",
"o_proj",
"qkv_proj",
"down_proj",
]
)
assert args.lora_target_modules == ["o_proj", "qkv_proj", "down_proj"]
def test_lora_target_modules_default_none(serve_parser):
"""Test that lora-target-modules defaults to None"""
args = serve_parser.parse_args(args=[])
assert args.lora_target_modules is None
......@@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
torch.testing.assert_close(
packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
)
def _test_target_modules(
model,
target_modules: list[str] | None,
device: str,
expected_lora: list[tuple[str, type]],
expected_no_lora: list[tuple[str, type]],
):
"""Create a LoRAModelManager and assert which modules have LoRA applied."""
LoRAModelManager(
model,
2,
2,
2,
LoRAConfig(
max_lora_rank=8,
max_cpu_loras=2,
max_loras=2,
lora_dtype=DEFAULT_DTYPE,
target_modules=target_modules,
),
device=device,
)
for module_path, lora_cls in expected_lora:
assert isinstance(model.get_submodule(module_path), lora_cls)
for module_path, lora_cls in expected_no_lora:
assert not isinstance(model.get_submodule(module_path), lora_cls)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
"""Test that target_modules config restricts which modules get LoRA applied."""
_test_target_modules(
dummy_model,
["dense1"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
],
expected_no_lora=[
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
"""Test that multiple target_modules work correctly."""
_test_target_modules(
dummy_model,
["dense1", "dense2"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
default_vllm_config, dist_init, dummy_model, device
):
"""Test that target_modules=None uses all supported modules."""
_test_target_modules(
dummy_model,
None,
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
not in the model's supported LoRA target modules."""
from unittest.mock import patch
import vllm.lora.worker_manager as wm_module
lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
)
dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)
# Patch from_local_checkpoint to inject an unsupported module
original_from_checkpoint = LoRAModel.from_local_checkpoint
def patched_from_checkpoint(*args, **kwargs):
lora = original_from_checkpoint(*args, **kwargs)
lora.loras["unsupported_module"] = LoRALayerWeights(
module_name="unsupported_module",
rank=8,
lora_alpha=16,
lora_a=torch.randn(8, 10),
lora_b=torch.randn(10, 8),
)
return lora
lora_request = LoRARequest("test", 1, dummy_lora_files)
with (
patch.object(LoRAModel, "from_local_checkpoint", patched_from_checkpoint),
patch.object(wm_module.logger, "warning_once") as mock_warning,
):
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
found = any("unsupported_module" in str(call) for call in warning_args)
assert found, (
f"Expected warning about 'unsupported_module', got: {warning_args}"
)
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
excluded by the deployment-time target_modules restriction."""
from unittest.mock import patch
import vllm.lora.worker_manager as wm_module
# Restrict to only dense2 — adapter has dense1 which will be excluded
lora_config = LoRAConfig(
max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE,
target_modules=["dense2"],
)
dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)
lora_request = LoRARequest("test", 1, dummy_lora_files)
with patch.object(wm_module.logger, "warning_once") as mock_warning:
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
# dense1 is supported by the model but excluded by target_modules
found = any("target_modules" in str(call) for call in warning_args)
assert found, (
f"Expected warning about target_modules restriction, got: {warning_args}"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.utils import is_in_target_modules, is_supported_lora_module
class TestIsSupportedLoraModule:
"""Tests for is_supported_lora_module (model-definition check)."""
def test_suffix_match(self):
assert is_supported_lora_module(
"model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
)
def test_no_match(self):
assert not is_supported_lora_module(
"model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
)
def test_exact_match(self):
assert is_supported_lora_module("o_proj", ["o_proj"])
def test_regex_suffix_matching(self):
"""Regex anchors to end — partial suffix should not match."""
assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", ["proj"])
def test_empty_supported_modules(self):
assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", [])
def test_multiple_supported_modules(self):
supported = ["q_proj", "k_proj", "v_proj", "o_proj"]
assert is_supported_lora_module("model.layers.0.self_attn.v_proj", supported)
assert not is_supported_lora_module("model.layers.0.mlp.gate_proj", supported)
class TestIsInTargetModules:
"""Tests for is_in_target_modules (deployment-time filter)."""
def test_none_allows_all(self):
assert is_in_target_modules("model.layers.0.self_attn.o_proj", None)
def test_suffix_in_target(self):
assert is_in_target_modules(
"model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
)
def test_suffix_not_in_target(self):
assert not is_in_target_modules(
"model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
)
def test_empty_target_modules(self):
assert not is_in_target_modules("model.layers.0.self_attn.o_proj", [])
def test_exact_name_match(self):
assert is_in_target_modules("dense1", ["dense1", "dense2"])
def test_exact_name_no_match(self):
assert not is_in_target_modules("dense3", ["dense1", "dense2"])
......@@ -43,6 +43,10 @@ class LoRAConfig:
`max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
target_modules: list[str] | None = None
"""Restrict LoRA to specific module suffixes (e.g., ["o_proj", "qkv_proj"]).
If None, all supported LoRA modules are used. This allows deployment-time
control over which modules have LoRA applied, useful for performance tuning."""
default_mm_loras: dict[str, str] | None = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
......@@ -84,6 +88,10 @@ class LoRAConfig:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
factors.append(self.enable_tower_connector_lora)
# target_modules affects which modules get LoRA applied
factors.append(
tuple(sorted(self.target_modules)) if self.target_modules else None
)
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
......
......@@ -506,6 +506,7 @@ class EngineArgs:
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
lora_target_modules: list[str] | None = LoRAConfig.target_modules
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
specialize_active_lora: bool = LoRAConfig.specialize_active_lora
......@@ -1107,6 +1108,9 @@ class EngineArgs:
lora_group.add_argument(
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
)
lora_group.add_argument(
"--lora-target-modules", **lora_kwargs["target_modules"]
)
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
lora_group.add_argument(
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
......@@ -1800,6 +1804,7 @@ class EngineArgs:
default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_dtype=self.lora_dtype,
target_modules=self.lora_target_modules,
enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras
......
......@@ -5,7 +5,6 @@ import math
from collections.abc import Callable
from typing import TypeVar
import regex as re
import torch
from torch import nn
......@@ -25,7 +24,9 @@ from vllm.lora.utils import (
from_layer,
from_layer_logits_processor,
get_supported_lora_modules,
is_in_target_modules,
is_moe_model,
is_supported_lora_module,
process_packed_modules_mapping,
replace_submodule,
)
......@@ -541,14 +542,23 @@ class LoRAModelManager:
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module), module_name
)
or target_module == module_name
for target_module in self.supported_lora_modules
)
def _match_target_modules(self, module_name: str) -> bool:
"""Check if a module should have LoRA applied.
This method first checks if the module is in vLLM's supported LoRA
modules, then applies deployment-time restrictions based on
LoRAConfig.target_modules.
Args:
module_name: Full dot-separated module name (e.g.,
"model.layers.0.self_attn.o_proj")
Returns:
True if LoRA should be applied to this module, False otherwise.
"""
if not is_supported_lora_module(module_name, self.supported_lora_modules):
return False
return is_in_target_modules(module_name, self.lora_config.target_modules)
def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
"""
......
......@@ -5,6 +5,7 @@ import os
from typing import TYPE_CHECKING
import huggingface_hub
import regex as re
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
from torch import nn
from transformers import PretrainedConfig
......@@ -226,6 +227,57 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
return list(supported_lora_modules)
def is_supported_lora_module(
module_name: str,
supported_lora_modules: list[str],
) -> bool:
"""Check if a module is in the model's supported LoRA modules.
Uses regex suffix matching against the model-defined supported modules
list (e.g., matching "model.layers.0.self_attn.o_proj" against
"o_proj").
Args:
module_name: Full dot-separated module name.
supported_lora_modules: List of module suffixes supported by the
model.
Returns:
True if the module is supported, False otherwise.
"""
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module),
module_name,
)
or target_module == module_name
for target_module in supported_lora_modules
)
def is_in_target_modules(
module_name: str,
target_modules: list[str] | None,
) -> bool:
"""Check if a module passes the deployment-time target_modules filter.
When target_modules is None (no restriction), all modules pass.
Otherwise, the module's suffix must be in the target_modules list.
Args:
module_name: Full dot-separated module name.
target_modules: Optional deployment-time restriction list from
LoRAConfig.target_modules.
Returns:
True if the module passes the filter, False otherwise.
"""
if target_modules is None:
return True
module_suffix = module_name.split(".")[-1]
return module_suffix in set(target_modules)
def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.
......
......@@ -17,7 +17,11 @@ from vllm.lora.model_manager import (
)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.lora.utils import (
get_adapter_absolute_path,
is_in_target_modules,
is_supported_lora_module,
)
logger = init_logger(__name__)
......@@ -142,6 +146,29 @@ class WorkerLoRAManager:
skip_prefixes=lora_skip_prefixes,
)
# Warn about adapter modules that will be ignored.
target_modules = self.lora_config.target_modules
for module_name in lora.loras:
if not is_supported_lora_module(module_name, supported_lora_modules):
logger.warning_once(
"LoRA module '%s' in adapter '%s' is not in the "
"model's supported LoRA target modules [%s]. "
"These parameters will be ignored, which may "
"cause abnormal model behavior.",
module_name,
lora_request.lora_path,
", ".join(sorted(supported_lora_modules)),
)
elif not is_in_target_modules(module_name, target_modules):
logger.warning_once(
"LoRA module '%s' in adapter '%s' is not in the "
"deployment-time target_modules restriction [%s]."
" These parameters will be ignored.",
module_name,
lora_request.lora_path,
", ".join(sorted(target_modules)),
)
except FileNotFoundError as e:
# FileNotFoundError should be raised if both
# - No adapter found to download from huggingface (or in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment