Unverified commit d3c9d0e5 authored by NielsRogge, committed by GitHub

[ViT, BEiT, DeiT, DPT] Improve code (#16799)



* Improve code

* Fix bugs

* Fix another bug

* Clean up DPT as well

* Update DPT model outputs
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 3785f466
@@ -542,7 +542,8 @@ class DeiTModel(DeiTPreTrainedModel):
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
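Why the task heads further down can switch from `outputs[2:]` to `outputs[1:]`: the base model now only includes `pooled_output` in its tuple output when a pooler is configured, instead of returning a `None` placeholder. A minimal sketch of that tuple layout (hypothetical helper, not the library code):

# Hypothetical illustration of the new head_outputs logic; not the actual library code.
def base_model_tuple(sequence_output, pooled_output, encoder_extras):
    # pooled_output is only included when a pooler is actually configured
    head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
    return head_outputs + encoder_extras

# The task heads build the backbone with add_pooling_layer=False, so pooled_output is None
# and the extra encoder outputs (hidden_states, attentions) now start at index 1, not index 2.
outputs = base_model_tuple("last_hidden_state", None, ("hidden_states", "attentions"))
assert outputs[1:] == ("hidden_states", "attentions")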
@@ -662,7 +663,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
             masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output

         return MaskedLMOutput(
@@ -775,7 +776,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return ImageClassifierOutput(
@@ -882,7 +883,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
         logits = (cls_logits + distillation_logits) / 2

         if not return_dict:
-            output = (logits, cls_logits, distillation_logits) + outputs[2:]
+            output = (logits, cls_logits, distillation_logits) + outputs[1:]
             return output

         return DeiTForImageClassificationWithTeacherOutput(
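For reference, the teacher-distilled classifier averages its two classification heads at inference time (the `logits = (cls_logits + distillation_logits) / 2` context line above). A tiny self-contained check of that averaging, with made-up numbers:

# Illustrative only: DeiTForImageClassificationWithTeacher returns the mean of the
# [CLS]-token head and the distillation-token head as its final logits.
import torch

cls_logits = torch.tensor([[2.0, 0.0]])           # hypothetical prediction from the [CLS] head
distillation_logits = torch.tensor([[0.0, 2.0]])  # hypothetical prediction from the distillation head
logits = (cls_logits + distillation_logits) / 2
assert torch.equal(logits, torch.tensor([[1.0, 1.0]]))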
......
@@ -750,7 +750,8 @@ class DPTModel(DPTPreTrainedModel):
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -938,7 +939,7 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
             return_dict=return_dict,
         )

-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]

         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -956,9 +957,9 @@ class DPTForDepthEstimation(DPTPreTrainedModel):

         if not return_dict:
             if output_hidden_states:
-                output = (predicted_depth,) + outputs[2:]
+                output = (predicted_depth,) + outputs[1:]
             else:
-                output = (predicted_depth,) + outputs[3:]
+                output = (predicted_depth,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

         return DepthEstimatorOutput(
@@ -1083,7 +1084,7 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
             return_dict=return_dict,
         )

-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]

         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -1120,9 +1121,9 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):

         if not return_dict:
             if output_hidden_states:
-                output = (logits,) + outputs[2:]
+                output = (logits,) + outputs[1:]
             else:
-                output = (logits,) + outputs[3:]
+                output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

         return SemanticSegmenterOutput(
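The DPT heads run the backbone with hidden states enabled because the neck needs the intermediate features, which is why `outputs[1]` is the hidden-states tuple in the non-`return_dict` path, and why the head only keeps those hidden states in its own output when the caller actually asked for them. A minimal sketch of that layout, under the assumption that the backbone tuple is (last_hidden_state, hidden_states, attentions):

# Illustrative only; not the library code. Assumed backbone tuple layout without return_dict:
#   (last_hidden_state, hidden_states)              when attentions were not requested
#   (last_hidden_state, hidden_states, attentions)  when they were
backbone_outputs = ("last_hidden_state", "hidden_states", "attentions")
predicted_depth = "predicted_depth"

output_hidden_states = False  # what the caller asked for; the backbone computed them regardless
if output_hidden_states:
    head_output = (predicted_depth,) + backbone_outputs[1:]  # keep the hidden states
else:
    head_output = (predicted_depth,) + backbone_outputs[2:]  # drop the internally needed hidden states
assert head_output == ("predicted_depth", "attentions")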
......
@@ -585,7 +585,8 @@ class ViTModel(ViTPreTrainedModel):
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -706,7 +707,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
             masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output

         return MaskedLMOutput(
@@ -798,8 +799,9 @@ class ViTForImageClassification(ViTPreTrainedModel):
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
+
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return ImageClassifierOutput(
......
@@ -41,7 +41,7 @@ if is_torch_available():
         BeitForSemanticSegmentation,
         BeitModel,
     )
-    from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+    from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST


 if is_vision_available():
@@ -96,6 +96,10 @@ class BeitModelTester:
         self.out_indices = out_indices
         self.num_labels = num_labels

+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
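The testers now precompute the expected sequence length once from the patch-grid arithmetic instead of re-deriving it with `to_2tuple` in every check. A quick sanity check of that arithmetic, with illustrative numbers rather than the tester's actual defaults:

# Illustrative numbers only (the testers use their own small image_size / patch_size values).
image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 non-overlapping patches
expected_seq_length = num_patches + 1          # + 1 for the [CLS] token prepended by BEiT/ViT
assert expected_seq_length == 197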
@@ -132,22 +136,16 @@ class BeitModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
         model = BeitForMaskedImageModeling(config=config)
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))

     def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
         config.num_labels = self.type_sequence_label_size
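The masked-image-modeling check uses `expected_seq_length - 1` because BEiT predicts one visual-vocabulary token per image patch while the [CLS] position gets no prediction, so the logits cover only the patch positions. Illustrative arithmetic only:

# Illustrative only: the masked-image-modeling logits cover the patches, not the [CLS] token.
num_patches = 196                      # hypothetical 224x224 image with 16x16 patches
expected_seq_length = num_patches + 1  # the hidden states include the [CLS] token
assert expected_seq_length - 1 == num_patches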
@@ -312,16 +310,8 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        # in BEiT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        # BEiT has a different seq_length
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -332,7 +322,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             # check that output_attentions also work using config
@@ -349,7 +339,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):

             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
             out_len = len(outputs)
@@ -369,7 +359,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )

     def test_hidden_states_output(self):
@@ -381,7 +371,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states

             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
@@ -389,10 +379,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(hidden_states), expected_num_layers)

             # BEiT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
+            seq_length = self.model_tester.expected_seq_length

             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
......
@@ -75,6 +75,10 @@ class FlaxBeitModelTester(unittest.TestCase):
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range

+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -104,20 +108,14 @@ class FlaxBeitModelTester(unittest.TestCase):
         model = FlaxBeitModel(config=config)
         result = model(pixel_values)

-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_masked_lm(self, config, pixel_values, labels):
         model = FlaxBeitForMaskedImageModeling(config=config)
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))

     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -151,13 +149,11 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
         self.config_tester.run_common_tests()

     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_length = num_patches + 1
+        seq_length = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -209,7 +205,7 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
             expected_arg_names = ["pixel_values"]
             self.assertListEqual(arg_names[:1], expected_arg_names)

-    # We neeed to override this test because Beit expects pixel_values instead of input_ids
+    # We need to override this test because Beit expects pixel_values instead of input_ids
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -234,12 +230,10 @@ class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 self.assertEqual(jitted_output.shape, output.shape)

     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_length = num_patches + 1  # we add 1 for the [CLS] token
+            seq_length = self.model_tester.expected_seq_length

             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             hidden_states = outputs.hidden_states
......
@@ -41,7 +41,7 @@ if is_torch_available():
         DeiTForMaskedImageModeling,
         DeiTModel,
     )
-    from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+    from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST


 if is_vision_available():
@@ -92,6 +92,10 @@ class DeiTModelTester:
         self.scope = scope
         self.encoder_stride = encoder_stride

+        # in DeiT, the expected seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 2
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -125,11 +129,9 @@ class DeiTModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 2 (we add 2 for the [CLS] and distillation tokens)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 2, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -212,16 +214,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        # in DeiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 2
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -232,7 +225,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             # check that output_attentions also work using config
@@ -243,18 +236,12 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )
             out_len = len(outputs)
@@ -267,26 +254,14 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+            self.assertEqual(out_len + 1, len(outputs))

-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )

     def test_hidden_states_output(self):
@@ -298,18 +273,14 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states

             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )
             self.assertEqual(len(hidden_states), expected_num_layers)

-            # DeiT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 2
+            seq_length = self.model_tester.expected_seq_length

             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
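These tests now assert the plain self-attention weight shape directly, since the ViT-family models are never encoder-decoder models and never use Reformer-style chunked attention. A toy check of the shape being asserted, independent of the library:

# Illustrative only: full self-attention weights have shape (batch, num_heads, seq_len, seq_len).
import torch

batch, num_heads, seq_len, head_dim = 1, 2, 5, 4
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
attn = torch.softmax(q @ k.transpose(-1, -2) / head_dim**0.5, dim=-1)
assert list(attn.shape[-3:]) == [num_heads, seq_len, seq_len]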
......
@@ -81,6 +81,9 @@ class DPTModelTester:
         self.initializer_range = initializer_range
         self.num_labels = num_labels
         self.scope = scope
+        # expected sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -115,9 +118,9 @@ class DPTModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        num_patches = (config.image_size // config.patch_size) ** 2
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
         config.num_labels = self.num_labels
@@ -206,8 +209,7 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
         config.return_dict = True

         # in DPT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_len = num_patches + 1
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -274,8 +276,7 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(hidden_states), expected_num_layers)

             # DPT has a different seq_length
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_len = num_patches + 1
+            seq_len = self.model_tester.expected_seq_length

             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
......
@@ -67,6 +67,10 @@ class FlaxViTModelTester(unittest.TestCase):
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range

+        # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -120,13 +124,11 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
         self.config_tester.run_common_tests()

     # We need to override this test because in ViT, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_length = num_patches + 1
+        seq_length = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -203,12 +205,11 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 self.assertEqual(jitted_output.shape, output.shape)

     # We need to override this test because in ViT, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_length = num_patches + 1  # we add 1 for the [CLS] token
+            seq_length = self.model_tester.expected_seq_length

             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             hidden_states = outputs.hidden_states
......
@@ -32,7 +32,6 @@ if is_tf_available():
     import tensorflow as tf

     from transformers import TFViTForImageClassification, TFViTModel
-    from transformers.models.vit.modeling_tf_vit import to_2tuple


 if is_vision_available():
@@ -81,6 +80,10 @@ class TFViTModelTester:
         self.initializer_range = initializer_range
         self.scope = scope

+        # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -111,20 +114,18 @@ class TFViTModelTester:
     def create_and_check_model(self, config, pixel_values, labels):
         model = TFViTModel(config=config)
         result = model(pixel_values, training=False)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

         # Test with an image with different size than the one specified in config.
         image_size = self.image_size // 2
         pixel_values = pixel_values[:, :, :image_size, :image_size]
         result = model(pixel_values, interpolate_pos_encoding=True, training=False)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(image_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        expected_seq_length = (image_size // self.patch_size) ** 2 + 1
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
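With `interpolate_pos_encoding=True` the model accepts an input with a different resolution than the configured one; there are simply fewer (or more) patches, so the test recomputes the expected sequence length from the cropped image size. A quick check of that arithmetic, with made-up config values:

# Illustrative numbers only (the tester uses its own small config values).
config_image_size, patch_size = 224, 16
image_size = config_image_size // 2                        # the test crops the input to half size
expected_seq_length = (image_size // patch_size) ** 2 + 1  # 7 * 7 patches + the [CLS] token
assert expected_seq_length == 50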
@@ -210,12 +211,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
         config.use_cache = True

         # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
@@ -228,10 +224,6 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             model = tf.keras.models.load_model(saved_model_dir)
             outputs = model(class_inputs_dict)

-            if self.is_encoder_decoder:
-                output_hidden_states = outputs["encoder_hidden_states"]
-                output_attentions = outputs["encoder_attentions"]
-            else:
-                output_hidden_states = outputs["hidden_states"]
-                output_attentions = outputs["attentions"]
+            output_hidden_states = outputs["hidden_states"]
+            output_attentions = outputs["attentions"]
@@ -250,7 +242,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(output_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )

     def test_attention_outputs(self):
@@ -258,12 +250,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
         config.return_dict = True

         # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -271,7 +258,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             config.return_dict = True
             model = model_class(config)
             outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             # check that output_attentions also work using config
@@ -279,12 +266,12 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             config.output_attentions = True
             model = model_class(config)
             outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )

             out_len = len(outputs)
@@ -294,20 +281,14 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             model = model_class(config)
             outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)

-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+            self.assertEqual(out_len + 1, len(outputs))

-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions

             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )

     def test_hidden_states_output(self):
@@ -316,7 +297,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
@@ -324,10 +305,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(len(hidden_states), expected_num_layers)

             # ViT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
+            seq_length = self.model_tester.expected_seq_length

             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
......
@@ -31,7 +31,7 @@ if is_torch_available():
     from torch import nn

     from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel
-    from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+    from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST


 if is_vision_available():
@@ -59,7 +59,6 @@ class ViTModelTester:
         attention_probs_dropout_prob=0.1,
         type_sequence_label_size=10,
         initializer_range=0.02,
-        num_labels=3,
         scope=None,
         encoder_stride=2,
     ):
@@ -82,6 +81,10 @@ class ViTModelTester:
         self.scope = scope
         self.encoder_stride = encoder_stride

+        # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -115,11 +118,9 @@ class ViTModelTester:
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -201,16 +202,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -221,7 +213,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             # check that output_attentions also work using config
@@ -232,18 +224,12 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )
             out_len = len(outputs)
@@ -256,26 +242,14 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+            self.assertEqual(out_len + 1, len(outputs))

-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
+            )

     def test_hidden_states_output(self):
@@ -287,22 +261,16 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states

             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )
             self.assertEqual(len(hidden_states), expected_num_layers)

-            # ViT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
-                [seq_length, self.model_tester.hidden_size],
+                [self.model_tester.expected_seq_length, self.model_tester.hidden_size],
             )

         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......