"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "85b324bee56c32142cc131149d7a92281964641f"
Unverified commit d24097e0, authored by amyeroberts, committed by GitHub

Fix swin embeddings interpolation (#30936)

parent eae2b6b8
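
This change removes the interpolate_pos_encoding argument, together with the fixed-image-size check that came with it, from the patch-embedding classes of the Swin family (Swin, Swinv2, DonutSwin, MaskFormerSwin). The patch embeddings already pad their input to a multiple of the patch size, so the check was rejecting resolutions the projection handles without issue; interpolating the absolute position embeddings remains the job of the corresponding embeddings classes.

A minimal usage sketch of what the flag enables, assuming the public SwinModel API; the checkpoint is a real Hub checkpoint, but the snippet is illustrative and not taken from this PR:

import torch
from transformers import SwinModel

model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

# A 384x384 input instead of the 224x224 resolution the checkpoint was trained at.
pixel_values = torch.randn(1, 3, 384, 384)

# Note: interpolate_pos_encoding only affects absolute position embeddings
# (use_absolute_embeddings=True in the config); Swin checkpoints default to
# relative position biases, but the argument is accepted either way.
outputs = model(pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)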
src/transformers/models/donut/modeling_donut_swin.py
@@ -205,9 +205,7 @@ class DonutSwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
@@ -228,7 +226,7 @@ class DonutSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions


-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
 class DonutSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -260,21 +258,10 @@ class DonutSwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values

-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
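
The ValueError deleted above was redundant with maybe_pad, which the hunk keeps as context: the patch-embedding layer rounds odd input sizes up to the next multiple of the patch size before projecting, so no fixed-resolution check is needed. A sketch of that helper as it appears in the Swin-family modeling files (reproduced from memory of the transformers source, so treat details as approximate):

def maybe_pad(self, pixel_values, height, width):
    # Pad the right edge so the width becomes a multiple of the patch width...
    if width % self.patch_size[1] != 0:
        pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
        pixel_values = nn.functional.pad(pixel_values, pad_values)
    # ...and the bottom edge so the height becomes a multiple of the patch height.
    if height % self.patch_size[0] != 0:
        pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
        pixel_values = nn.functional.pad(pixel_values, pad_values)
    return pixel_values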
src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -197,9 +197,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
     def forward(self, pixel_values, interpolate_pos_encoding):
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)

         if self.position_embeddings is not None:
@@ -213,7 +211,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
         return embeddings, output_dimensions


-# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->MaskFormerSwin
 class MaskFormerSwinPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -245,21 +243,10 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values

-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
src/transformers/models/swin/modeling_swin.py
@@ -291,9 +291,7 @@ class SwinEmbeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
@@ -345,21 +343,10 @@ class SwinPatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values

-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
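
With the size check gone from SwinPatchEmbeddings, variable-resolution inputs reach SwinEmbeddings, which interpolates the absolute position embeddings when interpolate_pos_encoding=True. A condensed sketch of the ViT-style helper it uses; this is assumed from the standard transformers implementation, and attribute names such as self.patch_size may differ slightly from the actual file:

def interpolate_pos_encoding(self, embeddings, height, width):
    # position_embeddings has shape (1, num_positions + 1, dim); slot 0 is a
    # reserved token, the remaining positions form a square patch grid.
    num_positions = self.position_embeddings.shape[1] - 1
    class_pos_embed = self.position_embeddings[:, :1]
    patch_pos_embed = self.position_embeddings[:, 1:]
    dim = embeddings.shape[-1]
    grid_size = int(num_positions**0.5)
    new_height = height // self.patch_size
    new_width = width // self.patch_size
    # Reshape the flat embeddings to a 2D grid and resize it bicubically.
    patch_pos_embed = patch_pos_embed.reshape(1, grid_size, grid_size, dim).permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)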
src/transformers/models/swinv2/modeling_swinv2.py
@@ -334,9 +334,7 @@ class Swinv2Embeddings(nn.Module):
         interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
         _, num_channels, height, width = pixel_values.shape
-        embeddings, output_dimensions = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
@@ -389,21 +387,10 @@ class Swinv2PatchEmbeddings(nn.Module):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values

-    def forward(
-        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
-    ) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
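
A quick sanity check of the fixed behaviour; this is a hypothetical snippet, not copied from the PR's test suite. A randomly initialized Swin should now accept an input that matches neither config.image_size nor a multiple of the patch size, because maybe_pad absorbs the difference where the removed ValueError used to fire:

import torch
from transformers import SwinConfig, SwinModel

config = SwinConfig(image_size=224, patch_size=4)
model = SwinModel(config)

# 199 is neither 224 nor a multiple of 4; before this fix the patch embeddings
# raised: Input image size (199*199) doesn't match model (224*224).
outputs = model(torch.randn(1, 3, 199, 199))
print(outputs.last_hidden_state.shape)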