Unverified Commit 80377eb0 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

F.scaled_dot_product_attention support (#26572)



* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey to config.attn_implementation if a config is passed in from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use if XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurence

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurences bis

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flacky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent ce0bbd51
......@@ -180,6 +180,7 @@ from .import_utils import (
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
......
......@@ -258,6 +258,19 @@ def get_torch_version():
return _torch_version
def is_torch_sdpa_available():
if not is_torch_available():
return False
elif _torch_version == "N/A":
return False
# NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
# - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
# - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
# NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
return version.parse(_torch_version) >= version.parse("2.1.1")
def is_torchvision_available():
return _torchvision_available
......
......@@ -890,13 +890,11 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
......@@ -949,12 +947,13 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname,
torch_dtype=torch.bfloat16,
)
model.to(torch_device)
......
......@@ -319,13 +319,11 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
......@@ -373,12 +371,13 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname,
torch_dtype=torch.bfloat16,
)
model.to(torch_device)
......
......@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Falcon model. """
import tempfile
import unittest
from parameterized import parameterized
......@@ -26,7 +27,7 @@ from transformers import (
is_torch_available,
set_seed,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -437,6 +438,76 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
# NOTE: This check is disabled for Falcon as the non-SDPA/SDPA implementation is in the same class (legacy reason).
# for name, submodule in model_eager.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# raise ValueError("The eager model should not have SDPA attention layers")
# has_sdpa = False
# for name, submodule in model_sdpa.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# has_sdpa = True
# break
# if not has_sdpa:
# raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):
......
......@@ -16,11 +16,14 @@
import unittest
from parameterized import parameterized
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
......@@ -309,6 +312,12 @@ class IdeficsModelTester:
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
......@@ -557,6 +566,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
model = IdeficsModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
......
......@@ -14,6 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """
import tempfile
import unittest
import pytest
......@@ -26,6 +27,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
......@@ -411,7 +413,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
......@@ -419,6 +421,85 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertListEqual(output_native, output_fa_2)
@require_flash_attn
@require_torch_gpu
@slow
def test_use_flash_attention_2_true(self):
"""
NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
model.save_pretrained(tmp_dir)
new_model = LlamaForCausalLM.from_pretrained(
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
).to("cuda")
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
has_flash = False
for name, submodule in new_model.named_modules():
if "FlashAttention" in submodule.__class__.__name__:
has_flash = True
break
if not has_flash:
raise ValueError("The flash model should have flash attention layers")
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
max_new_tokens = 30
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model_sdpa = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class LlamaIntegrationTest(unittest.TestCase):
......
......@@ -387,9 +387,9 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
......@@ -397,7 +397,10 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
......@@ -437,7 +440,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
......@@ -507,7 +510,7 @@ class MistralIntegrationTest(unittest.TestCase):
"mistralai/Mistral-7B-v0.1",
device_map="auto",
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
......
......@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
......
......@@ -891,12 +891,13 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname,
torch_dtype=torch.bfloat16,
)
model.to(torch_device)
......@@ -936,11 +937,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
......@@ -981,6 +982,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init._attn_implementation = "eager"
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
......@@ -2337,13 +2339,20 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
with torch.no_grad():
outputs = model(**inputs)[0]
input_ids = inputs["input_features"]
encoder = model.encoder
encoder_inputs = {"input_features": inputs["input_features"]}
del inputs["input_features"]
encoder = model.encoder
if "head_mask" in inputs:
encoder_inputs["head_mask"] = inputs["head_mask"]
if "attention_mask" in inputs:
encoder_inputs["attention_mask"] = inputs["attention_mask"]
if "output_attentions" in inputs:
encoder_inputs["output_attentions"] = inputs["output_attentions"]
with torch.no_grad():
inputs["encoder_outputs"] = encoder(input_ids)
inputs["encoder_outputs"] = encoder(**encoder_inputs)
outputs_embeds = model(**inputs)[0]
self.assertTrue((outputs_embeds == outputs).all())
......
......@@ -198,7 +198,14 @@ class ConfigTestUtils(unittest.TestCase):
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
self.assertListEqual(
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
missing_keys,
[
"is_encoder_decoder",
"_name_or_path",
"_commit_hash",
"_attn_implementation_internal",
"transformers_version",
],
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
......
......@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import gc
......@@ -28,6 +27,7 @@ from collections import defaultdict
from typing import Dict, List, Tuple
import numpy as np
from parameterized import parameterized
from pytest import mark
import transformers
......@@ -71,6 +71,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)
......@@ -776,102 +777,120 @@ class ModelTesterMixin:
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
main_input_name = model_class.main_input_name
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and not model_class._supports_sdpa:
continue
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
model(input_ids, bbox)
traced_model = torch.jit.trace(
model, (input_ids, bbox), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")
configs_no_init._attn_implementation = attn_implementation
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
main_input_name = model_class.main_input_name
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
model(input_ids, bbox)
traced_model = torch.jit.trace(
model, (input_ids, bbox), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
if model.config._attn_implementation == "sdpa":
trace_input = {main_input_name: main_input}
if "attention_mask" in inputs:
trace_input["attention_mask"] = inputs["attention_mask"]
else:
self.skipTest("testing SDPA without attention_mask is not supported")
model(main_input, attention_mask=inputs["attention_mask"])
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
model(main_input)
traced_model = torch.jit.trace(model, (main_input,))
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
self.assertTrue(models_equal)
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -2832,8 +2851,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
import torch
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
......@@ -2845,7 +2862,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to(torch_device)
for _, module in model.named_modules():
......@@ -2859,8 +2876,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
......@@ -2871,12 +2886,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)
......@@ -2956,8 +2971,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
......@@ -2968,12 +2981,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)
......@@ -3049,8 +3062,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_left_padding(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
......@@ -3060,9 +3071,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
......@@ -3078,7 +3089,10 @@ class ModelTesterMixin:
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
......@@ -3092,8 +3106,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
......@@ -3103,9 +3115,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
......@@ -3121,7 +3133,10 @@ class ModelTesterMixin:
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
......@@ -3130,13 +3145,330 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(out, out_fa))
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_device == "cpu" and torch_dtype == "float16":
self.skipTest("float16 not supported on cpu")
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size]
if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones(
batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
other_inputs = {
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
......@@ -3163,7 +3495,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
......@@ -3182,8 +3514,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
......@@ -3204,7 +3534,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
load_in_4bit=True,
)
......@@ -3282,8 +3612,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
......@@ -3291,7 +3619,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
fa2_model = AutoModelForCausalLM.from_config(
config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
......@@ -3313,7 +3641,7 @@ class ModelTesterMixin:
model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False))
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
fa2_correctly_converted = False
......
......@@ -60,7 +60,13 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flax_available,
is_tf_available,
is_torch_sdpa_available,
is_torchdynamo_available,
)
sys.path.append(str(Path(__file__).parent.parent / "utils"))
......@@ -1689,3 +1695,158 @@ class AttentionMaskTester(unittest.TestCase):
res_compiled = compiled_model(mask, inputs_embeds)
self.assertTrue(torch.equal(res_non_compiled, res_compiled))
@require_torch
@slow
def test_unmask_unattended_left_padding(self):
attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)
expanded_mask = torch.Tensor(
[
[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
reference_output = torch.Tensor(
[
[[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)
self.assertTrue(torch.equal(result, reference_output))
attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
min_inf = torch.finfo(torch.float32).min
reference_output = torch.Tensor(
[
[
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[min_inf, min_inf, 0, min_inf, min_inf],
[min_inf, min_inf, 0, 0, min_inf],
[min_inf, min_inf, 0, 0, 0],
]
],
[
[
[0, min_inf, min_inf, min_inf, min_inf],
[0, 0, min_inf, min_inf, min_inf],
[0, 0, 0, min_inf, min_inf],
[0, 0, 0, 0, min_inf],
[0, 0, 0, 0, 0],
]
],
[
[
[0, 0, 0, 0, 0],
[min_inf, 0, min_inf, min_inf, min_inf],
[min_inf, 0, 0, min_inf, min_inf],
[min_inf, 0, 0, 0, min_inf],
[min_inf, 0, 0, 0, 0],
]
],
]
)
self.assertTrue(torch.equal(reference_output, result))
@require_torch
@slow
def test_unmask_unattended_right_padding(self):
attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
@require_torch
@slow
def test_unmask_unattended_random_mask(self):
attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
@require_torch
class TestAttentionImplementation(unittest.TestCase):
def test_error_no_sdpa_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")
self.assertTrue(
"does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
in str(cm.exception)
)
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
def test_error_no_flash_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
)
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))
def test_not_available_flash(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
)
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
......@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import doctest
import logging
import os
import unittest
from glob import glob
from pathlib import Path
from typing import List, Union
......@@ -27,6 +27,63 @@ from transformers.testing_utils import require_tf, require_torch, slow
logger = logging.getLogger()
@require_torch
class TestDocLists(unittest.TestCase):
def test_flash_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
doctext = doctext.split("You can request to add FlashAttention-2 support")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_fa2 = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_flash_attn_2 = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_fa2.append(model_name)
for arch in archs_supporting_fa2:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation."
)
def test_sdpa_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split(
"For now, Transformers supports inference and training through SDPA for the following architectures:"
)[1]
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_sdpa = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_sdpa = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)
@unittest.skip("Temporarily disable the doc tests.")
@require_torch
@require_tf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment