Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1282bd81
Unverified
Commit
1282bd81
authored
Jun 03, 2025
by
汪志鹏
Committed by
GitHub
Jun 03, 2025
Browse files
Add tarsier model support (#18985)
Signed-off-by:
汪志鹏
<
wangzhipeng628@gmail.com
>
parent
bdce64f2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
689 additions
and
0 deletions
+689
-0
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+20
-0
examples/offline_inference/vision_language_multi_image.py
examples/offline_inference/vision_language_multi_image.py
+21
-0
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+1
-0
tests/models/registry.py
tests/models/registry.py
+2
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/model_executor/models/tarsier.py
vllm/model_executor/models/tarsier.py
+643
-0
No files found.
docs/models/supported_models.md
View file @
1282bd81
...
...
@@ -550,6 +550,7 @@ Specified using `--task generate`.
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-7B`
| | ✅︎ | ✅︎
\*
|
|
`SkyworkR1VChatModel`
| Skywork-R1V-38B | T + I |
`Skywork/Skywork-R1V-38B`
| | ✅︎ | ✅︎ |
|
`SmolVLMForConditionalGeneration`
| SmolVLM2 | T + I |
`SmolVLM2-2.2B-Instruct`
| ✅︎ | | ✅︎ |
|
`TarsierForConditionalGeneration`
| Tarsier | T + I
<sup>
E+
</sup>
|
`omni-search/Tarsier-7b`
,
`omni-search/Tarsier-34b`
| | ✅︎ | ✅︎ |
<sup>
^
</sup>
You need to set the architecture name via
`--hf-overrides`
to match the one in vLLM.
• For example, to use DeepSeek-VL2 series models:
...
...
examples/offline_inference/vision_language.py
View file @
1282bd81
...
...
@@ -333,6 +333,25 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
)
# omni-research/Tarsier-7b
def
run_tarsier
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
assert
modality
==
"image"
model_name
=
"omni-research/Tarsier-7b"
engine_args
=
EngineArgs
(
model
=
model_name
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
limit_mm_per_prompt
=
{
modality
:
1
},
)
prompts
=
[(
f
"USER: <image>
\n
{
question
}
ASSISTANT:"
)
for
question
in
questions
]
return
ModelRequestData
(
engine_args
=
engine_args
,
prompts
=
prompts
,
)
# InternVL
def
run_internvl
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
model_name
=
"OpenGVLab/InternVL3-2B"
...
...
@@ -1091,6 +1110,7 @@ model_example_map = {
"qwen2_5_omni"
:
run_qwen2_5_omni
,
"skywork_chat"
:
run_skyworkr1v
,
"smolvlm"
:
run_smolvlm
,
"tarsier"
:
run_tarsier
,
}
...
...
examples/offline_inference/vision_language_multi_image.py
View file @
1282bd81
...
...
@@ -691,6 +691,26 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)
def
load_tarsier
(
question
:
str
,
image_urls
:
list
[
str
])
->
ModelRequestData
:
model_name
=
"omni-research/Tarsier-7b"
engine_args
=
EngineArgs
(
model
=
model_name
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
)
prompt
=
f
"USER:
{
'<image>'
*
len
(
image_urls
)
}
\n
{
question
}
\n
ASSISTANT:"
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
]
return
ModelRequestData
(
engine_args
=
engine_args
,
prompt
=
prompt
,
image_data
=
image_data
,
)
model_example_map
=
{
"aria"
:
load_aria
,
"aya_vision"
:
load_aya_vision
,
...
...
@@ -712,6 +732,7 @@ model_example_map = {
"qwen2_vl"
:
load_qwen2_vl
,
"qwen2_5_vl"
:
load_qwen2_5_vl
,
"smolvlm"
:
load_smolvlm
,
"tarsier"
:
load_tarsier
,
}
...
...
tests/models/multimodal/processing/test_common.py
View file @
1282bd81
...
...
@@ -282,6 +282,7 @@ def _test_processing_correctness_one(
"Skywork/Skywork-R1V-38B"
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
"openai/whisper-large-v3"
,
"omni-research/Tarsier-7b"
,
])
@
pytest
.
mark
.
parametrize
(
"hit_rate"
,
[
0.3
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"num_batches"
,
[
32
])
...
...
tests/models/registry.py
View file @
1282bd81
...
...
@@ -406,6 +406,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"SmolVLMForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
# noqa: E501
trust_remote_code
=
True
),
"TarsierForConditionalGeneration"
:
_HfExamplesInfo
(
"omni-research/Tarsier-7b"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"TarsierForConditionalGeneration"
]}),
# noqa: E501
# [Encoder-decoder]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
...
...
vllm/model_executor/models/registry.py
View file @
1282bd81
...
...
@@ -211,6 +211,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"Phi4MMForCausalLM"
:
(
"phi4mm"
,
"Phi4MMForCausalLM"
),
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
# [Encoder-decoder]
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
# noqa: E501
...
...
vllm/model_executor/models/tarsier.py
0 → 100644
View file @
1282bd81
# SPDX-License-Identifier: Apache-2.0
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
(
Final
,
Literal
,
Optional
,
Protocol
,
TypedDict
,
TypeVar
,
Union
,
cast
)
import
torch
import
torch.nn
as
nn
from
transformers
import
BatchFeature
,
CLIPVisionConfig
from
transformers
import
LlavaConfig
as
HfLlavaConfig
from
transformers
import
PretrainedConfig
,
SiglipVisionConfig
from
transformers.image_utils
import
ImageInput
,
get_image_size
,
to_numpy_array
from
transformers.models.llava
import
LlavaProcessor
from
transformers.processing_utils
import
(
ProcessingKwargs
,
Unpack
,
_validate_images_text_input_order
)
from
transformers.tokenization_utils_base
import
PreTokenizedInput
,
TextInput
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
InputProcessingContext
from
vllm.jsontree
import
json_map_leaves
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.models.llava
import
LlavaDummyInputsBuilder
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
ProcessingCache
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.clip
import
CLIPVisionModel
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
VisionEncoderInfo
,
get_vision_encoder_info
class
TarsierImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
class
TarsierImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
TarsierImageInputs
=
Union
[
TarsierImagePixelInputs
,
TarsierImageEmbeddingInputs
]
class
TarsierHfConfig
(
Protocol
):
# Based on the Tarsier's LlavaConfig
vision_config
:
Final
[
PretrainedConfig
]
text_config
:
Final
[
PretrainedConfig
]
# Added from Tarsier's LlavaConfig
image_token_index
:
Final
[
int
]
vision_feature_select_strategy
:
Final
[
str
]
vision_feature_layer
:
Final
[
Union
[
int
,
list
[
int
]]]
projector_hidden_act
:
Final
[
str
]
image_newline_idx
:
Final
[
int
]
image_new_idx
:
Final
[
int
]
multimodal_projector_bias
:
bool
=
True
class
TarsierProcessorKwargs
(
ProcessingKwargs
,
total
=
False
):
_defaults
=
{
"text_kwargs"
:
{
"padding"
:
False
,
},
"images_kwargs"
:
{},
}
class
TarsierProcessor
(
LlavaProcessor
):
def
__call__
(
self
,
images
:
ImageInput
=
None
,
text
:
Union
[
TextInput
,
PreTokenizedInput
,
list
[
TextInput
],
list
[
PreTokenizedInput
]]
=
None
,
audio
=
None
,
videos
=
None
,
**
kwargs
:
Unpack
[
TarsierProcessorKwargs
],
)
->
BatchFeature
:
if
images
is
None
and
text
is
None
:
raise
ValueError
(
"You have to specify at least one of `images` or `text`."
)
# check if images and text inputs are reversed for BC
images
,
text
=
_validate_images_text_input_order
(
images
,
text
)
output_kwargs
=
self
.
_merge_kwargs
(
TarsierProcessorKwargs
,
tokenizer_init_kwargs
=
self
.
tokenizer
.
init_kwargs
,
**
kwargs
,
)
if
images
is
not
None
:
image_inputs
=
self
.
image_processor
(
images
,
**
output_kwargs
[
"images_kwargs"
])
else
:
image_inputs
=
{}
if
isinstance
(
text
,
str
):
text
=
[
text
]
elif
not
isinstance
(
text
,
list
)
and
not
isinstance
(
text
[
0
],
str
):
raise
ValueError
(
"Invalid input text. Please provide a string,"
" or a list of strings"
)
# try to expand inputs in processing if we have the necessary parts
prompt_strings
=
text
if
image_inputs
.
get
(
"pixel_values"
)
is
not
None
:
# Replace the image token with the expanded image token sequence
pixel_values
=
image_inputs
[
"pixel_values"
]
height
,
width
=
get_image_size
(
to_numpy_array
(
pixel_values
[
0
]))
num_image_tokens
=
(
height
//
self
.
patch_size
)
*
(
width
//
self
.
patch_size
+
1
)
+
self
.
num_additional_image_tokens
+
1
if
self
.
vision_feature_select_strategy
==
"default"
:
num_image_tokens
-=
1
prompt_strings
=
[]
for
sample
in
text
:
sample
=
sample
.
replace
(
self
.
image_token
,
self
.
image_token
*
num_image_tokens
)
prompt_strings
.
append
(
sample
)
return_tensors
=
output_kwargs
[
"text_kwargs"
].
pop
(
"return_tensors"
,
None
)
text_inputs
=
self
.
tokenizer
(
prompt_strings
,
**
output_kwargs
[
"text_kwargs"
])
return
BatchFeature
(
data
=
{
**
text_inputs
,
**
image_inputs
},
tensor_type
=
return_tensors
)
class
TarsierMultiModalProjector
(
nn
.
Module
):
def
__init__
(
self
,
vision_hidden_size
:
int
,
text_hidden_size
:
int
,
projector_hidden_act
:
str
,
multimodal_projector_bias
:
bool
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
linear_1
=
ColumnParallelLinear
(
vision_hidden_size
,
text_hidden_size
,
bias
=
multimodal_projector_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.linear_1"
)
self
.
act
=
get_act_fn
(
projector_hidden_act
)
self
.
linear_2
=
RowParallelLinear
(
text_hidden_size
,
text_hidden_size
,
bias
=
multimodal_projector_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.linear_2"
)
def
forward
(
self
,
image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
linear_1
(
image_features
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
,
_
=
self
.
linear_2
(
hidden_states
)
return
hidden_states
class
TarsierProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
)
->
TarsierHfConfig
:
return
self
.
ctx
.
get_hf_config
(
HfLlavaConfig
)
def
get_vision_encoder_info
(
self
)
->
VisionEncoderInfo
:
return
get_vision_encoder_info
(
self
.
get_hf_config
())
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
TarsierProcessor
:
hf_processor
=
self
.
ctx
.
get_hf_processor
(
TarsierProcessor
,
**
kwargs
)
# Patch for patch_size if needed (copied from vLLM LLaVA)
if
hasattr
(
hf_processor
,
'patch_size'
)
and
hf_processor
.
patch_size
is
None
:
patch_size
=
self
.
get_vision_encoder_info
().
get_patch_size
()
hf_processor
.
patch_size
=
patch_size
return
hf_processor
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
def
_apply_feature_select_strategy
(
self
,
strategy
:
str
,
encoder_num_image_tokens
:
int
,
)
->
int
:
if
strategy
==
"default"
:
return
encoder_num_image_tokens
-
1
if
strategy
==
"full"
:
return
encoder_num_image_tokens
msg
=
f
"Unexpected feature select strategy:
{
strategy
!
r
}
"
raise
NotImplementedError
(
msg
)
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
hf_config
=
self
.
get_hf_config
()
vision_encoder_info
=
self
.
get_vision_encoder_info
()
num_projected_patches
=
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
vision_encoder_info
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
),
)
if
num_projected_patches
<=
0
:
default_size
=
self
.
get_image_size_with_most_features
()
num_projected_patches_default
=
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
vision_encoder_info
.
get_num_image_tokens
(
image_width
=
default_size
.
width
,
image_height
=
default_size
.
height
,
),
)
if
num_projected_patches_default
<=
0
:
raise
ValueError
(
"Could not determine a valid number of image patches."
)
num_projected_patches
=
num_projected_patches_default
num_height_patches
=
int
(
math
.
sqrt
(
num_projected_patches
))
total_image_tokens_for_llm
=
num_projected_patches
\
+
num_height_patches
+
1
return
total_image_tokens_for_llm
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
vision_encoder_info
=
self
.
get_vision_encoder_info
()
width
=
height
=
vision_encoder_info
.
get_image_size
()
return
ImageSize
(
width
=
width
,
height
=
height
)
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
)
def
get_image_newline_idx
(
self
)
->
int
:
return
self
.
get_hf_config
().
image_newline_idx
def
get_image_new_idx
(
self
)
->
int
:
return
self
.
get_hf_config
().
image_new_idx
_I_Tarsier
=
TypeVar
(
"_I_Tarsier"
,
bound
=
TarsierProcessingInfo
)
class
TarsierDummyInputsBuilder
(
LlavaDummyInputsBuilder
[
_I_Tarsier
]):
pass
class
TarsierMultiModalProcessor
(
BaseMultiModalProcessor
[
_I_Tarsier
]):
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
Sequence
[
PromptUpdate
]:
hf_config
=
self
.
info
.
get_hf_config
()
image_token_id
=
hf_config
.
image_token_index
# The <IMAGE> token ID
def
get_replacement
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
))
if
isinstance
(
images
,
ImageEmbeddingItems
):
num_projected_patches
=
images
.
get_feature_size
(
item_idx
)
# This assumes num_projected_patches is a perfect square
num_height_patches
=
int
(
math
.
sqrt
(
num_projected_patches
))
num_final_image_tokens
=
num_projected_patches
\
+
num_height_patches
+
1
else
:
image_size
=
images
.
get_image_size
(
item_idx
)
num_final_image_tokens
=
self
.
info
.
get_num_image_tokens
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
image_token_id
]
*
num_final_image_tokens
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
# Replace each single <IMAGE> token
replacement
=
get_replacement
,
),
]
def
_build_tarsier_hf_info
(
ctx
:
InputProcessingContext
)
->
TarsierProcessingInfo
:
return
TarsierProcessingInfo
(
ctx
)
def
_build_tarsier_hf_processor
(
info
:
_I_Tarsier
,
dummy_inputs
:
BaseDummyInputsBuilder
[
_I_Tarsier
],
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
)
->
BaseMultiModalProcessor
:
if
isinstance
(
info
,
TarsierProcessingInfo
):
return
TarsierMultiModalProcessor
(
info
,
dummy_inputs
,
cache
=
cache
,
)
raise
NotImplementedError
(
type
(
info
))
def
init_vision_tower_for_tarsier
(
hf_config
:
TarsierHfConfig
,
# Use the Tarsier specific config protocol
quant_config
:
Optional
[
QuantizationConfig
],
*
,
require_post_norm
:
Optional
[
bool
]
=
None
,
prefix
:
str
=
""
,
)
->
Union
[
CLIPVisionModel
,
SiglipVisionModel
]:
vision_config
=
hf_config
.
vision_config
feature_layers
=
hf_config
.
vision_feature_layer
base_num_hidden_layers
=
vision_config
.
num_hidden_layers
def
_get_layer_index
(
feature_layer_index
:
int
,
num_hidden_layers_total
:
int
)
->
int
:
if
feature_layer_index
<
0
:
return
num_hidden_layers_total
+
feature_layer_index
+
1
return
feature_layer_index
if
isinstance
(
feature_layers
,
int
):
num_hidden_layers_to_init
=
_get_layer_index
(
feature_layers
,
base_num_hidden_layers
)
elif
isinstance
(
feature_layers
,
(
list
,
tuple
)):
num_hidden_layers_to_init
=
max
(
_get_layer_index
(
idx
,
base_num_hidden_layers
)
for
idx
in
feature_layers
)
else
:
raise
TypeError
(
f
"vision_layer_feature type:
{
type
(
feature_layers
)
}
"
" is not supported"
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
CLIPVisionModel
(
vision_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_to_init
,
require_post_norm
=
require_post_norm
,
prefix
=
prefix
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
SiglipVisionModel
(
vision_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_to_init
,
require_post_norm
=
require_post_norm
,
prefix
=
prefix
,
)
msg
=
f
"Unsupported vision config for Tarsier:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
_build_tarsier_hf_processor
,
info
=
_build_tarsier_hf_info
,
dummy_inputs
=
TarsierDummyInputsBuilder
)
class
TarsierForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
config
:
TarsierHfConfig
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
# Storing the Tarsier-specific HF config
self
.
vision_tower
=
init_vision_tower_for_tarsier
(
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
))
projector_bias
=
getattr
(
config
,
"multimodal_projector_bias"
,
True
)
self
.
multi_modal_projector
=
TarsierMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
,
multimodal_projector_bias
=
projector_bias
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
))
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
# Use text_config from Tarsier's main config
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
register_buffer
(
'image_newline_idx_tensor'
,
torch
.
tensor
([
config
.
image_newline_idx
],
dtype
=
torch
.
long
),
persistent
=
False
)
self
.
register_buffer
(
'image_new_idx_tensor'
,
torch
.
tensor
([
config
.
image_new_idx
],
dtype
=
torch
.
long
),
persistent
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
# Assuming 3 channels
actual_dims
=
tuple
(
data
.
shape
[
1
:])
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"batch_size"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
f
"The expected shape of pixel values is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
data
.
shape
)
}
."
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
TarsierImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
TarsierImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
,
concat
=
True
)),
)
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
TarsierImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_select_image_features
(
self
,
image_features
:
torch
.
Tensor
,
*
,
strategy
:
str
)
->
torch
.
Tensor
:
if
strategy
==
"default"
:
return
image_features
[:,
1
:]
elif
strategy
==
"full"
:
return
image_features
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
Union
[
CLIPVisionModel
,
SiglipVisionModel
],
pixel_values
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
# From vLLM LLaVA, vision tower output handling
image_hidden_states
=
vision_tower
(
pixel_values
)
if
not
isinstance
(
image_hidden_states
,
torch
.
Tensor
):
raise
TypeError
(
f
"image_hidden_states type:
{
type
(
image_hidden_states
)
}
"
" is not supported"
)
def
select_features_fn
(
leaf
:
torch
.
Tensor
):
return
self
.
_select_image_features
(
leaf
,
strategy
=
self
.
config
.
vision_feature_select_strategy
,
)
selected_features
=
cast
(
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]],
json_map_leaves
(
select_features_fn
,
image_hidden_states
),
)
return
selected_features
def
_add_tarsier_split_tokens
(
self
,
projected_image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Implements Tarsier's `add_split_tokens` logic.
"""
num_images
,
num_projected_patches
,
embed_dim
=
\
projected_image_features
.
shape
num_height_patches
=
int
(
math
.
sqrt
(
num_projected_patches
))
num_width_patches
=
num_projected_patches
//
num_height_patches
device
=
projected_image_features
.
device
embedding_layer
=
self
.
language_model
.
model
.
embed_tokens
image_newline_emb
=
embedding_layer
(
self
.
image_newline_idx_tensor
.
to
(
device
)).
squeeze
(
0
)
image_new_emb
=
embedding_layer
(
self
.
image_new_idx_tensor
.
to
(
device
)).
squeeze
(
0
)
try
:
current_image_features_grid
=
projected_image_features
.
view
(
num_images
,
num_height_patches
,
num_width_patches
,
embed_dim
)
except
RuntimeError
as
e
:
raise
RuntimeError
(
"Cannot reshape projected_image_features"
f
" with shape
{
projected_image_features
.
shape
}
"
f
"to (
{
num_images
}
,
{
num_height_patches
}
,"
f
"
{
num_width_patches
}
,
{
embed_dim
}
). "
"Ensure num_projected_patches is compatible"
" with a grid structure. "
f
"num_projected_patches=
{
num_projected_patches
}
, "
f
"derived num_height_patches=
{
num_height_patches
}
. "
)
from
e
image_newline_expanded
=
image_newline_emb
.
expand
(
(
num_images
,
num_height_patches
,
1
,
embed_dim
))
features_with_newlines
=
torch
.
cat
(
[
current_image_features_grid
,
image_newline_expanded
],
dim
=
2
# Concatenate along width dim
)
new_num_patches_after_newline
=
num_projected_patches
\
+
num_height_patches
features_with_newlines_flat
=
features_with_newlines
.
view
(
num_images
,
new_num_patches_after_newline
,
embed_dim
)
image_new_expanded
=
image_new_emb
.
expand
((
num_images
,
1
,
embed_dim
))
final_image_features
=
torch
.
cat
(
[
features_with_newlines_flat
,
image_new_expanded
],
dim
=
1
# Concatenate along patch sequence dim
)
return
final_image_features
def
_process_image_pixels
(
self
,
inputs
:
TarsierImagePixelInputs
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"pixel_values"
]
image_features_selected
=
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
)
# type: ignore
if
isinstance
(
image_features_selected
,
torch
.
Tensor
):
projected_features
=
self
.
multi_modal_projector
(
image_features_selected
)
final_features
=
self
.
_add_tarsier_split_tokens
(
projected_features
)
return
final_features
else
:
raise
TypeError
(
f
"_image_pixels_to_features type:"
f
"
{
type
(
image_features_selected
)
}
is not supported"
)
def
_process_image_input
(
self
,
image_input
:
TarsierImageInputs
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
if
image_input
[
"type"
]
==
"image_embeds"
:
projected_features
=
image_input
[
"data"
]
if
isinstance
(
projected_features
,
torch
.
Tensor
):
return
self
.
_add_tarsier_split_tokens
(
projected_features
)
else
:
raise
ValueError
(
"Incorrect type of image_embeds. "
f
"Got type:
{
type
(
projected_features
)
}
. "
)
assert
self
.
vision_tower
is
not
None
return
self
.
_process_image_pixels
(
image_input
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
return
self
.
_process_image_input
(
image_input
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
config
.
image_token_index
,
)
return
inputs_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
elif
inputs_embeds
is
None
:
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
vision_embeddings
)
input_ids
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment