Unverified Commit 667b823b authored by Francesco Saverio Zuppichini, committed by GitHub
Browse files

Swin support for any input size (#15986)



* padding done

* correctly return one attention per layer

* almost correct, attentions are not flattened: one tuple per stage

* tests green

* doc

* conversations

* reshaping hidden_states

* view in the test

* reshape_hidden_states in Encoder and Model

* new outputs with reshaped_hidden_states

* conversations

* doc

* Update docs/source/model_doc/swin.mdx
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* conversations

* fix tests

* minor changes

* resolved conversations

* attentions one per stage

* typo

* typos

* typos

* function signature

* CI

* clean up tests
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 204c54d4
...@@ -34,6 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial fo ...@@ -34,6 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial fo
Tips: Tips:
- One can use the [`AutoFeatureExtractor`] API to prepare images for the model. - One can use the [`AutoFeatureExtractor`] API to prepare images for the model.
- Swin pads its inputs, so it supports any input height and width (as long as they are divisible by `32`).
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch_size, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png" <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
alt="drawing" width="600"/> alt="drawing" width="600"/>
......
...@@ -230,15 +230,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -230,15 +230,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True config.return_dict = True
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False inputs_dict["output_hidden_states"] = False
...@@ -248,8 +239,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -248,8 +239,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class)) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions attentions = outputs.attentions
self.assertEqual(len(attentions), len(self.model_tester.depths)) expected_num_attentions = len(self.model_tester.depths)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also work using config # check that output_attentions also work using config
del inputs_dict["output_attentions"] del inputs_dict["output_attentions"]
...@@ -260,19 +252,13 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -260,19 +252,13 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class)) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions attentions = outputs.attentions
self.assertEqual(len(attentions), len(self.model_tester.depths)) self.assertEqual(len(attentions), expected_num_attentions)
if chunk_length is not None: self.assertListEqual(
self.assertListEqual( list(attentions[0].shape[-3:]),
list(attentions[0].shape[-4:]), [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
[self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], )
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
out_len = len(outputs) out_len = len(outputs)
# Check attention is always last and order is fine # Check attention is always last and order is fine
...@@ -286,25 +272,19 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -286,25 +272,19 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
if hasattr(self.model_tester, "num_hidden_states_types"): if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else: else:
added_hidden_states = 1 # also another +1 for reshaped_hidden_states
added_hidden_states = 2
self.assertEqual(out_len + added_hidden_states, len(outputs)) self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), len(self.model_tester.depths)) self.assertEqual(len(self_attentions), expected_num_attentions)
if chunk_length is not None:
self.assertListEqual( self.assertListEqual(
list(self_attentions[0].shape[-4:]), list(self_attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
) )
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
def test_hidden_states_output(self): def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class): def check_hidden_states_output(inputs_dict, config, model_class):
...@@ -315,7 +295,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -315,7 +295,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class)) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states hidden_states = outputs.hidden_states
expected_num_layers = getattr( expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1 self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
...@@ -325,6 +305,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -325,6 +305,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
# Swin has a different seq_length # Swin has a different seq_length
image_size = to_2tuple(self.model_tester.image_size) image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size) patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.assertListEqual( self.assertListEqual(
...@@ -332,6 +313,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -332,6 +313,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
[num_patches, self.model_tester.embed_dim], [num_patches, self.model_tester.embed_dim],
) )
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
...@@ -395,7 +388,5 @@ class SwinModelIntegrationTest(unittest.TestCase): ...@@ -395,7 +388,5 @@ class SwinModelIntegrationTest(unittest.TestCase):
# verify the logits # verify the logits
expected_shape = torch.Size((1, 1000)) expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device) expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment