Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c0d8f163
Unverified
Commit
c0d8f163
authored
Aug 05, 2024
by
Jungho Christopher Cho
Committed by
GitHub
Aug 05, 2024
Browse files
[Model] SiglipVisionModel ported from transformers (#6942)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
cc08fc72
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
650 additions
and
53 deletions
+650
-53
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+2
-1
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+27
-52
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+621
-0
No files found.
examples/offline_inference_vision_language.py
View file @
c0d8f163
...
...
@@ -65,7 +65,8 @@ def run_phi3v(question):
# PaliGemma
def
run_paligemma
(
question
):
prompt
=
question
# PaliGemma has special prompt format for VQA
prompt
=
"caption en"
llm
=
LLM
(
model
=
"google/paligemma-3b-mix-224"
)
return
llm
,
prompt
...
...
vllm/model_executor/models/paligemma.py
View file @
c0d8f163
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
import
torch
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
PaliGemmaConfig
,
SiglipVisionConfig
,
SiglipVisionModel
from
transformers
import
PaliGemmaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
...
...
@@ -18,9 +17,11 @@ from vllm.model_executor.models.gemma import GemmaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.interfaces
import
SupportsVision
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
merge_vision_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -32,55 +33,22 @@ _KEYS_TO_MODIFY_MAPPING = {
def
get_max_paligemma_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
text_config
=
hf_config
.
text_config
return
text_config
.
num_image_tokens
def
dummy_seq_data_for_paligemma
(
hf_config
:
PaliGemmaConfig
,
seq_len
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
):
if
image_feature_size_override
is
None
:
image_feature_size
=
hf_config
.
text_config
.
num_image_tokens
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_paligemma
(
hf_config
:
SiglipVisionConfig
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
width
=
height
=
hf_config
.
image_size
if
image_width_override
is
not
None
:
width
=
image_width_override
if
image_height_override
is
not
None
:
height
=
image_height_override
vision_config
=
hf_config
.
vision_config
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
get_max_siglip_image_tokens
(
vision_config
)
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
vision_config
=
hf_config
.
vision_config
seq_data
=
dummy_seq_data_for_
paligemma
(
hf
_config
,
seq_data
=
dummy_seq_data_for_
siglip
(
vision
_config
,
seq_len
,
image_token_id
=
hf_config
.
image_token_index
,
)
mm_data
=
dummy_image_for_
paligemma
(
vision_config
)
mm_data
=
dummy_image_for_
siglip
(
vision_config
)
return
seq_data
,
mm_data
...
...
@@ -208,30 +176,37 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
pixel_values
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
target_dtype
=
vision_tower
.
get_input_embeddings
().
weight
.
dtype
image_outputs
=
vision_tower
(
pixel_values
.
to
(
dtype
=
target_dtype
),
output_hidden_states
=
True
)
selected_image_features
=
image_outputs
.
last_hidden_state
image_features
=
vision_tower
(
pixel_values
.
to
(
dtype
=
target_dtype
))
return
selected_
image_features
return
image_features
def
_process_image_pixels
(
self
,
inputs
:
PaliGemmaImagePixelInputs
)
->
torch
.
Tensor
:
self
,
inputs
:
PaliGemmaImagePixelInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"data"
]
return
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
)
return
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
,
)
def
_process_image_input
(
self
,
image_input
:
PaliGemmaImageInputs
)
->
torch
.
Tensor
:
self
,
image_input
:
PaliGemmaImageInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
image_features
=
self
.
_process_image_pixels
(
image_input
,
)
return
self
.
multi_modal_projector
(
image_features
)
...
...
vllm/model_executor/models/siglip.py
0 → 100644
View file @
c0d8f163
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment