Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ee38fc31
Unverified
Commit
ee38fc31
authored
Mar 21, 2024
by
Joao Gante
Committed by
GitHub
Mar 21, 2024
Browse files
Llama: always convert the causal mask in the SDPA code path (#29663)
* always convert the mask * rebase and fix copies
parent
5ffef2a9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
33 deletions
+12
-33
src/transformers/models/cohere/modeling_cohere.py
src/transformers/models/cohere/modeling_cohere.py
+4
-11
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+4
-11
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+4
-11
No files found.
src/transformers/models/cohere/modeling_cohere.py
View file @
ee38fc31
...
...
@@ -1005,17 +1005,10 @@ class CohereModel(CoherePreTrainedModel):
and
attention_mask
is
not
None
and
attention_mask
.
device
.
type
==
"cuda"
):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing
=
(
torch
.
jit
.
is_tracing
()
or
isinstance
(
input_tensor
,
torch
.
fx
.
Proxy
)
or
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
())
)
if
not
is_tracing
and
torch
.
any
(
attention_mask
!=
1
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
AttentionMaskConverter
.
_unmask_unattended
(
causal_mask
,
min_dtype
)
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
AttentionMaskConverter
.
_unmask_unattended
(
causal_mask
,
min_dtype
)
return
causal_mask
...
...
src/transformers/models/gemma/modeling_gemma.py
View file @
ee38fc31
...
...
@@ -1011,17 +1011,10 @@ class GemmaModel(GemmaPreTrainedModel):
and
attention_mask
is
not
None
and
attention_mask
.
device
.
type
==
"cuda"
):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing
=
(
torch
.
jit
.
is_tracing
()
or
isinstance
(
input_tensor
,
torch
.
fx
.
Proxy
)
or
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
())
)
if
not
is_tracing
and
torch
.
any
(
attention_mask
!=
1
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
AttentionMaskConverter
.
_unmask_unattended
(
causal_mask
,
min_dtype
)
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
AttentionMaskConverter
.
_unmask_unattended
(
causal_mask
,
min_dtype
)
return
causal_mask
...
...
src/transformers/models/llama/modeling_llama.py
View file @
ee38fc31
...
...
@@ -1100,17 +1100,10 @@ class LlamaModel(LlamaPreTrainedModel):
and
attention_mask
is
not
None
and
attention_mask
.
device
.
type
==
"cuda"
):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing
=
(
torch
.
jit
.
is_tracing
()
or
isinstance
(
input_tensor
,
torch
.
fx
.
Proxy
)
or
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
())
)
if
not
is_tracing
and
torch
.
any
(
attention_mask
!=
1
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
AttentionMaskConverter
.
_unmask_unattended
(
causal_mask
,
min_dtype
)
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
AttentionMaskConverter
.
_unmask_unattended
(
causal_mask
,
min_dtype
)
return
causal_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment