"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "34a3c25a3068ab5cdbecb08ddf2866f1209fd2dd"
Unverified commit a7aca42f, authored by NielsRogge and committed by GitHub

Improve Swin for VisionEncoderDecoder (#16070)



* Add Swin2Bart test

* Fix swin tests
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 0a057201
src/transformers/models/swin/configuration_swin.py
@@ -94,6 +94,7 @@ class SwinConfig(PretrainedConfig):
     attribute_map = {
         "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
     }
 
     def __init__(
@@ -141,4 +142,4 @@ class SwinConfig(PretrainedConfig):
         self.encoder_stride = encoder_stride
         # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
-        self.hidden_size = embed_dim * 8
+        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
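The hardcoded embed_dim * 8 only held for 4-stage Swin variants; the new expression doubles embed_dim once per patch-merging step, so it generalizes to any number of stages, including the 3-stage tester config below. Together with the new attribute_map entry, generic code such as VisionEncoderDecoderModel can now read num_hidden_layers and hidden_size off a SwinConfig. A quick sanity check of the arithmetic (illustrative values, not part of the diff; the first pair matches the public Swin-Tiny config):

# Swin doubles the channel dimension after each of the len(depths) - 1 patch-merging steps
embed_dim, depths = 96, [2, 2, 6, 2]                     # Swin-Tiny, 4 stages
assert int(embed_dim * 2 ** (len(depths) - 1)) == 768    # agrees with the old embed_dim * 8

embed_dim, depths = 16, [1, 2, 1]                        # the 3-stage tester config below
assert int(embed_dim * 2 ** (len(depths) - 1)) == 64     # the old formula would wrongly give 128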
tests/swin/test_modeling_swin.py
@@ -56,8 +56,8 @@ class SwinModelTester:
         patch_size=2,
         num_channels=3,
         embed_dim=16,
-        depths=[1],
-        num_heads=[2],
+        depths=[1, 2, 1],
+        num_heads=[2, 2, 4],
         window_size=2,
         mlp_ratio=2.0,
         qkv_bias=True,
@@ -73,7 +73,7 @@ class SwinModelTester:
         scope=None,
         use_labels=True,
         type_sequence_label_size=10,
-        encoder_stride=2,
+        encoder_stride=8,
     ):
         self.parent = parent
         self.batch_size = batch_size
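encoder_stride describes the encoder's total spatial downsampling, which the masked-image-modeling head uses to upsample features back to pixels, so it must track the new multi-stage config: the patch embedding divides each spatial dimension by patch_size, and each of the two patch-merging steps halves it again. A sketch of that arithmetic (my reading of the new tester config, not code from the diff):

patch_size, depths = 2, [1, 2, 1]
total_stride = patch_size * 2 ** (len(depths) - 1)   # 2 * 2**2 = 8
assert total_stride == 8                             # hence encoder_stride=8 above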
@@ -139,8 +139,7 @@ class SwinModelTester:
         model.eval()
         result = model(pixel_values)
 
-        # since the model we're testing only consists of a single layer, expected_seq_len = number of patches
-        expected_seq_len = (config.image_size // config.patch_size) ** 2
+        expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
         expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
 
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
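The old assertion only held for a single-stage model, where the sequence length stays at the number of patches; each patch-merging step halves both spatial dimensions, dividing the token count by 4. A worked example, assuming the tester's image_size is 32 (that default is not shown in this hunk):

image_size, patch_size, depths = 32, 2, [1, 2, 1]
num_patches = (image_size // patch_size) ** 2                 # 16 * 16 = 256 tokens after patch embedding
expected_seq_len = num_patches // (4 ** (len(depths) - 1))    # 256 // 4**2 = 16 tokens after two merges
assert expected_seq_len == 16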
tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -22,8 +22,10 @@ from datasets import load_dataset
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 
+from ..bart.test_modeling_bart import BartModelTester
 from ..bert.test_modeling_bert import BertModelTester
 from ..deit.test_modeling_deit import DeiTModelTester
+from ..swin.test_modeling_swin import SwinModelTester
 from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
 from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
 from ..vit.test_modeling_vit import ViTModelTester
@@ -35,8 +37,10 @@ if is_torch_available():
     from transformers import (
         AutoTokenizer,
+        BartForCausalLM,
         BertLMHeadModel,
         DeiTModel,
+        SwinModel,
         TrOCRForCausalLM,
         VisionEncoderDecoderConfig,
         VisionEncoderDecoderModel,
@@ -514,6 +518,90 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
         }
 
 
+@require_torch
+class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
+    def get_encoder_decoder_model(self, config, decoder_config):
+        encoder_model = SwinModel(config).eval()
+        decoder_model = BartForCausalLM(decoder_config).eval()
+        return encoder_model, decoder_model
+
+    def prepare_config_and_inputs(self):
+        model_tester_encoder = SwinModelTester(self, batch_size=13, embed_dim=32)
+        model_tester_decoder = BartModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
+        encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
+        decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
+        config, pixel_values, _ = encoder_config_and_inputs
+        decoder_config, decoder_inputs_dict = decoder_config_and_inputs
+        # make sure that cross-attention layers are added
+        decoder_config.add_cross_attention = True
+        # disable cache for now
+        decoder_config.use_cache = False
+        return {
+            "config": config,
+            "pixel_values": pixel_values,
+            "decoder_config": decoder_config,
+            **decoder_inputs_dict,
+        }
+
+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels=None,
+        pixel_values=None,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+
+        outputs_encoder_decoder = enc_dec_model(
+            pixel_values=pixel_values,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+        )
+
+        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
+        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
+
+        # in Swin, attention is computed within local windows, so the attention seq_len equals the window area
+        seq_len = encoder_model.config.window_size**2
+        self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads[0], seq_len, seq_len))
+
+        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
+        num_decoder_layers = (
+            decoder_config.num_decoder_layers
+            if hasattr(decoder_config, "num_decoder_layers")
+            else decoder_config.num_hidden_layers
+        )
+        self.assertEqual(len(decoder_attentions), num_decoder_layers)
+        self.assertEqual(
+            decoder_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+        )
+
+        cross_attentions = outputs_encoder_decoder["cross_attentions"]
+        self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+        # after the Swin encoder, the sequence length is the patch count divided by 4 per patch-merging step
+        encoder_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
+        cross_attention_input_seq_len = decoder_input_ids.shape[-1]
+        self.assertEqual(
+            cross_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, cross_attention_input_seq_len, encoder_seq_len),
+        )
+
+    # there are no published pretrained BART-causal checkpoints for now, so skip the real-model round trip
+    def test_real_model_save_load_from_pretrained(self):
+        pass
+
+
 @require_torch
 class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
     def get_encoder_decoder_model(self, config, decoder_config):
...
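The new test exercises what this commit enables end to end: with num_hidden_layers and hidden_size now resolvable on SwinConfig, a Swin encoder can be paired with any causal-LM decoder. A minimal usage sketch with small randomly initialized configs (the config values are illustrative, not taken from the commit; VisionEncoderDecoderModel inserts a projection automatically when the encoder and decoder hidden sizes differ):

import torch
from transformers import (
    BartConfig,
    BartForCausalLM,
    SwinConfig,
    SwinModel,
    VisionEncoderDecoderModel,
)

# a Swin-Tiny-sized encoder and a small BART decoder with cross-attention enabled
encoder = SwinModel(SwinConfig(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]))
decoder = BartForCausalLM(BartConfig(d_model=256, is_decoder=True, add_cross_attention=True))

model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

pixel_values = torch.randn(1, 3, 224, 224)  # SwinConfig defaults to image_size=224
decoder_input_ids = torch.tensor([[model.config.decoder.bos_token_id]])
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
print(outputs.logits.shape)  # torch.Size([1, 1, vocab_size])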