Unverified Commit 105b8ce4 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Reduce LoRA-related static variable (#13166)

parent 2cb8c154
...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -98,9 +99,13 @@ def dist_init_torch_only(): ...@@ -98,9 +99,13 @@ def dist_init_torch_only():
backend=backend) backend=backend)
class DummyLoRAModel(nn.Sequential, SupportsLoRA):
pass
@pytest.fixture @pytest.fixture
def dummy_model() -> nn.Module: def dummy_model() -> nn.Module:
model = nn.Sequential( model = DummyLoRAModel(
OrderedDict([ OrderedDict([
("dense1", ColumnParallelLinear(764, 100)), ("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)), ("dense2", RowParallelLinear(100, 50)),
...@@ -121,12 +126,13 @@ def dummy_model() -> nn.Module: ...@@ -121,12 +126,13 @@ def dummy_model() -> nn.Module:
("sampler", Sampler()) ("sampler", Sampler())
])) ]))
model.config = MagicMock() model.config = MagicMock()
model.embedding_modules = {"lm_head": "lm_head"}
return model return model
@pytest.fixture @pytest.fixture
def dummy_model_gate_up() -> nn.Module: def dummy_model_gate_up() -> nn.Module:
model = nn.Sequential( model = DummyLoRAModel(
OrderedDict([ OrderedDict([
("dense1", ColumnParallelLinear(764, 100)), ("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)), ("dense2", RowParallelLinear(100, 50)),
...@@ -147,6 +153,13 @@ def dummy_model_gate_up() -> nn.Module: ...@@ -147,6 +153,13 @@ def dummy_model_gate_up() -> nn.Module:
("sampler", Sampler()) ("sampler", Sampler())
])) ]))
model.config = MagicMock() model.config = MagicMock()
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model.embedding_modules = {"lm_head": "lm_head"}
return model return model
......
...@@ -12,6 +12,12 @@ from vllm.model_executor.models.utils import WeightsMapper ...@@ -12,6 +12,12 @@ from vllm.model_executor.models.utils import WeightsMapper
lora_lst = [ lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b" "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
] ]
BAICHUAN_LORA_MODULES = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
@pytest.mark.parametrize("lora_name", lora_lst) @pytest.mark.parametrize("lora_name", lora_lst)
...@@ -22,12 +28,11 @@ def test_load_checkpoints( ...@@ -22,12 +28,11 @@ def test_load_checkpoints(
baichuan_regex_lora_files, baichuan_regex_lora_files,
chatglm3_lora_files, chatglm3_lora_files,
): ):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = [] expected_lora_modules: List[str] = []
for module in supported_lora_modules: for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_modules.extend(packed_modules_mapping[module])
else: else:
...@@ -90,12 +95,12 @@ def test_load_checkpoints( ...@@ -90,12 +95,12 @@ def test_load_checkpoints(
def test_lora_weights_mapping(baichuan_lora_files): def test_lora_weights_mapping(baichuan_lora_files):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = [] expected_lora_modules: List[str] = []
for module in supported_lora_modules: for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_modules.extend(packed_modules_mapping[module])
else: else:
......
...@@ -11,17 +11,20 @@ from vllm.model_executor.models.llama import LlamaForCausalLM ...@@ -11,17 +11,20 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
# Provide absolute path and huggingface lora ids # Provide absolute path and huggingface lora ids
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"] lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
LLAMA_LORA_MODULES = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name) @pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request): def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_name = request.getfixturevalue(lora_fixture_name) lora_name = request.getfixturevalue(lora_fixture_name)
supported_lora_modules = LlamaForCausalLM.supported_lora_modules
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = [] expected_lora_modules: List[str] = []
for module in supported_lora_modules: for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_modules.extend(packed_modules_mapping[module])
else: else:
......
...@@ -19,7 +19,6 @@ from vllm.lora.peft_helper import PEFTHelper ...@@ -19,7 +19,6 @@ from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager) WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform from vllm.platforms import current_platform
EMBEDDING_MODULES = { EMBEDDING_MODULES = {
...@@ -114,19 +113,16 @@ def create_packed_lora( ...@@ -114,19 +113,16 @@ def create_packed_lora(
def test_replace_submodules(dist_init, dummy_model): def test_replace_submodules(dist_init, dummy_model):
model = dummy_model model = dummy_model
model.supported_lora_modules = ["dense1", "layer1.dense2"]
model.packed_modules_mapping = {}
manager = LoRAModelManager( manager = LoRAModelManager(
model, 1, 1, 1, model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device(DEVICES[0])) torch.device(DEVICES[0]))
model = manager.model model = manager.model
assert isinstance(model.get_submodule("dense1"), assert isinstance(model.get_submodule("dense1"),
ColumnParallelLinearWithLoRA) ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense1"), assert isinstance(model.get_submodule("layer1.dense1"),
ColumnParallelLinearWithLoRA) ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("dense2"), RowParallelLinear) assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense2"), assert isinstance(model.get_submodule("layer1.dense2"),
RowParallelLinearWithLoRA) RowParallelLinearWithLoRA)
...@@ -134,8 +130,6 @@ def test_replace_submodules(dist_init, dummy_model): ...@@ -134,8 +130,6 @@ def test_replace_submodules(dist_init, dummy_model):
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device): def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"], model, ["layer1.dense1", "dense2", "lm_head"],
device=device) device=device)
...@@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device): ...@@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.device == device assert manager.device == device
assert manager.punica_wrapper.device == device assert manager.punica_wrapper.device == device
assert hasattr(manager, "supported_lora_modules")
assert sorted(manager.supported_lora_modules) == [
"dense1",
"dense2",
"lm_head",
"output",
]
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"], model, ["layer1.dense1", "dense2", "lm_head"],
device=device) device=device)
...@@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): ...@@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is # This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager # tested in test_lora_model_manager
model = dummy_model model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model_lora1 = create_lora(1,
model, ["layer1.dense1", "dense2", "lm_head"], model, ["layer1.dense1", "dense2", "lm_head"],
device=device) device=device)
...@@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, ...@@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device): def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model_lora = create_packed_lora( model_lora = create_packed_lora(
1, 1,
model, model,
......
...@@ -26,6 +26,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights ...@@ -26,6 +26,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules,
is_regex_target_modules, is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models import SupportsLoRA, supports_multimodal
...@@ -332,15 +333,15 @@ class LoRAModelManager(AdapterModelManager): ...@@ -332,15 +333,15 @@ class LoRAModelManager(AdapterModelManager):
# Used for long context lora. # Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {} self.scaling_factor_to_offset: Dict[float, int] = {}
super().__init__(model) super().__init__(model)
if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = get_supported_lora_modules(self.model)
self.supported_lora_modules = copy.deepcopy( assert self.supported_lora_modules, "No supported LoRA modules found in"
self.model.supported_lora_modules) f"{self.model.__class__.__name__}."
if lora_config.long_lora_scaling_factors: if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation # We need to replace rotary emb layer to do batch computation
# for long lora. # for long lora.
self.supported_lora_modules.append("rotary_emb") self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy( self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping) self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model # Used to indicate whether the model is a multimodal model
self.supports_mm: bool = ( self.supports_mm: bool = (
supports_multimodal(self.model) supports_multimodal(self.model)
...@@ -756,7 +757,7 @@ def create_lora_manager( ...@@ -756,7 +757,7 @@ def create_lora_manager(
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager: **kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model.""" """Create a LoRA adapter for a given model."""
if not hasattr(model, "supported_lora_modules"): if not hasattr(model, "packed_modules_mapping"):
raise ValueError(f"Model {type(model)} is not supported for LoRA.") raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls( lora_manager = lora_manager_cls(
model=model, model=model,
......
...@@ -29,6 +29,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, ...@@ -29,6 +29,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
ReplicatedLinearWithLoRA, ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) VocabParallelEmbeddingWithLoRA)
from vllm.model_executor.layers.linear import LinearBase
# yapf: enable # yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...@@ -68,6 +69,14 @@ def from_layer(layer: nn.Module, ...@@ -68,6 +69,14 @@ def from_layer(layer: nn.Module,
ret = lora_cls(layer) ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config) ret.create_lora_weights(max_loras, lora_config, model_config)
return ret return ret
# The Case for HFCompatibleLinear
if (hasattr(layer, "get_lora_class")
and layer.__class__.__name__ == "HFCompatibleLinear"):
lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer return layer
...@@ -170,6 +179,23 @@ def is_regex_target_modules(load_modules: Union[str, List[str]], ...@@ -170,6 +179,23 @@ def is_regex_target_modules(load_modules: Union[str, List[str]],
return False return False
def get_supported_lora_modules(model: nn.Module) -> List[str]:
"""
In vLLM, all linear layers support LoRA.
"""
supported_lora_modules: Set[str] = set()
# step1: traverse the model to get all the linear subfixes.
for name, module in model.named_modules():
if isinstance(module, (LinearBase, )):
supported_lora_modules.add(name.split(".")[-1])
# step 2: get the embedding modules if the model's mbedding_modules
# is not empty.
if model.embedding_modules:
for name in model.embedding_modules:
supported_lora_modules.add(name)
return list(supported_lora_modules)
def get_adapter_absolute_path(lora_path: str) -> str: def get_adapter_absolute_path(lora_path: str) -> str:
""" """
Resolves the given lora_path to an absolute local path. Resolves the given lora_path to an absolute local path.
......
...@@ -84,9 +84,10 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -84,9 +84,10 @@ class WorkerLoRAManager(AbstractWorkerManager):
def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
try: try:
model = self._adapter_manager.model supported_lora_modules = (
supported_lora_modules = model.supported_lora_modules self._adapter_manager.supported_lora_modules)
packed_modules_mapping = model.packed_modules_mapping packed_modules_mapping = (
self._adapter_manager.packed_modules_mapping)
expected_lora_modules: List[str] = [] expected_lora_modules: List[str] = []
for module in supported_lora_modules: for module in supported_lora_modules:
if module in packed_modules_mapping: if module in packed_modules_mapping:
...@@ -107,6 +108,7 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -107,6 +108,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights. # to ensure correct loading of lora weights.
model = self._adapter_manager.model
hf_to_vllm_mapper = None hf_to_vllm_mapper = None
if (hasattr(model, "hf_to_vllm_mapper") if (hasattr(model, "hf_to_vllm_mapper")
and model.hf_to_vllm_mapper is not None): and model.hf_to_vllm_mapper is not None):
......
...@@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj", "up_proj",
], ],
} }
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__( def __init__(
self, self,
......
...@@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
......
...@@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP): ...@@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"]
} }
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -357,11 +357,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -357,11 +357,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
], ],
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {"embed_tokens": "input_embeddings"} embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -415,14 +415,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -415,14 +415,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"out_proj",
"gate_up_proj",
"c_proj",
"wte",
"lm_head",
]
embedding_modules = { embedding_modules = {
"wte": "input_embeddings", "wte": "input_embeddings",
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
......
...@@ -344,18 +344,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -344,18 +344,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
], ],
} }
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -390,17 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -390,17 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
], ],
} }
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
......
...@@ -534,21 +534,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -534,21 +534,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
"dense_h_to_4h": ["dense_h_to_4h"], "dense_h_to_4h": ["dense_h_to_4h"],
"merged_proj": ["gate_proj", "dense_h_to_4h"] "merged_proj": ["gate_proj", "dense_h_to_4h"]
} }
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
# vision
"fc1",
"fc2",
"merged_proj",
"linear_proj"
]
embedding_modules = {}
embedding_padding_modules = []
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
......
...@@ -261,15 +261,12 @@ class GPTBigCodeModel(nn.Module): ...@@ -261,15 +261,12 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]} packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] # LoRA specific attributes
embedding_modules = { embedding_modules = {
"wte": "input_embeddings", "wte": "input_embeddings",
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
} }
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -351,10 +351,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -351,10 +351,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
......
...@@ -329,13 +329,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -329,13 +329,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
"layer",
]
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
......
...@@ -597,21 +597,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -597,21 +597,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
"up_proj", "up_proj",
], ],
} }
# LoRA specific attributes
supported_lora_modules = [
# vision_model
"fc1",
"fc2",
"out_proj",
# text_model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -118,11 +118,11 @@ class SupportsLoRA(Protocol): ...@@ -118,11 +118,11 @@ class SupportsLoRA(Protocol):
There is no need to redefine this flag if this class is in the There is no need to redefine this flag if this class is in the
MRO of your model class. MRO of your model class.
""" """
# The `embedding_module` and `embedding_padding_modules`
packed_modules_mapping: ClassVar[Dict[str, List[str]]] # are empty by default.
supported_lora_modules: ClassVar[List[str]] embedding_modules: ClassVar[Dict[str, str]] = {}
embedding_modules: ClassVar[Dict[str, str]] embedding_padding_modules: ClassVar[List[str]] = []
embedding_padding_modules: ClassVar[List[str]] packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
# We can't use runtime_checkable with ClassVar for issubclass checks # We can't use runtime_checkable with ClassVar for issubclass checks
...@@ -132,7 +132,6 @@ class _SupportsLoRAType(Protocol): ...@@ -132,7 +132,6 @@ class _SupportsLoRAType(Protocol):
supports_lora: Literal[True] supports_lora: Literal[True]
packed_modules_mapping: Dict[str, List[str]] packed_modules_mapping: Dict[str, List[str]]
supported_lora_modules: List[str]
embedding_modules: Dict[str, str] embedding_modules: Dict[str, str]
embedding_padding_modules: List[str] embedding_padding_modules: List[str]
...@@ -155,7 +154,6 @@ def supports_lora( ...@@ -155,7 +154,6 @@ def supports_lora(
if not result: if not result:
lora_attrs = ( lora_attrs = (
"packed_modules_mapping", "packed_modules_mapping",
"supported_lora_modules",
"embedding_modules", "embedding_modules",
"embedding_padding_modules", "embedding_padding_modules",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment