Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
963 additions
and
748 deletions
+963
-748
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+30
-80
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+13
-19
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+49
-17
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+12
-17
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+20
-20
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+0
-3
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+54
-48
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+24
-10
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+17
-27
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+28
-21
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+36
-68
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+47
-42
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+37
-92
vllm/model_executor/models/glm4.py
vllm/model_executor/models/glm4.py
+313
-0
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+19
-23
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+76
-69
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+43
-36
vllm/model_executor/models/granitemoeshared.py
vllm/model_executor/models/granitemoeshared.py
+43
-36
vllm/model_executor/models/grok1.py
vllm/model_executor/models/grok1.py
+101
-96
vllm/model_executor/models/h2ovl.py
vllm/model_executor/models/h2ovl.py
+1
-24
No files found.
vllm/model_executor/models/aya_vision.py
View file @
9c4ecf15
...
...
@@ -20,22 +20,21 @@ from vllm.jsontree import json_map_leaves
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
MultiModalFieldConfig
,
PromptReplacement
,
PromptUpdate
,
encode_token
s
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
PromptUpdateDetail
s
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
class
AyaVisionImagePixelInputs
(
TypedDict
):
...
...
@@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict):
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
AyaVisionMultiModalProjector
(
nn
.
Module
):
...
...
@@ -125,32 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
def
get_image_processor
(
self
)
->
GotOcr2ImageProcessor
:
return
self
.
get_hf_processor
().
image_processor
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
get_max_image_tokens
(
self
)
->
int
:
hf_processor
=
self
.
get_hf_processor
()
image_processor
=
hf_processor
.
image_processor
image_size
=
self
.
get_image_size_with_most_features
()
tokenizer
=
hf_processor
.
tokenizer
num_patches
=
self
.
get_num_patches
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
size
=
image_processor
.
size
,
min_patches
=
image_processor
.
min_patches
,
max_patches
=
image_processor
.
max_patches
)
image_string
=
hf_processor
.
_prompt_split_image
(
num_patches
)
x
=
encode_tokens
(
tokenizer
,
image_string
,
add_special_tokens
=
False
,
)
return
len
(
x
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
...
...
@@ -180,28 +146,29 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
class
AyaVisionDummyInputsBuilder
(
BaseDummyInputsBuilder
[
AyaVisionProcessingInfo
]):
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
image_size
=
\
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
image_size
.
width
,
height
=
image_size
.
height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
AyaVisionMultiModalProcessor
(
...
...
@@ -221,7 +188,6 @@ class AyaVisionMultiModalProcessor(
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_processor
=
hf_processor
.
image_processor
hf_config
=
self
.
info
.
get_hf_config
()
# HF processor pops the `num_patches` kwarg, which is needed by vLLM
if
(
images
:
=
mm_data
.
get
(
"images"
))
is
not
None
and
'<image>'
in
prompt
:
...
...
@@ -234,6 +200,7 @@ class AyaVisionMultiModalProcessor(
parsed_images
.
get_image_size
(
i
)
for
i
in
range
(
len
(
parsed_images
))
]
num_patches
=
[
self
.
info
.
get_num_patches
(
image_width
=
image_size
.
width
,
...
...
@@ -243,20 +210,6 @@ class AyaVisionMultiModalProcessor(
max_patches
=
image_processor
.
max_patches
)
for
image_size
in
image_sizes
]
image_tokens_list
=
[
hf_processor
.
_prompt_split_image
(
num_patch
)
for
num_patch
in
num_patches
]
tokenizer
=
self
.
info
.
get_tokenizer
()
image_token_ids
=
[
tokenizer
.
encode
(
image_tokens
,
add_special_tokens
=
False
)
for
image_tokens
in
image_tokens_list
]
embed_is_patch
=
[
torch
.
tensor
(
image_repl_tokens
)
==
hf_config
.
image_token_index
for
image_repl_tokens
in
image_token_ids
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
processed_outputs
[
"num_patches"
]
=
torch
.
tensor
(
num_patches
)
return
processed_outputs
...
...
@@ -271,7 +224,6 @@ class AyaVisionMultiModalProcessor(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_patches
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
...
...
@@ -283,6 +235,7 @@ class AyaVisionMultiModalProcessor(
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_token
=
hf_processor
.
image_token
img_patch_token
=
hf_processor
.
img_patch_token
image_processor
=
hf_processor
.
image_processor
def
get_replacement
(
item_idx
:
int
):
...
...
@@ -294,8 +247,11 @@ class AyaVisionMultiModalProcessor(
image_height
=
image_size
.
height
,
size
=
image_processor
.
size
,
min_patches
=
image_processor
.
min_patches
,
max_patches
=
image_processor
.
max_patches
)
return
hf_processor
.
_prompt_split_image
(
num_patches
=
num_patches
)
max_patches
=
image_processor
.
max_patches
,
)
repl
=
hf_processor
.
_prompt_split_image
(
num_patches
=
num_patches
)
return
PromptUpdateDetails
.
select_text
(
repl
,
img_patch_token
)
return
[
PromptReplacement
(
...
...
@@ -424,7 +380,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
**
kwargs
:
object
)
->
Optional
[
AyaVisionImagePixelInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
num_patches
=
kwargs
.
pop
(
"num_patches"
,
None
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
assert
image_embeds
is
None
,
"Aya Vision does not support image_embeds."
...
...
@@ -436,30 +391,25 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of num_patches. "
f
"Got type:
{
type
(
num_patches
)
}
"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
AyaVisionImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
num_patches
=
num_patches
,
embed_is_patch
=
embed_is_patch
,
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
image_features
=
self
.
_process_image_input
(
image_input
,
**
kwargs
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
return
self
.
_process_image_input
(
image_input
,
**
kwargs
)
def
get_input_embeddings
(
self
,
...
...
@@ -471,9 +421,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
=
input_ids
,
inputs_embeds
=
inputs_embeds
,
multimodal_embeddings
=
select_patch_features
(
multimodal_embeddings
)
,
placeholder_token_id
=
self
.
config
.
image_token_index
)
multimodal_embeddings
=
multimodal_embeddings
,
placeholder_token_id
=
self
.
config
.
image_token_index
,
)
return
inputs_embeds
...
...
vllm/model_executor/models/bamba.py
View file @
9c4ecf15
...
...
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
MambaMixer2
,
extra_groups_for_head_shards
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
head_dim
=
config
.
mamba_d_head
,
rms_norm_eps
=
config
.
rms_norm_eps
,
activation
=
config
.
hidden_act
,
chunk_size
=
config
.
mamba_chunk_size
,
quant_config
=
quant_config
)
self
.
feed_forward
=
BambaMLP
(
config
,
quant_config
=
quant_config
)
...
...
@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
,
sequence_idx
)
mamba2_metadata
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
...
...
@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
:
BambaConfig
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
...
...
@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx
=
None
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefills
>
0
:
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
for
i
,
(
srt
,
end
)
in
enumerate
(
zip
(
attn_metadata
.
query_start_loc
,
attn_metadata
.
query_start_loc
[
1
:],
)):
seq_idx
[
srt
:
end
]
=
i
seq_idx
.
unsqueeze_
(
0
)
mamba2_metadata
=
prepare_mamba2_metadata
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
input_ids
=
input_ids
,
attn_metadata
=
attn_metadata
,
)
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
...
...
@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
hidden_states
=
hidden_states
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
sequence_idx
=
seq_idx
,
mamba2_metadata
=
mamba2_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/bert.py
View file @
9c4ecf15
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.pooler
import
(
CrossEncodingPooler
,
Pooler
,
PoolingType
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
self
.
size
=
config
.
hidden_size
self
.
word_embeddings
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
position_embeddings
=
VocabParallelEmbedding
(
config
.
max_position_embeddings
,
config
.
hidden_size
)
self
.
token_type_embeddings
=
VocabParallelEmbedding
(
config
.
type_vocab_size
,
config
.
hidden_size
)
self
.
LayerNorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
position_ids
=
nn
.
Parameter
(
torch
.
empty
((
1
,
config
.
max_position_embeddings
)),
)
self
.
position_embedding_type
=
config
.
position_embedding_type
if
self
.
position_embedding_type
!=
"absolute"
:
raise
ValueError
(
"Only 'absolute' position_embedding_type"
+
" is supported"
)
if
self
.
position_embedding_type
==
"absolute"
:
self
.
position_embeddings
=
VocabParallelEmbedding
(
config
.
max_position_embeddings
,
config
.
hidden_size
)
self
.
position_ids
=
nn
.
Parameter
(
torch
.
empty
((
1
,
config
.
max_position_embeddings
)),
)
elif
self
.
position_embedding_type
==
"rotary"
:
self
.
position_embeddings
=
None
self
.
position_ids
=
None
else
:
raise
ValueError
(
"Only 'absolute' and 'rotary' "
+
"position_embedding_type is supported"
)
def
forward
(
self
,
...
...
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
# Input embeddings.
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
# Position embeddings.
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
...
...
@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
embeddings
=
inputs_embeds
+
token_type_embeddings
if
self
.
position_embedding_type
==
"absolute"
:
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
embeddings
+=
position_embeddings
embeddings
=
self
.
LayerNorm
(
embeddings
)
return
embeddings
...
...
@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
@
support_torch_compile
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
...
...
@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
BertLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.layer.
{
layer_idx
}
"
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
for
layer
in
self
.
layer
:
hidden_states
=
layer
(
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
)
return
hidden_states
...
...
@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
config
:
BertConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
layer_norm_eps
=
config
.
layer_norm_eps
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.attention"
)
self
.
intermediate
=
BertIntermediate
(
...
...
@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.output"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
):
attn_output
=
self
.
attention
(
hidden_states
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
):
attn_output
=
self
.
attention
(
positions
,
hidden_states
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
return
output
...
...
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
layer_norm_eps
:
float
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
...
...
@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
num_attention_heads
=
num_attention_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.output"
)
self
.
output
=
BertSelfOutput
(
hidden_size
=
hidden_size
,
...
...
@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
self_output
=
self
.
self
(
hidden_states
)
self_output
=
self
.
self
(
positions
,
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
...
...
@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
...
...
@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
)
if
rotary_kwargs
:
self
.
rotary_emb
=
get_rope
(
**
rotary_kwargs
)
else
:
self
.
rotary_emb
=
None
self
.
attn
=
Attention
(
num_heads
=
self
.
num_heads
,
head_size
=
self
.
head_dim
,
scale
=
self
.
scaling
,
...
...
@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
rotary_emb
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
output
=
self
.
attn
(
q
,
k
,
v
)
return
output
...
...
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
embedding_class
:
type
=
BertEmbedding
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
add_pooling_layer
:
bool
=
False
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
self
.
embeddings
=
embedding_class
(
config
)
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.encoder"
)
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
...
...
@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
return
self
.
encoder
(
hidden_states
)
return
self
.
encoder
(
position_ids
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
...
...
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
self
.
_build_pooler
(
pooler_config
)
...
...
vllm/model_executor/models/blip2.py
View file @
9c4ecf15
...
...
@@ -15,12 +15,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.blip
import
BlipVisionModel
...
...
@@ -406,13 +407,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_num_image_tokens
()}
def
get_num_image_tokens
(
self
)
->
int
:
hf_config
=
self
.
get_hf_config
()
return
hf_config
.
num_query_tokens
...
...
@@ -420,29 +414,27 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
class
Blip2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Blip2ProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
hf_config
=
self
.
info
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
max_image_size
=
vision_config
.
image_size
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
max_image_size
,
height
=
max_image_size
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
Blip2MultiModalProcessor
(
BaseMultiModalProcessor
[
Blip2ProcessingInfo
]):
...
...
@@ -627,6 +619,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return
self
.
language_projection
(
query_output
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/chameleon.py
View file @
9c4ecf15
...
...
@@ -30,12 +30,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
,
...
...
@@ -64,13 +65,6 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_num_image_tokens
()}
def
get_num_image_tokens
(
self
)
->
int
:
processor
=
self
.
get_hf_processor
()
return
processor
.
image_seq_length
...
...
@@ -79,28 +73,31 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
class
ChameleonDummyInputsBuilder
(
BaseDummyInputsBuilder
[
ChameleonProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
config
=
self
.
info
.
get_hf_config
()
width
=
height
=
config
.
vq_config
.
resolution
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
width
,
height
=
height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
"<image>"
*
num_images
,
mm_data
=
mm_data
,
)
class
ChameleonMultiModalProcessor
(
BaseMultiModalProcessor
[
ChameleonProcessingInfo
]):
...
...
@@ -162,9 +159,9 @@ class ChameleonMultiModalProcessor(
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
PromptUpdateDetails
(
full
=
(
[
image_start_id
]
+
image_tokens
+
[
image_end_id
]
)
,
features
=
image_token
s
,
replacement
=
PromptUpdateDetails
.
select_token_id
(
[
image_start_id
]
+
image_tokens
+
[
image_end_id
],
embed_token_id
=
image_token
_id
,
),
)
]
...
...
@@ -988,6 +985,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/clip.py
View file @
9c4ecf15
...
...
@@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
)
->
int
:
return
self
.
get_patch_grid_length
()
**
2
+
1
def
get_max_image_tokens
(
self
)
->
int
:
return
self
.
get_patch_grid_length
()
**
2
+
1
def
get_image_size
(
self
)
->
int
:
return
self
.
vision_config
.
image_size
...
...
vllm/model_executor/models/deepseek.py
View file @
9c4ecf15
...
...
@@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -385,6 +386,56 @@ class DeepseekModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
...
@@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
vllm/model_executor/models/deepseek_v2.py
View file @
9c4ecf15
...
...
@@ -163,14 +163,16 @@ class DeepseekV2MoE(nn.Module):
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# This is a special case to avoid FP16 overflow
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# This is a special case to avoid FP16 overflow
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
...
...
@@ -502,6 +504,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
self
.
layer_idx
=
layer_idx
if
model_config
.
use_mla
:
attn_cls
=
DeepseekV2MLAAttention
else
:
...
...
@@ -564,19 +567,30 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
hidden_states
,
)
# Fully Connected
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
\
hidden_states
.
dtype
==
torch
.
float16
:
#
This is a special case to avoid FP16 overflow
if
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# We scale both hidden_states and residual before
#
rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
\
hidden_states
.
dtype
==
torch
.
float16
:
# This is a special case to avoid FP16 overflow
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
...
...
vllm/model_executor/models/deepseek_vl2.py
View file @
9c4ecf15
...
...
@@ -19,14 +19,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModal
FieldConfig
,
MultiModal
Kwargs
,
NestedTensors
)
from
vllm.multimodal.inputs
import
(
MultiModal
DataDict
,
MultiModal
FieldConfig
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.deepseek_vl2
import
(
DeepseekVLV2Config
,
MlpProjectorConfig
,
...
...
@@ -168,47 +168,34 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
image_width
=
x
[
1
],
image_height
=
x
[
0
]))
return
ImageSize
(
width
=
width
,
height
=
height
)
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
max_image_size
=
self
.
get_image_size_with_most_features
()
max_image_tokens
=
self
.
get_num_image_tokens
(
image_height
=
max_image_size
.
height
,
image_width
=
max_image_size
.
width
,
cropping
=
num_images
<=
2
)
return
{
"image"
:
max_image_tokens
}
class
DeepseekVL2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
DeepseekVL2ProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
hf_processor
=
self
.
info
.
get_hf_processor
()
image_token
:
str
=
hf_processor
.
image_token
max_image_size
=
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
max_image_size
.
width
,
height
=
max_image_size
.
height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
DeepseekVL2MultiModalProcessor
(
BaseMultiModalProcessor
[
DeepseekVL2ProcessingInfo
]):
...
...
@@ -604,6 +591,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
self
.
_pixel_values_to_embedding
(
pixel_values
=
pixel_values
,
images_spatial_crop
=
images_spatial_crop
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/florence2.py
View file @
9c4ecf15
...
...
@@ -10,7 +10,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BartTokenizer
,
BatchFeature
,
PretrainedConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
...
@@ -21,13 +21,14 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartScaledWordEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.parse
import
MultiModalDataDict
,
MultiModalDataItems
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
...
...
@@ -764,42 +765,33 @@ class Florence2ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
def
get_
max
_image_tokens
(
self
)
->
int
:
def
get_
num
_image_tokens
(
self
)
->
int
:
processor_config
=
self
.
ctx
.
get_hf_image_processor_config
()
return
processor_config
[
"image_seq_length"
]
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
class
Florence2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Florence2ProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
target_width
=
target_height
=
self
.
info
.
get_hf_config
().
projection_dim
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
Florence2MultiModalProcessor
(
EncDecMultiModalProcessor
[
Florence2ProcessingInfo
]):
...
...
@@ -826,6 +818,18 @@ class Florence2MultiModalProcessor(
)
->
Union
[
str
,
list
[
int
]]:
return
[
self
.
info
.
get_hf_config
().
eos_token_id
]
def
_apply_hf_processor_tokens_only
(
self
,
prompt_tokens
:
list
[
int
],
)
->
list
[
int
]:
hf_processor
=
self
.
info
.
get_hf_processor
()
tokenizer
:
BartTokenizer
=
hf_processor
.
tokenizer
prompt_text
=
tokenizer
.
decode
(
prompt_tokens
)
# convert task tokens to prompt
prompt_text
=
hf_processor
.
_construct_prompts
([
prompt_text
])[
0
]
prompt_tokens
=
tokenizer
.
encode
(
prompt_text
,
add_special_tokens
=
False
)
return
prompt_tokens
def
_call_hf_processor
(
self
,
prompt
:
str
,
...
...
@@ -859,7 +863,7 @@ class Florence2MultiModalProcessor(
)
->
Sequence
[
PromptUpdate
]:
hf_config
=
self
.
info
.
get_hf_config
()
pad_token_id
=
hf_config
.
pad_token_id
num_image_tokens
=
self
.
info
.
get_
max
_image_tokens
()
num_image_tokens
=
self
.
info
.
get_
num
_image_tokens
()
image_tokens
=
[
pad_token_id
]
*
num_image_tokens
return
[
...
...
@@ -1038,6 +1042,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values
=
image_input
[
"data"
]
return
self
.
_encode_image
(
pixel_values
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/fuyu.py
View file @
9c4ecf15
...
...
@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
from
typing
import
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
import
torch
import
torch.nn
as
nn
...
...
@@ -31,19 +31,19 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from
vllm.model_executor.models.persimmon
import
PersimmonForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID
=
71011
...
...
@@ -66,14 +66,6 @@ class FuyuImagePatchInputs(TypedDict):
flattened just like `flat_data`.
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
FuyuProcessingInfo
(
BaseProcessingInfo
):
...
...
@@ -89,21 +81,6 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
max_ncols
,
max_nrows
=
self
.
get_image_feature_grid_size
(
image_width
=
target_width
,
image_height
=
target_height
,
)
max_image_tokens
=
(
max_ncols
+
1
)
*
max_nrows
return
{
"image"
:
max_image_tokens
}
def
get_image_feature_grid_size
(
self
,
*
,
...
...
@@ -128,6 +105,19 @@ class FuyuProcessingInfo(BaseProcessingInfo):
nrows
=
math
.
ceil
(
image_height
/
patch_height
)
return
ncols
,
nrows
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
ncols
,
nrows
=
self
.
get_image_feature_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
)
return
ncols
*
nrows
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
image_processor
=
self
.
get_image_processor
()
return
ImageSize
(
width
=
image_processor
.
size
[
"width"
],
...
...
@@ -136,27 +126,25 @@ class FuyuProcessingInfo(BaseProcessingInfo):
class
FuyuDummyInputsBuilder
(
BaseDummyInputsBuilder
[
FuyuProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
FuyuMultiModalProcessor
(
BaseMultiModalProcessor
[
FuyuProcessingInfo
]):
...
...
@@ -192,19 +180,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs
[
"image_patches"
]
=
image_patches
[
0
]
# get patch grid size for each image
embed_is_patch
=
[]
for
image
in
images
:
ncols
,
nrows
=
self
.
info
.
get_image_feature_grid_size
(
image_width
=
image
.
width
,
image_height
=
image
.
height
,
)
mask
=
torch
.
tensor
(([
True
]
*
ncols
+
[
False
])
*
nrows
)
embed_is_patch
.
append
(
mask
)
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
def
_apply_hf_processor_tokens_only
(
...
...
@@ -224,8 +199,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
image_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
))
return
dict
(
image_patches
=
MultiModalFieldConfig
.
batched
(
"image"
))
def
_get_prompt_updates
(
self
,
...
...
@@ -252,9 +226,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens
=
([
_IMAGE_TOKEN_ID
]
*
ncols
+
[
_NEWLINE_TOKEN_ID
])
*
nrows
return
PromptUpdateDetails
(
full
=
image_tokens
+
[
bos_token_id
],
features
=
image_tokens
,
return
PromptUpdateDetails
.
select_token_id
(
image_tokens
+
[
bos_token_id
],
embed_token_id
=
_IMAGE_TOKEN_ID
,
)
return
[
...
...
@@ -329,20 +303,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise
ValueError
(
"Incorrect type of image patches. "
f
"Got type:
{
type
(
image_patches
)
}
"
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
image_patches_flat
=
flatten_bn
(
image_patches
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
flat_data
=
self
.
_validate_pixel_values
(
flatten_bn
(
image_patches_flat
,
concat
=
True
)),
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
embed_is_patch
=
embed_is_patch
,
)
return
None
...
...
@@ -358,18 +325,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
return
self
.
_process_image_input
(
image_input
)
def
get_input_embeddings
(
self
,
...
...
@@ -379,8 +344,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
),
_IMAGE_TOKEN_ID
)
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
_IMAGE_TOKEN_ID
,
)
return
inputs_embeds
def
forward
(
...
...
vllm/model_executor/models/gemma.py
View file @
9c4ecf15
...
...
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -319,6 +319,46 @@ class GemmaModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
...
...
@@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/gemma3_mm.py
View file @
9c4ecf15
...
...
@@ -15,8 +15,9 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
# yapf: disable
...
...
@@ -25,10 +26,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptTargetMatch
,
PromptUpdate
,
PromptUpdateDetails
,
encode_tokens
,
find_mm_placeholders
,
find_mm_placeholders
,
replace_token_matches
)
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
...
@@ -36,7 +37,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
logger
=
init_logger
(
__name__
)
...
...
@@ -54,14 +54,6 @@ class Gemma3ImagePixelInputs(TypedDict):
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
...
...
@@ -77,13 +69,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
_resolve_image_kwargs
(
self
,
processor
:
Gemma3Processor
,
...
...
@@ -183,7 +168,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
image
_token
=
processor
.
boi_token
boi
_token
=
processor
.
boi_token
num_crops
=
self
.
get_num_crops
(
image_width
=
image_width
,
...
...
@@ -192,19 +177,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
)
if
num_crops
==
0
:
image_text
=
image
_token
image_text
=
boi
_token
else
:
crops_image_tokens
=
" "
.
join
(
image_token
for
_
in
range
(
num_crops
))
crops_image_tokens
=
" "
.
join
(
boi_token
for
_
in
range
(
num_crops
))
image_text
=
(
f
"Here is the original image
{
image
_token
}
and here are some "
f
"Here is the original image
{
boi
_token
}
and here are some "
f
"crops to help you see better
{
crops_image_tokens
}
"
)
repl_full
=
image_text
.
replace
(
image
_token
,
repl_full
=
image_text
.
replace
(
boi
_token
,
processor
.
full_image_sequence
)
repl_features
=
repl_full
.
strip
(
"
\n
"
)
return
PromptUpdateDetails
(
full
=
repl_full
,
features
=
repl_features
)
tokenizer
=
processor
.
tokenizer
vocab
=
tokenizer
.
get_vocab
()
image_token_id
=
vocab
[
tokenizer
.
image_token
]
return
PromptUpdateDetails
.
select_token_id
(
repl_full
,
image_token_id
)
def
get_num_image_tokens
(
self
,
...
...
@@ -213,19 +200,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_height
:
int
,
processor
:
Optional
[
Gemma3Processor
],
)
->
int
:
tokenizer
=
self
.
get_tokenizer
()
image_repl
=
self
.
get_image_repl
(
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
num_crops
=
self
.
get_num_crops
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
)
image_seq_len
=
processor
.
image_seq_length
image_repl_tokens
=
encode_tokens
(
tokenizer
,
image_repl
.
features
,
add_special_tokens
=
False
,
)
return
len
(
image_repl_tokens
)
return
(
num_crops
+
1
)
*
image_seq_len
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
...
...
@@ -237,43 +222,34 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
# Result in the max possible feature size (h:w = max_num_crops:1)
return
ImageSize
(
height
=
50
*
max_num_crops
,
width
=
50
)
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
)
class
Gemma3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Gemma3ProcessingInfo
]):
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
class
Gemma3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Gemma3ProcessingInfo
]):
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
boi_token
return
image_token
*
num_images
def
get_dummy_
processor_inputs
(
def
get_dummy_
mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
boi_token
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
Gemma3MultiModalProcessor
(
BaseMultiModalProcessor
[
Gemma3ProcessingInfo
]):
...
...
@@ -301,28 +277,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_repl_features
=
[
self
.
info
.
get_image_repl
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
).
features
for
size
in
image_sizes
]
tokenizer
=
self
.
info
.
get_tokenizer
()
image_repls_feature_tokens
=
[
tokenizer
.
encode
(
image_repl
,
add_special_tokens
=
False
)
for
image_repl
in
image_repl_features
]
vocab
=
tokenizer
.
get_vocab
()
image_token_id
=
vocab
[
tokenizer
.
image_token
]
embed_is_patch
=
[
torch
.
tensor
(
image_repl_tokens
)
==
image_token_id
for
image_repl_tokens
in
image_repls_feature_tokens
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
num_crops
=
[
self
.
info
.
get_num_crops
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
...
...
@@ -344,7 +298,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
+
1
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_updates
(
...
...
@@ -454,6 +407,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
item_idx
=
p
.
item_idx
,
start_idx
=
repl_orig_idxs
[
p
.
start_idx
],
tokens
=
p
.
tokens
,
is_embed
=
p
.
is_embed
,
)
for
p
in
placeholders
]
for
modality
,
placeholders
in
repls
.
items
()
...
...
@@ -572,7 +526,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
assert
image_embeds
is
None
,
"Gemma3 does not support image_embeds."
if
pixel_values
is
None
:
...
...
@@ -586,19 +539,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise
ValueError
(
"Incorrect type of num_crops. "
f
"Got type:
{
type
(
num_crops
)
}
"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
Gemma3ImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
num_patches
=
num_crops
+
1
,
embed_is_patch
=
embed_is_patch
,
)
def
_image_pixels_to_features
(
...
...
@@ -629,18 +576,16 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
e
.
flatten
(
0
,
1
)
for
e
in
image_embeds
.
split
(
num_patches
.
tolist
())
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
return
self
.
_process_image_input
(
image_input
)
def
get_input_embeddings
(
self
,
...
...
@@ -652,7 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
self
.
config
.
image_token_index
,
)
return
inputs_embeds
...
...
vllm/model_executor/models/glm4.py
0 → 100644
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Zhipu AI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
Glm4Config
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.llama
import
LlamaMLP
as
Glm4MLP
from
.llama
import
LlamaModel
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
class
Glm4Attention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Glm4Config
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
head_dim
:
Optional
[
int
]
=
None
,
qkv_bias
:
bool
=
False
,
rope_theta
:
float
=
10000
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rope_scaling
:
Optional
[
Tuple
]
=
None
,
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
partial_rotary_factor
=
getattr
(
config
,
"partial_rotary_factor"
,
0.5
)
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
head_dim
or
hidden_size
//
self
.
total_num_heads
self
.
rotary_dim
=
int
(
partial_rotary_factor
*
self
.
head_dim
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
qkv_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
rotary_dim
,
max_position
=
max_position
,
base
=
self
.
rope_theta
,
rope_scaling
=
rope_scaling
,
partial_rotary_factor
=
partial_rotary_factor
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
attn_type
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Glm4DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Glm4Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
1000000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
self
.
self_attn
=
Glm4Attention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
max_position
=
config
.
max_position_embeddings
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
qkv_bias
=
getattr
(
config
,
'attention_bias'
,
False
),
head_dim
=
getattr
(
config
,
'head_dim'
,
None
),
cache_config
=
cache_config
,
quant_config
=
quant_config
,
rope_scaling
=
rope_scaling
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
attn_type
=
AttentionType
.
DECODER
,
)
self
.
mlp
=
Glm4MLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_self_attn_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_mlp_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
hidden_states
=
self
.
post_self_attn_layernorm
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
# Fully Connected
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
post_mlp_layernorm
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
,
residual
ALL_DECODER_LAYER_TYPES
=
{
"attention"
:
Glm4DecoderLayer
,
}
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
})
class
Glm4Model
(
LlamaModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
Glm4DecoderLayer
)
class
Glm4ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Glm4Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/glm4v.py
View file @
9c4ecf15
...
...
@@ -12,7 +12,7 @@ from torch import nn
from
torch.nn
import
LayerNorm
from
torchvision
import
transforms
from
torchvision.transforms
import
InterpolationMode
from
transformers
import
PreTrainedTokenizer
,
TensorType
from
transformers
import
BatchFeature
,
PreTrainedTokenizer
,
TensorType
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
...
...
@@ -28,13 +28,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BatchFeature
,
MultiModalFieldConfig
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
ChatGLMConfig
...
...
@@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_num_image_feature_tokens
()}
def
get_num_image_tokens
(
self
)
->
int
:
hf_config
=
self
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
...
...
@@ -454,31 +447,31 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
class
GLM4VDummyInputsBuilder
(
BaseDummyInputsBuilder
[
GLM4VProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
base_text
=
"<|begin_of_image|><|endoftext|><|end_of_image|>"
return
base_text
*
num_images
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
hf_config
=
self
.
info
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
target_width
=
target_height
=
vision_config
[
"image_size"
]
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
base_text
=
"<|begin_of_image|><|endoftext|><|end_of_image|>"
return
ProcessorInputs
(
prompt_text
=
base_text
*
num_images
,
mm_data
=
mm_data
,
)
class
GLM4VMultiModalProcessor
(
BaseMultiModalProcessor
[
GLM4VProcessingInfo
]):
...
...
@@ -596,6 +589,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return
self
.
transformer
.
vision
(
pixel_values
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
transformer
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/granite.py
View file @
9c4ecf15
...
...
@@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
,
maybe_prefix
)
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
,
maybe_prefix
)
class
GraniteMLP
(
nn
.
Module
):
...
...
@@ -260,6 +260,7 @@ class GraniteModel(nn.Module):
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
...
...
@@ -321,6 +322,65 @@ class GraniteModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
GraniteForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
...
...
@@ -428,71 +488,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
skip_prefixes
=
[
"rotary_emb.inv_freq"
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached"
,
"rotary_emb.sin_cached"
,
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
:
skip_prefixes
.
append
(
"lm_head.weight"
)
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skip_prefixes
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/granitemoe.py
View file @
9c4ecf15
...
...
@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
from
.
import
mixtral
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
make_layers
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
make_layers
,
maybe_prefix
class
GraniteMoeMoE
(
nn
.
Module
):
...
...
@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module):
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
# Required by MixtralModel
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
...
...
@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w1_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralModel
.
load_weights
(
self
,
new_weights
.
items
())
class
GraniteMoeForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
...
...
@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
# Required by MixtralForCausalLM
self
.
model
=
GraniteMoeModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
...
...
@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w1_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
elif
n
==
'lm_head.weight'
and
self
.
config
.
tie_word_embeddings
:
pass
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralForCausalLM
.
load_weights
(
self
,
new_weights
.
items
())
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/granitemoeshared.py
View file @
9c4ecf15
...
...
@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors
from
.
import
mixtral
from
.granitemoe
import
GraniteMoeAttention
,
GraniteMoeMoE
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
make_layers
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
make_layers
,
maybe_prefix
class
GraniteMoeSharedMLP
(
nn
.
Module
):
...
...
@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module):
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
# Required by MixtralModel
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
...
@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w1_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralModel
.
load_weights
(
self
,
new_weights
.
items
())
class
GraniteMoeSharedForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
...
...
@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
GraniteMoeSharedModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
...
...
@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w1_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
elif
n
==
'lm_head.weight'
and
self
.
config
.
tie_word_embeddings
:
pass
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralForCausalLM
.
load_weights
(
self
,
new_weights
.
items
())
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/grok1.py
View file @
9c4ecf15
...
...
@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -302,6 +302,8 @@ class Grok1Model(nn.Module):
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
...
@@ -370,94 +372,6 @@ class Grok1Model(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
Grok1ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Grok1Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
output_multiplier_scale
=
getattr
(
config
,
"output_multiplier_scale"
,
DEFAULT_OUTPUT_MULTIPLIER_SCALE
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
self
.
output_multiplier_scale
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
...
...
@@ -480,9 +394,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
...
...
@@ -553,13 +464,107 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
"norm.scale"
in
name
:
name
=
name
.
replace
(
"scale"
,
"weight"
)
# Skip lm_head when tie_word_embeddings is True
if
"lm_head"
in
name
and
self
.
config
.
tie_word_embeddings
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
Grok1ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Grok1Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
output_multiplier_scale
=
getattr
(
config
,
"output_multiplier_scale"
,
DEFAULT_OUTPUT_MULTIPLIER_SCALE
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
self
.
output_multiplier_scale
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
skip_prefixes
=
[
"rotary_emb.inv_freq"
]
# Skip lm_head when tie_word_embeddings is True
if
self
.
config
.
tie_word_embeddings
:
skip_prefixes
.
append
(
"lm_head"
)
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skip_prefixes
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/h2ovl.py
View file @
9c4ecf15
...
...
@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
repl_features
=
IMG_CONTEXT
*
feature_size
repl_full
=
IMG_START
+
repl_features
+
IMG_END
return
PromptUpdateDetails
(
full
=
repl_full
,
features
=
repl_features
)
return
PromptUpdateDetails
.
select_text
(
repl_full
,
IMG_CONTEXT
)
def
resolve_min_max_num
(
self
,
...
...
@@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
**
kwargs
,
)
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
max_tokens_one_image
=
self
.
get_max_image_tokens
(
use_msac
=
None
)
if
mm_counts
.
get
(
"image"
,
0
)
<=
1
:
max_tokens_per_image
=
max_tokens_one_image
else
:
max_tokens_per_image
=
self
.
get_max_image_tokens
(
use_msac
=
False
)
return
{
"image"
:
max_tokens_per_image
}
def
get_num_image_tokens
(
self
,
*
,
...
...
@@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
use_msac
=
use_msac
,
)
def
get_max_image_tokens
(
self
,
use_msac
:
Optional
[
bool
]
=
None
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
use_msac
=
use_msac
,
)
class
H2OVLMultiModalProcessor
(
InternVLMultiModalProcessor
[
H2OVLProcessingInfo
]
):
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment