"vscode:/vscode.git/clone" did not exist on "5e973209aaf5fa15459555eaa42bcd20ea63aa0d"
Unverified Commit 5be4e52b authored by B-201's avatar B-201 Committed by GitHub
Browse files

[Model][LoRA]LoRA support added for glm-4v (#10418)


Signed-off-by: default avatarB-201 <Joy25810@foxmail.com>
parent 01aae1cc
......@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
......@@ -574,25 +575,8 @@ class ChatGLMModel(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
SupportsMultiModal):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP,
SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......@@ -692,3 +676,79 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader(param, combined_weight)
loaded_params.add(combined_name)
return loaded_params
class ChatGLM(ChatGLMBaseModel):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
class ChatGLMV(ChatGLMBaseModel):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"],
"merged_proj": ["gate_proj", "dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
# vision
"fc1",
"fc2",
"merged_proj",
"linear_proj"
]
embedding_modules = {}
embedding_padding_modules = []
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="transformer.encoder",
connector="transformer.vision.linear_proj",
tower_model="transformer.vision.transformer")
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
embedding_padding_modules = []
def __new__(
cls,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config
# Initialize VL
if hasattr(config, "visual"):
return ChatGLM(vllm_config=vllm_config, prefix=prefix)
# Initialize LLM
else:
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment