Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5be4e52b
Unverified
Commit
5be4e52b
authored
Nov 18, 2024
by
B-201
Committed by
GitHub
Nov 18, 2024
Browse files
[Model][LoRA]LoRA support added for glm-4v (#10418)
Signed-off-by:
B-201
<
Joy25810@foxmail.com
>
parent
01aae1cc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
79 additions
and
19 deletions
+79
-19
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+79
-19
No files found.
vllm/model_executor/models/chatglm.py
View file @
5be4e52b
...
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.glm4_vision_encoder
import
EVA2CLIPModel
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalData
,
MultiModalKwargs
...
...
@@ -574,25 +575,8 @@ class ChatGLMModel(nn.Module):
return
hidden_states
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
mm_input_mapper_for_glmv
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_glmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_glmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_glmv
)
class
ChatGLMForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
class
ChatGLMBaseModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
SupportsMultiModal
):
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
"dense_h_to_4h"
:
[
"dense_h_to_4h"
]
}
# LoRA specific attributes
supported_lora_modules
=
[
"query_key_value"
,
"dense"
,
"dense_h_to_4h"
,
"dense_4h_to_h"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -692,3 +676,79 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader
(
param
,
combined_weight
)
loaded_params
.
add
(
combined_name
)
return
loaded_params
class
ChatGLM
(
ChatGLMBaseModel
):
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
"dense_h_to_4h"
:
[
"dense_h_to_4h"
]
}
# LoRA specific attributes
supported_lora_modules
=
[
"query_key_value"
,
"dense"
,
"dense_h_to_4h"
,
"dense_4h_to_h"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
class
ChatGLMV
(
ChatGLMBaseModel
):
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
"dense_h_to_4h"
:
[
"dense_h_to_4h"
],
"merged_proj"
:
[
"gate_proj"
,
"dense_h_to_4h"
]
}
# LoRA specific attributes
supported_lora_modules
=
[
"query_key_value"
,
"dense"
,
"dense_h_to_4h"
,
"dense_4h_to_h"
,
# vision
"fc1"
,
"fc2"
,
"merged_proj"
,
"linear_proj"
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"transformer.encoder"
,
connector
=
"transformer.vision.linear_proj"
,
tower_model
=
"transformer.vision.transformer"
)
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
mm_input_mapper_for_glmv
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_glmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_glmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_glmv
)
class
ChatGLMForCausalLM
(
ChatGLMBaseModel
,
SupportsLoRA
,
SupportsPP
,
SupportsMultiModal
):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
packed_modules_mapping
=
{}
supported_lora_modules
=
[]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__new__
(
cls
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
config
=
vllm_config
.
model_config
.
hf_config
# Initialize VL
if
hasattr
(
config
,
"visual"
):
return
ChatGLM
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
# Initialize LLM
else
:
return
ChatGLMV
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment