Unverified Commit 9a3f4d4d authored by Damith Senanayake, committed by GitHub

Bark model: pass the check_device_map parameter on to super() when enabling Flash Attention 2 (#29357)

* Fix error #29332: BarkModel's _check_and_enable_flash_attn_2() override is called with a check_device_map argument it does not accept, so the call fails.

* style fixup
parent 6d67837f
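
For context, a minimal sketch of the kind of call that hit issue #29332. The checkpoint name and hardware setup are illustrative assumptions, not taken from this commit; what is grounded in the diff is that the generic loading path passes a check_device_map keyword to _check_and_enable_flash_attn_2, which the Bark override did not accept before this change, so the load failed (an unexpected keyword argument raises a TypeError).

import torch
from transformers import BarkModel

# Requesting Flash Attention 2 routes through _check_and_enable_flash_attn_2;
# before this commit the Bark override rejected the check_device_map keyword.
model = BarkModel.from_pretrained(
    "suno/bark-small",                        # assumed checkpoint, for illustration only
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # needs a Flash Attention 2 capable GPU
).to("cuda")
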
@@ -1881,6 +1881,7 @@ class BarkModel(BarkPreTrainedModel):
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         hard_check_only: bool = False,
+        check_device_map: bool = False,
     ):
         """
         `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model
@@ -1901,7 +1902,7 @@ class BarkModel(BarkPreTrainedModel):
         can initialize the correct attention module
         """
         config = super()._check_and_enable_flash_attn_2(
-            config, torch_dtype, device_map, hard_check_only=hard_check_only
+            config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
         )
         config.semantic_config._attn_implementation = config._attn_implementation
...
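
For readability, here is a condensed sketch of the overridden classmethod as it looks after this change, reconstructed from the hunks above. The BarkPreTrainedModel base class comes from transformers; the docstring body and the sub-model configs beyond semantic_config are truncated in the diff, so they are only indicated by comments here.

from typing import Dict, Optional, Union

import torch
from transformers.models.bark.modeling_bark import BarkPreTrainedModel


class BarkModel(BarkPreTrainedModel):
    @classmethod
    def _check_and_enable_flash_attn_2(
        cls,
        config,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        hard_check_only: bool = False,
        check_device_map: bool = False,  # new: accepted and forwarded to super()
    ):
        # Run the generic Flash Attention 2 checks on the top-level config,
        # now forwarding check_device_map instead of dropping it.
        config = super()._check_and_enable_flash_attn_2(
            config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
        )
        # Propagate the chosen attention implementation to the sub-model configs.
        config.semantic_config._attn_implementation = config._attn_implementation
        # ...the remaining sub-model configs (truncated in the diff above) are
        # updated the same way before the config is returned.
        return config
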