Unverified Commit 7e5d46de authored by amyeroberts, committed by GitHub

Respect the config's attn_implementation if set (#32383)

* Respect the config's attn if set

* Update test - can override in from_config

* Fix
Parent: 458b0cd2
@@ -1454,7 +1454,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
 
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
-        config._attn_implementation = kwargs.pop("attn_implementation", None)
+
+        if config._attn_implementation_internal is not None:
+            # In this case, the config has been created with the attn_implementation set by the user, which we
+            # should respect.
+            attn_implementation = config._attn_implementation_internal
+        else:
+            attn_implementation = None
+
+        config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
         config = cls._autoset_attn_implementation(
             config,
             use_flash_attention_2=use_flash_attention_2,
...
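
For context, a minimal usage sketch of the behavior the hunk above targets, assuming a hypothetical tiny Mistral checkpoint name ("tiny-mistral", a stand-in only) and a torch build with SDPA support: an attn_implementation chosen when the config is created is now picked up by from_config instead of being discarded.

from transformers import AutoConfig, AutoModelForCausalLM

# "tiny-mistral" is a placeholder checkpoint name used purely for illustration.
config = AutoConfig.from_pretrained("tiny-mistral", attn_implementation="sdpa")

# The user's choice is stored on the config (_attn_implementation_internal),
# so _from_config now uses it as the default instead of None.
model = AutoModelForCausalLM.from_config(config)
assert model.config._attn_implementation == "sdpa"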
@@ -574,6 +574,60 @@ class ModelUtilsTest(TestCasePlus):
                         module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
                     )
 
+    def test_model_from_config_attn_implementation(self):
+        # test that the model can be instantiated with attn_implementation of either
+        # 1. config created with explicit attn_implementation and from_config
+        # 2. explicit from_config's attn_implementation argument with a config argument
+        # 3. config created with explicit attn_implementation and from_config overriding with explicit attn_implementation argument
+        attn_implementation_available = ["eager"]
+        if is_torch_sdpa_available():
+            attn_implementation_available.append("sdpa")
+
+        if is_flash_attn_2_available():
+            attn_implementation_available.append("flash_attention_2")
+
+        mistral_attention_classes = {
+            "eager": "MistralAttention",
+            "sdpa": "MistralSdpaAttention",
+            "flash_attention_2": "MistralFlashAttention2",
+        }
+        for requested_attn_implementation in attn_implementation_available:
+            config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
+            # Ensure the config was set correctly
+            self.assertEqual(config._attn_implementation, requested_attn_implementation)
+            self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
+            model = AutoModelForCausalLM.from_config(config)
+            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+            for module in model.modules():
+                if "Attention" in module.__class__.__name__:
+                    self.assertEqual(
+                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+                    )
+
+            config = AutoConfig.from_pretrained(TINY_MISTRAL)
+            # When the config is not set, the default is "eager"
+            self.assertEqual(config._attn_implementation, "eager")
+            self.assertEqual(config._attn_implementation_internal, None)
+            model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
+            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+            for module in model.modules():
+                if "Attention" in module.__class__.__name__:
+                    self.assertEqual(
+                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+                    )
+
+            # Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
+            config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
+            self.assertEqual(config._attn_implementation, "foo-bar-baz")
+            self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
+            model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
+            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+            for module in model.modules():
+                if "Attention" in module.__class__.__name__:
+                    self.assertEqual(
+                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+                    )
+
     def test_torch_dtype_byte_sizes(self):
         torch_dtypes_and_bytes = [
             (torch.double, 8),
...
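
The third case in the test above pins down precedence: the config's value only becomes the default for kwargs.pop, so an explicit attn_implementation argument passed to from_config still overrides it. A minimal sketch under the same placeholder-checkpoint assumption:

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("tiny-mistral", attn_implementation="sdpa")

# The explicit argument wins over the config's stored choice.
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")
assert model.config._attn_implementation == "eager"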