"docs/vscode:/vscode.git/clone" did not exist on "7df4b90c760804295cd4c23a0840055b772898ee"
Unverified Commit dfa7b580, authored by JB (Don), committed by GitHub

[`BERT`] Add support for sdpa (#28802)

* Adding SDPA support for BERT (see the usage sketch below)

* Using the proper input name for testing model input in inference()

* Adding documentation for SDPA in BERT model page

* Use the stable link for the documentation

* Adding a gate to only call .contiguous() for torch < 2.2.0

* Additions and fixes to the documentation

* Minor updates to documentation

* Adding extra requirements needed for the contiguous() bug

* Adding "Adapted from" in plcae of the "Copied from"

* Add benchmark speedup tables to the documentation

* Minor fixes to the documentation

* Use ClapText as a replacement for Bert in the Copied-From

* Some more fixes for the fix-copies references

* Overriding the test_eager_matches_sdpa_generate in bert tests to not load with low_cpu_mem_usage

[test all]

* Undo changes to separate test

* Refactored SDPA self attention code for KV projections

* Change use_sdpa to attn_implementation

* Fix test_sdpa_can_dispatch_on_flash by preparing input (required for MultipleChoice models)
parent 2de5cb12
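For orientation, here is a minimal usage sketch of the switch this commit adds: loading BERT with `attn_implementation="sdpa"` versus the eager implementation and checking that the outputs agree. This is not taken from the diff below; the checkpoint name is only an example, and SDPA requires a PyTorch version that ships `torch.nn.functional.scaled_dot_product_attention`.

```python
# Minimal sketch of the attn_implementation switch added in this PR.
# Assumes a transformers install that includes this commit and a PyTorch
# build with scaled_dot_product_attention; the checkpoint is only an example.
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("SDPA support for BERT", return_tensors="pt")

# SDPA path (dispatches to torch.nn.functional.scaled_dot_product_attention)
model_sdpa = AutoModel.from_pretrained(checkpoint, attn_implementation="sdpa")

# Eager path (the original BertSelfAttention implementation)
model_eager = AutoModel.from_pretrained(checkpoint, attn_implementation="eager")

with torch.no_grad():
    out_sdpa = model_sdpa(**inputs).last_hidden_state
    out_eager = model_eager(**inputs).last_hidden_state

# The two implementations should agree up to numerical tolerance.
print(torch.allclose(out_sdpa, out_eager, atol=1e-4, rtol=1e-4))
```

Per the transformers documentation, supported models pick SDPA automatically when it is available, so passing the argument explicitly is mainly useful for comparisons like this one.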
@@ -245,11 +245,18 @@ class SplinterSelfOutput(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
+SPLINTER_SELF_ATTENTION_CLASSES = {
+    "eager": SplinterSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER
 class SplinterAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = SplinterSelfOutput(config)
         self.pruned_heads = set()
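The Splinter hunk above registers only the "eager" implementation in its dispatch table; BERT itself gains an SDPA variant selected the same way. Below is a rough, hedged sketch of that pattern using placeholder names (not the actual classes in modeling_bert.py), including the torch < 2.2.0 `.contiguous()` gate mentioned in the commit message.

```python
# Illustrative sketch of the dispatch pattern only: the "My*" names are
# placeholders, not the classes added to modeling_bert.py by this PR.
import torch
from packaging import version


class MySelfAttention(torch.nn.Module):
    """Stand-in for the eager self-attention implementation."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.config = config
        self.position_embedding_type = position_embedding_type


class MySdpaSelfAttention(MySelfAttention):
    """Stand-in for the SDPA self-attention implementation."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        # Per the commit message, .contiguous() is only called as a workaround
        # on torch < 2.2.0, so the gate is computed once here.
        self.require_contiguous_qkv = version.parse(torch.__version__) < version.parse("2.2.0")


# One entry per supported attention backend; config._attn_implementation
# (set when the model is loaded) picks which class gets instantiated.
MY_SELF_ATTENTION_CLASSES = {
    "eager": MySelfAttention,
    "sdpa": MySdpaSelfAttention,
}


class MyAttention(torch.nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = MY_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
```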
@@ -295,11 +295,18 @@ class XLMRobertaSelfOutput(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta
+XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
+    "eager": XLMRobertaSelfAttention,
+}
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta,ROBERTA->XLM_ROBERTA
 class XLMRobertaAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = XLMRobertaSelfOutput(config)
         self.pruned_heads = set()
@@ -690,7 +697,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
     """
 
-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRoberta
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -723,7 +730,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -664,7 +664,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
     an input to the forward pass. .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
     """
 
-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRobertaXL
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRobertaXL
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -697,7 +697,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
    def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -783,7 +783,7 @@ class XmodModel(XmodPreTrainedModel):
     """
 
-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Xmod
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Xmod
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -18,7 +18,14 @@ import unittest
 
 from transformers import BertConfig, is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import CaptureLogger, require_torch, require_torch_accelerator, slow, torch_device
+from transformers.testing_utils import (
+    CaptureLogger,
+    require_torch,
+    require_torch_accelerator,
+    require_torch_sdpa,
+    slow,
+    torch_device,
+)
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -621,6 +628,79 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
             loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
             loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
 
+    # This test was copied from the common test_eager_matches_sdpa_generate(), but without low_cpu_mem_usage=True.
+    # TODO: Remove this and use the parent method (in common tests) once BERT supports low_cpu_mem_usage=True.
+    @require_torch_sdpa
+    @slow
+    def test_eager_matches_sdpa_generate(self):
+        max_new_tokens = 30
+
+        if len(self.all_generative_model_classes) == 0:
+            self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_sdpa:
+                self.skipTest(f"{model_class.__name__} does not support SDPA")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            dummy_input = inputs_dict[model_class.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+                dummy_input = dummy_input.to(torch.float16)
+
+            # make sure that all models have enough positions for generation
+            if hasattr(config, "max_position_embeddings"):
+                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+                model_sdpa = model_class.from_pretrained(
+                    tmpdirname,
+                    torch_dtype=torch.float16,
+                    # low_cpu_mem_usage=True,
+                ).to(torch_device)
+
+                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+
+                model_eager = model_class.from_pretrained(
+                    tmpdirname,
+                    torch_dtype=torch.float16,
+                    # low_cpu_mem_usage=True,
+                    attn_implementation="eager",
+                ).to(torch_device)
+
+                self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+                for name, submodule in model_eager.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        raise ValueError("The eager model should not have SDPA attention layers")
+
+                has_sdpa = False
+                for name, submodule in model_sdpa.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        has_sdpa = True
+                        break
+                if not has_sdpa:
+                    raise ValueError("The SDPA model should have SDPA attention layers")
+
+                # Just test that a large cache works as expected
+                res_eager = model_eager.generate(
+                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+                )
+
+                res_sdpa = model_sdpa.generate(
+                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+                )
+
+                self.assertTrue(torch.allclose(res_eager, res_sdpa))
+
 
 @require_torch
 class BertModelIntegrationTest(unittest.TestCase):
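The overridden test above asserts that greedy generation is token-identical between the two attention implementations. Here is a hedged sketch of the same check outside the test harness, using `BertLMHeadModel` as the generative head; the checkpoint and prompt are examples only, and exact ties in greedy decoding could still flip.

```python
# Sketch of the eager-vs-SDPA generate comparison performed in the test above.
# "google-bert/bert-base-uncased" is just an example checkpoint.
import torch
from transformers import BertLMHeadModel, BertTokenizer

checkpoint = "google-bert/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(checkpoint)
input_ids = tokenizer("SDPA generation check", return_tensors="pt").input_ids

model_sdpa = BertLMHeadModel.from_pretrained(
    checkpoint, is_decoder=True, attn_implementation="sdpa"
)
model_eager = BertLMHeadModel.from_pretrained(
    checkpoint, is_decoder=True, attn_implementation="eager"
)

with torch.no_grad():
    out_sdpa = model_sdpa.generate(input_ids, max_new_tokens=20, do_sample=False)
    out_eager = model_eager.generate(input_ids, max_new_tokens=20, do_sample=False)

# Greedy decoding should give token-identical sequences, as the test asserts.
print(torch.equal(out_sdpa, out_eager))
```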
@@ -3603,12 +3603,14 @@ class ModelTesterMixin:
                 self.assertTrue(model_eager.config._attn_implementation == "eager")
 
                 for name, submodule in model_eager.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         raise ValueError("The eager model should not have SDPA attention layers")
 
                 has_sdpa = False
                 for name, submodule in model_sdpa.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         has_sdpa = True
                         break
                 if not has_sdpa and model_sdpa.config.model_type != "falcon":
@@ -3691,19 +3693,21 @@ class ModelTesterMixin:
                             decoder_input_ids = decoder_input_ids.to(torch_device)
 
                             # TODO: never an `attention_mask` arg here?
-                            other_inputs = {
+                            processed_inputs = {
+                                model.main_input_name: dummy_input,
                                 "decoder_input_ids": decoder_input_ids,
                                 "decoder_attention_mask": dummy_attention_mask,
                                 "output_hidden_states": True,
                             }
                         else:
-                            other_inputs = {
+                            processed_inputs = {
+                                model.main_input_name: dummy_input,
                                 "output_hidden_states": True,
                             }
 
                             # Otherwise fails for e.g. WhisperEncoderModel
                             if "attention_mask" in inspect.signature(model_eager.forward).parameters:
-                                other_inputs["attention_mask"] = dummy_attention_mask
+                                processed_inputs["attention_mask"] = dummy_attention_mask
 
                         # TODO: test gradients as well (& for FA2 as well!)
                         with torch.no_grad():
@@ -3712,8 +3716,9 @@ class ModelTesterMixin:
                                 enable_math=True,
                                 enable_mem_efficient=enable_kernels,
                             ):
-                                outputs_eager = model_eager(dummy_input, **other_inputs)
-                                outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
+                                prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
+                                outputs_eager = model_eager(**prepared_inputs)
+                                outputs_sdpa = model_sdpa(**prepared_inputs)
 
                             logits_eager = (
                                 outputs_eager.hidden_states[-1]
@@ -3799,6 +3804,7 @@ class ModelTesterMixin:
                 self.skipTest(f"{model_class.__name__} does not support SDPA")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            if config.model_type in ["llava", "llava_next", "vipllava"]:
                 self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input")
             if config.model_type in ["idefics"]:
@@ -3867,12 +3873,14 @@ class ModelTesterMixin:
                 self.assertTrue(model_eager.config._attn_implementation == "eager")
 
                 for name, submodule in model_eager.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         raise ValueError("The eager model should not have SDPA attention layers")
 
                 has_sdpa = False
                 for name, submodule in model_sdpa.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         has_sdpa = True
                         break
                 if not has_sdpa:
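The common-test hunks wrap the forward passes in `torch.backends.cuda.sdp_kernel` so that specific SDPA backends can be enabled or disabled. Below is a minimal sketch of that context manager in isolation; the flags are chosen for illustration, a CUDA device is required, and newer PyTorch releases expose `torch.nn.attention.sdpa_kernel` as the replacement API.

```python
# Minimal sketch: constraining SDPA backend selection while running a forward pass.
# The checkpoint is only an example; requires a CUDA-enabled PyTorch install.
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint, attn_implementation="sdpa").to("cuda")

inputs = tokenizer("Forcing the math backend", return_tensors="pt").to("cuda")

# Disable the fused kernels so scaled_dot_product_attention falls back to the
# math backend; useful when comparing numerics against the eager implementation.
with torch.no_grad(), torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=False
):
    out = model(**inputs).last_hidden_state

print(out.shape)
```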