Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
024ad87c
Unverified
Commit
024ad87c
authored
Jul 12, 2024
by
Cyrus Leung
Committed by
GitHub
Jul 12, 2024
Browse files
[Bugfix] Fix dtype mismatch in PaliGemma (#6367)
parent
aea19f09
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
5 deletions
+12
-5
tests/models/test_paligemma.py
tests/models/test_paligemma.py
+1
-1
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+1
-0
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+10
-4
No files found.
tests/models/test_paligemma.py
View file @
024ad87c
...
@@ -129,7 +129,7 @@ def run_test(
...
@@ -129,7 +129,7 @@ def run_test(
[
0.25
,
0.5
,
1.0
],
[
0.25
,
0.5
,
1.0
],
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
...
...
vllm/model_executor/models/gemma.py
View file @
024ad87c
...
@@ -277,6 +277,7 @@ class GemmaModel(nn.Module):
...
@@ -277,6 +277,7 @@ class GemmaModel(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
...
...
vllm/model_executor/models/paligemma.py
View file @
024ad87c
...
@@ -19,7 +19,7 @@ from vllm.model_executor.models.gemma import GemmaModel
...
@@ -19,7 +19,7 @@ from vllm.model_executor.models.gemma import GemmaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
SamplerOutput
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
.interfaces
import
SupportsVision
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
from
.utils
import
merge_vision_embeddings
...
@@ -111,7 +111,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -111,7 +111,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
orig_prompt
=
llm_inputs
.
get
(
"prompt"
)
orig_prompt
=
llm_inputs
.
get
(
"prompt"
)
orig_prompt_ids
=
llm_inputs
.
get
(
"prompt_token_ids"
)
orig_prompt_ids
=
llm_inputs
.
get
(
"prompt_token_ids"
)
if
image_token_str
in
orig_prompt
:
if
orig_prompt
is
not
None
and
image_token_str
in
orig_prompt
:
logger
.
warning
(
logger
.
warning
(
"The image token '%s' was detected in the prompt and "
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
"will be removed. Please follow the proper prompt format"
...
@@ -214,7 +214,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -214,7 +214,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
image_outputs
=
vision_tower
(
pixel_values
,
output_hidden_states
=
True
)
target_dtype
=
vision_tower
.
get_input_embeddings
().
weight
.
dtype
image_outputs
=
vision_tower
(
pixel_values
.
to
(
dtype
=
target_dtype
),
output_hidden_states
=
True
)
selected_image_features
=
image_outputs
.
last_hidden_state
selected_image_features
=
image_outputs
.
last_hidden_state
...
@@ -236,9 +238,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -236,9 +238,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return
self
.
multi_modal_projector
(
image_features
)
return
self
.
multi_modal_projector
(
image_features
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
)
->
SamplerOutput
:
**
kwargs
:
object
)
->
SamplerOutput
:
parsed_image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
parsed_image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
@@ -263,6 +268,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -263,6 +268,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
positions
,
positions
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
None
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment