Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
963 additions
and
748 deletions
+963
-748
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+30
-80
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+13
-19
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+49
-17
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+12
-17
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+20
-20
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+0
-3
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+54
-48
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+24
-10
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+17
-27
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+28
-21
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+36
-68
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+47
-42
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+37
-92
vllm/model_executor/models/glm4.py
vllm/model_executor/models/glm4.py
+313
-0
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+19
-23
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+76
-69
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+43
-36
vllm/model_executor/models/granitemoeshared.py
vllm/model_executor/models/granitemoeshared.py
+43
-36
vllm/model_executor/models/grok1.py
vllm/model_executor/models/grok1.py
+101
-96
vllm/model_executor/models/h2ovl.py
vllm/model_executor/models/h2ovl.py
+1
-24
No files found.
vllm/model_executor/models/aya_vision.py
View file @
9c4ecf15
...
@@ -20,22 +20,21 @@ from vllm.jsontree import json_map_leaves
...
@@ -20,22 +20,21 @@ from vllm.jsontree import json_map_leaves
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
MultiModalFieldConfig
,
MultiModalFieldConfig
,
PromptReplacement
,
PromptUpdate
,
PromptReplacement
,
PromptUpdate
,
encode_token
s
)
PromptUpdateDetail
s
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
class
AyaVisionImagePixelInputs
(
TypedDict
):
class
AyaVisionImagePixelInputs
(
TypedDict
):
...
@@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict):
...
@@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict):
num_patches
:
torch
.
Tensor
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
AyaVisionMultiModalProjector
(
nn
.
Module
):
class
AyaVisionMultiModalProjector
(
nn
.
Module
):
...
@@ -125,32 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
...
@@ -125,32 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
def
get_image_processor
(
self
)
->
GotOcr2ImageProcessor
:
def
get_image_processor
(
self
)
->
GotOcr2ImageProcessor
:
return
self
.
get_hf_processor
().
image_processor
return
self
.
get_hf_processor
().
image_processor
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
get_max_image_tokens
(
self
)
->
int
:
hf_processor
=
self
.
get_hf_processor
()
image_processor
=
hf_processor
.
image_processor
image_size
=
self
.
get_image_size_with_most_features
()
tokenizer
=
hf_processor
.
tokenizer
num_patches
=
self
.
get_num_patches
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
size
=
image_processor
.
size
,
min_patches
=
image_processor
.
min_patches
,
max_patches
=
image_processor
.
max_patches
)
image_string
=
hf_processor
.
_prompt_split_image
(
num_patches
)
x
=
encode_tokens
(
tokenizer
,
image_string
,
add_special_tokens
=
False
,
)
return
len
(
x
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
...
@@ -180,28 +146,29 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
...
@@ -180,28 +146,29 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
class
AyaVisionDummyInputsBuilder
(
class
AyaVisionDummyInputsBuilder
(
BaseDummyInputsBuilder
[
AyaVisionProcessingInfo
]):
BaseDummyInputsBuilder
[
AyaVisionProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
self
,
num_images
=
mm_counts
.
get
(
"image"
,
0
)
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
processor
=
self
.
info
.
get_hf_processor
()
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
image_size
=
\
image_size
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
image_size
.
width
,
self
.
_get_dummy_images
(
width
=
image_size
.
width
,
height
=
image_size
.
height
,
height
=
image_size
.
height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
AyaVisionMultiModalProcessor
(
class
AyaVisionMultiModalProcessor
(
...
@@ -221,7 +188,6 @@ class AyaVisionMultiModalProcessor(
...
@@ -221,7 +188,6 @@ class AyaVisionMultiModalProcessor(
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_processor
=
hf_processor
.
image_processor
image_processor
=
hf_processor
.
image_processor
hf_config
=
self
.
info
.
get_hf_config
()
# HF processor pops the `num_patches` kwarg, which is needed by vLLM
# HF processor pops the `num_patches` kwarg, which is needed by vLLM
if
(
images
:
=
if
(
images
:
=
mm_data
.
get
(
"images"
))
is
not
None
and
'<image>'
in
prompt
:
mm_data
.
get
(
"images"
))
is
not
None
and
'<image>'
in
prompt
:
...
@@ -234,6 +200,7 @@ class AyaVisionMultiModalProcessor(
...
@@ -234,6 +200,7 @@ class AyaVisionMultiModalProcessor(
parsed_images
.
get_image_size
(
i
)
parsed_images
.
get_image_size
(
i
)
for
i
in
range
(
len
(
parsed_images
))
for
i
in
range
(
len
(
parsed_images
))
]
]
num_patches
=
[
num_patches
=
[
self
.
info
.
get_num_patches
(
self
.
info
.
get_num_patches
(
image_width
=
image_size
.
width
,
image_width
=
image_size
.
width
,
...
@@ -243,20 +210,6 @@ class AyaVisionMultiModalProcessor(
...
@@ -243,20 +210,6 @@ class AyaVisionMultiModalProcessor(
max_patches
=
image_processor
.
max_patches
)
max_patches
=
image_processor
.
max_patches
)
for
image_size
in
image_sizes
for
image_size
in
image_sizes
]
]
image_tokens_list
=
[
hf_processor
.
_prompt_split_image
(
num_patch
)
for
num_patch
in
num_patches
]
tokenizer
=
self
.
info
.
get_tokenizer
()
image_token_ids
=
[
tokenizer
.
encode
(
image_tokens
,
add_special_tokens
=
False
)
for
image_tokens
in
image_tokens_list
]
embed_is_patch
=
[
torch
.
tensor
(
image_repl_tokens
)
==
hf_config
.
image_token_index
for
image_repl_tokens
in
image_token_ids
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
processed_outputs
[
"num_patches"
]
=
torch
.
tensor
(
num_patches
)
processed_outputs
[
"num_patches"
]
=
torch
.
tensor
(
num_patches
)
return
processed_outputs
return
processed_outputs
...
@@ -271,7 +224,6 @@ class AyaVisionMultiModalProcessor(
...
@@ -271,7 +224,6 @@ class AyaVisionMultiModalProcessor(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_patches
),
"image"
,
num_patches
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
...
@@ -283,6 +235,7 @@ class AyaVisionMultiModalProcessor(
...
@@ -283,6 +235,7 @@ class AyaVisionMultiModalProcessor(
)
->
Sequence
[
PromptUpdate
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_token
=
hf_processor
.
image_token
image_token
=
hf_processor
.
image_token
img_patch_token
=
hf_processor
.
img_patch_token
image_processor
=
hf_processor
.
image_processor
image_processor
=
hf_processor
.
image_processor
def
get_replacement
(
item_idx
:
int
):
def
get_replacement
(
item_idx
:
int
):
...
@@ -294,8 +247,11 @@ class AyaVisionMultiModalProcessor(
...
@@ -294,8 +247,11 @@ class AyaVisionMultiModalProcessor(
image_height
=
image_size
.
height
,
image_height
=
image_size
.
height
,
size
=
image_processor
.
size
,
size
=
image_processor
.
size
,
min_patches
=
image_processor
.
min_patches
,
min_patches
=
image_processor
.
min_patches
,
max_patches
=
image_processor
.
max_patches
)
max_patches
=
image_processor
.
max_patches
,
return
hf_processor
.
_prompt_split_image
(
num_patches
=
num_patches
)
)
repl
=
hf_processor
.
_prompt_split_image
(
num_patches
=
num_patches
)
return
PromptUpdateDetails
.
select_text
(
repl
,
img_patch_token
)
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
@@ -424,7 +380,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -424,7 +380,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
**
kwargs
:
object
)
->
Optional
[
AyaVisionImagePixelInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
AyaVisionImagePixelInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
num_patches
=
kwargs
.
pop
(
"num_patches"
,
None
)
num_patches
=
kwargs
.
pop
(
"num_patches"
,
None
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
assert
image_embeds
is
None
,
"Aya Vision does not support image_embeds."
assert
image_embeds
is
None
,
"Aya Vision does not support image_embeds."
...
@@ -436,30 +391,25 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -436,30 +391,25 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of num_patches. "
raise
ValueError
(
"Incorrect type of num_patches. "
f
"Got type:
{
type
(
num_patches
)
}
"
)
f
"Got type:
{
type
(
num_patches
)
}
"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
)
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
AyaVisionImagePixelInputs
(
return
AyaVisionImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
num_patches
=
num_patches
,
num_patches
=
num_patches
,
embed_is_patch
=
embed_is_patch
,
)
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
,
**
kwargs
)
return
scatter_patch_features
(
return
self
.
_process_image_input
(
image_input
,
**
kwargs
)
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -471,9 +421,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -471,9 +421,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
multimodal_embeddings
=
select_patch_features
(
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
)
,
placeholder_token_id
=
self
.
config
.
image_token_index
,
placeholder_token_id
=
self
.
config
.
image_token_index
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/bamba.py
View file @
9c4ecf15
...
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
MambaMixer2
,
extra_groups_for_head_shards
)
MambaMixer2
,
extra_groups_for_head_shards
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
head_dim
=
config
.
mamba_d_head
,
head_dim
=
config
.
mamba_d_head
,
rms_norm_eps
=
config
.
rms_norm_eps
,
rms_norm_eps
=
config
.
rms_norm_eps
,
activation
=
config
.
hidden_act
,
activation
=
config
.
hidden_act
,
chunk_size
=
config
.
mamba_chunk_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
feed_forward
=
BambaMLP
(
config
,
quant_config
=
quant_config
)
self
.
feed_forward
=
BambaMLP
(
config
,
quant_config
=
quant_config
)
...
@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
...
@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
,
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
,
sequence_idx
)
mamba2_metadata
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
...
@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
...
@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
:
BambaConfig
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
...
@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
...
@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx
=
None
attn_metadata
=
get_forward_context
().
attn_metadata
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefills
>
0
:
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
mamba2_metadata
=
prepare_mamba2_metadata
(
for
i
,
(
srt
,
end
)
in
enumerate
(
chunk_size
=
self
.
config
.
mamba_chunk_size
,
zip
(
input_ids
=
input_ids
,
attn_metadata
.
query_start_loc
,
attn_metadata
=
attn_metadata
,
attn_metadata
.
query_start_loc
[
1
:],
)
)):
seq_idx
[
srt
:
end
]
=
i
seq_idx
.
unsqueeze_
(
0
)
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
...
@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
...
@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
,
mamba_cache_params
=
layer_mamba_cache_params
,
sequence_idx
=
seq_idx
,
mamba2_metadata
=
mamba2_metadata
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
vllm/model_executor/models/bert.py
View file @
9c4ecf15
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.pooler
import
(
CrossEncodingPooler
,
Pooler
,
from
vllm.model_executor.layers.pooler
import
(
CrossEncodingPooler
,
Pooler
,
PoolingType
)
PoolingType
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
...
@@ -38,19 +39,24 @@ class BertEmbedding(nn.Module):
self
.
size
=
config
.
hidden_size
self
.
size
=
config
.
hidden_size
self
.
word_embeddings
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
word_embeddings
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
self
.
position_embeddings
=
VocabParallelEmbedding
(
config
.
max_position_embeddings
,
config
.
hidden_size
)
self
.
token_type_embeddings
=
VocabParallelEmbedding
(
self
.
token_type_embeddings
=
VocabParallelEmbedding
(
config
.
type_vocab_size
,
config
.
hidden_size
)
config
.
type_vocab_size
,
config
.
hidden_size
)
self
.
LayerNorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
LayerNorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
position_ids
=
nn
.
Parameter
(
torch
.
empty
((
1
,
config
.
max_position_embeddings
)),
)
self
.
position_embedding_type
=
config
.
position_embedding_type
self
.
position_embedding_type
=
config
.
position_embedding_type
if
self
.
position_embedding_type
!=
"absolute"
:
if
self
.
position_embedding_type
==
"absolute"
:
raise
ValueError
(
"Only 'absolute' position_embedding_type"
+
self
.
position_embeddings
=
VocabParallelEmbedding
(
" is supported"
)
config
.
max_position_embeddings
,
config
.
hidden_size
)
self
.
position_ids
=
nn
.
Parameter
(
torch
.
empty
((
1
,
config
.
max_position_embeddings
)),
)
elif
self
.
position_embedding_type
==
"rotary"
:
self
.
position_embeddings
=
None
self
.
position_ids
=
None
else
:
raise
ValueError
(
"Only 'absolute' and 'rotary' "
+
"position_embedding_type is supported"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
...
@@ -64,9 +70,6 @@ class BertEmbedding(nn.Module):
# Input embeddings.
# Input embeddings.
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
# Position embeddings.
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
if
token_type_ids
is
None
:
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
input_shape
,
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
...
@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
...
@@ -74,7 +77,12 @@ class BertEmbedding(nn.Module):
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
embeddings
=
inputs_embeds
+
token_type_embeddings
if
self
.
position_embedding_type
==
"absolute"
:
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
embeddings
+=
position_embeddings
embeddings
=
self
.
LayerNorm
(
embeddings
)
embeddings
=
self
.
LayerNorm
(
embeddings
)
return
embeddings
return
embeddings
...
@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
...
@@ -98,7 +106,10 @@ class BertPooler(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
...
@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
...
@@ -107,16 +118,18 @@ class BertEncoder(nn.Module):
BertLayer
(
config
=
config
,
BertLayer
(
config
=
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.layer.
{
layer_idx
}
"
)
prefix
=
f
"
{
prefix
}
.layer.
{
layer_idx
}
"
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
])
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
for
layer
in
self
.
layer
:
for
layer
in
self
.
layer
:
hidden_states
=
layer
(
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
)
return
hidden_states
return
hidden_states
...
@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
...
@@ -126,6 +139,7 @@ class BertLayer(nn.Module):
config
:
BertConfig
,
config
:
BertConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
...
@@ -135,6 +149,7 @@ class BertLayer(nn.Module):
layer_norm_eps
=
config
.
layer_norm_eps
,
layer_norm_eps
=
config
.
layer_norm_eps
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.attention"
)
prefix
=
f
"
{
prefix
}
.attention"
)
self
.
intermediate
=
BertIntermediate
(
self
.
intermediate
=
BertIntermediate
(
...
@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
...
@@ -150,8 +165,8 @@ class BertLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.output"
)
prefix
=
f
"
{
prefix
}
.output"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
):
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
):
attn_output
=
self
.
attention
(
hidden_states
)
attn_output
=
self
.
attention
(
positions
,
hidden_states
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
intermediate_output
=
self
.
intermediate
(
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
output
=
self
.
output
(
intermediate_output
,
attn_output
)
return
output
return
output
...
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
...
@@ -166,6 +181,7 @@ class BertAttention(nn.Module):
layer_norm_eps
:
float
,
layer_norm_eps
:
float
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
...
@@ -174,6 +190,7 @@ class BertAttention(nn.Module):
num_attention_heads
=
num_attention_heads
,
num_attention_heads
=
num_attention_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.output"
)
prefix
=
f
"
{
prefix
}
.output"
)
self
.
output
=
BertSelfOutput
(
hidden_size
=
hidden_size
,
self
.
output
=
BertSelfOutput
(
hidden_size
=
hidden_size
,
...
@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
...
@@ -183,9 +200,10 @@ class BertAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
self_output
=
self
.
self
(
hidden_states
)
self_output
=
self
.
self
(
positions
,
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
return
self
.
output
(
self_output
,
hidden_states
)
...
@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
...
@@ -197,6 +215,7 @@ class BertSelfAttention(nn.Module):
num_attention_heads
:
int
,
num_attention_heads
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
...
@@ -225,6 +244,11 @@ class BertSelfAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
)
prefix
=
f
"
{
prefix
}
.qkv_proj"
)
if
rotary_kwargs
:
self
.
rotary_emb
=
get_rope
(
**
rotary_kwargs
)
else
:
self
.
rotary_emb
=
None
self
.
attn
=
Attention
(
num_heads
=
self
.
num_heads
,
self
.
attn
=
Attention
(
num_heads
=
self
.
num_heads
,
head_size
=
self
.
head_dim
,
head_size
=
self
.
head_dim
,
scale
=
self
.
scaling
,
scale
=
self
.
scaling
,
...
@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
...
@@ -236,10 +260,15 @@ class BertSelfAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
rotary_emb
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
output
=
self
.
attn
(
q
,
k
,
v
)
output
=
self
.
attn
(
q
,
k
,
v
)
return
output
return
output
...
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
...
@@ -321,11 +350,13 @@ class BertModel(nn.Module, SupportsQuant):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
embedding_class
:
type
=
BertEmbedding
,
embedding_class
:
type
=
BertEmbedding
,
rotary_kwargs
:
Optional
[
dict
]
=
None
,
add_pooling_layer
:
bool
=
False
):
add_pooling_layer
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
self
.
embeddings
=
embedding_class
(
config
)
self
.
embeddings
=
embedding_class
(
config
)
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
rotary_kwargs
=
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.encoder"
)
prefix
=
f
"
{
prefix
}
.encoder"
)
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
...
@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
...
@@ -347,7 +378,7 @@ class BertModel(nn.Module, SupportsQuant):
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
token_type_ids
=
token_type_ids
)
return
self
.
encoder
(
hidden_states
)
return
self
.
encoder
(
position_ids
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
...
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
self
.
_build_pooler
(
pooler_config
)
self
.
_pooler
=
self
.
_build_pooler
(
pooler_config
)
...
...
vllm/model_executor/models/blip2.py
View file @
9c4ecf15
...
@@ -15,12 +15,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -15,12 +15,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptIndexTargets
,
BaseProcessingInfo
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.blip
import
BlipVisionModel
from
.blip
import
BlipVisionModel
...
@@ -406,13 +407,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
...
@@ -406,13 +407,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_num_image_tokens
()}
def
get_num_image_tokens
(
self
)
->
int
:
def
get_num_image_tokens
(
self
)
->
int
:
hf_config
=
self
.
get_hf_config
()
hf_config
=
self
.
get_hf_config
()
return
hf_config
.
num_query_tokens
return
hf_config
.
num_query_tokens
...
@@ -420,29 +414,27 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
...
@@ -420,29 +414,27 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
class
Blip2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Blip2ProcessingInfo
]):
class
Blip2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Blip2ProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
hf_config
=
self
.
info
.
get_hf_config
()
hf_config
=
self
.
info
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
max_image_size
=
vision_config
.
image_size
max_image_size
=
vision_config
.
image_size
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
max_image_size
,
self
.
_get_dummy_images
(
width
=
max_image_size
,
height
=
max_image_size
,
height
=
max_image_size
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
Blip2MultiModalProcessor
(
BaseMultiModalProcessor
[
Blip2ProcessingInfo
]):
class
Blip2MultiModalProcessor
(
BaseMultiModalProcessor
[
Blip2ProcessingInfo
]):
...
@@ -627,6 +619,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -627,6 +619,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return
self
.
language_projection
(
query_output
)
return
self
.
language_projection
(
query_output
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/chameleon.py
View file @
9c4ecf15
...
@@ -30,12 +30,13 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -30,12 +30,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
,
...
@@ -64,13 +65,6 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
...
@@ -64,13 +65,6 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_num_image_tokens
()}
def
get_num_image_tokens
(
self
)
->
int
:
def
get_num_image_tokens
(
self
)
->
int
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
return
processor
.
image_seq_length
return
processor
.
image_seq_length
...
@@ -79,28 +73,31 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
...
@@ -79,28 +73,31 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
class
ChameleonDummyInputsBuilder
(
class
ChameleonDummyInputsBuilder
(
BaseDummyInputsBuilder
[
ChameleonProcessingInfo
]):
BaseDummyInputsBuilder
[
ChameleonProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
config
=
self
.
info
.
get_hf_config
()
config
=
self
.
info
.
get_hf_config
()
width
=
height
=
config
.
vq_config
.
resolution
width
=
height
=
config
.
vq_config
.
resolution
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
width
,
self
.
_get_dummy_images
(
width
=
width
,
height
=
height
,
height
=
height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
"<image>"
*
num_images
,
mm_data
=
mm_data
,
)
class
ChameleonMultiModalProcessor
(
class
ChameleonMultiModalProcessor
(
BaseMultiModalProcessor
[
ChameleonProcessingInfo
]):
BaseMultiModalProcessor
[
ChameleonProcessingInfo
]):
...
@@ -162,9 +159,9 @@ class ChameleonMultiModalProcessor(
...
@@ -162,9 +159,9 @@ class ChameleonMultiModalProcessor(
PromptReplacement
(
PromptReplacement
(
modality
=
"image"
,
modality
=
"image"
,
target
=
[
image_token_id
],
target
=
[
image_token_id
],
replacement
=
PromptUpdateDetails
(
replacement
=
PromptUpdateDetails
.
select_token_id
(
full
=
(
[
image_start_id
]
+
image_tokens
+
[
image_end_id
]
)
,
[
image_start_id
]
+
image_tokens
+
[
image_end_id
],
features
=
image_token
s
,
embed_token_id
=
image_token
_id
,
),
),
)
)
]
]
...
@@ -988,6 +985,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -988,6 +985,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/clip.py
View file @
9c4ecf15
...
@@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
...
@@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
)
->
int
:
)
->
int
:
return
self
.
get_patch_grid_length
()
**
2
+
1
return
self
.
get_patch_grid_length
()
**
2
+
1
def
get_max_image_tokens
(
self
)
->
int
:
return
self
.
get_patch_grid_length
()
**
2
+
1
def
get_image_size
(
self
)
->
int
:
def
get_image_size
(
self
)
->
int
:
return
self
.
vision_config
.
image_size
return
self
.
vision_config
.
image_size
...
...
vllm/model_executor/models/deepseek.py
View file @
9c4ecf15
...
@@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -385,6 +386,56 @@ class DeepseekModel(nn.Module):
...
@@ -385,6 +386,56 @@ class DeepseekModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
@@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
...
@@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
self
)
# (param_name, shard_name, shard_id)
return
loader
.
load_weights
(
weights
)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
\ No newline at end of file
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/deepseek_v2.py
View file @
9c4ecf15
...
@@ -163,14 +163,16 @@ class DeepseekV2MoE(nn.Module):
...
@@ -163,14 +163,16 @@ class DeepseekV2MoE(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
else
:
# This is a special case to avoid FP16 overflow
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
else
:
# This is a special case to avoid FP16 overflow
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
...
@@ -502,6 +504,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -502,6 +504,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# DecoderLayers are created with `make_layers` which passes the prefix
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
# with the layer's index.
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
self
.
layer_idx
=
layer_idx
if
model_config
.
use_mla
:
if
model_config
.
use_mla
:
attn_cls
=
DeepseekV2MLAAttention
attn_cls
=
DeepseekV2MLAAttention
else
:
else
:
...
@@ -564,19 +567,30 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -564,19 +567,30 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
)
)
# Fully Connected
if
hidden_states
.
dtype
==
torch
.
float16
:
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
\
# Fix FP16 overflow
hidden_states
.
dtype
==
torch
.
float16
:
# We scale both hidden_states and residual before
#
This is a special case to avoid FP16 overflow
#
rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
\
hidden_states
.
dtype
==
torch
.
float16
:
if
isinstance
(
self
.
mlp
,
# This is a special case to avoid FP16 overflow
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
hidden_states
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
return
hidden_states
,
residual
...
...
vllm/model_executor/models/deepseek_vl2.py
View file @
9c4ecf15
...
@@ -19,14 +19,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -19,14 +19,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModal
FieldConfig
,
MultiModal
Kwargs
,
from
vllm.multimodal.inputs
import
(
MultiModal
DataDict
,
MultiModal
FieldConfig
,
NestedTensors
)
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.deepseek_vl2
import
(
DeepseekVLV2Config
,
from
vllm.transformers_utils.configs.deepseek_vl2
import
(
DeepseekVLV2Config
,
MlpProjectorConfig
,
MlpProjectorConfig
,
...
@@ -168,47 +168,34 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
...
@@ -168,47 +168,34 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
image_width
=
x
[
1
],
image_height
=
x
[
0
]))
image_width
=
x
[
1
],
image_height
=
x
[
0
]))
return
ImageSize
(
width
=
width
,
height
=
height
)
return
ImageSize
(
width
=
width
,
height
=
height
)
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
max_image_size
=
self
.
get_image_size_with_most_features
()
max_image_tokens
=
self
.
get_num_image_tokens
(
image_height
=
max_image_size
.
height
,
image_width
=
max_image_size
.
width
,
cropping
=
num_images
<=
2
)
return
{
"image"
:
max_image_tokens
}
class
DeepseekVL2DummyInputsBuilder
(
class
DeepseekVL2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
DeepseekVL2ProcessingInfo
]):
BaseDummyInputsBuilder
[
DeepseekVL2ProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
hf_processor
=
self
.
info
.
get_hf_processor
()
image_token
:
str
=
hf_processor
.
image_token
max_image_size
=
self
.
info
.
get_image_size_with_most_features
()
max_image_size
=
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
max_image_size
.
width
,
self
.
_get_dummy_images
(
width
=
max_image_size
.
width
,
height
=
max_image_size
.
height
,
height
=
max_image_size
.
height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
DeepseekVL2MultiModalProcessor
(
class
DeepseekVL2MultiModalProcessor
(
BaseMultiModalProcessor
[
DeepseekVL2ProcessingInfo
]):
BaseMultiModalProcessor
[
DeepseekVL2ProcessingInfo
]):
...
@@ -604,6 +591,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -604,6 +591,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
self
.
_pixel_values_to_embedding
(
return
self
.
_pixel_values_to_embedding
(
pixel_values
=
pixel_values
,
images_spatial_crop
=
images_spatial_crop
)
pixel_values
=
pixel_values
,
images_spatial_crop
=
images_spatial_crop
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/florence2.py
View file @
9c4ecf15
...
@@ -10,7 +10,7 @@ import torch
...
@@ -10,7 +10,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
einops
import
rearrange
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BartTokenizer
,
BatchFeature
,
PretrainedConfig
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
@@ -21,13 +21,14 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
...
@@ -21,13 +21,14 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartScaledWordEmbedding
)
BartScaledWordEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.parse
import
MultiModalDataDict
,
MultiModalDataItems
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
EncDecMultiModalProcessor
,
PromptIndexTargets
,
PromptInsertion
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
...
@@ -764,42 +765,33 @@ class Florence2ProcessingInfo(BaseProcessingInfo):
...
@@ -764,42 +765,33 @@ class Florence2ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
def
get_
max
_image_tokens
(
self
)
->
int
:
def
get_
num
_image_tokens
(
self
)
->
int
:
processor_config
=
self
.
ctx
.
get_hf_image_processor_config
()
processor_config
=
self
.
ctx
.
get_hf_image_processor_config
()
return
processor_config
[
"image_seq_length"
]
return
processor_config
[
"image_seq_length"
]
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
class
Florence2DummyInputsBuilder
(
class
Florence2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Florence2ProcessingInfo
]):
BaseDummyInputsBuilder
[
Florence2ProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
target_width
=
target_height
=
self
.
info
.
get_hf_config
().
projection_dim
target_width
=
target_height
=
self
.
info
.
get_hf_config
().
projection_dim
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
Florence2MultiModalProcessor
(
class
Florence2MultiModalProcessor
(
EncDecMultiModalProcessor
[
Florence2ProcessingInfo
]):
EncDecMultiModalProcessor
[
Florence2ProcessingInfo
]):
...
@@ -826,6 +818,18 @@ class Florence2MultiModalProcessor(
...
@@ -826,6 +818,18 @@ class Florence2MultiModalProcessor(
)
->
Union
[
str
,
list
[
int
]]:
)
->
Union
[
str
,
list
[
int
]]:
return
[
self
.
info
.
get_hf_config
().
eos_token_id
]
return
[
self
.
info
.
get_hf_config
().
eos_token_id
]
def
_apply_hf_processor_tokens_only
(
self
,
prompt_tokens
:
list
[
int
],
)
->
list
[
int
]:
hf_processor
=
self
.
info
.
get_hf_processor
()
tokenizer
:
BartTokenizer
=
hf_processor
.
tokenizer
prompt_text
=
tokenizer
.
decode
(
prompt_tokens
)
# convert task tokens to prompt
prompt_text
=
hf_processor
.
_construct_prompts
([
prompt_text
])[
0
]
prompt_tokens
=
tokenizer
.
encode
(
prompt_text
,
add_special_tokens
=
False
)
return
prompt_tokens
def
_call_hf_processor
(
def
_call_hf_processor
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
...
@@ -859,7 +863,7 @@ class Florence2MultiModalProcessor(
...
@@ -859,7 +863,7 @@ class Florence2MultiModalProcessor(
)
->
Sequence
[
PromptUpdate
]:
)
->
Sequence
[
PromptUpdate
]:
hf_config
=
self
.
info
.
get_hf_config
()
hf_config
=
self
.
info
.
get_hf_config
()
pad_token_id
=
hf_config
.
pad_token_id
pad_token_id
=
hf_config
.
pad_token_id
num_image_tokens
=
self
.
info
.
get_
max
_image_tokens
()
num_image_tokens
=
self
.
info
.
get_
num
_image_tokens
()
image_tokens
=
[
pad_token_id
]
*
num_image_tokens
image_tokens
=
[
pad_token_id
]
*
num_image_tokens
return
[
return
[
...
@@ -1038,6 +1042,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1038,6 +1042,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values
=
image_input
[
"data"
]
pixel_values
=
image_input
[
"data"
]
return
self
.
_encode_image
(
pixel_values
)
return
self
.
_encode_image
(
pixel_values
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/fuyu.py
View file @
9c4ecf15
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
""" PyTorch Fuyu model."""
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
from
typing
import
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -31,19 +31,19 @@ from vllm.model_executor.layers.sampler import SamplerOutput
...
@@ -31,19 +31,19 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from
vllm.model_executor.models.persimmon
import
PersimmonForCausalLM
from
vllm.model_executor.models.persimmon
import
PersimmonForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
# Cannot find the following 2 numbers from hf config.
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID
=
71011
_IMAGE_TOKEN_ID
=
71011
...
@@ -66,14 +66,6 @@ class FuyuImagePatchInputs(TypedDict):
...
@@ -66,14 +66,6 @@ class FuyuImagePatchInputs(TypedDict):
flattened just like `flat_data`.
flattened just like `flat_data`.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
FuyuProcessingInfo
(
BaseProcessingInfo
):
class
FuyuProcessingInfo
(
BaseProcessingInfo
):
...
@@ -89,21 +81,6 @@ class FuyuProcessingInfo(BaseProcessingInfo):
...
@@ -89,21 +81,6 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
max_ncols
,
max_nrows
=
self
.
get_image_feature_grid_size
(
image_width
=
target_width
,
image_height
=
target_height
,
)
max_image_tokens
=
(
max_ncols
+
1
)
*
max_nrows
return
{
"image"
:
max_image_tokens
}
def
get_image_feature_grid_size
(
def
get_image_feature_grid_size
(
self
,
self
,
*
,
*
,
...
@@ -128,6 +105,19 @@ class FuyuProcessingInfo(BaseProcessingInfo):
...
@@ -128,6 +105,19 @@ class FuyuProcessingInfo(BaseProcessingInfo):
nrows
=
math
.
ceil
(
image_height
/
patch_height
)
nrows
=
math
.
ceil
(
image_height
/
patch_height
)
return
ncols
,
nrows
return
ncols
,
nrows
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
ncols
,
nrows
=
self
.
get_image_feature_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
)
return
ncols
*
nrows
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
image_processor
=
self
.
get_image_processor
()
image_processor
=
self
.
get_image_processor
()
return
ImageSize
(
width
=
image_processor
.
size
[
"width"
],
return
ImageSize
(
width
=
image_processor
.
size
[
"width"
],
...
@@ -136,27 +126,25 @@ class FuyuProcessingInfo(BaseProcessingInfo):
...
@@ -136,27 +126,25 @@ class FuyuProcessingInfo(BaseProcessingInfo):
class
FuyuDummyInputsBuilder
(
BaseDummyInputsBuilder
[
FuyuProcessingInfo
]):
class
FuyuDummyInputsBuilder
(
BaseDummyInputsBuilder
[
FuyuProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
FuyuMultiModalProcessor
(
BaseMultiModalProcessor
[
FuyuProcessingInfo
]):
class
FuyuMultiModalProcessor
(
BaseMultiModalProcessor
[
FuyuProcessingInfo
]):
...
@@ -192,19 +180,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
...
@@ -192,19 +180,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs
[
"image_patches"
]
=
image_patches
[
0
]
processed_outputs
[
"image_patches"
]
=
image_patches
[
0
]
# get patch grid size for each image
embed_is_patch
=
[]
for
image
in
images
:
ncols
,
nrows
=
self
.
info
.
get_image_feature_grid_size
(
image_width
=
image
.
width
,
image_height
=
image
.
height
,
)
mask
=
torch
.
tensor
(([
True
]
*
ncols
+
[
False
])
*
nrows
)
embed_is_patch
.
append
(
mask
)
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
return
processed_outputs
def
_apply_hf_processor_tokens_only
(
def
_apply_hf_processor_tokens_only
(
...
@@ -224,8 +199,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
...
@@ -224,8 +199,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs
:
BatchFeature
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
image_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
return
dict
(
image_patches
=
MultiModalFieldConfig
.
batched
(
"image"
))
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
))
def
_get_prompt_updates
(
def
_get_prompt_updates
(
self
,
self
,
...
@@ -252,9 +226,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
...
@@ -252,9 +226,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens
=
([
_IMAGE_TOKEN_ID
]
*
ncols
+
image_tokens
=
([
_IMAGE_TOKEN_ID
]
*
ncols
+
[
_NEWLINE_TOKEN_ID
])
*
nrows
[
_NEWLINE_TOKEN_ID
])
*
nrows
return
PromptUpdateDetails
(
return
PromptUpdateDetails
.
select_token_id
(
full
=
image_tokens
+
[
bos_token_id
],
image_tokens
+
[
bos_token_id
],
features
=
image_tokens
,
embed_token_id
=
_IMAGE_TOKEN_ID
,
)
)
return
[
return
[
...
@@ -329,20 +303,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -329,20 +303,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise
ValueError
(
"Incorrect type of image patches. "
raise
ValueError
(
"Incorrect type of image patches. "
f
"Got type:
{
type
(
image_patches
)
}
"
)
f
"Got type:
{
type
(
image_patches
)
}
"
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
image_patches_flat
=
flatten_bn
(
image_patches
)
image_patches_flat
=
flatten_bn
(
image_patches
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
FuyuImagePatchInputs
(
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
type
=
"image_patches"
,
flat_data
=
self
.
_validate_pixel_values
(
flat_data
=
self
.
_validate_pixel_values
(
flatten_bn
(
image_patches_flat
,
concat
=
True
)),
flatten_bn
(
image_patches_flat
,
concat
=
True
)),
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
embed_is_patch
=
embed_is_patch
,
)
)
return
None
return
None
...
@@ -358,18 +325,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -358,18 +325,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -379,8 +344,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -379,8 +344,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
input_ids
,
select_patch_features
(
multimodal_embeddings
),
_IMAGE_TOKEN_ID
)
inputs_embeds
,
multimodal_embeddings
,
_IMAGE_TOKEN_ID
,
)
return
inputs_embeds
return
inputs_embeds
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/gemma.py
View file @
9c4ecf15
...
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -319,6 +319,46 @@ class GemmaModel(nn.Module):
...
@@ -319,6 +319,46 @@ class GemmaModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
# (param_name, shard_name, shard_id)
self
,
(
"qkv_proj"
,
"q_proj"
,
"q"
),
skip_prefixes
=
([
"lm_head."
]
(
"qkv_proj"
,
"k_proj"
,
"k"
),
if
self
.
config
.
tie_word_embeddings
else
None
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
return
loader
.
load_weights
(
weights
)
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/gemma3_mm.py
View file @
9c4ecf15
...
@@ -15,8 +15,9 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm
...
@@ -15,8 +15,9 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
# yapf: disable
# yapf: disable
...
@@ -25,10 +26,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -25,10 +26,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo
,
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptTargetMatch
,
PromptReplacement
,
PromptTargetMatch
,
PromptUpdate
,
PromptUpdateDetails
,
PromptUpdate
,
PromptUpdateDetails
,
encode_tokens
,
find_mm_placeholders
,
find_mm_placeholders
,
replace_token_matches
)
replace_token_matches
)
# yapf: enable
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
@@ -36,7 +37,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
...
@@ -36,7 +37,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from
.siglip
import
SiglipVisionModel
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -54,14 +54,6 @@ class Gemma3ImagePixelInputs(TypedDict):
...
@@ -54,14 +54,6 @@ class Gemma3ImagePixelInputs(TypedDict):
num_patches
:
torch
.
Tensor
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
...
@@ -77,13 +69,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -77,13 +69,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
_resolve_image_kwargs
(
def
_resolve_image_kwargs
(
self
,
self
,
processor
:
Gemma3Processor
,
processor
:
Gemma3Processor
,
...
@@ -183,7 +168,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -183,7 +168,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if
processor
is
None
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
image
_token
=
processor
.
boi_token
boi
_token
=
processor
.
boi_token
num_crops
=
self
.
get_num_crops
(
num_crops
=
self
.
get_num_crops
(
image_width
=
image_width
,
image_width
=
image_width
,
...
@@ -192,19 +177,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -192,19 +177,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
)
)
if
num_crops
==
0
:
if
num_crops
==
0
:
image_text
=
image
_token
image_text
=
boi
_token
else
:
else
:
crops_image_tokens
=
" "
.
join
(
image_token
crops_image_tokens
=
" "
.
join
(
boi_token
for
_
in
range
(
num_crops
))
for
_
in
range
(
num_crops
))
image_text
=
(
image_text
=
(
f
"Here is the original image
{
image
_token
}
and here are some "
f
"Here is the original image
{
boi
_token
}
and here are some "
f
"crops to help you see better
{
crops_image_tokens
}
"
)
f
"crops to help you see better
{
crops_image_tokens
}
"
)
repl_full
=
image_text
.
replace
(
image
_token
,
repl_full
=
image_text
.
replace
(
boi
_token
,
processor
.
full_image_sequence
)
processor
.
full_image_sequence
)
repl_features
=
repl_full
.
strip
(
"
\n
"
)
return
PromptUpdateDetails
(
full
=
repl_full
,
features
=
repl_features
)
tokenizer
=
processor
.
tokenizer
vocab
=
tokenizer
.
get_vocab
()
image_token_id
=
vocab
[
tokenizer
.
image_token
]
return
PromptUpdateDetails
.
select_token_id
(
repl_full
,
image_token_id
)
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
...
@@ -213,19 +200,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -213,19 +200,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_height
:
int
,
image_height
:
int
,
processor
:
Optional
[
Gemma3Processor
],
processor
:
Optional
[
Gemma3Processor
],
)
->
int
:
)
->
int
:
tokenizer
=
self
.
get_tokenizer
()
if
processor
is
None
:
image_repl
=
self
.
get_image_repl
(
processor
=
self
.
get_hf_processor
()
num_crops
=
self
.
get_num_crops
(
image_width
=
image_width
,
image_width
=
image_width
,
image_height
=
image_height
,
image_height
=
image_height
,
processor
=
processor
,
processor
=
processor
,
)
)
image_seq_len
=
processor
.
image_seq_length
image_repl_tokens
=
encode_tokens
(
return
(
num_crops
+
1
)
*
image_seq_len
tokenizer
,
image_repl
.
features
,
add_special_tokens
=
False
,
)
return
len
(
image_repl_tokens
)
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
...
@@ -237,43 +222,34 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -237,43 +222,34 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
# Result in the max possible feature size (h:w = max_num_crops:1)
# Result in the max possible feature size (h:w = max_num_crops:1)
return
ImageSize
(
height
=
50
*
max_num_crops
,
width
=
50
)
return
ImageSize
(
height
=
50
*
max_num_crops
,
width
=
50
)
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
class
Gemma3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Gemma3ProcessingInfo
]):
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
)
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
class
Gemma3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Gemma3ProcessingInfo
]):
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
boi_token
return
image_token
*
num_images
def
get_dummy_
processor_inputs
(
def
get_dummy_
mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
boi_token
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
Gemma3MultiModalProcessor
(
BaseMultiModalProcessor
[
Gemma3ProcessingInfo
]):
class
Gemma3MultiModalProcessor
(
BaseMultiModalProcessor
[
Gemma3ProcessingInfo
]):
...
@@ -301,28 +277,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -301,28 +277,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
]
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_repl_features
=
[
self
.
info
.
get_image_repl
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
).
features
for
size
in
image_sizes
]
tokenizer
=
self
.
info
.
get_tokenizer
()
image_repls_feature_tokens
=
[
tokenizer
.
encode
(
image_repl
,
add_special_tokens
=
False
)
for
image_repl
in
image_repl_features
]
vocab
=
tokenizer
.
get_vocab
()
image_token_id
=
vocab
[
tokenizer
.
image_token
]
embed_is_patch
=
[
torch
.
tensor
(
image_repl_tokens
)
==
image_token_id
for
image_repl_tokens
in
image_repls_feature_tokens
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
num_crops
=
[
num_crops
=
[
self
.
info
.
get_num_crops
(
image_width
=
size
.
width
,
self
.
info
.
get_num_crops
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
image_height
=
size
.
height
,
...
@@ -344,7 +298,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -344,7 +298,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
+
1
),
"image"
,
num_crops
+
1
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -454,6 +407,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -454,6 +407,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
item_idx
=
p
.
item_idx
,
item_idx
=
p
.
item_idx
,
start_idx
=
repl_orig_idxs
[
p
.
start_idx
],
start_idx
=
repl_orig_idxs
[
p
.
start_idx
],
tokens
=
p
.
tokens
,
tokens
=
p
.
tokens
,
is_embed
=
p
.
is_embed
,
)
for
p
in
placeholders
)
for
p
in
placeholders
]
]
for
modality
,
placeholders
in
repls
.
items
()
for
modality
,
placeholders
in
repls
.
items
()
...
@@ -572,7 +526,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -572,7 +526,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
assert
image_embeds
is
None
,
"Gemma3 does not support image_embeds."
assert
image_embeds
is
None
,
"Gemma3 does not support image_embeds."
if
pixel_values
is
None
:
if
pixel_values
is
None
:
...
@@ -586,19 +539,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -586,19 +539,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise
ValueError
(
"Incorrect type of num_crops. "
raise
ValueError
(
"Incorrect type of num_crops. "
f
"Got type:
{
type
(
num_crops
)
}
"
)
f
"Got type:
{
type
(
num_crops
)
}
"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
Gemma3ImagePixelInputs
(
return
Gemma3ImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
num_patches
=
num_crops
+
1
,
num_patches
=
num_crops
+
1
,
embed_is_patch
=
embed_is_patch
,
)
)
def
_image_pixels_to_features
(
def
_image_pixels_to_features
(
...
@@ -629,18 +576,16 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -629,18 +576,16 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
e
.
flatten
(
0
,
1
)
for
e
in
image_embeds
.
split
(
num_patches
.
tolist
())
e
.
flatten
(
0
,
1
)
for
e
in
image_embeds
.
split
(
num_patches
.
tolist
())
]
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -652,7 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -652,7 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
self
.
config
.
image_token_index
,
self
.
config
.
image_token_index
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/glm4.py
0 → 100644
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Zhipu AI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
Glm4Config
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.llama
import
LlamaMLP
as
Glm4MLP
from
.llama
import
LlamaModel
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
class
Glm4Attention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Glm4Config
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
head_dim
:
Optional
[
int
]
=
None
,
qkv_bias
:
bool
=
False
,
rope_theta
:
float
=
10000
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rope_scaling
:
Optional
[
Tuple
]
=
None
,
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
partial_rotary_factor
=
getattr
(
config
,
"partial_rotary_factor"
,
0.5
)
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
head_dim
or
hidden_size
//
self
.
total_num_heads
self
.
rotary_dim
=
int
(
partial_rotary_factor
*
self
.
head_dim
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
qkv_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
rotary_dim
,
max_position
=
max_position
,
base
=
self
.
rope_theta
,
rope_scaling
=
rope_scaling
,
partial_rotary_factor
=
partial_rotary_factor
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
attn_type
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Glm4DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Glm4Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
1000000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
self
.
self_attn
=
Glm4Attention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
max_position
=
config
.
max_position_embeddings
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
qkv_bias
=
getattr
(
config
,
'attention_bias'
,
False
),
head_dim
=
getattr
(
config
,
'head_dim'
,
None
),
cache_config
=
cache_config
,
quant_config
=
quant_config
,
rope_scaling
=
rope_scaling
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
attn_type
=
AttentionType
.
DECODER
,
)
self
.
mlp
=
Glm4MLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_self_attn_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_mlp_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
hidden_states
=
self
.
post_self_attn_layernorm
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
# Fully Connected
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
post_mlp_layernorm
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
,
residual
ALL_DECODER_LAYER_TYPES
=
{
"attention"
:
Glm4DecoderLayer
,
}
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
})
class
Glm4Model
(
LlamaModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
Glm4DecoderLayer
)
class
Glm4ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Glm4Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/glm4v.py
View file @
9c4ecf15
...
@@ -12,7 +12,7 @@ from torch import nn
...
@@ -12,7 +12,7 @@ from torch import nn
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
from
torchvision
import
transforms
from
torchvision
import
transforms
from
torchvision.transforms
import
InterpolationMode
from
torchvision.transforms
import
InterpolationMode
from
transformers
import
PreTrainedTokenizer
,
TensorType
from
transformers
import
BatchFeature
,
PreTrainedTokenizer
,
TensorType
from
transformers.image_utils
import
ImageInput
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
from
transformers.tokenization_utils_base
import
TextInput
...
@@ -28,13 +28,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -28,13 +28,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BatchFeature
,
BaseProcessingInfo
,
PromptReplacement
,
MultiModalFieldConfig
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm.transformers_utils.configs
import
ChatGLMConfig
...
@@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
...
@@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_num_image_feature_tokens
()}
def
get_num_image_tokens
(
self
)
->
int
:
def
get_num_image_tokens
(
self
)
->
int
:
hf_config
=
self
.
get_hf_config
()
hf_config
=
self
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
...
@@ -454,31 +447,31 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
...
@@ -454,31 +447,31 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
class
GLM4VDummyInputsBuilder
(
BaseDummyInputsBuilder
[
GLM4VProcessingInfo
]):
class
GLM4VDummyInputsBuilder
(
BaseDummyInputsBuilder
[
GLM4VProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
base_text
=
"<|begin_of_image|><|endoftext|><|end_of_image|>"
return
base_text
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
hf_config
=
self
.
info
.
get_hf_config
()
hf_config
=
self
.
info
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
target_width
=
target_height
=
vision_config
[
"image_size"
]
target_width
=
target_height
=
vision_config
[
"image_size"
]
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
base_text
=
"<|begin_of_image|><|endoftext|><|end_of_image|>"
return
ProcessorInputs
(
prompt_text
=
base_text
*
num_images
,
mm_data
=
mm_data
,
)
class
GLM4VMultiModalProcessor
(
BaseMultiModalProcessor
[
GLM4VProcessingInfo
]):
class
GLM4VMultiModalProcessor
(
BaseMultiModalProcessor
[
GLM4VProcessingInfo
]):
...
@@ -596,6 +589,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
...
@@ -596,6 +589,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return
self
.
transformer
.
vision
(
pixel_values
)
return
self
.
transformer
.
vision
(
pixel_values
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
transformer
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/granite.py
View file @
9c4ecf15
...
@@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
maybe_prefix
)
make_layers
,
maybe_prefix
)
class
GraniteMLP
(
nn
.
Module
):
class
GraniteMLP
(
nn
.
Module
):
...
@@ -260,6 +260,7 @@ class GraniteModel(nn.Module):
...
@@ -260,6 +260,7 @@ class GraniteModel(nn.Module):
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
...
@@ -321,6 +322,65 @@ class GraniteModel(nn.Module):
...
@@ -321,6 +322,65 @@ class GraniteModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
GraniteForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
GraniteForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -428,71 +488,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -428,71 +488,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
skip_prefixes
=
[
# (param_name, shard_name, shard_id)
"rotary_emb.inv_freq"
,
(
".qkv_proj"
,
".q_proj"
,
"q"
),
# Models trained using ColossalAI may include these tensors in
(
".qkv_proj"
,
".k_proj"
,
"k"
),
# the checkpoint. Skip them.
(
".qkv_proj"
,
".v_proj"
,
"v"
),
"rotary_emb.cos_cached"
,
(
".gate_up_proj"
,
".gate_proj"
,
0
),
"rotary_emb.sin_cached"
,
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
# With tie_word_embeddings, we can skip lm_head.weight
loaded_params
:
Set
[
str
]
=
set
()
# The weight might appear unnecessarily in the files if the model is
for
name
,
loaded_weight
in
weights
:
# processed with quantization, LoRA, fine-tuning, etc.
if
"rotary_emb.inv_freq"
in
name
:
if
self
.
config
.
tie_word_embeddings
:
continue
skip_prefixes
.
append
(
"lm_head.weight"
)
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skip_prefixes
)
# Models trained using ColossalAI may include these tensors in
return
loader
.
load_weights
(
weights
)
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/granitemoe.py
View file @
9c4ecf15
...
@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
from
.
import
mixtral
from
.
import
mixtral
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
make_layers
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
make_layers
,
maybe_prefix
class
GraniteMoeMoE
(
nn
.
Module
):
class
GraniteMoeMoE
(
nn
.
Module
):
...
@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module):
...
@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module):
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
# Required by MixtralModel
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
...
@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module):
...
@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w1_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralModel
.
load_weights
(
self
,
new_weights
.
items
())
class
GraniteMoeForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
GraniteMoeForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
fall_back_to_pt_during_load
=
False
...
@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
config
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
# Required by MixtralForCausalLM
self
.
model
=
GraniteMoeModel
(
vllm_config
=
vllm_config
,
self
.
model
=
GraniteMoeModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
...
@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
loader
=
AutoWeightsLoader
(
for
n
,
p
in
weights
:
self
,
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
skip_prefixes
=
([
"lm_head."
]
for
e
in
range
(
p
.
size
(
0
)):
if
self
.
config
.
tie_word_embeddings
else
None
),
w1_name
=
n
.
replace
(
)
'.block_sparse_moe.input_linear.weight'
,
return
loader
.
load_weights
(
weights
)
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
elif
n
==
'lm_head.weight'
and
self
.
config
.
tie_word_embeddings
:
pass
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralForCausalLM
.
load_weights
(
self
,
new_weights
.
items
())
vllm/model_executor/models/granitemoeshared.py
View file @
9c4ecf15
...
@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors
from
.
import
mixtral
from
.
import
mixtral
from
.granitemoe
import
GraniteMoeAttention
,
GraniteMoeMoE
from
.granitemoe
import
GraniteMoeAttention
,
GraniteMoeMoE
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
make_layers
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
make_layers
,
maybe_prefix
class
GraniteMoeSharedMLP
(
nn
.
Module
):
class
GraniteMoeSharedMLP
(
nn
.
Module
):
...
@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module):
...
@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module):
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
# Required by MixtralModel
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module):
...
@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w1_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralModel
.
load_weights
(
self
,
new_weights
.
items
())
class
GraniteMoeSharedForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
GraniteMoeSharedForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
fall_back_to_pt_during_load
=
False
...
@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
config
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
GraniteMoeSharedModel
(
vllm_config
=
vllm_config
,
self
.
model
=
GraniteMoeSharedModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
=
maybe_prefix
(
...
@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
loader
=
AutoWeightsLoader
(
for
n
,
p
in
weights
:
self
,
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
skip_prefixes
=
([
"lm_head."
]
for
e
in
range
(
p
.
size
(
0
)):
if
self
.
config
.
tie_word_embeddings
else
None
),
w1_name
=
n
.
replace
(
)
'.block_sparse_moe.input_linear.weight'
,
return
loader
.
load_weights
(
weights
)
f
".block_sparse_moe.experts.
{
e
}
.w1.weight"
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w3.weight"
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
f
".block_sparse_moe.experts.
{
e
}
.w2.weight"
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
elif
n
==
'lm_head.weight'
and
self
.
config
.
tie_word_embeddings
:
pass
else
:
new_weights
[
n
]
=
p
return
mixtral
.
MixtralForCausalLM
.
load_weights
(
self
,
new_weights
.
items
())
vllm/model_executor/models/grok1.py
View file @
9c4ecf15
...
@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -302,6 +302,8 @@ class Grok1Model(nn.Module):
...
@@ -302,6 +302,8 @@ class Grok1Model(nn.Module):
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
@@ -370,94 +372,6 @@ class Grok1Model(nn.Module):
...
@@ -370,94 +372,6 @@ class Grok1Model(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
class
Grok1ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Grok1Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
output_multiplier_scale
=
getattr
(
config
,
"output_multiplier_scale"
,
DEFAULT_OUTPUT_MULTIPLIER_SCALE
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
self
.
output_multiplier_scale
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -480,9 +394,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -480,9 +394,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
self
.
quant_config
is
not
None
and
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
# Loading kv cache quantization scales
...
@@ -553,13 +464,107 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -553,13 +464,107 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
"norm.scale"
in
name
:
if
"norm.scale"
in
name
:
name
=
name
.
replace
(
"scale"
,
"weight"
)
name
=
name
.
replace
(
"scale"
,
"weight"
)
# Skip lm_head when tie_word_embeddings is True
if
"lm_head"
in
name
and
self
.
config
.
tie_word_embeddings
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
return
loaded_params
return
loaded_params
class
Grok1ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Grok1Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
output_multiplier_scale
=
getattr
(
config
,
"output_multiplier_scale"
,
DEFAULT_OUTPUT_MULTIPLIER_SCALE
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
self
.
output_multiplier_scale
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
skip_prefixes
=
[
"rotary_emb.inv_freq"
]
# Skip lm_head when tie_word_embeddings is True
if
self
.
config
.
tie_word_embeddings
:
skip_prefixes
.
append
(
"lm_head"
)
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skip_prefixes
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/h2ovl.py
View file @
9c4ecf15
...
@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
...
@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
repl_features
=
IMG_CONTEXT
*
feature_size
repl_features
=
IMG_CONTEXT
*
feature_size
repl_full
=
IMG_START
+
repl_features
+
IMG_END
repl_full
=
IMG_START
+
repl_features
+
IMG_END
return
PromptUpdateDetails
(
full
=
repl_full
,
features
=
repl_features
)
return
PromptUpdateDetails
.
select_text
(
repl_full
,
IMG_CONTEXT
)
def
resolve_min_max_num
(
def
resolve_min_max_num
(
self
,
self
,
...
@@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
...
@@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
**
kwargs
,
**
kwargs
,
)
)
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
max_tokens_one_image
=
self
.
get_max_image_tokens
(
use_msac
=
None
)
if
mm_counts
.
get
(
"image"
,
0
)
<=
1
:
max_tokens_per_image
=
max_tokens_one_image
else
:
max_tokens_per_image
=
self
.
get_max_image_tokens
(
use_msac
=
False
)
return
{
"image"
:
max_tokens_per_image
}
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
*
,
*
,
...
@@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
...
@@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
use_msac
=
use_msac
,
use_msac
=
use_msac
,
)
)
def
get_max_image_tokens
(
self
,
use_msac
:
Optional
[
bool
]
=
None
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
use_msac
=
use_msac
,
)
class
H2OVLMultiModalProcessor
(
InternVLMultiModalProcessor
[
H2OVLProcessingInfo
]
class
H2OVLMultiModalProcessor
(
InternVLMultiModalProcessor
[
H2OVLProcessingInfo
]
):
):
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment