Unverified commit de460e28, authored by Kristen Pereira, committed by GitHub

Add dynamic resolution input/interpolate position embedding to deit (#31131)



* Added interpolate pos encoding feature and test to deit

* Added interpolate pos encoding feature and test for deit TF model

* Re-added accidentally deleted test for multi_gpu

* Stored only patch_size instead of the entire config and removed commented-out code

* Update modeling_tf_deit.py to remove extra line
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent d64e4da7
src/transformers/models/deit/modeling_deit.py
@@ -73,9 +73,53 @@ class DeiTEmbeddings(nn.Module):
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size

-    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        Interpolates the pre-trained position encodings so that the model can be used on higher-resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+        num_patches = embeddings.shape[1] - 2
+        num_positions = self.position_embeddings.shape[1] - 2
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0, :]
+        dist_pos_embed = self.position_embeddings[:, 1, :]
+        patch_pos_embed = self.position_embeddings[:, 2:, :]
+        dim = embeddings.shape[-1]
+        h0 = height // self.patch_size
+        w0 = width // self.patch_size
+        # we add a small number to avoid a floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        _, _, height, width = pixel_values.shape
         embeddings = self.patch_embeddings(pixel_values)
         batch_size, seq_length, _ = embeddings.size()
         if bool_masked_pos is not None:
@@ -85,9 +129,16 @@ class DeiTEmbeddings(nn.Module):
             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
         distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
         embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
-        embeddings = embeddings + self.position_embeddings
+        position_embedding = self.position_embeddings
+        if interpolate_pos_encoding:
+            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+        embeddings = embeddings + position_embedding
         embeddings = self.dropout(embeddings)
         return embeddings
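Reviewer note: a minimal standalone sketch of the resizing step above, on dummy tensors, may help. The sizes below (224px pretraining resolution, patch size 16, hidden size 768, a 480x640 input) mirror the base checkpoint and the integration test; everything else is illustrative.

import math
import torch
from torch import nn

hidden_size, patch_size = 768, 16
num_positions = (224 // patch_size) ** 2                    # 196 pretrained patch positions (14x14 grid)
pos_embed = torch.zeros(1, num_positions + 2, hidden_size)  # [CLS] + distillation token + patches

height, width = 480, 640
h0, w0 = height // patch_size + 0.1, width // patch_size + 0.1  # 30.1, 40.1 (the DINO epsilon trick)

side = int(math.sqrt(num_positions))                        # 14
patch_pos = pos_embed[:, 2:, :].reshape(1, side, side, hidden_size).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(
    patch_pos, scale_factor=(h0 / side, w0 / side), mode="bicubic", align_corners=False
)
print(patch_pos.shape)  # torch.Size([1, 768, 30, 40]) -> flattens to 1200 patch positions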
@@ -120,10 +171,6 @@ class DeiTPatchEmbeddings(nn.Module):
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
-        if height != self.image_size[0] or width != self.image_size[1]:
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
-            )
         x = self.projection(pixel_values).flatten(2).transpose(1, 2)
         return x
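Removing the hard size check is safe because the convolutional patch projection itself accepts any input whose sides are multiples of the patch size. A quick sketch with illustrative shapes:

import torch
from torch import nn

# 16x16 patches with stride 16: a 480x640 image yields a 30x40 grid of tokens
projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
tokens = projection(torch.zeros(1, 3, 480, 640)).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 1200, 768])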
@@ -480,6 +527,8 @@ DEIT_INPUTS_DOCSTRING = r"""
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
"""
@@ -528,6 +577,7 @@ class DeiTModel(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -554,7 +604,9 @@ class DeiTModel(DeiTPreTrainedModel):
         if pixel_values.dtype != expected_dtype:
             pixel_values = pixel_values.to(expected_dtype)
-        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
         encoder_outputs = self.encoder(
             embedding_output,
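For context, a hedged end-to-end usage sketch of the new flag. The checkpoint name and sizes come from the integration test further down; the blank test image and the printed shape are illustrative and assume Hub access.

import torch
from PIL import Image
from transformers import AutoImageProcessor, DeiTModel

processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
model = DeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-224")

processor.size = {"height": 480, "width": 640}
image = Image.new("RGB", (640, 480))  # stand-in for a real image
inputs = processor(images=image, return_tensors="pt", do_center_crop=False)

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 1202, 768]): CLS + distillation + 30*40 patches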
@@ -635,6 +687,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[tuple, MaskedImageModelingOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -674,6 +727,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
         sequence_output = outputs[0]
@@ -742,6 +796,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[tuple, ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -784,6 +839,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
         sequence_output = outputs[0]
@@ -901,6 +957,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -910,6 +967,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
         sequence_output = outputs[0]
src/transformers/models/deit/modeling_tf_deit.py
@@ -146,9 +146,42 @@ class TFDeiTEmbeddings(keras.layers.Layer):
             with tf.name_scope(self.dropout.name):
                 self.dropout.build(None)

+    def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
+        num_patches = embeddings.shape[1] - 2
+        num_positions = self.position_embeddings.shape[1] - 2
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0, :]
+        dist_pos_embed = self.position_embeddings[:, 1, :]
+        patch_pos_embed = self.position_embeddings[:, 2:, :]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid a floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = tf.reshape(
+            patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        )
+        # tf.image.resize operates on NHWC tensors, so no transposes are needed around it
+        patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic")
+        patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
+        return tf.concat(
+            [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
+        )
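Unlike torch.nn.functional.interpolate, which expects NCHW input, tf.image.resize consumes and produces NHWC tensors, so the PyTorch permutes have no TF counterpart. A dummy-tensor sketch with illustrative sizes:

import math
import tensorflow as tf

num_positions, dim = 196, 768
side = int(math.sqrt(num_positions))                 # 14
patch_pos = tf.reshape(tf.zeros((1, num_positions, dim)), (1, side, side, dim))
patch_pos = tf.image.resize(patch_pos, size=(30, 40), method="bicubic")  # NHWC in, NHWC out
patch_pos = tf.reshape(patch_pos, (1, -1, dim))
print(patch_pos.shape)  # (1, 1200, 768)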
     def call(
-        self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None, training: bool = False
+        self,
+        pixel_values: tf.Tensor,
+        bool_masked_pos: tf.Tensor | None = None,
+        training: bool = False,
+        interpolate_pos_encoding: bool = False,
     ) -> tf.Tensor:
+        _, height, width, _ = pixel_values.shape
         embeddings = self.patch_embeddings(pixel_values)
         batch_size, seq_length, _ = shape_list(embeddings)
@@ -162,7 +195,11 @@ class TFDeiTEmbeddings(keras.layers.Layer):
         cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
         distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
         embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
-        embeddings = embeddings + self.position_embeddings
+        position_embedding = self.position_embeddings
+        if interpolate_pos_encoding:
+            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+        embeddings = embeddings + position_embedding
         embeddings = self.dropout(embeddings, training=training)
         return embeddings
@@ -197,10 +234,7 @@ class TFDeiTPatchEmbeddings(keras.layers.Layer):
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
-        if tf.executing_eagerly() and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
-            )
         x = self.projection(pixel_values)
         batch_size, height, width, num_channels = shape_list(x)
         x = tf.reshape(x, (batch_size, height * width, num_channels))
@@ -599,6 +633,7 @@ class TFDeiTMainLayer(keras.layers.Layer):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -621,7 +656,12 @@ class TFDeiTMainLayer(keras.layers.Layer):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask)
-        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, training=training)
+        embedding_output = self.embeddings(
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            training=training,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
         encoder_outputs = self.encoder(
             embedding_output,
@@ -705,6 +745,8 @@ DEIT_INPUTS_DOCSTRING = r"""
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -741,6 +783,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
         outputs = self.deit(
@@ -750,6 +793,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
         return outputs
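And the TF twin of the PyTorch usage sketch above, under the same assumptions (illustrative blank image, Hub access):

from PIL import Image
from transformers import AutoImageProcessor, TFDeiTModel

processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
model = TFDeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-224")

processor.size = {"height": 480, "width": 640}
image = Image.new("RGB", (640, 480))  # stand-in for a real image
inputs = processor(images=image, return_tensors="tf", do_center_crop=False)

outputs = model(**inputs, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # (1, 1202, 768)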
@@ -869,6 +913,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[tuple, TFMaskedImageModelingOutput]:
         r"""
@@ -909,6 +954,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
@@ -1003,6 +1049,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[tf.Tensor, TFImageClassifierOutput]:
         r"""
@@ -1046,6 +1093,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
@@ -1126,6 +1174,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         training: bool = False,
     ) -> Union[tuple, TFDeiTForImageClassificationWithTeacherOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1136,6 +1185,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             training=training,
         )
tests/models/deit/test_modeling_deit.py
@@ -423,6 +423,28 @@ class DeiTModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224").to(
+            torch_device
+        )
+
+        image_processor = self.default_image_processor
+        # resize the test image to {"height": 480, "width": 640}
+        image = prepare_img()
+        image_processor.size = {"height": 480, "width": 640}
+        # center cropping is disabled so the image is not cropped back to 224x224
+        inputs = image_processor(images=image, return_tensors="pt", do_center_crop=False).to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs, interpolate_pos_encoding=True)
+
+        # verify the logits shape
+        expected_shape = torch.Size((1, 1000))
+        self.assertEqual(outputs.logits.shape, expected_shape)
+
     @slow
     @require_accelerate
     @require_torch_accelerator
tests/models/deit/test_modeling_tf_deit.py
@@ -293,3 +293,20 @@ class DeiTModelIntegrationTest(unittest.TestCase):
         expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
         self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+        image_processor = self.default_image_processor
+        # resize the test image to {"height": 480, "width": 640}
+        image = prepare_img()
+        image_processor.size = {"height": 480, "width": 640}
+        # center cropping is disabled so the image is not cropped back to 224x224
+        inputs = image_processor(images=image, return_tensors="tf", do_center_crop=False)
+
+        # forward pass
+        outputs = model(**inputs, interpolate_pos_encoding=True)
+
+        # verify the logits shape
+        expected_shape = tf.TensorShape((1, 1000))
+        self.assertEqual(outputs.logits.shape, expected_shape)