Unverified commit 667b823b, authored by Francesco Saverio Zuppichini and committed by GitHub

Swin support for any input size (#15986)



* padding done

* correctly return one attention per layer

* almost correct; attentions are not flattened, one tuple per stage

* tests green

* doc

* address review conversations

* reshaping hidden_states

* view in the test

* reshape_hidden_states in Encoder and Model

* new outputs with reshaped_hidden_states

* address review conversations

* doc

* Update docs/source/model_doc/swin.mdx
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* address review conversations

* fix tests

* minor changes

* resolved conversations

* attentions one per stage

* typo

* typos

* typos

* function signature

* CI

* clean up tests
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 204c54d4
docs/source/model_doc/swin.mdx
@@ -34,6 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.
Tips:
- One can use the [`AutoFeatureExtractor`] API to prepare images for the model.
- Swin pads the inputs, so it supports any input height and width, as long as both are divisible by `32`.
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it outputs both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have shape `(batch_size, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
alt="drawing" width="600"/>
@@ -230,15 +230,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True
        image_size = to_2tuple(self.model_tester.image_size)
        patch_size = to_2tuple(self.model_tester.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        seq_len = num_patches
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
@@ -248,8 +239,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), len(self.model_tester.depths))
            attentions = outputs.attentions
            expected_num_attentions = len(self.model_tester.depths)
            self.assertEqual(len(attentions), expected_num_attentions)
            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
@@ -260,15 +252,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), len(self.model_tester.depths))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), expected_num_attentions)
            if chunk_length is not None:
                self.assertListEqual(
                    list(attentions[0].shape[-4:]),
                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
                )
            else:
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
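Outside the test harness, the two properties these assertions pin down can be checked directly: the `attentions` tuple holds one tensor per stage, and each tensor carries windowed attention maps whose last two dimensions are `window_size**2`. A minimal sketch with a randomly initialized model (sizes are the `SwinConfig` defaults):

```python
import torch
from transformers import SwinConfig, SwinModel

config = SwinConfig()  # depths defaults to [2, 2, 6, 2]: four stages
model = SwinModel(config)
model.eval()

with torch.no_grad():
    outputs = model(torch.randn(1, 3, 224, 224), output_attentions=True)

# One attention tensor per stage rather than a flat tuple over all layers.
assert len(outputs.attentions) == len(config.depths)

window_size_squared = config.window_size**2  # 7 * 7 = 49 by default
for attentions in outputs.attentions:
    # expected layout: (batch_size * num_windows, num_heads, window_size**2, window_size**2)
    assert attentions.shape[-2:] == (window_size_squared, window_size_squared)
```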
@@ -286,21 +272,15 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
                # also another +1 for reshaped_hidden_states
                added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))
            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self_attentions = outputs.attentions
            self.assertEqual(len(self_attentions), expected_num_attentions)
            self.assertEqual(len(self_attentions), len(self.model_tester.depths))
            if chunk_length is not None:
                self.assertListEqual(
                    list(self_attentions[0].shape[-4:]),
                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
                )
            else:
                self.assertListEqual(
                    list(self_attentions[0].shape[-3:]),
                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
@@ -315,7 +295,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        with torch.no_grad():
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
        hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
        hidden_states = outputs.hidden_states
        expected_num_layers = getattr(
            self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
@@ -325,6 +305,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        # Swin has a different seq_length
        image_size = to_2tuple(self.model_tester.image_size)
        patch_size = to_2tuple(self.model_tester.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.assertListEqual(
@@ -332,6 +313,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            [num_patches, self.model_tester.embed_dim],
        )
        reshaped_hidden_states = outputs.reshaped_hidden_states
        self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
        batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
        reshaped_hidden_states = (
            reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
        )
        self.assertListEqual(
            list(reshaped_hidden_states.shape[-2:]),
            [num_patches, self.model_tester.embed_dim],
        )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
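The `view`/`permute` pair in the new assertions is a lossless round trip between the spatial and the sequence layout; a standalone sketch with illustrative sizes:

```python
import torch

batch_size, num_channels, height, width = 2, 96, 56, 56  # illustrative sizes
spatial = torch.randn(batch_size, num_channels, height, width)

# (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
sequence = spatial.view(batch_size, num_channels, height * width).permute(0, 2, 1)
assert sequence.shape == (batch_size, height * width, num_channels)

# ... and back, recovering the original tensor exactly.
restored = sequence.permute(0, 2, 1).view(batch_size, num_channels, height, width)
assert torch.equal(restored, spatial)
```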
@@ -395,7 +388,5 @@ class SwinModelIntegrationTest(unittest.TestCase):
        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)
        expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
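For context, the integration test follows the library's usual pattern: run a pretrained classifier on a reference image and pin the logits. A sketch of that pattern, where the checkpoint name and image URL are the ones commonly used in the transformers examples and are assumptions here, not taken from this diff:

```python
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, SwinForImageClassification

checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # assumed checkpoint
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = SwinForImageClassification.from_pretrained(checkpoint)
model.eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

assert logits.shape == torch.Size((1, 1000))  # ImageNet-1k classification head
print(model.config.id2label[logits.argmax(-1).item()])
```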