Unverified Commit 024ad87c authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix dtype mismatch in PaliGemma (#6367)

parent aea19f09
......@@ -129,7 +129,7 @@ def run_test(
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["float", "half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
......
......@@ -277,6 +277,7 @@ class GemmaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
......
......@@ -19,7 +19,7 @@ from vllm.model_executor.models.gemma import GemmaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import SamplerOutput, SequenceData
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
......@@ -111,7 +111,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
orig_prompt = llm_inputs.get("prompt")
orig_prompt_ids = llm_inputs.get("prompt_token_ids")
if image_token_str in orig_prompt:
if orig_prompt is not None and image_token_str in orig_prompt:
logger.warning(
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
......@@ -214,7 +214,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = vision_tower(pixel_values, output_hidden_states=True)
target_dtype = vision_tower.get_input_embeddings().weight.dtype
image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
output_hidden_states=True)
selected_image_features = image_outputs.last_hidden_state
......@@ -236,9 +238,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return self.multi_modal_projector(image_features)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object) -> SamplerOutput:
parsed_image_input = self._parse_and_validate_image_input(**kwargs)
......@@ -263,6 +268,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment