"vscode:/vscode.git/clone" did not exist on "393447baa4a1ddda57ff785ceb2a97a00ceb23f3"
Unverified Commit 80377eb0 authored by fxmarty, committed by GitHub

F.scaled_dot_product_attention support (#26572)



* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey config.attn_implementation if a config is passed to from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use of XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurrences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurrence

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurrences (again)

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flaky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent ce0bbd51
......@@ -180,6 +180,7 @@ from .import_utils import (
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
......
......@@ -258,6 +258,19 @@ def get_torch_version():
return _torch_version
def is_torch_sdpa_available():
if not is_torch_available():
return False
elif _torch_version == "N/A":
return False
# NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
# - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
# - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
# NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
return version.parse(_torch_version) >= version.parse("2.1.1")
def is_torchvision_available():
return _torchvision_available
......
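For context, a minimal sketch (not part of this diff; the checkpoint name is a placeholder) of how this helper can gate the attention backend passed to `from_pretrained`. Requesting an implementation explicitly also hard-checks its requirements, per the "hard_check SDPA requirements when specifically requested" change above:

from transformers import AutoModelForCausalLM
from transformers.utils.import_utils import is_torch_sdpa_available

# Fall back to the eager implementation when torch < 2.1.1 (see the notes above).
attn_implementation = "sdpa" if is_torch_sdpa_available() else "eager"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder: any architecture with _supports_sdpa = True
    attn_implementation=attn_implementation,
)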
......@@ -890,13 +890,11 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
- model = model_class.from_pretrained(
-     tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
- )
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
......@@ -949,12 +947,13 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
)
model.to(torch_device)
......
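The test updates in this commit all follow the same pattern: the deprecated `use_flash_attention_2` argument is replaced by the new `attn_implementation` argument. A minimal sketch of the before/after, assuming a checkpoint whose architecture supports the requested backend (the name below is a placeholder):

import torch
from transformers import AutoModel

ckpt = "distilbert-base-uncased"  # placeholder checkpoint

# Deprecated, still accepted for backward compatibility ("deprecate use_flash_attention_2=True" above):
# model_fa = AutoModel.from_pretrained(ckpt, torch_dtype=torch.bfloat16, use_flash_attention_2=True)

# New API: request "eager", "sdpa" or "flash_attention_2" explicitly.
model_fa = AutoModel.from_pretrained(ckpt, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")

# Without attn_implementation, SDPA is used when torch>=2.1.1 is available and the
# architecture supports it; otherwise the eager implementation is used.
model = AutoModel.from_pretrained(ckpt, torch_dtype=torch.bfloat16)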
......@@ -319,13 +319,11 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
- model = model_class.from_pretrained(
-     tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
- )
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
......@@ -373,12 +371,13 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
)
model.to(torch_device)
......
......@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Falcon model. """
import tempfile
import unittest
from parameterized import parameterized
......@@ -26,7 +27,7 @@ from transformers import (
is_torch_available,
set_seed,
)
- from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+ from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -437,6 +438,76 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
# NOTE: This check is disabled for Falcon as the non-SDPA/SDPA implementation is in the same class (legacy reason).
# for name, submodule in model_eager.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# raise ValueError("The eager model should not have SDPA attention layers")
# has_sdpa = False
# for name, submodule in model_sdpa.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# has_sdpa = True
# break
# if not has_sdpa:
# raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):
......
......@@ -16,11 +16,14 @@
import unittest
from parameterized import parameterized
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
......@@ -309,6 +312,12 @@ class IdeficsModelTester:
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
......@@ -557,6 +566,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
model = IdeficsModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
......
......@@ -14,6 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """
import tempfile
import unittest
import pytest
......@@ -26,6 +27,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
......@@ -411,7 +413,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
......@@ -419,6 +421,85 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertListEqual(output_native, output_fa_2)
@require_flash_attn
@require_torch_gpu
@slow
def test_use_flash_attention_2_true(self):
"""
NOTE: this is the only test testing that the legacy `use_flash_attention_2=True` argument still works as intended.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
model.save_pretrained(tmp_dir)
new_model = LlamaForCausalLM.from_pretrained(
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
).to("cuda")
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
has_flash = False
for name, submodule in new_model.named_modules():
if "FlashAttention" in submodule.__class__.__name__:
has_flash = True
break
if not has_flash:
raise ValueError("The flash model should have flash attention layers")
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
max_new_tokens = 30
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model_sdpa = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class LlamaIntegrationTest(unittest.TestCase):
......
......@@ -387,9 +387,9 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(
-     tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
- ).to(torch_device)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+     torch_device
+ )
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
......@@ -397,7 +397,10 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
......@@ -437,7 +440,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
......@@ -507,7 +510,7 @@ class MistralIntegrationTest(unittest.TestCase):
"mistralai/Mistral-7B-v0.1",
device_map="auto",
load_in_4bit=True,
- use_flash_attention_2=True,
+ attn_implementation="flash_attention_2",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
......
......@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
......
......@@ -891,12 +891,13 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
)
model.to(torch_device)
......@@ -936,11 +937,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
......@@ -981,6 +982,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init._attn_implementation = "eager"
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
......@@ -2337,13 +2339,20 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
with torch.no_grad():
outputs = model(**inputs)[0]
input_ids = inputs["input_features"]
encoder = model.encoder
encoder_inputs = {"input_features": inputs["input_features"]}
del inputs["input_features"]
encoder = model.encoder
if "head_mask" in inputs:
encoder_inputs["head_mask"] = inputs["head_mask"]
if "attention_mask" in inputs:
encoder_inputs["attention_mask"] = inputs["attention_mask"]
if "output_attentions" in inputs:
encoder_inputs["output_attentions"] = inputs["output_attentions"]
with torch.no_grad():
inputs["encoder_outputs"] = encoder(input_ids)
inputs["encoder_outputs"] = encoder(**encoder_inputs)
outputs_embeds = model(**inputs)[0]
self.assertTrue((outputs_embeds == outputs).all())
......
......@@ -198,7 +198,14 @@ class ConfigTestUtils(unittest.TestCase):
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to add in config_common_kwargs above.
self.assertListEqual(
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
missing_keys,
[
"is_encoder_decoder",
"_name_or_path",
"_commit_hash",
"_attn_implementation_internal",
"transformers_version",
],
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
......
......@@ -60,7 +60,13 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
- from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
+ from transformers.utils.import_utils import (
+     is_flash_attn_2_available,
+     is_flax_available,
+     is_tf_available,
+     is_torch_sdpa_available,
+     is_torchdynamo_available,
+ )
sys.path.append(str(Path(__file__).parent.parent / "utils"))
......@@ -1689,3 +1695,158 @@ class AttentionMaskTester(unittest.TestCase):
res_compiled = compiled_model(mask, inputs_embeds)
self.assertTrue(torch.equal(res_non_compiled, res_compiled))
@require_torch
@slow
def test_unmask_unattended_left_padding(self):
attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)
expanded_mask = torch.Tensor(
[
[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
reference_output = torch.Tensor(
[
[[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)
self.assertTrue(torch.equal(result, reference_output))
attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
min_inf = torch.finfo(torch.float32).min
reference_output = torch.Tensor(
[
[
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[min_inf, min_inf, 0, min_inf, min_inf],
[min_inf, min_inf, 0, 0, min_inf],
[min_inf, min_inf, 0, 0, 0],
]
],
[
[
[0, min_inf, min_inf, min_inf, min_inf],
[0, 0, min_inf, min_inf, min_inf],
[0, 0, 0, min_inf, min_inf],
[0, 0, 0, 0, min_inf],
[0, 0, 0, 0, 0],
]
],
[
[
[0, 0, 0, 0, 0],
[min_inf, 0, min_inf, min_inf, min_inf],
[min_inf, 0, 0, min_inf, min_inf],
[min_inf, 0, 0, 0, min_inf],
[min_inf, 0, 0, 0, 0],
]
],
]
)
self.assertTrue(torch.equal(reference_output, result))
@require_torch
@slow
def test_unmask_unattended_right_padding(self):
attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
@require_torch
@slow
def test_unmask_unattended_random_mask(self):
attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
@require_torch
class TestAttentionImplementation(unittest.TestCase):
def test_error_no_sdpa_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")
self.assertTrue(
"does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
in str(cm.exception)
)
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
def test_error_no_flash_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
)
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))
def test_not_available_flash(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
)
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
......@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import doctest
import logging
import os
import unittest
from glob import glob
from pathlib import Path
from typing import List, Union
......@@ -27,6 +27,63 @@ from transformers.testing_utils import require_tf, require_torch, slow
logger = logging.getLogger()
@require_torch
class TestDocLists(unittest.TestCase):
def test_flash_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
doctext = doctext.split("You can request to add FlashAttention-2 support")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_fa2 = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_flash_attn_2 = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_fa2.append(model_name)
for arch in archs_supporting_fa2:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation."
)
def test_sdpa_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split(
"For now, Transformers supports inference and training through SDPA for the following architectures:"
)[1]
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_sdpa = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_sdpa = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)
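For context, these two checks grep every modeling file for the opt-in class attributes introduced by this PR and cross-check them against docs/source/en/perf_infer_gpu_one.md. A minimal sketch of what a modeling file declares (the MyModel* names are illustrative; only the two flags and the *_ATTENTION_CLASSES pattern come from this PR):

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class MyModelAttention(nn.Module): ...  # eager implementation (stub)
class MyModelFlashAttention2(MyModelAttention): ...
class MyModelSdpaAttention(MyModelAttention): ...

# Layers pick their attention class from this mapping based on config._attn_implementation.
MYMODEL_ATTENTION_CLASSES = {
    "eager": MyModelAttention,
    "flash_attention_2": MyModelFlashAttention2,
    "sdpa": MyModelSdpaAttention,
}

class MyModelPreTrainedModel(PreTrainedModel):
    config_class = PretrainedConfig
    _supports_flash_attn_2 = True  # picked up by test_flash_support_list
    _supports_sdpa = True  # picked up by test_sdpa_support_list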
@unittest.skip("Temporarily disable the doc tests.")
@require_torch
@require_tf
......