Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3f60d11a
Unverified
Commit
3f60d11a
authored
Feb 23, 2024
by
Alessandro Palla
Committed by
GitHub
Feb 23, 2024
Browse files
Improve _update_causal_mask performance (#29210)
* Fix issue 29206 * Fix style
parent
75ed76ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
14 deletions
+8
-14
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+4
-7
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+4
-7
No files found.
src/transformers/models/gemma/modeling_gemma.py
View file @
3f60d11a
...
@@ -959,15 +959,14 @@ class GemmaModel(GemmaPreTrainedModel):
...
@@ -959,15 +959,14 @@ class GemmaModel(GemmaPreTrainedModel):
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
# We use the current dtype to avoid any overflows
# We use the current dtype to avoid any overflows
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
torch
.
finfo
(
dtype
).
min
min_dtype
=
torch
.
finfo
(
dtype
).
min
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
min_dtype
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
mask_length
=
attention_mask
.
shape
[
-
1
]
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
padding_mask
,
torch
.
finfo
(
dtype
).
min
)
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
...
@@ -980,9 +979,7 @@ class GemmaModel(GemmaPreTrainedModel):
...
@@ -980,9 +979,7 @@ class GemmaModel(GemmaPreTrainedModel):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
causal_mask
.
min
(),
dim
=-
1
,
keepdim
=
True
)).
to
(
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
min_dtype
,
dim
=-
1
,
keepdim
=
True
)).
to
(
dtype
)
dtype
)
return
causal_mask
return
causal_mask
...
...
src/transformers/models/llama/modeling_llama.py
View file @
3f60d11a
...
@@ -1066,15 +1066,14 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -1066,15 +1066,14 @@ class LlamaModel(LlamaPreTrainedModel):
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
# We use the current dtype to avoid any overflows
# We use the current dtype to avoid any overflows
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
torch
.
finfo
(
dtype
).
min
min_dtype
=
torch
.
finfo
(
dtype
).
min
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
min_dtype
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
mask_length
=
attention_mask
.
shape
[
-
1
]
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
padding_mask
,
torch
.
finfo
(
dtype
).
min
)
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
...
@@ -1087,9 +1086,7 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -1087,9 +1086,7 @@ class LlamaModel(LlamaPreTrainedModel):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
causal_mask
.
min
(),
dim
=-
1
,
keepdim
=
True
)).
to
(
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
min_dtype
,
dim
=-
1
,
keepdim
=
True
)).
to
(
dtype
)
dtype
)
return
causal_mask
return
causal_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment