Unverified Commit 105b8ce4 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Reduce LoRA-related static variable (#13166)

parent 2cb8c154
......@@ -23,6 +23,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform
......@@ -98,9 +99,13 @@ def dist_init_torch_only():
backend=backend)
class DummyLoRAModel(nn.Sequential, SupportsLoRA):
    """``nn.Sequential`` that also inherits ``SupportsLoRA``.

    Used by the dummy-model fixtures in place of a bare ``nn.Sequential``.
    """
@pytest.fixture
def dummy_model() -> nn.Module:
model = nn.Sequential(
model = DummyLoRAModel(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
......@@ -121,12 +126,13 @@ def dummy_model() -> nn.Module:
("sampler", Sampler())
]))
model.config = MagicMock()
model.embedding_modules = {"lm_head": "lm_head"}
return model
@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
model = nn.Sequential(
model = DummyLoRAModel(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
......@@ -147,6 +153,13 @@ def dummy_model_gate_up() -> nn.Module:
("sampler", Sampler())
]))
model.config = MagicMock()
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model.embedding_modules = {"lm_head": "lm_head"}
return model
......
......@@ -12,6 +12,12 @@ from vllm.model_executor.models.utils import WeightsMapper
lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]
# Module names the tests in this file iterate over to build
# `expected_lora_modules`; a name that is a key of the model's
# `packed_modules_mapping` is expanded into its mapped sub-module names.
BAICHUAN_LORA_MODULES = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
@pytest.mark.parametrize("lora_name", lora_lst)
......@@ -22,12 +28,11 @@ def test_load_checkpoints(
baichuan_regex_lora_files,
chatglm3_lora_files,
):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
......@@ -90,12 +95,12 @@ def test_load_checkpoints(
def test_lora_weights_mapping(baichuan_lora_files):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
......
......@@ -11,17 +11,20 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
# Provide absolute path and huggingface lora ids
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
# Module names the test below iterates over to build
# `expected_lora_modules`; a name that is a key of
# `LlamaForCausalLM.packed_modules_mapping` is expanded into its mapped
# sub-module names.
LLAMA_LORA_MODULES = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_name = request.getfixturevalue(lora_fixture_name)
supported_lora_modules = LlamaForCausalLM.supported_lora_modules
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
......
......@@ -19,7 +19,6 @@ from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform
EMBEDDING_MODULES = {
......@@ -114,19 +113,16 @@ def create_packed_lora(
def test_replace_submodules(dist_init, dummy_model):
model = dummy_model
model.supported_lora_modules = ["dense1", "layer1.dense2"]
model.packed_modules_mapping = {}
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device(DEVICES[0]))
model = manager.model
assert isinstance(model.get_submodule("dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense2"),
RowParallelLinearWithLoRA)
......@@ -134,8 +130,6 @@ def test_replace_submodules(dist_init, dummy_model):
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
......@@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.device == device
assert manager.punica_wrapper.device == device
assert hasattr(manager, "supported_lora_modules")
assert sorted(manager.supported_lora_modules) == [
"dense1",
"dense2",
"lm_head",
"output",
]
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
......@@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"],
device=device)
......@@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model_lora = create_packed_lora(
1,
model,
......
......@@ -26,6 +26,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
......@@ -332,15 +333,15 @@ class LoRAModelManager(AdapterModelManager):
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
super().__init__(model)
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
self.supported_lora_modules = get_supported_lora_modules(self.model)
# The message has to be parenthesized: a string literal on its own line after
# an `assert` is a separate no-op expression statement, so the model class
# name was never part of the assertion message.
assert self.supported_lora_modules, (
    "No supported LoRA modules found in "
    f"{self.model.__class__.__name__}.")
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
......@@ -756,7 +757,7 @@ def create_lora_manager(
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not hasattr(model, "supported_lora_modules"):
if not hasattr(model, "packed_modules_mapping"):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
......
......@@ -29,6 +29,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
from vllm.model_executor.layers.linear import LinearBase
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
......@@ -68,6 +69,14 @@ def from_layer(layer: nn.Module,
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
# The Case for HFCompatibleLinear
if (hasattr(layer, "get_lora_class")
and layer.__class__.__name__ == "HFCompatibleLinear"):
lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
......@@ -170,6 +179,23 @@ def is_regex_target_modules(load_modules: Union[str, List[str]],
return False
def get_supported_lora_modules(model: nn.Module) -> List[str]:
    """
    Collect the names of the modules in `model` that can take LoRA weights.

    In vLLM, all linear layers support LoRA, so every `LinearBase` submodule
    contributes the last component of its dotted name (its suffix). The keys
    of the model's `embedding_modules` mapping are added as well when the
    model defines a non-empty one.

    Args:
        model: The model whose submodules are inspected.

    Returns:
        The de-duplicated module names, sorted so that the result does not
        depend on set iteration order.
    """
    # Step 1: traverse the model to get the suffixes of all linear layers.
    supported_lora_modules: Set[str] = {
        name.rsplit(".", 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, LinearBase)
    }
    # Step 2: add the embedding modules if the model's embedding_modules is
    # not empty. A model that does not define the attribute at all is treated
    # the same as one with an empty mapping instead of raising AttributeError.
    embedding_modules = getattr(model, "embedding_modules", None)
    if embedding_modules:
        supported_lora_modules.update(embedding_modules)
    return sorted(supported_lora_modules)
def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.
......
......@@ -84,9 +84,10 @@ class WorkerLoRAManager(AbstractWorkerManager):
def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
try:
model = self._adapter_manager.model
supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping
supported_lora_modules = (
self._adapter_manager.supported_lora_modules)
packed_modules_mapping = (
self._adapter_manager.packed_modules_mapping)
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
......@@ -107,6 +108,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model = self._adapter_manager.model
hf_to_vllm_mapper = None
if (hasattr(model, "hf_to_vllm_mapper")
and model.hf_to_vllm_mapper is not None):
......
......@@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
......
......@@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
......
......@@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
......
......@@ -357,11 +357,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -415,14 +415,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"out_proj",
"gate_up_proj",
"c_proj",
"wte",
"lm_head",
]
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
......
......@@ -344,18 +344,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
......
......@@ -390,17 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......
......@@ -534,21 +534,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
"dense_h_to_4h": ["dense_h_to_4h"],
"merged_proj": ["gate_proj", "dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
# vision
"fc1",
"fc2",
"merged_proj",
"linear_proj"
]
embedding_modules = {}
embedding_padding_modules = []
def get_mm_mapping(self) -> MultiModelKeys:
"""
......
......@@ -261,15 +261,12 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
# LoRA specific attributes
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
......
......@@ -351,10 +351,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
......
......@@ -329,13 +329,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
"layer",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
......
......@@ -597,21 +597,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision_model
"fc1",
"fc2",
"out_proj",
# text_model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -118,11 +118,11 @@ class SupportsLoRA(Protocol):
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
packed_modules_mapping: ClassVar[Dict[str, List[str]]]
supported_lora_modules: ClassVar[List[str]]
embedding_modules: ClassVar[Dict[str, str]]
embedding_padding_modules: ClassVar[List[str]]
# The `embedding_modules` and `embedding_padding_modules`
# are empty by default.
embedding_modules: ClassVar[Dict[str, str]] = {}
embedding_padding_modules: ClassVar[List[str]] = []
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
# We can't use runtime_checkable with ClassVar for issubclass checks
......@@ -132,7 +132,6 @@ class _SupportsLoRAType(Protocol):
supports_lora: Literal[True]
packed_modules_mapping: Dict[str, List[str]]
supported_lora_modules: List[str]
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]
......@@ -155,7 +154,6 @@ def supports_lora(
if not result:
lora_attrs = (
"packed_modules_mapping",
"supported_lora_modules",
"embedding_modules",
"embedding_padding_modules",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment