Unverified commit 1c21f48a authored by hyenal, committed by GitHub

add sdpa to ViT [follow-up of #29325] (#30555)



remove blank line (+1 squashed commit)
Squashed commits:
[24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits)
Squashed commits:
[08bd27e7a] [run-slow]vit_msn,vision_encoder_decoder
[ec96a8db3] [run-slow]vit_msn
[ead817eca] fix vit msn multi gpu
[d12cdc8fd] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[3fdbfa88f] doc
[a3ff33e4a] finish implementation
[e20b7b7fb] Update test_modeling_common.py
[e290c5810] Update test_modeling_flax_common.py
[d3af86f46] comment
[ff7dd32d8] more comments
[59b137889] suggestion
[7e2ba6d67] attn_implementation as attribute of the class
[fe66ab71f] minor
[38642b568] Apply suggestions from code review

Accept comments
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[22cde7d52] Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[48e137cc6] Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[99f4c679f] Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[96cf20a6d] Update src/transformers/models/vit_msn/modeling_vit_msn.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[c59377d23] Update src/transformers/models/vit_mae/modeling_vit_mae.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[b70a47259] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[00c84d216] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[61f00ebb0] all tests are passing locally
[e9e0b82b7] vision encoder/decoder
[4d5076b56] test-vision (+20 squashed commits)
Squashed commits:
[d1add8db9] yolo
[9fde65716] fix flax
[986566c28] minor
[ca2f21d1f] vit
[3333efd7a] easy models change
[ebfc21402] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[b8b8603ed] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[48ecc7e26] all tests are passing locally
[bff7fc366] minor
[62f88306f] fix yolo and text_encoder tests
[121507555] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1064cae0a] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[b7f52ff3a] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[cffaa10dd] fix-copies
[ef6c511c4] test vit hybrid
[7d4ba8644] vit hybrid
[66f919033] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1fcc0a031] fixes
[cfde6eb21] fixup
[e77df1ed3] all except yolo and encoder decoder (+17 squashed commits)
Squashed commits:
[602913e22] vit + vit_mae are working
[547f6c4cc] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/ passes
[61a97dfa9] it's the complete opposite...
[aefab37d4] fix more tests
[71802a1b9] fix all torch tests
[40b12eb58] encoder - decoder tests
[941552b69] slow decorator where appropriate
[14d055d80] has_attentions to yolo and msn
[3381fa19f] add correct name
[e261316a7] repo consistency
[31c6d0c08] fixup
[9d214276c] minor fix
[11ed2e1b7] chore
[eca6644c4] add sdpa to vit-based models
[cffbf390b] make fix-copies result
[6468319b0] fix style
[d324cd02a] add sdpa for vit
Co-authored-by: Liubov Yaronskaya <luba.yaronskaya@gmail.com>
parent 9fd606db
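
For orientation, here is a minimal sketch of what this change enables (the checkpoint name is illustrative, not part of this commit):

```python
import torch
from transformers import ViTModel

# With this change, ViT-based models accept attn_implementation="sdpa", which
# routes attention through torch.nn.functional.scaled_dot_product_attention.
model = ViTModel.from_pretrained(
    "google/vit-base-patch16-224-in21k", attn_implementation="sdpa"
)

pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch of one RGB image
with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 197, 768]): 196 patches + [CLS]
```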
@@ -71,6 +71,7 @@ class TFDeiTModelTester:
num_labels=3,
scope=None,
encoder_stride=2,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -90,6 +91,7 @@ class TFDeiTModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
self.attn_implementation = attn_implementation
# in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
num_patches = (image_size // patch_size) ** 2
@@ -121,6 +123,7 @@ class TFDeiTModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -70,6 +70,7 @@ class VideoMAEModelTester:
initializer_range=0.02,
mask_ratio=0.9,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -91,6 +92,7 @@ class VideoMAEModelTester:
self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope
self.attn_implementation = attn_implementation
# in VideoMAE, the number of tokens equals num_frames/tubelet_size * num_patches per frame
self.num_patches_per_frame = (image_size // patch_size) ** 2
@@ -132,6 +134,7 @@ class VideoMAEModelTester:
decoder_intermediate_size=self.intermediate_size,
decoder_num_attention_heads=self.num_attention_heads,
decoder_num_hidden_layers=self.num_hidden_layers,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -197,7 +200,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# hence we define a single mask, which we then repeat for each example in the batch
mask = torch.ones((self.model_tester.num_masks,))
mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))])
bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool()
batch_size = inputs_dict["pixel_values"].shape[0]
bool_masked_pos = mask.expand(batch_size, -1).bool()
inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device)
if return_labels:
@@ -492,7 +492,9 @@ class TFVisionEncoderDecoderMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
pt_model = VisionEncoderDecoderModel.from_pretrained(
tmpdirname, from_tf=True, attn_implementation=tf_model.config._attn_implementation
)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
@@ -49,6 +49,7 @@ class FlaxViTModelTester(unittest.TestCase):
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -66,6 +67,7 @@ class FlaxViTModelTester(unittest.TestCase):
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.attn_implementation = attn_implementation
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
@@ -87,6 +89,7 @@ class FlaxViTModelTester(unittest.TestCase):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
attn_implementation=self.attn_implementation,
)
return config, pixel_values
@@ -63,6 +63,7 @@ class TFViTModelTester:
initializer_range=0.02,
num_labels=3,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -81,6 +82,7 @@ class TFViTModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.attn_implementation = attn_implementation
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
@@ -111,6 +113,7 @@ class TFViTModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -68,6 +68,8 @@ class ViTModelTester:
initializer_range=0.02,
scope=None,
encoder_stride=2,
mask_ratio=0.5,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -87,10 +89,14 @@ class ViTModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
self.attn_implementation = attn_implementation
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -118,6 +124,7 @@ class ViTModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -58,6 +58,7 @@ class ViTHybridModelTester:
initializer_range=0.02,
backbone_featmap_shape=[1, 16, 4, 4],
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -77,6 +78,7 @@ class ViTHybridModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.backbone_featmap_shape = backbone_featmap_shape
self.attn_implementation = attn_implementation
# in ViT hybrid, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
# the number of patches is based on the feature map of the backbone, which by default uses an output stride
@@ -122,6 +124,7 @@ class ViTHybridModelTester:
backbone_featmap_shape=self.backbone_featmap_shape,
backbone_config=backbone_config,
backbone=None,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -72,6 +72,7 @@ class TFViTMAEModelTester:
num_labels=3,
mask_ratio=0.6,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -91,6 +92,7 @@ class TFViTMAEModelTester:
self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope
self.attn_implementation = attn_implementation
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded up
# (we add 1 for the [CLS] token)
@@ -127,6 +129,7 @@ class TFViTMAEModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
mask_ratio=self.mask_ratio,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -63,8 +63,9 @@ class ViTMAEModelTester:
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
mask_ratio=0.6,
scope=None,
mask_ratio=0.5,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -84,11 +85,15 @@ class ViTMAEModelTester:
self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope
self.attn_implementation = attn_implementation
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded up
# (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -120,6 +125,7 @@ class ViTMAEModelTester:
decoder_intermediate_size=self.intermediate_size,
decoder_num_attention_heads=self.num_attention_heads,
decoder_num_hidden_layers=self.num_hidden_layers,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -59,6 +59,7 @@ class ViTMSNModelTester:
type_sequence_label_size=10,
initializer_range=0.02,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -77,6 +78,7 @@ class ViTMSNModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.attn_implementation = attn_implementation
# in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
@@ -106,6 +108,7 @@ class ViTMSNModelTester:
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -62,6 +62,7 @@ class YolosModelTester:
scope=None,
n_targets=8,
num_detection_tokens=10,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -83,6 +84,7 @@ class YolosModelTester:
self.scope = scope
self.n_targets = n_targets
self.num_detection_tokens = num_detection_tokens
self.attn_implementation = attn_implementation
# we set the expected sequence length (which is used in several tests)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)
@@ -123,6 +125,7 @@ class YolosModelTester:
initializer_range=self.initializer_range,
num_detection_tokens=self.num_detection_tokens,
num_labels=self.num_labels,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -2788,7 +2788,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded = model_class.from_pretrained(
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
@@ -3724,6 +3726,11 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
# FIXME: we deactivate the boolean mask for models that take "use_mask_token" in their constructors.
# These models support masking only when `use_mask_token=True`; otherwise they cannot consume an input mask.
# Supporting that would require instantiating the class much later, after `use_mask` is set, i.e. a significant refactor.
# However, masking is not done at any layer that matters here (i.e. self-attention), so we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
is_encoder_decoder = model.config.is_encoder_decoder
@@ -3861,6 +3868,27 @@ class ModelTesterMixin:
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if not deactivate_mask and (
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
):
dummy_mask = torch.ones((self.model_tester.num_masks,))
# In the case of an additional token (like the [CLS] token) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int(
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
@@ -371,7 +371,9 @@ class FlaxModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded = pt_model_class.from_pretrained(
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
@@ -84,7 +84,7 @@ def check_sdpa_support_list():
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
if arch not in doctext:
if arch not in doctext and arch not in doctext.replace("-", "_"):
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)
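
As a quick sanity check of the kind the new slow tests automate, the eager and SDPA paths can be compared directly (a hedged sketch; the checkpoint and tolerance are illustrative):

```python
import torch
from transformers import ViTForImageClassification

# Load the same checkpoint once per attention backend.
model_eager = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", attn_implementation="eager"
).eval()
model_sdpa = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", attn_implementation="sdpa"
).eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits_eager = model_eager(pixel_values).logits
    logits_sdpa = model_sdpa(pixel_values).logits

# The two implementations should agree up to numerical noise.
print(torch.allclose(logits_eager, logits_sdpa, atol=1e-4))
```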