Unverified Commit 5f1fcc29 authored by amyeroberts, committed by GitHub

[Idefics2] - Fix FA2 call for Perceiver layer (#32275)

* Fix FA2 call for Perceiver layer

* [run_slow] idefics2

* [run_slow] idefics2

* [run_slow] idefics2

* Fix up

* [run_slow] idefics2

* [run_slow] idefics2

* [run_slow] idefics2
parent b75ad566
@@ -894,7 +894,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
             attention_mask,
             q_len,
             dropout=dropout_rate,
-            sliding_window=False,
+            sliding_window=None,
             is_causal=self.is_causal,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
         )
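For context on the fix above: the flash-attention helper that Idefics2PerceiverFlashAttention2 delegates to (in recent transformers versions, _flash_attention_forward from transformers.modeling_flash_attention_utils) treats sliding_window as an optional integer window size, not a boolean flag. The Perceiver resampler attends over the full key sequence, so None ("no window") is the correct value; False would pass an `is not None` check and then be compared as the integer 0. A simplified, hedged sketch of that decision logic (illustrative only, not the library source):

from typing import Optional, Union

def should_use_sliding_windows(sliding_window: Optional[Union[int, bool]], key_length: int) -> bool:
    # Simplified version of the check a flash-attention wrapper performs before
    # requesting windowed attention. With None, the branch is skipped (full attention).
    # With False, `False is not None` is True and `key_length > False` compares
    # against 0, so a degenerate 0-token window would be requested.
    return sliding_window is not None and key_length > sliding_window

print(should_use_sliding_windows(None, 64))   # False -> full attention (behaviour after the fix)
print(should_use_sliding_windows(False, 64))  # True  -> 0-token window (the bug being fixed)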
@@ -29,7 +29,14 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_flash_attn,
+    require_torch,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -491,13 +498,13 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
         torch.cuda.empty_cache()
 
     @slow
+    @unittest.skip("Test hits OOM on CI - https://github.com/huggingface/transformers/issues/32288")
     def test_integration_test(self):
         model = Idefics2ForConditionalGeneration.from_pretrained(
             "HuggingFaceM4/idefics2-8b-base",
             torch_dtype=torch.bfloat16,
-            device_map="auto",
         )
         model.to(torch_device)
 
         # Create inputs
         text = "<image>In this image, we see"
@@ -517,7 +524,8 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
     def test_integration_test_4bit(self):
         # Let's make sure we test the preprocessing to replace what is used
         model = Idefics2ForConditionalGeneration.from_pretrained(
-            "HuggingFaceM4/idefics2-8b-base", load_in_4bit=True, device_map="auto"
+            "HuggingFaceM4/idefics2-8b-base",
+            load_in_4bit=True,
         )
 
         # Create pixel inputs
@@ -530,3 +538,37 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
         expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
         self.assertEqual(generated_texts[0], expected_generated_text)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @require_bitsandbytes
+    def test_flash_attn_2_eager_equivalence(self):
+        # Create inputs
+        text = "<image>In this image, we see"
+        images = self.image1
+        inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
+        inputs.to(torch_device)
+
+        # Eager model
+        model_eager = Idefics2ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics2-8b-base",
+            attn_implementation="eager",
+            load_in_4bit=True,
+        )
+        generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10)
+        generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True)
+
+        del model_eager
+
+        # Flash Attention 2 model
+        model_flash_attention_2 = Idefics2ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics2-8b-base",
+            attn_implementation="flash_attention_2",
+            load_in_4bit=True,
+        )
+        generated_ids_flash_attention_2 = model_flash_attention_2.generate(**inputs, max_new_tokens=10)
+        generated_texts_flash_attention_2 = self.processor.batch_decode(
+            generated_ids_flash_attention_2, skip_special_tokens=True
+        )
+
+        self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])
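The new test above checks that eager and FlashAttention-2 generations agree, which is also the configuration an end user would enable. A minimal usage sketch under the same assumptions as the test (a CUDA GPU with flash-attn and bitsandbytes installed; the model id is the one used in the test, and the blank placeholder image stands in for a real input):

import numpy as np
from PIL import Image
from transformers import AutoProcessor, Idefics2ForConditionalGeneration

# Load the processor and a 4-bit quantized model with FlashAttention-2 enabled,
# mirroring the setup exercised by test_flash_attn_2_eager_equivalence.
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b-base")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b-base",
    attn_implementation="flash_attention_2",
    load_in_4bit=True,
)

# A real image would normally be loaded here; a blank placeholder keeps the sketch self-contained.
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
text = "<image>In this image, we see"

inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=10)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])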