Unverified commit 21c912e7, authored by hoshi-hiyouga, committed by GitHub

Fix config + attn_implementation in AutoModelForCausalLM.from_pretrained (#30299)

* Update modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py
parent b1cd4874
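
For orientation, the scenario in the title is a from_pretrained call that receives both an explicit config object and an attn_implementation argument. A minimal usage sketch of that pattern (the checkpoint id is only a placeholder, and sdpa is assumed to be available in the installed torch build):

from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder checkpoint id; any causal LM checkpoint follows the same pattern.
checkpoint = "mistralai/Mistral-7B-v0.1"

# Load the config separately, then request a specific attention backend.
config = AutoConfig.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, config=config, attn_implementation="sdpa"
)

# The explicit request should be reflected on the loaded model's config.
print(model.config._attn_implementation)  # expected: "sdpa"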
modeling_utils.py
@@ -3146,7 +3146,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             config = copy.deepcopy(config)
 
             kwarg_attn_imp = kwargs.pop("attn_implementation", None)
-            if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
+            if kwarg_attn_imp is not None:
                 config._attn_implementation = kwarg_attn_imp
 
         model_kwargs = kwargs
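
The change above drops the config._attn_implementation != kwarg_attn_imp guard. The likely reason, as far as the diff itself suggests, is that config._attn_implementation is a getter that can report a fallback value (such as "eager") even when nothing has been explicitly set, so comparing the kwarg against it skipped the assignment exactly when the user asked for that fallback value, and the explicit request was never recorded. A toy reproduction of that pitfall, using a stand-in config class rather than the real PretrainedConfig:

class ToyConfig:
    # Stand-in that mimics the assumed getter behavior: fall back to
    # "eager" when no implementation has been explicitly recorded.
    def __init__(self):
        self._attn_implementation_internal = None

    @property
    def _attn_implementation(self):
        return self._attn_implementation_internal or "eager"

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value


requested = "eager"

# Old guard: the getter already reports "eager", so the explicit request
# is never written back and the internal field stays None.
config = ToyConfig()
if requested is not None and config._attn_implementation != requested:
    config._attn_implementation = requested
print(config._attn_implementation_internal)  # None -> the request was lost

# New behavior: an explicit request is always recorded on the config.
config = ToyConfig()
if requested is not None:
    config._attn_implementation = requested
print(config._attn_implementation_internal)  # "eager"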
test_modeling_utils.py
@@ -427,6 +427,44 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
         self.assertEqual(model.dtype, torch.float32)
 
+    def test_model_from_pretrained_attn_implementation(self):
+        # test that the model can be instantiated with attn_implementation of either
+        # 1. explicit from_pretrained's attn_implementation argument
+        # 2. explicit from_pretrained's attn_implementation argument with a config argument
+        attn_implementation_available = ["eager"]
+        if is_torch_sdpa_available():
+            attn_implementation_available.append("sdpa")
+
+        if is_flash_attn_2_available():
+            attn_implementation_available.append("flash_attention_2")
+
+        mistral_attention_classes = {
+            "eager": "MistralAttention",
+            "sdpa": "MistralSdpaAttention",
+            "flash_attention_2": "MistralFlashAttention2",
+        }
+        for requested_attn_implementation in attn_implementation_available:
+            model = AutoModelForCausalLM.from_pretrained(
+                TINY_MISTRAL, attn_implementation=requested_attn_implementation
+            )
+            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+            for module in model.modules():
+                if "Attention" in module.__class__.__name__:
+                    self.assertEqual(
+                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+                    )
+
+            config = AutoConfig.from_pretrained(TINY_MISTRAL)
+            model = AutoModelForCausalLM.from_pretrained(
+                TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation
+            )
+            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+            for module in model.modules():
+                if "Attention" in module.__class__.__name__:
+                    self.assertEqual(
+                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+                    )
+
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)
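
The new test only exercises sdpa and flash_attention_2 when the corresponding availability helpers return True; user code can apply the same guard before requesting a backend. A small sketch, assuming these helpers are importable from transformers.utils as they are used in the test module, and using a placeholder checkpoint id:

from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

# Pick the best attention backend the current environment supports.
if is_flash_attn_2_available():
    attn_implementation = "flash_attention_2"  # in practice also expects CUDA + fp16/bf16
elif is_torch_sdpa_available():
    attn_implementation = "sdpa"
else:
    attn_implementation = "eager"

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-MistralForCausalLM",  # placeholder checkpoint id
    attn_implementation=attn_implementation,
)
print(model.config._attn_implementation)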