Commit 4557a0de (unverified)
Authored Nov 02, 2023 by Michael Benayoun
Committed by GitHub on Nov 02, 2023
Wrap `_prepare_4d_causal_attention_mask` as a leaf function (#27236)
Parent: 8a312956
Showing 1 changed file with 7 additions and 0 deletions:
src/transformers/models/llama/modeling_llama.py (+7, -0)
src/transformers/models/llama/modeling_llama.py

@@ -40,6 +40,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.import_utils import is_torch_fx_available
 from .configuration_llama import LlamaConfig

@@ -48,6 +49,12 @@ if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and simply appear as a node in the graph.
+if is_torch_fx_available():
+    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
+
+
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "LlamaConfig"
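For context, here is a minimal, self-contained sketch (not part of the commit) of what registering a function with `torch.fx.wrap` does. The `build_mask` helper and `Toy` module below are made up for illustration; the point is that the wrapped helper is recorded in the traced graph as a single call_function node rather than being traced through line by line, which is presumably why the commit registers `_prepare_4d_causal_attention_mask` this way when torch.fx is available.

# Sketch only: `build_mask` and `Toy` are hypothetical names, not from the commit.
import torch
import torch.fx


def build_mask(size):
    # Once the function is wrapped, this body is hidden from the tracer and
    # only runs at execution time with concrete values.
    return torch.tril(torch.ones(size, size))


# Same pattern as the commit: register the helper as an FX leaf function.
build_mask = torch.fx.wrap(build_mask)


class Toy(torch.nn.Module):
    def forward(self, x):
        mask = build_mask(x.shape[-1])
        return x * mask


traced = torch.fx.symbolic_trace(Toy())
print(traced.graph)              # build_mask shows up as one call_function node
print(traced(torch.ones(3, 3)))  # the traced module still executes normally

Note that `torch.fx.wrap` must be called at module top level, which is why the commit performs the registration at import time in modeling_llama.py rather than inside a function.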