"src/array/cuda/array_index_select.hip" did not exist on "7b9afbfa28e9441fd3bfbe4e3c24eb62816cb1a7"
Unverified Commit 7b02c326 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[Bugfix](gemma3_mm): handle flatten_batch constraint for multiple images (#6562)

parent fefa19fe
...@@ -288,13 +288,22 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
pixel_values = torch.stack(
flatten_nested_list([item.pixel_values for item in items]), dim=0
)
pixel_values = pixel_values.to(device=self.vision_tower.device)
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
vision_outputs = self.vision_tower(pixel_values=pixel_values)
# Process images one by one to handle flatten_batch=True constraint in vision_tower
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
vision_outputs_list = []
for pixel_value in all_pixel_values:
# Add batch dimension for single image processing
pixel_value_batch = pixel_value.unsqueeze(0)
pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
vision_output = self.vision_tower(pixel_values=pixel_value_batch)
vision_outputs_list.append(vision_output)
# Concatenate all vision outputs
vision_outputs = torch.cat(vision_outputs_list, dim=0)
image_features = self.multi_modal_projector(vision_outputs)
return image_features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment