Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3f60d11a
Unverified
Commit
3f60d11a
authored
Feb 23, 2024
by
Alessandro Palla
Committed by
GitHub
Feb 23, 2024
Browse files
Improve _update_causal_mask performance (#29210)
* Fix issue 29206 * Fix style
parent
75ed76ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
14 deletions
+8
-14
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+4
-7
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+4
-7
No files found.
src/transformers/models/gemma/modeling_gemma.py
View file @
3f60d11a
...
@@ -959,15 +959,14 @@ class GemmaModel(GemmaPreTrainedModel):
...
@@ -959,15 +959,14 @@ class GemmaModel(GemmaPreTrainedModel):
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
# We use the current dtype to avoid any overflows
# We use the current dtype to avoid any overflows
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
torch
.
finfo
(
dtype
).
min
min_dtype
=
torch
.
finfo
(
dtype
).
min
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
min_dtype
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
mask_length
=
attention_mask
.
shape
[
-
1
]
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
padding_mask
,
torch
.
finfo
(
dtype
).
min
)
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
...
@@ -980,9 +979,7 @@ class GemmaModel(GemmaPreTrainedModel):
...
@@ -980,9 +979,7 @@ class GemmaModel(GemmaPreTrainedModel):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
causal_mask
.
min
(),
dim
=-
1
,
keepdim
=
True
)).
to
(
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
min_dtype
,
dim
=-
1
,
keepdim
=
True
)).
to
(
dtype
)
dtype
)
return
causal_mask
return
causal_mask
...
...
src/transformers/models/llama/modeling_llama.py
View file @
3f60d11a
...
@@ -1066,15 +1066,14 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -1066,15 +1066,14 @@ class LlamaModel(LlamaPreTrainedModel):
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
# We use the current dtype to avoid any overflows
# We use the current dtype to avoid any overflows
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
torch
.
finfo
(
dtype
).
min
min_dtype
=
torch
.
finfo
(
dtype
).
min
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
min_dtype
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
mask_length
=
attention_mask
.
shape
[
-
1
]
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
padding_mask
,
torch
.
finfo
(
dtype
).
min
)
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
if
self
.
config
.
_attn_implementation
==
"sdpa"
and
attention_mask
is
not
None
:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
...
@@ -1087,9 +1086,7 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -1087,9 +1086,7 @@ class LlamaModel(LlamaPreTrainedModel):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
causal_mask
.
min
(),
dim
=-
1
,
keepdim
=
True
)).
to
(
causal_mask
=
causal_mask
.
mul
(
~
torch
.
all
(
causal_mask
==
min_dtype
,
dim
=-
1
,
keepdim
=
True
)).
to
(
dtype
)
dtype
)
return
causal_mask
return
causal_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment