Unverified Commit 5f1fcc29 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

[Idefics2] - Fix FA2 call for Perceiver layer (#32275)

* Fix FA2 call for Perciever layer

* [run_slow] idefics2

* [run_slow] idefics2

* [run_slow] idefics2

* Fix up

* [run_slow] idefics2

* [run_slow] idefics2

* [run_slow] idefics2
parent b75ad566
...@@ -894,7 +894,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): ...@@ -894,7 +894,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
attention_mask, attention_mask,
q_len, q_len,
dropout=dropout_rate, dropout=dropout_rate,
sliding_window=False, sliding_window=None,
is_causal=self.is_causal, is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask, use_top_left_mask=self._flash_attn_uses_top_left_mask,
) )
......
...@@ -29,7 +29,14 @@ from transformers import ( ...@@ -29,7 +29,14 @@ from transformers import (
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
) )
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
...@@ -491,13 +498,13 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase): ...@@ -491,13 +498,13 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@slow @slow
@unittest.skip("Test hits OOM on CI - https://github.com/huggingface/transformers/issues/32288")
def test_integration_test(self): def test_integration_test(self):
model = Idefics2ForConditionalGeneration.from_pretrained( model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base", "HuggingFaceM4/idefics2-8b-base",
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map="auto", device_map="auto",
) )
model.to(torch_device)
# Create inputs # Create inputs
text = "<image>In this image, we see" text = "<image>In this image, we see"
...@@ -517,7 +524,8 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase): ...@@ -517,7 +524,8 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
def test_integration_test_4bit(self): def test_integration_test_4bit(self):
# Let' s make sure we test the preprocessing to replace what is used # Let' s make sure we test the preprocessing to replace what is used
model = Idefics2ForConditionalGeneration.from_pretrained( model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base", load_in_4bit=True, device_map="auto" "HuggingFaceM4/idefics2-8b-base",
load_in_4bit=True,
) )
# Create pixel inputs # Create pixel inputs
...@@ -530,3 +538,37 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase): ...@@ -530,3 +538,37 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River," expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
self.assertEqual(generated_texts[0], expected_generated_text) self.assertEqual(generated_texts[0], expected_generated_text)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
def test_flash_attn_2_eager_equivalence(self):
# Create inputs
text = "<image>In this image, we see"
images = self.image1
inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
inputs.to(torch_device)
# Eager model
model_eager = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base",
attn_implementation="eager",
load_in_4bit=True,
)
generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10)
generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True)
del model_eager
# Flash Attention 2 model
model_flash_attention_2 = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base",
attn_implementation="flash_attention_2",
load_in_4bit=True,
)
generated_ids_flash_attention_2 = model_flash_attention_2.generate(**inputs, max_new_tokens=10)
generated_texts_flash_attention_2 = self.processor.batch_decode(
generated_ids_flash_attention_2, skip_special_tokens=True
)
self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment