Unverified Commit 0eebd748 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Model][Gemma3] Simplify image input validation (#18710)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 27bebcd8
...@@ -504,18 +504,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -504,18 +504,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return next(self.parameters()).dtype return next(self.parameters()).dtype
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size image_size = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, image_size, image_size)
if data.shape[1:] != expected_dims:
def _validate_shape(d: torch.Tensor): raise ValueError(
if d.shape != expected_dims: "The expected shape of pixel values per image per batch is "
raise ValueError( f"{expected_dims}. You supplied {tuple(data.shape)}.")
"The expected shape of pixel values per image per batch "
f"is {expected_dims}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment