Unverified commit 0040469b, authored by Tianlin Liu, committed by GitHub

Correct attention mask dtype for Flax GPT2 (#25636)

* Correct attention mask dtype

* reformat code

* add a test for boolean mask

* convert test to fast test

* delete unwanted print

* use assertTrue for testing
parent 4b796978
src/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -753,7 +753,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
         if attention_mask is not None:
             position_ids = attention_mask.cumsum(axis=-1) - 1
-            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+            extended_attention_mask = lax.dynamic_update_slice(
+                extended_attention_mask, attention_mask.astype("i4"), (0, 0)
+            )
         else:
             position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
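For context, the failure this hunk fixes can be reproduced in a few lines of JAX. The shapes and values below are illustrative rather than taken from the model: `lax.dynamic_update_slice` rejects an update whose dtype differs from the operand's, which is why a boolean attention mask used to error out against the `i4` buffer.

```python
import jax.numpy as jnp
from jax import lax

# The i4 buffer allocated in prepare_inputs_for_generation (shape illustrative).
extended_attention_mask = jnp.ones((1, 8), dtype="i4")
# A user-supplied boolean attention mask.
attention_mask = jnp.array([[True, True, True, False, False]])

# Pre-fix behavior: raises a TypeError, because dynamic_update_slice requires
# the update to share the operand's dtype (here int32 vs. bool).
# lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))

# Post-fix behavior: cast the mask to "i4" before splicing it in.
patched = lax.dynamic_update_slice(
    extended_attention_mask, attention_mask.astype("i4"), (0, 0)
)
print(patched)  # [[1 1 1 0 0 1 1 1]]
```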
tests/models/gpt2/test_modeling_flax_gpt2.py
@@ -187,6 +187,26 @@ class FlaxGPT2ModelTester:
         diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
         self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

+    def check_bool_attention_mask_in_generation(self, model_class_name, config, input_ids, attention_mask):
+        model = model_class_name(config)
+
+        output_int_att_mask = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=3,
+        )
+
+        output_bool_att_mask = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask.astype(bool),
+            max_new_tokens=3,
+        )
+
+        self.parent.assertTrue(
+            (output_bool_att_mask.sequences == output_int_att_mask.sequences).all(),
+            "Generated responses differ between boolean and integer attention masks",
+        )
+

 @require_flax
 class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
@@ -208,6 +228,13 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
                 model_class_name, config, input_ids, attention_mask
             )

+    def test_bool_attention_mask_in_generation(self):
+        for model_class_name in self.all_generative_model_classes:
+            config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
+
+            self.model_tester.check_bool_attention_mask_in_generation(
+                model_class_name, config, input_ids, attention_mask
+            )
+
     @slow
     def test_batch_generation(self):
         tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
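End to end, the new test boils down to the sketch below. The tiny GPT2Config values are hypothetical stand-ins for whatever the tester's prepare_config_and_inputs() actually returns; the point is only that greedy generation yields identical sequences for an integer mask and its boolean counterpart once the cast is in place.

```python
import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel, GPT2Config

# Hypothetical small config so the sketch runs quickly on randomly
# initialized weights; the real test uses the tester's config.
config = GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=99)
model = FlaxGPT2LMHeadModel(config)

input_ids = jnp.array([[5, 6, 7, 8]], dtype="i4")
attention_mask = jnp.array([[1, 1, 1, 0]], dtype="i4")

# Same prompt, once with an integer mask and once with a boolean mask.
out_int = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=3)
out_bool = model.generate(input_ids=input_ids, attention_mask=attention_mask.astype(bool), max_new_tokens=3)

# Greedy decoding is deterministic, so the two runs must agree.
assert (out_int.sequences == out_bool.sequences).all()
```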