"ml/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "1fd4cb87b202ddacbc45c2a14284acfafc552101"
Unverified Commit 83f9314d authored by Jim Allanson, committed by GitHub

fix: cast input pixels to appropriate dtype for image_to_text pipelines (#24947)

* fix: cast input pixels to appropriate dtype for image_to_text tasks

* fix: add casting to pixel inputs of additional models after running copy checks
parent 1c7e5e23
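For context, the failure mode behind this change as a minimal sketch, not the actual pipeline code (the Conv2d shapes and the bfloat16 dtype below are illustrative assumptions): image processors hand the vision tower float32 pixel values, so when a checkpoint is loaded in reduced precision the patch-embedding convolution receives inputs whose dtype does not match its weights. Casting the pixels to the weight dtype, as every hunk below does, resolves the mismatch.

import torch
from torch import nn

# Illustrative stand-in for the patch_embedding used by the vision embeddings below;
# channel/patch sizes are assumptions, not taken from any particular model config.
patch_embedding = nn.Conv2d(3, 768, kernel_size=16, stride=16, bias=False)
patch_embedding = patch_embedding.to(dtype=torch.bfloat16)  # model loaded in reduced precision
# (bfloat16 chosen so the sketch also runs on CPU with a recent PyTorch; float16 on GPU hits the same issue)

pixel_values = torch.rand(1, 3, 224, 224)  # image processors emit float32 by default

# Without the cast, the convolution raises a dtype-mismatch RuntimeError
# (exact message depends on device/backend). With the cast applied in this commit:
target_dtype = patch_embedding.weight.dtype
patch_embeds = patch_embedding(pixel_values.to(dtype=target_dtype))
print(patch_embeds.dtype)  # torch.bfloat16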
@@ -1022,7 +1022,8 @@ class AltCLIPVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

@@ -246,7 +246,7 @@ class BlipVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)

@@ -109,7 +109,7 @@ class Blip2VisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)

@@ -284,7 +284,8 @@ class BridgeTowerVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

@@ -196,7 +196,8 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

@@ -192,7 +192,8 @@ class CLIPVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

@@ -628,7 +628,8 @@ class GitVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

@@ -110,7 +110,7 @@ class InstructBlipVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)

@@ -143,7 +143,8 @@ class XCLIPVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

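End to end, the scenario these hunks unblock looks roughly like the sketch below; the checkpoint name, device index, and image URL are illustrative assumptions, not taken from this commit. An image-to-text pipeline loaded with torch_dtype=torch.float16 can now consume the float32 pixel values produced by its image processor.

import torch
from transformers import pipeline

# Assumed example checkpoint; any of the vision backbones touched above behaves the same way.
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
    torch_dtype=torch.float16,  # half-precision weights
    device=0,                   # assumes a CUDA device is available
)

# The image processor still produces float32 pixel values; with this commit the vision
# embeddings cast them to the model's dtype before the patch-embedding convolution.
url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"  # example image
print(captioner(url))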