Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e4ba896
Commit
1e4ba896
authored
Oct 29, 2024
by
zhuwenwen
Browse files
[Model] Add telechat-12b and GLM-4v support
parent
3bda0405
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
760 additions
and
33 deletions
+760
-33
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+3
-0
csrc/attention/static_switch_tc.h
csrc/attention/static_switch_tc.h
+6
-0
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+5
-0
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+16
-0
tests/models/decoder_only/vision_language/test_glm4.py
tests/models/decoder_only/vision_language/test_glm4.py
+133
-0
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+2
-2
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+6
-6
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+5
-3
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+282
-19
vllm/model_executor/models/glm4_vision_encoder.py
vllm/model_executor/models/glm4_vision_encoder.py
+295
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+5
-1
vllm/model_executor/models/telechat_12B.py
vllm/model_executor/models/telechat_12B.py
+2
-2
No files found.
csrc/attention/static_switch.h
View file @
1e4ba896
...
...
@@ -48,6 +48,9 @@
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 160) { \
constexpr static int HEAD_SIZE = 160; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
...
...
csrc/attention/static_switch_tc.h
View file @
1e4ba896
...
...
@@ -40,6 +40,12 @@
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 160) { \
constexpr static int HEAD_SIZE = 160; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
...
...
docs/source/models/supported_models.rst
View file @
1e4ba896
...
...
@@ -224,6 +224,11 @@ Multimodal Language Models
- Image
- :code:`adept/fuyu-8b` etc.
-
* - :code:`ChatGLMModel`
- GLM-4V
- Image
- :code:`THUDM/glm-4v-9b` etc.
-
* - :code:`InternVLChatModel`
- InternVL2
- Image\ :sup:`E+`
...
...
examples/offline_inference_vision_language.py
View file @
1e4ba896
...
...
@@ -265,6 +265,21 @@ def run_mllama(question, modality):
return
llm
,
prompt
,
stop_token_ids
# GLM-4v
def run_glm4v(question: str, modality: str):
    """Build the GLM-4v example: engine, prompt and stop token ids.

    Only the "image" modality is accepted. The question is returned
    unchanged as the prompt.
    """
    assert modality == "image"

    engine_args = {
        "model": "THUDM/glm-4v-9b",
        "max_model_len": 2048,
        "max_num_seqs": 2,
        "trust_remote_code": True,
        "enforce_eager": True,
    }
    llm = LLM(**engine_args)

    return llm, question, [151329, 151336, 151338]
model_example_map
=
{
"llava"
:
run_llava
,
"llava-next"
:
run_llava_next
,
...
...
@@ -280,6 +295,7 @@ model_example_map = {
"qwen_vl"
:
run_qwen_vl
,
"qwen2_vl"
:
run_qwen2_vl
,
"mllama"
:
run_mllama
,
"glm4v"
:
run_glm4v
,
}
...
...
tests/models/decoder_only/vision_language/test_glm4.py
0 → 100644
View file @
1e4ba896
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
pytest
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.transformers_utils.tokenizer
import
patch_padding_side
from
....conftest
import
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
from
....utils
import
large_gpu_test
from
...utils
import
check_logprobs_close
# Checkpoints this module is parametrized over, and the dtype used for the
# "dtype" test parameter.
models = ["THUDM/glm-4v-9b"]
target_dtype = "bfloat16"

# One question per image asset, built through IMAGE_ASSETS.prompts.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign": "What's the content of the image?",
    "cherry_blossom": "What is the season?",
})
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Generate greedily with vLLM, then with HF, and compare logprobs.

    Each element of ``inputs`` is a ``(prompts, images)`` pair that is sent
    to both runners. The vLLM runner is used first and its context is left
    before the HF runner is created.
    """
    stop_token_ids = [151329, 151336, 151338]

    # max_model_len should be greater than image_feature_size
    vllm_kwargs = dict(
        max_model_len=2048,
        max_num_seqs=2,
        dtype=dtype,
        limit_mm_per_prompt={"image": mm_limit},
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        enforce_eager=True,
    )
    vllm_outputs_per_image = []
    with vllm_runner(model, **vllm_kwargs) as vllm_model:
        for prompts, images in inputs:
            vllm_outputs_per_image.append(
                vllm_model.generate_greedy_logprobs(
                    prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=images,
                    stop_token_ids=stop_token_ids))

    hf_outputs_per_image = []
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_processor = hf_model.processor
        patch_padding_side(hf_processor)

        def processor(*args, text="", images=None, **kwargs):
            # With an image, go through the chat template; otherwise fall
            # back to calling the wrapped processor directly.
            if images is not None:
                conversation = [{
                    "role": "user",
                    "image": images,
                    "content": text
                }]
                return hf_processor.apply_chat_template(
                    conversation,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    **kwargs,
                )
            return hf_processor(*args, **kwargs)

        hf_model.processor = processor
        # Expose the transformer's output layer as the output embeddings.
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.transformer.output_layer

        for prompts, images in inputs:
            hf_outputs_per_image.append(
                hf_model.generate_greedy_logprobs_limit(
                    prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=images,
                ))

    paired_outputs = zip(hf_outputs_per_image, vllm_outputs_per_image)
    for hf_outputs, vllm_outputs in paired_outputs:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Compare HF and vLLM outputs for every image asset and size factor.

    For each (image, prompt) pair one batch is built that holds the prompt
    repeated once per size factor alongside the image rescaled by each
    factor.
    """
    inputs_per_image = []
    for asset, prompt in zip(image_assets, HF_IMAGE_PROMPTS):
        image = asset.pil_image
        repeated_prompts = [prompt for _ in size_factors]
        rescaled_images = [
            rescale_image_size(image, factor) for factor in size_factors
        ]
        inputs_per_image.append((repeated_prompts, rescaled_images))

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )
\ No newline at end of file
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
1e4ba896
...
...
@@ -22,7 +22,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def
__init__
(
self
):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_lm_
t
n
=
os
.
environ
.
get
(
'LM_
T
N'
)
==
'1'
self
.
use_lm_
n
n
=
os
.
environ
.
get
(
'LM_
N
N'
)
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
...
...
@@ -42,7 +42,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_llama_nn
and
not
self
.
use_lm_
t
n
:
if
self
.
use_llama_nn
and
self
.
use_lm_
n
n
:
if
bias
is
not
None
:
if
len
(
x
.
shape
)
==
2
:
return
torch
.
addmm
(
bias
,
x
,
layer
.
weight
)
...
...
vllm/model_executor/model_loader/utils.py
View file @
1e4ba896
...
...
@@ -22,18 +22,18 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
vis
ual
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
vis
ions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
architectures
==
[
'QWenLMHeadModel'
]
and
vis
ual
!=
[]:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
vis
ions
!=
[]:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
architectures
==
[
'BloomForCausalLM'
]
or
architectures
==
[
'LlamaForCausalLM'
]
:
os
.
environ
[
'LM_
T
N'
]
=
'
1
'
if
architectures
==
[
'BloomForCausalLM'
]
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
os
.
environ
[
'LM_
N
N'
]
=
'
0
'
else
:
os
.
environ
[
'LM_
T
N'
]
=
'
0
'
os
.
environ
[
'LM_
N
N'
]
=
'
1
'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
...
...
@@ -50,7 +50,7 @@ def get_model_architecture(
os
.
environ
[
'AWQ_PAD'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_
T
N'
]
=
'
1
'
os
.
environ
[
'LM_
N
N'
]
=
'
0
'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
os
.
environ
[
'AWQ_PAD'
]
=
'0'
...
...
vllm/model_executor/models/__init__.py
View file @
1e4ba896
...
...
@@ -15,8 +15,7 @@ _GENERATION_MODELS = {
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
# baichuan-13b
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
# ChatGLMModel supports multimodal
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"DbrxForCausalLM"
:
(
"dbrx"
,
"DbrxForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
...
...
@@ -53,6 +52,7 @@ _GENERATION_MODELS = {
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"PhiMoEForCausalLM"
:
(
"phimoe"
,
"PhiMoEForCausalLM"
),
# QWenLMHeadModel supports multimodal
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
"Qwen2VLForConditionalGeneration"
:
...
...
@@ -82,6 +82,8 @@ _MULTIMODAL_MODELS = {
(
"blip2"
,
"Blip2ForConditionalGeneration"
),
"ChameleonForConditionalGeneration"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"FuyuForCausalLM"
:
(
"fuyu"
,
"FuyuForCausalLM"
),
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
...
...
vllm/model_executor/models/chatglm.py
View file @
1e4ba896
# coding=utf-8
# Adapted from
# https://github.com/THUDM/
ChatGLM2-6B
# https://github.com/THUDM/
GLM-4
"""Inference-only ChatGLM model compatible with THUDM weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
argparse
import
Namespace
from
array
import
array
from
typing
import
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
import
torch
from
PIL
import
Image
from
torch
import
nn
from
torch.nn
import
LayerNorm
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
MultiModalConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -26,14 +31,197 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.glm4_vision_encoder
import
EVA2CLIPModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalInputs
)
from
vllm.multimodal.base
import
MultiModalData
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
logger
=
init_logger
(
__name__
)
def calculate_image_placeholder(vision_config):
    """Return the number of placeholder tokens used for one image.

    The value is ``(image_size // patch_size // 2) ** 2``, read from the
    ``"image_size"`` and ``"patch_size"`` keys of ``vision_config``.
    """
    patches_per_side = vision_config["image_size"] // vision_config[
        "patch_size"]
    tokens_per_side = patches_per_side // 2
    return tokens_per_side**2
def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: MultiModalData[object],
) -> Dict:
    """Turn an image into the ``pixel_values`` consumed by the model.

    The image is run through the tokenizer's chat template as a single
    user message and the ``'images'`` entry of the result is returned,
    wrapped in ``MultiModalInputs`` under the ``'pixel_values'`` key.

    Raises:
        RuntimeError: no tokenizer could be obtained.
        Exception: whatever ``apply_chat_template`` raises; it is logged
            and re-raised unchanged.
    """
    model_config = ctx.model_config
    # Honour the user's trust_remote_code choice instead of forcing it on,
    # matching how input_processor_for_glmv loads the tokenizer.
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalInputs({'pixel_values': pixel_values})
def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Write vision embeddings over the image spans of ``inputs_embeds``.

    Every position from a ``boi_token_id`` through its paired
    ``eoi_token_id`` (inclusive) is selected, and those rows of
    ``inputs_embeds`` are overwritten in place with ``vision_embeddings``
    flattened to two dimensions. The same ``inputs_embeds`` tensor is
    returned.
    """
    span_starts = torch.nonzero(input_ids == boi_token_id, as_tuple=True)[0]
    span_stops = torch.nonzero(input_ids == eoi_token_id, as_tuple=True)[0]

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, stop in zip(span_starts, span_stops):
        assert start < stop
        image_mask[start:stop + 1] = True

    hidden_dim = vision_embeddings.shape[-1]
    flat_vision = vision_embeddings.view(-1, hidden_dim)
    inputs_embeds[image_mask] = flat_vision
    return inputs_embeds
class GLMImagePixelInputs(TypedDict):
    """Typed dict holding the pixel values of GLM image inputs."""

    # Shape: `(batch_size, num_channels, height, width)`
    pixel_values: torch.Tensor
def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    Returns 1 when the HF config has no ``vision_config``; for a dict
    ``vision_config`` the result of ``calculate_image_placeholder`` is
    returned.

    Raises:
        NotImplementedError: ``vision_config`` is neither None nor a dict.
    """
    vision_config = getattr(ctx.get_hf_config(ChatGLMConfig),
                            'vision_config', None)

    if vision_config is None:
        return 1
    if isinstance(vision_config, dict):
        return calculate_image_placeholder(vision_config)

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")
def dummy_data_for_glmv(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    """Build dummy sequence data (and a dummy image) for the model.

    Without a ``vision_config`` the sequence is ``seq_len`` zeros and no
    multi-modal data is returned. With a dict ``vision_config`` the
    sequence is boi, the image placeholder zeros, eoi, then zero padding
    of ``seq_len - image_placeholder_length - 2`` tokens; it is paired
    with a blank RGB image of ``image_size`` x ``image_size``.

    Raises:
        NotImplementedError: ``vision_config`` is neither None nor a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        text_only = SequenceData(
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
        return text_only, None

    if not isinstance(vision_config, dict):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    side = vision_config["image_size"]
    placeholder_len = calculate_image_placeholder(vision_config)

    image_span = ([hf_config.boi_token_id] + [0] * placeholder_len +
                  [hf_config.eoi_token_id])
    padding = [0] * (seq_len - placeholder_len - 2)
    seq_data = SequenceData(
        array(VLLM_TOKEN_ID_ARRAY_TYPE, image_span + padding))

    blank_image = Image.new("RGB", (side, side), color=0)
    return seq_data, {"image": blank_image}
def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return every index at which ``target`` occurs in ``input_ids``."""
    hits: List[int] = []
    for position, token in enumerate(input_ids):
        if token == target:
            hits.append(position)
    return hits
def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand the image placeholder of a GLM-4v prompt to its full length.

    The prompt text and image are tokenized together through the
    tokenizer's chat template. For each boi/eoi pair in the result, the
    tokens up to and including boi are kept and the token right after boi
    is repeated ``image_placeholder_length`` times (all sharing position
    ``boi + 1``); copying then resumes at the eoi token.
    ``llm_inputs["prompt_token_ids"]`` and ``llm_inputs["position_ids"]``
    are overwritten with the expanded sequences.

    ``llm_inputs`` is returned unchanged when the model has no
    ``vision_config`` or when the request carries no image data.

    Raises:
        NotImplementedError: ``vision_config`` is neither None nor a dict.
        Exception: whatever ``apply_chat_template`` raises; it is logged
            and re-raised unchanged.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return llm_inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # A text-only request has nothing to expand; indexing
    # llm_inputs['multi_modal_data']["image"] below would raise KeyError.
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    # Load the tokenizer from the same config field the input mapper uses,
    # so a separately configured tokenizer is honoured in both places.
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": llm_inputs['prompt']
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process content (%s)", llm_inputs['prompt'])
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_positions = find_all_positions(input_ids, hf_config.boi_token_id)
    eoi_positions = find_all_positions(input_ids, hf_config.eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    new_position_ids = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        # Copy everything not yet emitted, up to and including boi.
        new_input_ids.extend(
            input_ids[final_processed_position:boi_position + 1])
        new_position_ids.extend(
            range(final_processed_position, boi_position + 1))
        # Repeat the token right after boi for the whole image span; every
        # repeated token gets the same position id.
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        new_position_ids.extend([boi_position + 1] *
                                image_placeholder_length)
        # Resume copying at eoi on the next iteration / in the tail copy.
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])
    new_position_ids.extend(range(final_processed_position, len(input_ids)))

    assert len(new_input_ids) == len(new_position_ids)

    llm_inputs["prompt_token_ids"] = new_input_ids
    llm_inputs["position_ids"] = new_position_ids
    return llm_inputs
class
GLMAttention
(
nn
.
Module
):
...
...
@@ -306,8 +494,11 @@ class ChatGLMModel(nn.Module):
):
super
().
__init__
()
self
.
config
=
config
self
.
embedding
=
VocabParallelEmbedding
(
config
.
padded_vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
,
quant_config
=
quant_config
)
self
.
num_layers
=
config
.
num_layers
self
.
multi_query_group_num
=
config
.
multi_query_group_num
...
...
@@ -318,26 +509,72 @@ class ChatGLMModel(nn.Module):
config
.
hidden_size
,
quant_config
=
quant_config
)
vision_config_flag
=
getattr
(
config
,
'vision_config'
,
None
)
if
vision_config_flag
is
not
None
:
self
.
vision_config
=
Namespace
(
**
config
.
vision_config
)
self
.
vision
=
EVA2CLIPModel
(
self
.
config
,
quant_config
)
else
:
self
.
vision
=
None
def _parse_and_validate_image_input(
        self, **kwargs: object) -> GLMImagePixelInputs:
    """Extract ``pixel_values`` from ``kwargs`` as ``GLMImagePixelInputs``.

    ``pixel_values`` is popped from ``kwargs``. When it is present and the
    model has a vision tower, a tensor with more than two dimensions is
    concatenated along its first dimension, and a list of tensors is
    concatenated into one tensor. In every accepted case the (possibly
    None) value is returned wrapped in ``GLMImagePixelInputs``.

    Raises:
        TypeError: ``pixel_values`` is neither a tensor nor a list.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    if pixel_values is not None and self.vision is not None:
        if isinstance(pixel_values, torch.Tensor):
            if pixel_values.ndim > 2:
                pixel_values = torch.concat(list(pixel_values))
        elif isinstance(pixel_values, list):
            # Keep going to the common return below: returning the bare
            # tensor here would break callers that index the result with
            # ["pixel_values"].
            pixel_values = torch.concat(pixel_values)
        else:
            raise TypeError("""pixel_values must be a torch.Tensor
or a list of torch.Tensor
""")
    return GLMImagePixelInputs(pixel_values=pixel_values)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
position
_id
s
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embedding
(
input_ids
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
[
"pixel_values"
]
is
not
None
:
pixel_values
=
image_input
[
"pixel_values"
].
to
(
dtype
=
inputs_embeds
.
dtype
)
image_embeds
=
self
.
vision
(
pixel_values
)
boi_token_id
=
self
.
config
.
boi_token_id
eoi_token_id
=
self
.
config
.
eoi_token_id
inputs_embeds
=
merge_glm_vision_embeddings
(
input_ids
=
input_ids
,
inputs_embeds
=
inputs_embeds
,
vision_embeddings
=
image_embeds
,
boi_token_id
=
boi_token_id
,
eoi_token_id
=
eoi_token_id
)
# Run encoder.
hidden_states
=
self
.
encoder
(
hidden_states
=
inputs_embeds
,
position_ids
=
position
_id
s
,
position_ids
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
)
return
hidden_states
class
ChatGLMForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
mm_input_mapper_for_glmv
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_glmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_glmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_glmv
)
class
ChatGLMForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsMultiModal
):
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
"dense_h_to_4h"
:
[
"dense_h_to_4h"
]
...
...
@@ -355,6 +592,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
def
__init__
(
self
,
config
:
ChatGLMConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
...
...
@@ -363,6 +601,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
multimodal_config
=
multimodal_config
self
.
quant_config
=
quant_config
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
...
...
@@ -384,16 +623,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
**
kwargs
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
**
kwargs
)
return
hidden_states
def
compute_logits
(
...
...
@@ -414,8 +652,24 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict
:
Dict
[
str
,
Dict
[
str
,
Optional
[
torch
.
Tensor
]]]
=
{
"transformer.vision.linear_proj.merged_proj.weight"
:
{
"transformer.vision.linear_proj.gate_proj.weight"
:
None
,
"transformer.vision.linear_proj.dense_h_to_4h.weight"
:
None
,
}
}
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
is_weight_to_be_merge
=
False
for
_
,
merged_weight_dict
in
merged_weights_dict
.
items
():
if
name
in
merged_weight_dict
:
assert
merged_weight_dict
[
name
]
is
None
merged_weight_dict
[
name
]
=
loaded_weight
is_weight_to_be_merge
=
True
if
is_weight_to_be_merge
:
continue
if
"rotary_pos_emb.inv_freq"
in
name
:
continue
if
"word_embeddings"
in
name
:
...
...
@@ -428,6 +682,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
for
combined_name
,
merged_weight_dict
in
merged_weights_dict
.
items
():
if
combined_name
in
params_dict
:
param
=
params_dict
[
combined_name
]
combined_weight
=
torch
.
cat
(
list
(
merged_weight_dict
.
values
()),
dim
=
0
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
combined_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attention.query_key_value.weight"
,
...
...
vllm/model_executor/models/glm4_vision_encoder.py
0 → 100644
View file @
1e4ba896
# coding=utf-8
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
from
argparse
import
Namespace
from
typing
import
Optional
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
class PatchEmbedding(nn.Module):
    """Convolutional patch embedding with a class token and positions."""

    def __init__(self, config):
        super().__init__()
        patch = config.patch_size
        width = config.hidden_size
        self.proj = nn.Conv2d(config.in_channels,
                              width,
                              kernel_size=patch,
                              stride=patch)
        self.cls_embedding = nn.Parameter(torch.zeros(1, width))
        self.position_embedding = nn.Embedding(config.num_positions, width)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        # Move the input onto the projection weights' device first.
        on_device = images.to(self.proj.weight.device)
        patches = self.proj(on_device).flatten(2).transpose(1, 2)

        # Prepend one class token per batch element.
        batch = patches.shape[0]
        cls_token = self.cls_embedding.expand(batch, -1, -1)
        tokens = torch.cat((cls_token, patches), dim=1)

        # Add the full position table, broadcast over the batch, in place.
        tokens += self.position_embedding.weight.unsqueeze(0)
        return tokens
class Attention(nn.Module):
    """Non-causal self-attention block of the vision encoder."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden = config.hidden_size
        total_heads = config.num_heads

        self.hidden_size = hidden
        self.tp_size = get_tensor_model_parallel_world_size()
        # Heads are divided evenly by the tensor-parallel world size.
        self.num_heads_per_rank = total_heads // self.tp_size
        self.head_dim = hidden // total_heads
        self.scale = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            hidden,
            self.head_dim,
            total_heads,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            hidden,
            hidden,
            quant_config=quant_config,
        )
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` (unpacked as B, L, _) and project the result."""
        batch, seq_len, _ = x.shape
        fused, _ = self.query_key_value(x)  # B, L, 3 * H * D

        def to_heads(part: torch.Tensor) -> torch.Tensor:
            # B, L, H * D -> B, H, L, D
            return part.reshape(batch, seq_len, self.num_heads_per_rank,
                                self.head_dim).permute(0, 2, 1, 3)

        query, key, value = (to_heads(part)
                             for part in fused.chunk(3, dim=-1))

        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.,
            is_causal=False)

        # reshape (not view): the transposed tensor is not contiguous.
        merged = context.transpose(1, 2).reshape(batch, seq_len, -1)
        projected, _ = self.dense(merged)
        return self.output_dropout(projected)
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # Activation is looked up by the name stored in config.hidden_act.
        self.activation_fn = get_act_fn(config.hidden_act)
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            hidden,
            inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            inner,
            hidden,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 in turn."""
        expanded, _ = self.fc1(x)
        projected, _ = self.fc2(self.activation_fn(expanded))
        return projected
class TransformerLayer(nn.Module):
    """Vision transformer layer with LayerNorm applied after each sublayer.

    Both the attention output and the MLP output are normalized before
    being added back to their residual input.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        norm_kwargs = {"eps": config.layer_norm_eps}
        self.input_layernorm = LayerNorm(config.hidden_size, **norm_kwargs)
        self.attention = Attention(config, quant_config=quant_config)
        self.mlp = MLP(config, quant_config=quant_config)
        self.post_attention_layernorm = LayerNorm(config.hidden_size,
                                                  **norm_kwargs)

    def forward(self, hidden_states):
        """Run attention and MLP, each followed by norm and residual add."""
        normed_attention = self.input_layernorm(
            self.attention(hidden_states))
        after_attention = hidden_states + normed_attention

        normed_mlp = self.post_attention_layernorm(self.mlp(after_attention))
        return after_attention + normed_mlp
class Transformer(nn.Module):
    """Stack of ``config.num_hidden_layers`` TransformerLayer modules."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        stack = []
        for _ in range(config.num_hidden_layers):
            stack.append(TransformerLayer(config, quant_config=quant_config))
        self.layers = nn.ModuleList(stack)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order."""
        current = hidden_states
        for layer in self.layers:
            current = layer(current)
        return current
class GLU(nn.Module):
    """Projection head: linear -> LayerNorm/GELU -> gated MLP."""

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """
        The original implementation keeps two separate projections,
        ``gate_proj`` and ``dense_h_to_4h``, each a ColumnParallelLinear
        from ``config.hidden_size`` to ``config.ffn_hidden_size`` without
        bias, and concatenates their outputs on the last dimension:

        ```
        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        Here both are folded into a single MergedColumnParallelLinear
        (``merged_proj``) with output sizes
        ``[config.ffn_hidden_size] * 2``, so one call replaces the above:

        ```
        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        hidden = config.hidden_size
        ffn_hidden = config.ffn_hidden_size

        self.linear_proj = ReplicatedLinear(in_features,
                                            hidden,
                                            bias=False,
                                            quant_config=quant_config)
        self.norm1 = nn.LayerNorm(hidden)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            hidden, [ffn_hidden] * 2,
            bias=False,
            quant_config=quant_config)

        self.dense_4h_to_h = RowParallelLinear(ffn_hidden,
                                               hidden,
                                               bias=False,
                                               quant_config=quant_config)

    def forward(self, x):
        """Project, normalize + GELU, then apply the gated MLP."""
        projected, _ = self.linear_proj(x)
        activated = self.act1(self.norm1(projected))
        gate_and_up, _ = self.merged_proj(activated)
        gated = self.act2(gate_and_up)
        result, _ = self.dense_4h_to_h(gated)
        return result
class EVA2CLIPModel(nn.Module):
    """GLM-4v vision tower: patch embedding, transformer, conv, GLU head."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # config.vision_config is unpacked as a mapping into attribute form.
        vision_config = Namespace(**config.vision_config)
        text_hidden = config.hidden_size

        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config,
                                       quant_config=quant_config)
        self.linear_proj = GLU(config,
                               in_features=text_hidden,
                               quant_config=quant_config)
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=text_hidden,
                              kernel_size=2,
                              stride=2)
        self.boi = nn.Parameter(torch.zeros(1, 1, text_hidden))
        self.eoi = nn.Parameter(torch.zeros(1, 1, text_hidden))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        encoded = self.transformer(self.patch_embedding(images))
        # Drop the first token of every sequence before regridding.
        patch_tokens = encoded[:, 1:]

        batch, num_tokens, width = patch_tokens.shape
        side = int(num_tokens**0.5)
        # Lay the tokens out as a square grid, channels first, for the conv.
        grid = patch_tokens.view(batch, side, side, width).permute(0, 3, 1, 2)
        pooled = self.conv(grid)

        features = self.linear_proj(pooled.flatten(2).transpose(1, 2))

        # Wrap each sequence in the learned boi / eoi embeddings.
        begin = self.boi.expand(features.shape[0], -1, -1)
        end = self.eoi.expand(features.shape[0], -1, -1)
        wrapped = torch.cat((begin, features, end), dim=1)
        return wrapped / self.scaling_factor
\ No newline at end of file
vllm/model_executor/models/llama.py
View file @
1e4ba896
...
...
@@ -573,9 +573,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
,
"mlp.down_proj.weight"
# "lm_head.weight"
]
if
self
.
use_lm_nn
:
lay_key_words
.
append
(
"lm_head.weight"
)
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attn.qkv_proj.weight"
]
...
...
vllm/model_executor/models/telechat_12B.py
View file @
1e4ba896
...
...
@@ -45,13 +45,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
,
print_warning_once
from
.interfaces
import
SupportsLoRA
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment