Unverified Commit 3717a4dd authored by Bhoomit's avatar Bhoomit Committed by GitHub
Browse files

[Misc][LoRA] Add --lora-target-modules to restrict LoRA to specific modules (#34984)


Signed-off-by: default avatarBhoomit Vasani <bhoomit.2010@gmail.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent ecfcdd2c
...@@ -389,3 +389,17 @@ vllm serve model --enable-lora --max-lora-rank 64 ...@@ -389,3 +389,17 @@ vllm serve model --enable-lora --max-lora-rank 64
# Bad: unnecessarily high, wastes memory # Bad: unnecessarily high, wastes memory
vllm serve model --enable-lora --max-lora-rank 256 vllm serve model --enable-lora --max-lora-rank 256
``` ```
### Restricting LoRA to Specific Modules
The `--lora-target-modules` parameter allows you to restrict which model modules have LoRA applied at deployment time. This is useful for performance tuning when you only need LoRA on specific layers:
```bash
# Apply LoRA only to output projection layers
vllm serve model --enable-lora --lora-target-modules o_proj
# Apply LoRA to multiple specific modules
vllm serve model --enable-lora --lora-target-modules o_proj qkv_proj down_proj
```
When `--lora-target-modules` is not specified, LoRA will be applied to all supported modules in the model. This parameter accepts module suffixes (the last component of the module name), such as `o_proj`, `qkv_proj`, `gate_proj`, etc.
...@@ -291,3 +291,32 @@ def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises): ...@@ -291,3 +291,32 @@ def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
else: else:
with pytest.raises(raises): with pytest.raises(raises):
vllm_parser.parse_args(args=args) vllm_parser.parse_args(args=args)
### Tests for LoRA target modules parsing
def test_lora_target_modules_single(serve_parser):
"""Test parsing single lora-target-modules argument"""
args = serve_parser.parse_args(
args=["--enable-lora", "--lora-target-modules", "o_proj"]
)
assert args.lora_target_modules == ["o_proj"]
def test_lora_target_modules_multiple(serve_parser):
"""Test parsing multiple lora-target-modules arguments"""
args = serve_parser.parse_args(
args=[
"--enable-lora",
"--lora-target-modules",
"o_proj",
"qkv_proj",
"down_proj",
]
)
assert args.lora_target_modules == ["o_proj", "qkv_proj", "down_proj"]
def test_lora_target_modules_default_none(serve_parser):
"""Test that lora-target-modules defaults to None"""
args = serve_parser.parse_args(args=[])
assert args.lora_target_modules is None
...@@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic ...@@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
torch.testing.assert_close( torch.testing.assert_close(
packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
) )
def _test_target_modules(
model,
target_modules: list[str] | None,
device: str,
expected_lora: list[tuple[str, type]],
expected_no_lora: list[tuple[str, type]],
):
"""Create a LoRAModelManager and assert which modules have LoRA applied."""
LoRAModelManager(
model,
2,
2,
2,
LoRAConfig(
max_lora_rank=8,
max_cpu_loras=2,
max_loras=2,
lora_dtype=DEFAULT_DTYPE,
target_modules=target_modules,
),
device=device,
)
for module_path, lora_cls in expected_lora:
assert isinstance(model.get_submodule(module_path), lora_cls)
for module_path, lora_cls in expected_no_lora:
assert not isinstance(model.get_submodule(module_path), lora_cls)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
"""Test that target_modules config restricts which modules get LoRA applied."""
_test_target_modules(
dummy_model,
["dense1"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
],
expected_no_lora=[
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
"""Test that multiple target_modules work correctly."""
_test_target_modules(
dummy_model,
["dense1", "dense2"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
default_vllm_config, dist_init, dummy_model, device
):
"""Test that target_modules=None uses all supported modules."""
_test_target_modules(
dummy_model,
None,
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
not in the model's supported LoRA target modules."""
from unittest.mock import patch
import vllm.lora.worker_manager as wm_module
lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
)
dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)
# Patch from_local_checkpoint to inject an unsupported module
original_from_checkpoint = LoRAModel.from_local_checkpoint
def patched_from_checkpoint(*args, **kwargs):
lora = original_from_checkpoint(*args, **kwargs)
lora.loras["unsupported_module"] = LoRALayerWeights(
module_name="unsupported_module",
rank=8,
lora_alpha=16,
lora_a=torch.randn(8, 10),
lora_b=torch.randn(10, 8),
)
return lora
lora_request = LoRARequest("test", 1, dummy_lora_files)
with (
patch.object(LoRAModel, "from_local_checkpoint", patched_from_checkpoint),
patch.object(wm_module.logger, "warning_once") as mock_warning,
):
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
found = any("unsupported_module" in str(call) for call in warning_args)
assert found, (
f"Expected warning about 'unsupported_module', got: {warning_args}"
)
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
excluded by the deployment-time target_modules restriction."""
from unittest.mock import patch
import vllm.lora.worker_manager as wm_module
# Restrict to only dense2 — adapter has dense1 which will be excluded
lora_config = LoRAConfig(
max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE,
target_modules=["dense2"],
)
dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)
lora_request = LoRARequest("test", 1, dummy_lora_files)
with patch.object(wm_module.logger, "warning_once") as mock_warning:
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
# dense1 is supported by the model but excluded by target_modules
found = any("target_modules" in str(call) for call in warning_args)
assert found, (
f"Expected warning about target_modules restriction, got: {warning_args}"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.utils import is_in_target_modules, is_supported_lora_module
class TestIsSupportedLoraModule:
"""Tests for is_supported_lora_module (model-definition check)."""
def test_suffix_match(self):
assert is_supported_lora_module(
"model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
)
def test_no_match(self):
assert not is_supported_lora_module(
"model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
)
def test_exact_match(self):
assert is_supported_lora_module("o_proj", ["o_proj"])
def test_regex_suffix_matching(self):
"""Regex anchors to end — partial suffix should not match."""
assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", ["proj"])
def test_empty_supported_modules(self):
assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", [])
def test_multiple_supported_modules(self):
supported = ["q_proj", "k_proj", "v_proj", "o_proj"]
assert is_supported_lora_module("model.layers.0.self_attn.v_proj", supported)
assert not is_supported_lora_module("model.layers.0.mlp.gate_proj", supported)
class TestIsInTargetModules:
"""Tests for is_in_target_modules (deployment-time filter)."""
def test_none_allows_all(self):
assert is_in_target_modules("model.layers.0.self_attn.o_proj", None)
def test_suffix_in_target(self):
assert is_in_target_modules(
"model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
)
def test_suffix_not_in_target(self):
assert not is_in_target_modules(
"model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
)
def test_empty_target_modules(self):
assert not is_in_target_modules("model.layers.0.self_attn.o_proj", [])
def test_exact_name_match(self):
assert is_in_target_modules("dense1", ["dense1", "dense2"])
def test_exact_name_no_match(self):
assert not is_in_target_modules("dense3", ["dense1", "dense2"])
...@@ -43,6 +43,10 @@ class LoRAConfig: ...@@ -43,6 +43,10 @@ class LoRAConfig:
`max_loras`.""" `max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto" lora_dtype: torch.dtype | LoRADType = "auto"
"""Data type for LoRA. If auto, will default to base model dtype.""" """Data type for LoRA. If auto, will default to base model dtype."""
target_modules: list[str] | None = None
"""Restrict LoRA to specific module suffixes (e.g., ["o_proj", "qkv_proj"]).
If None, all supported LoRA modules are used. This allows deployment-time
control over which modules have LoRA applied, useful for performance tuning."""
default_mm_loras: dict[str, str] | None = None default_mm_loras: dict[str, str] | None = None
"""Dictionary mapping specific modalities to LoRA model paths; this field """Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a is only applicable to multimodal models and should be leveraged when a
...@@ -84,6 +88,10 @@ class LoRAConfig: ...@@ -84,6 +88,10 @@ class LoRAConfig:
factors.append(self.fully_sharded_loras) factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype) factors.append(self.lora_dtype)
factors.append(self.enable_tower_connector_lora) factors.append(self.enable_tower_connector_lora)
# target_modules affects which modules get LoRA applied
factors.append(
tuple(sorted(self.target_modules)) if self.target_modules else None
)
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
......
...@@ -506,6 +506,7 @@ class EngineArgs: ...@@ -506,6 +506,7 @@ class EngineArgs:
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
lora_target_modules: list[str] | None = LoRAConfig.target_modules
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
specialize_active_lora: bool = LoRAConfig.specialize_active_lora specialize_active_lora: bool = LoRAConfig.specialize_active_lora
...@@ -1107,6 +1108,9 @@ class EngineArgs: ...@@ -1107,6 +1108,9 @@ class EngineArgs:
lora_group.add_argument( lora_group.add_argument(
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"] "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
) )
lora_group.add_argument(
"--lora-target-modules", **lora_kwargs["target_modules"]
)
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
lora_group.add_argument( lora_group.add_argument(
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"] "--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
...@@ -1800,6 +1804,7 @@ class EngineArgs: ...@@ -1800,6 +1804,7 @@ class EngineArgs:
default_mm_loras=self.default_mm_loras, default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras, fully_sharded_loras=self.fully_sharded_loras,
lora_dtype=self.lora_dtype, lora_dtype=self.lora_dtype,
target_modules=self.lora_target_modules,
enable_tower_connector_lora=self.enable_tower_connector_lora, enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora, specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras max_cpu_loras=self.max_cpu_loras
......
...@@ -5,7 +5,6 @@ import math ...@@ -5,7 +5,6 @@ import math
from collections.abc import Callable from collections.abc import Callable
from typing import TypeVar from typing import TypeVar
import regex as re
import torch import torch
from torch import nn from torch import nn
...@@ -25,7 +24,9 @@ from vllm.lora.utils import ( ...@@ -25,7 +24,9 @@ from vllm.lora.utils import (
from_layer, from_layer,
from_layer_logits_processor, from_layer_logits_processor,
get_supported_lora_modules, get_supported_lora_modules,
is_in_target_modules,
is_moe_model, is_moe_model,
is_supported_lora_module,
process_packed_modules_mapping, process_packed_modules_mapping,
replace_submodule, replace_submodule,
) )
...@@ -541,14 +542,23 @@ class LoRAModelManager: ...@@ -541,14 +542,23 @@ class LoRAModelManager:
model.loras[module_name] = lora model.loras[module_name] = lora
return model return model
def _match_target_modules(self, module_name: str): def _match_target_modules(self, module_name: str) -> bool:
return any( """Check if a module should have LoRA applied.
re.match(
r".*\.{target_module}$".format(target_module=target_module), module_name This method first checks if the module is in vLLM's supported LoRA
) modules, then applies deployment-time restrictions based on
or target_module == module_name LoRAConfig.target_modules.
for target_module in self.supported_lora_modules
) Args:
module_name: Full dot-separated module name (e.g.,
"model.layers.0.self_attn.o_proj")
Returns:
True if LoRA should be applied to this module, False otherwise.
"""
if not is_supported_lora_module(module_name, self.supported_lora_modules):
return False
return is_in_target_modules(module_name, self.lora_config.target_modules)
def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None: def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
""" """
......
...@@ -5,6 +5,7 @@ import os ...@@ -5,6 +5,7 @@ import os
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import huggingface_hub import huggingface_hub
import regex as re
from huggingface_hub.utils import HfHubHTTPError, HFValidationError from huggingface_hub.utils import HfHubHTTPError, HFValidationError
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -226,6 +227,57 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]: ...@@ -226,6 +227,57 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
return list(supported_lora_modules) return list(supported_lora_modules)
def is_supported_lora_module(
module_name: str,
supported_lora_modules: list[str],
) -> bool:
"""Check if a module is in the model's supported LoRA modules.
Uses regex suffix matching against the model-defined supported modules
list (e.g., matching "model.layers.0.self_attn.o_proj" against
"o_proj").
Args:
module_name: Full dot-separated module name.
supported_lora_modules: List of module suffixes supported by the
model.
Returns:
True if the module is supported, False otherwise.
"""
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module),
module_name,
)
or target_module == module_name
for target_module in supported_lora_modules
)
def is_in_target_modules(
module_name: str,
target_modules: list[str] | None,
) -> bool:
"""Check if a module passes the deployment-time target_modules filter.
When target_modules is None (no restriction), all modules pass.
Otherwise, the module's suffix must be in the target_modules list.
Args:
module_name: Full dot-separated module name.
target_modules: Optional deployment-time restriction list from
LoRAConfig.target_modules.
Returns:
True if the module passes the filter, False otherwise.
"""
if target_modules is None:
return True
module_suffix = module_name.split(".")[-1]
return module_suffix in set(target_modules)
def get_adapter_absolute_path(lora_path: str) -> str: def get_adapter_absolute_path(lora_path: str) -> str:
""" """
Resolves the given lora_path to an absolute local path. Resolves the given lora_path to an absolute local path.
......
...@@ -17,7 +17,11 @@ from vllm.lora.model_manager import ( ...@@ -17,7 +17,11 @@ from vllm.lora.model_manager import (
) )
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import (
get_adapter_absolute_path,
is_in_target_modules,
is_supported_lora_module,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -142,6 +146,29 @@ class WorkerLoRAManager: ...@@ -142,6 +146,29 @@ class WorkerLoRAManager:
skip_prefixes=lora_skip_prefixes, skip_prefixes=lora_skip_prefixes,
) )
# Warn about adapter modules that will be ignored.
target_modules = self.lora_config.target_modules
for module_name in lora.loras:
if not is_supported_lora_module(module_name, supported_lora_modules):
logger.warning_once(
"LoRA module '%s' in adapter '%s' is not in the "
"model's supported LoRA target modules [%s]. "
"These parameters will be ignored, which may "
"cause abnormal model behavior.",
module_name,
lora_request.lora_path,
", ".join(sorted(supported_lora_modules)),
)
elif not is_in_target_modules(module_name, target_modules):
logger.warning_once(
"LoRA module '%s' in adapter '%s' is not in the "
"deployment-time target_modules restriction [%s]."
" These parameters will be ignored.",
module_name,
lora_request.lora_path,
", ".join(sorted(target_modules)),
)
except FileNotFoundError as e: except FileNotFoundError as e:
# FileNotFoundError should be raised if both # FileNotFoundError should be raised if both
# - No adapter found to download from huggingface (or in # - No adapter found to download from huggingface (or in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment