Unverified Commit 0eebd748 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Model][Gemma3] Simplify image input validation (#18710)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 27bebcd8
......@@ -504,18 +504,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return next(self.parameters()).dtype
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
if d.shape != expected_dims:
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_dims}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
image_size = self.config.vision_config.image_size
expected_dims = (3, image_size, image_size)
if data.shape[1:] != expected_dims:
raise ValueError(
"The expected shape of pixel values per image per batch is "
f"{expected_dims}. You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment