Unverified Commit 80377eb0 authored by fxmarty, committed by GitHub

F.scaled_dot_product_attention support (#26572)



* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey to config.attn_implementation if a config is passed in from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use of XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurrences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurrence

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurrences bis

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flaky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent ce0bbd51
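For context, the user-facing change bundled in this commit is that the attention backend is now chosen through the `attn_implementation` argument ("eager", "sdpa" or "flash_attention_2"), with the old `use_flash_attention_2=True` flag kept only as a deprecated alias; on supported models SDPA is picked by default when torch>=2.1.1 is installed. A minimal usage sketch (the checkpoint name is only illustrative):

import torch
from transformers import AutoModelForCausalLM

# Explicitly request PyTorch's scaled_dot_product_attention backend.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)

# The resolved backend is recorded on the model config.
assert model.config._attn_implementation == "sdpa"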
@@ -180,6 +180,7 @@ from .import_utils import (
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_npu_available,
+is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
...
@@ -258,6 +258,19 @@ def get_torch_version():
return _torch_version
+def is_torch_sdpa_available():
+if not is_torch_available():
+return False
+elif _torch_version == "N/A":
+return False
+# NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
+# - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
+# - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
+# NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
+return version.parse(_torch_version) >= version.parse("2.1.1")
def is_torchvision_available():
return _torchvision_available
...
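The helper above gates everything on torch>=2.1.1; the primitive it unlocks is torch.nn.functional.scaled_dot_product_attention, whose `scale` argument and arbitrary-mask support are exactly the features the NOTE comments reference. A self-contained sketch of the call (shapes are illustrative, not taken from this diff):

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 4, 8, 8, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)
# Additive float mask: 0.0 keeps a position, -inf masks it out (boolean masks also work).
attn_mask = torch.zeros(batch, 1, q_len, kv_len)

out = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attn_mask,  # cannot be combined with is_causal=True
    dropout_p=0.0,
    is_causal=False,
    scale=1.0 / head_dim**0.5,  # an explicit `scale` requires torch>=2.1
)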
@@ -890,13 +890,11 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
-model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
-)
+model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
@@ -949,12 +947,13 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+tmpdirname,
+torch_dtype=torch.bfloat16,
)
model.to(torch_device)
...
@@ -319,13 +319,11 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
-model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
-)
+model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
@@ -373,12 +371,13 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+tmpdirname,
+torch_dtype=torch.bfloat16,
)
model.to(torch_device)
...
@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Falcon model. """
+import tempfile
import unittest
from parameterized import parameterized
@@ -26,7 +27,7 @@ from transformers import (
is_torch_available,
set_seed,
)
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -437,6 +438,76 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
# NOTE: This check is disabled for Falcon as the non-SDPA/SDPA implementation is in the same class (legacy reason).
# for name, submodule in model_eager.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# raise ValueError("The eager model should not have SDPA attention layers")
# has_sdpa = False
# for name, submodule in model_sdpa.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# has_sdpa = True
# break
# if not has_sdpa:
# raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):
...
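The Falcon test above (and the later tests in this diff) are guarded by a require_torch_sdpa marker imported from transformers.testing_utils. Its definition is not part of this excerpt, but given the is_torch_sdpa_available() gate added earlier it is presumably a thin skip decorator along these lines (a sketch, not the verbatim implementation):

import unittest

from transformers.utils.import_utils import is_torch_sdpa_available


def require_torch_sdpa(test_case):
    # Skip the decorated test unless torch>=2.1.1 (and therefore SDPA support) is available.
    return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)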
@@ -16,11 +16,14 @@
import unittest
+from parameterized import parameterized
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
+require_torch_sdpa,
require_vision,
slow,
torch_device,
@@ -309,6 +312,12 @@ class IdeficsModelTester:
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+@require_torch_sdpa
+@slow
+@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
@@ -557,6 +566,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
model = IdeficsModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+@require_torch_sdpa
+@slow
+@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
...
@@ -14,6 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """
+import tempfile
import unittest
import pytest
@@ -26,6 +27,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_gpu,
+require_torch_sdpa,
slow,
torch_device,
)
@@ -411,7 +413,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
-"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
+"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -419,6 +421,85 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertListEqual(output_native, output_fa_2)
@require_flash_attn
@require_torch_gpu
@slow
def test_use_flash_attention_2_true(self):
"""
NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
model.save_pretrained(tmp_dir)
new_model = LlamaForCausalLM.from_pretrained(
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
).to("cuda")
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
has_flash = False
for name, submodule in new_model.named_modules():
if "FlashAttention" in submodule.__class__.__name__:
has_flash = True
break
if not has_flash:
raise ValueError("The flash model should have flash attention layers")
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
max_new_tokens = 30
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model_sdpa = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class LlamaIntegrationTest(unittest.TestCase):
...
@@ -387,9 +387,9 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
-model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
-).to(torch_device)
+model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+torch_device
+)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
@@ -397,7 +397,10 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+tmpdirname,
+torch_dtype=torch.float16,
+attn_implementation="flash_attention_2",
+low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
@@ -437,7 +440,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
-use_flash_attention_2=True,
+attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
@@ -507,7 +510,7 @@ class MistralIntegrationTest(unittest.TestCase):
"mistralai/Mistral-7B-v0.1",
device_map="auto",
load_in_4bit=True,
-use_flash_attention_2=True,
+attn_implementation="flash_attention_2",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
...
@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
-"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
+"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
...
@@ -891,12 +891,13 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+tmpdirname,
+torch_dtype=torch.bfloat16,
)
model.to(torch_device)
@@ -936,11 +937,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
-model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False)
+model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
@@ -981,6 +982,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
+configs_no_init._attn_implementation = "eager"
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
@@ -2337,13 +2339,20 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
with torch.no_grad():
outputs = model(**inputs)[0]
-input_ids = inputs["input_features"]
+encoder = model.encoder
+encoder_inputs = {"input_features": inputs["input_features"]}
del inputs["input_features"]
-encoder = model.encoder
+if "head_mask" in inputs:
+encoder_inputs["head_mask"] = inputs["head_mask"]
+if "attention_mask" in inputs:
+encoder_inputs["attention_mask"] = inputs["attention_mask"]
+if "output_attentions" in inputs:
+encoder_inputs["output_attentions"] = inputs["output_attentions"]
with torch.no_grad():
-inputs["encoder_outputs"] = encoder(input_ids)
+inputs["encoder_outputs"] = encoder(**encoder_inputs)
outputs_embeds = model(**inputs)[0]
self.assertTrue((outputs_embeds == outputs).all())
...
@@ -198,7 +198,14 @@ class ConfigTestUtils(unittest.TestCase):
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
self.assertListEqual(
-missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
+missing_keys,
+[
+"is_encoder_decoder",
+"_name_or_path",
+"_commit_hash",
+"_attn_implementation_internal",
+"transformers_version",
+],
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
...
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import gc
@@ -28,6 +27,7 @@ from collections import defaultdict
from typing import Dict, List, Tuple
import numpy as np
+from parameterized import parameterized
from pytest import mark
import transformers
@@ -71,6 +71,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
+require_torch_sdpa,
slow,
torch_device,
)
@@ -776,102 +777,120 @@ class ModelTesterMixin:
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
-model = model_class(config=configs_no_init)
-model.to(torch_device)
-model.eval()
-inputs = self._prepare_for_class(inputs_dict, model_class)
-main_input_name = model_class.main_input_name
-try:
-if model.config.is_encoder_decoder:
-model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
-main_input = inputs[main_input_name]
-attention_mask = inputs["attention_mask"]
-decoder_input_ids = inputs["decoder_input_ids"]
-decoder_attention_mask = inputs["decoder_attention_mask"]
-model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
-traced_model = torch.jit.trace(
-model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
-)
-elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
-input_ids = inputs["input_ids"]
-bbox = inputs["bbox"]
-image = inputs["image"].tensor
-model(input_ids, bbox, image)
-traced_model = torch.jit.trace(
-model, (input_ids, bbox, image), check_trace=False
-) # when traced model is checked, an error is produced due to name mangling
-elif "bbox" in inputs: # Bros requires additional inputs (bbox)
-input_ids = inputs["input_ids"]
-bbox = inputs["bbox"]
-model(input_ids, bbox)
-traced_model = torch.jit.trace(
-model, (input_ids, bbox), check_trace=False
-) # when traced model is checked, an error is produced due to name mangling
-else:
-main_input = inputs[main_input_name]
-model(main_input)
-traced_model = torch.jit.trace(model, main_input)
-except RuntimeError:
-self.fail("Couldn't trace module.")
-with tempfile.TemporaryDirectory() as tmp_dir_name:
-pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+for attn_implementation in ["eager", "sdpa"]:
+if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+continue
+configs_no_init._attn_implementation = attn_implementation
+model = model_class(config=configs_no_init)
+model.to(torch_device)
+model.eval()
+inputs = self._prepare_for_class(inputs_dict, model_class)
+main_input_name = model_class.main_input_name
try:
-torch.jit.save(traced_model, pt_file_name)
-except Exception:
-self.fail("Couldn't save module.")
-try:
-loaded_model = torch.jit.load(pt_file_name)
-except Exception:
-self.fail("Couldn't load module.")
+if model.config.is_encoder_decoder:
+model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+main_input = inputs[main_input_name]
+attention_mask = inputs["attention_mask"]
+decoder_input_ids = inputs["decoder_input_ids"]
+decoder_attention_mask = inputs["decoder_attention_mask"]
+model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+traced_model = torch.jit.trace(
+model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+)
+elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
+input_ids = inputs["input_ids"]
+bbox = inputs["bbox"]
+image = inputs["image"].tensor
+model(input_ids, bbox, image)
+traced_model = torch.jit.trace(
+model, (input_ids, bbox, image), check_trace=False
+) # when traced model is checked, an error is produced due to name mangling
+elif "bbox" in inputs: # Bros requires additional inputs (bbox)
+input_ids = inputs["input_ids"]
+bbox = inputs["bbox"]
+model(input_ids, bbox)
+traced_model = torch.jit.trace(
+model, (input_ids, bbox), check_trace=False
+) # when traced model is checked, an error is produced due to name mangling
+else:
+main_input = inputs[main_input_name]
+if model.config._attn_implementation == "sdpa":
+trace_input = {main_input_name: main_input}
+if "attention_mask" in inputs:
+trace_input["attention_mask"] = inputs["attention_mask"]
+else:
+self.skipTest("testing SDPA without attention_mask is not supported")
+model(main_input, attention_mask=inputs["attention_mask"])
+# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
+traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
+else:
+model(main_input)
+traced_model = torch.jit.trace(model, (main_input,))
+except RuntimeError:
+self.fail("Couldn't trace module.")
+with tempfile.TemporaryDirectory() as tmp_dir_name:
+pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+try:
+torch.jit.save(traced_model, pt_file_name)
+except Exception:
+self.fail("Couldn't save module.")
+try:
+loaded_model = torch.jit.load(pt_file_name)
+except Exception:
+self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
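The refactored torchscript test above traces SDPA models with keyword arguments through example_kwarg_inputs, which torch.jit.trace accepts since torch 2.0 (fine here given the torch>=2.1.1 floor for SDPA). A self-contained sketch of that tracing path, using a stand-in module rather than a Transformers model:

import torch


class TinyModel(torch.nn.Module):
    def forward(self, input_ids, attention_mask=None):
        hidden = input_ids.float()
        if attention_mask is not None:
            hidden = hidden * attention_mask.float()
        return hidden.sum(dim=-1)


model = TinyModel().eval()
trace_input = {
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
}
# Trace the forward pass by keyword, mirroring how the test builds trace_input.
traced = torch.jit.trace(model, example_kwarg_inputs=trace_input)
print(traced(**trace_input))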
@@ -2832,8 +2851,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
-import torch
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -2845,7 +2862,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to(torch_device)
for _, module in model.named_modules():
@@ -2859,8 +2876,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
-import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -2871,12 +2886,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)
@@ -2956,8 +2971,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
-import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -2968,12 +2981,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)
@@ -3049,8 +3062,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_left_padding(self):
-import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3060,9 +3071,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
-model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
-).to(torch_device)
+model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+torch_device
+)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -3078,7 +3089,10 @@ class ModelTesterMixin:
)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+tmpdirname,
+torch_dtype=torch.float16,
+attn_implementation="flash_attention_2",
+low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
@@ -3092,8 +3106,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
-import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3103,9 +3115,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
-model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
-).to(torch_device)
+model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+torch_device
+)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -3121,7 +3133,10 @@ class ModelTesterMixin:
)
model = model_class.from_pretrained(
-tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+tmpdirname,
+torch_dtype=torch.float16,
+attn_implementation="flash_attention_2",
+low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
@@ -3130,13 +3145,330 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(out, out_fa))
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_device == "cpu" and torch_dtype == "float16":
self.skipTest("float16 not supported on cpu")
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size]
if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones(
batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
other_inputs = {
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
-import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
@@ -3163,7 +3495,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
-use_flash_attention_2=True,
+attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
@@ -3182,8 +3514,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
-import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3204,7 +3534,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
-use_flash_attention_2=True,
+attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
load_in_4bit=True,
)
@@ -3282,8 +3612,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
-import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3291,7 +3619,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
fa2_model = AutoModelForCausalLM.from_config(
-config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
+config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
@@ -3313,7 +3641,7 @@ class ModelTesterMixin:
model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
-self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False))
+self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
fa2_correctly_converted = False
...
...@@ -60,7 +60,13 @@ from transformers.utils import ( ...@@ -60,7 +60,13 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
) )
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flax_available,
is_tf_available,
is_torch_sdpa_available,
is_torchdynamo_available,
)
sys.path.append(str(Path(__file__).parent.parent / "utils")) sys.path.append(str(Path(__file__).parent.parent / "utils"))
...@@ -1689,3 +1695,158 @@ class AttentionMaskTester(unittest.TestCase): ...@@ -1689,3 +1695,158 @@ class AttentionMaskTester(unittest.TestCase):
res_compiled = compiled_model(mask, inputs_embeds) res_compiled = compiled_model(mask, inputs_embeds)
self.assertTrue(torch.equal(res_non_compiled, res_compiled)) self.assertTrue(torch.equal(res_non_compiled, res_compiled))
@require_torch
@slow
def test_unmask_unattended_left_padding(self):
attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)
expanded_mask = torch.Tensor(
[
[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
reference_output = torch.Tensor(
[
[[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)
self.assertTrue(torch.equal(result, reference_output))
attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
min_inf = torch.finfo(torch.float32).min
reference_output = torch.Tensor(
[
[
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[min_inf, min_inf, 0, min_inf, min_inf],
[min_inf, min_inf, 0, 0, min_inf],
[min_inf, min_inf, 0, 0, 0],
]
],
[
[
[0, min_inf, min_inf, min_inf, min_inf],
[0, 0, min_inf, min_inf, min_inf],
[0, 0, 0, min_inf, min_inf],
[0, 0, 0, 0, min_inf],
[0, 0, 0, 0, 0],
]
],
[
[
[0, 0, 0, 0, 0],
[min_inf, 0, min_inf, min_inf, min_inf],
[min_inf, 0, 0, min_inf, min_inf],
[min_inf, 0, 0, 0, min_inf],
[min_inf, 0, 0, 0, 0],
]
],
]
)
self.assertTrue(torch.equal(reference_output, result))
@require_torch
@slow
def test_unmask_unattended_right_padding(self):
attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
@require_torch
@slow
def test_unmask_unattended_random_mask(self):
attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
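A minimal sketch of the failure mode these `_unmask_unattended` tests exercise, assuming PyTorch >= 2.0; the tensors are illustrative:

```python
# With left padding, a padding token's query row can be fully masked. SDPA then
# computes a softmax over an all -inf row, which yields NaN (notably with the
# memory-efficient backend), so _unmask_unattended rewrites such rows.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 3, 8)  # (batch, heads, seq_len, head_dim)
mask = torch.zeros(1, 1, 3, 3)
mask[0, 0, 0, :] = float("-inf")  # first query attends to nothing (a pad position)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])  # may be NaN depending on the backend chosen

# _unmask_unattended makes such fully-masked rows attend to everything instead
# (see the reference outputs above); their outputs are discarded anyway.
```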
@require_torch
class TestAttentionImplementation(unittest.TestCase):
def test_error_no_sdpa_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")
self.assertTrue(
"does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
in str(cm.exception)
)
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
def test_error_no_flash_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
)
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))
def test_not_available_flash(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
)
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import doctest
import logging
import os
import unittest
+from glob import glob
from pathlib import Path
from typing import List, Union
@@ -27,6 +27,63 @@ from transformers.testing_utils import require_tf, require_torch, slow

logger = logging.getLogger()
@require_torch
class TestDocLists(unittest.TestCase):
def test_flash_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
doctext = doctext.split("You can request to add FlashAttention-2 support")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_fa2 = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_flash_attn_2 = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_fa2.append(model_name)
for arch in archs_supporting_fa2:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation."
)
def test_sdpa_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split(
"For now, Transformers supports inference and training through SDPA for the following architectures:"
)[1]
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_sdpa = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_sdpa = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)
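For a single architecture, the flags these doc-list tests scan for can also be checked directly; a sketch, using LlamaForCausalLM as one of the architectures covered by this PR:

```python
from transformers import LlamaForCausalLM

# Class-level support flags set in the modeling files.
print(LlamaForCausalLM._supports_flash_attn_2)  # True
print(LlamaForCausalLM._supports_sdpa)          # True
```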
@unittest.skip("Temporarily disable the doc tests.") @unittest.skip("Temporarily disable the doc tests.")
@require_torch @require_torch
@require_tf @require_tf
......