Unverified Commit bc764f42 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: left-padding test, revisited (#29515)



* left-padding test revisited

* Apply suggestions from code review
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 631fa7bf
......@@ -1833,49 +1833,68 @@ class GenerationTesterMixin:
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
def test_left_padding_compatibility(self):
# The check done in this test is fairly difficult -- depending on the model architecture, passing the right
# position index for the position embeddings can still result in a different output, due to numerical masking.
# On the other hand, for some types of position embeddings, an incorrect position index can have a minimal
# impact on the output.
# There are two tricks employed to check whether left-padding compatibility is in place:
# 1 - To reduce the negative impact of the numerical attention mask on a correct position index, we set the
# padding size to 1.
# 2 - To reduce the chance of false positives (i.e. passing when it should be failing), we run the check
# multiple times with random inputs, and it has to pass with all of them.
# NOTE: because of 2), there is some chance of false positives in this test.
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
# First, filter out models that don't support left padding
# - The model must have generative capabilities
if len(self.all_generative_model_classes) == 0:
self.skipTest(reason="No generative architecture available for this model.")
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _, _, _ = self._get_input_ids_and_config()
if config.is_encoder_decoder:
continue # skip for encoder-decoder models -- they don't need left-padding compatibility
continue
else:
decoder_only_classes.append(model_class)
if len(decoder_only_classes) == 0:
self.skipTest(reason="No decoder-only architecture available for this model.")
# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
# added support for it yet. We skip these models for now.
has_encoder_attributes = any(
attr_name
for attr_name in config.to_dict().keys()
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
)
if has_encoder_attributes:
self.skipTest(
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
)
# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
for model_class in decoder_only_classes:
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
no_failures = True
for _ in range(10): # there may be false positives with 10 runs, we rely on the CI to catch the flakiness
_, input_ids, attention_mask, _ = self._get_input_ids_and_config()
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
pad_size = (input_ids.shape[0], 1)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = {"input_ids": padded_input_ids, "attention_mask": padded_attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(padded_attention_mask, dim=-1) - 1
position_ids.masked_fill_(padded_attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
if not torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-7):
no_failures = False
break
self.assertTrue(no_failures)
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
# With left-padding (length 32)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
def test_past_key_values_format(self):
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
......
......@@ -1527,7 +1527,3 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -818,7 +818,3 @@ class BigBirdPegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTeste
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -569,7 +569,3 @@ class BlenderbotStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -568,7 +568,3 @@ class BlenderbotSmallStandaloneDecoderModelTest(ModelTesterMixin, GenerationTest
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -249,10 +249,6 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model = CTRLModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
@require_torch
class CTRLModelLanguageGenerationTest(unittest.TestCase):
......
......@@ -895,7 +895,3 @@ class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -736,7 +736,3 @@ class MBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, u
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -823,7 +823,3 @@ class MvpStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, uni
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -596,7 +596,3 @@ class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -669,7 +669,3 @@ class PLBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
......@@ -1146,10 +1146,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
# decoder cannot keep gradients
return
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
@require_torch
class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
......
......@@ -3230,7 +3230,3 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
@unittest.skip("The model doesn't support fast init from base")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
def test_left_padding_compatibility(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment