Unverified Commit 4557a0de authored by Michael Benayoun, committed by GitHub

Wrap `_prepare_4d_causal_attention_mask` as a leaf function (#27236)

parent 8a312956
@@ -40,6 +40,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.import_utils import is_torch_fx_available
 from .configuration_llama import LlamaConfig
@@ -48,6 +49,12 @@ if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and simply appear as a node in the graph.
+if is_torch_fx_available():
+    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
+
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "LlamaConfig"
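
The diff registers `_prepare_4d_causal_attention_mask` with `torch.fx.wrap`, so symbolic tracing records the call as a single graph node instead of tracing through the mask-construction logic. Below is a minimal sketch of the same pattern outside transformers; `make_causal_mask` and `TinyModel` are hypothetical stand-ins, not part of the commit.

```python
import torch
import torch.fx


def make_causal_mask(seq_len: int) -> torch.Tensor:
    # Hypothetical stand-in for `_prepare_4d_causal_attention_mask`:
    # builds a lower-triangular (causal) mask.
    return torch.tril(torch.ones(seq_len, seq_len))


# Same pattern as the diff: register the helper as an FX leaf so tracing
# does not descend into it. `torch.fx.wrap` returns the function unchanged,
# which is why the commit can reassign the name to its return value.
make_causal_mask = torch.fx.wrap(make_causal_mask)


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mask = make_causal_mask(x.shape[-1])
        return x * mask


traced = torch.fx.symbolic_trace(TinyModel())
# The printed graph contains one `call_function` node for make_causal_mask
# rather than the `ones`/`tril` ops it is built from.
print(traced.graph)
```

Without the `torch.fx.wrap` call, tracing would inline the helper's internals into the graph; wrapping keeps the mask preparation opaque to FX, which is what the commit needs for the LLaMA attention-mask helper.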