Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ec92f983
Unverified
Commit
ec92f983
authored
Apr 17, 2024
by
fxmarty
Committed by
GitHub
Apr 17, 2024
Browse files
Fix quality Olmo + SDPA (#30302)
fix olmo
parent
05bdef16
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
9 deletions
+25
-9
src/transformers/models/olmo/modeling_olmo.py
src/transformers/models/olmo/modeling_olmo.py
+25
-9
No files found.
src/transformers/models/olmo/modeling_olmo.py
View file @
ec92f983
...
@@ -653,6 +653,7 @@ class OlmoSdpaAttention(OlmoAttention):
...
@@ -653,6 +653,7 @@ class OlmoSdpaAttention(OlmoAttention):
value_states
,
value_states
,
attn_mask
=
causal_mask
,
attn_mask
=
causal_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
is_causal
=
causal_mask
is
None
and
q_len
>
1
,
)
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
...
@@ -970,9 +971,7 @@ class OlmoModel(OlmoPreTrainedModel):
...
@@ -970,9 +971,7 @@ class OlmoModel(OlmoPreTrainedModel):
if
position_ids
is
None
:
if
position_ids
is
None
:
position_ids
=
cache_position
.
unsqueeze
(
0
)
position_ids
=
cache_position
.
unsqueeze
(
0
)
causal_mask
=
self
.
_update_causal_mask
(
causal_mask
=
self
.
_update_causal_mask
(
attention_mask
,
inputs_embeds
,
cache_position
,
past_seen_tokens
)
attention_mask
,
inputs_embeds
,
cache_position
,
past_seen_tokens
+
inputs_embeds
.
shape
[
1
]
)
# embed positions
# embed positions
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
...
@@ -1036,17 +1035,32 @@ class OlmoModel(OlmoPreTrainedModel):
...
@@ -1036,17 +1035,32 @@ class OlmoModel(OlmoPreTrainedModel):
attentions
=
all_self_attns
,
attentions
=
all_self_attns
,
)
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def
_update_causal_mask
(
self
,
attention_mask
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
cache_position
:
torch
.
Tensor
,
past_seen_tokens
:
int
,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def
_update_causal_mask
(
self
,
attention_mask
,
input_tensor
,
cache_position
,
current_length
):
if
self
.
config
.
_attn_implementation
==
"flash_attention_2"
:
if
self
.
config
.
_attn_implementation
==
"flash_attention_2"
:
if
attention_mask
is
not
None
and
0.0
in
attention_mask
:
if
attention_mask
is
not
None
and
0.0
in
attention_mask
:
return
attention_mask
return
attention_mask
return
None
return
None
if
self
.
config
.
_attn_implementation
==
"sdpa"
:
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
if
AttentionMaskConverter
.
_ignore_causal_mask_sdpa
(
attention_mask
,
inputs_embeds
=
input_tensor
,
past_key_values_length
=
past_seen_tokens
):
return
None
dtype
,
device
=
input_tensor
.
dtype
,
input_tensor
.
device
dtype
,
device
=
input_tensor
.
dtype
,
input_tensor
.
device
min_dtype
=
torch
.
finfo
(
dtype
).
min
min_dtype
=
torch
.
finfo
(
dtype
).
min
sequence_length
=
input_tensor
.
shape
[
1
]
sequence_length
=
input_tensor
.
shape
[
1
]
...
@@ -1054,7 +1068,9 @@ class OlmoModel(OlmoPreTrainedModel):
...
@@ -1054,7 +1068,9 @@ class OlmoModel(OlmoPreTrainedModel):
target_length
=
self
.
config
.
max_position_embeddings
target_length
=
self
.
config
.
max_position_embeddings
else
:
# dynamic cache
else
:
# dynamic cache
target_length
=
(
target_length
=
(
attention_mask
.
shape
[
-
1
]
if
isinstance
(
attention_mask
,
torch
.
Tensor
)
else
current_length
+
1
attention_mask
.
shape
[
-
1
]
if
isinstance
(
attention_mask
,
torch
.
Tensor
)
else
past_seen_tokens
+
sequence_length
+
1
)
)
causal_mask
=
torch
.
full
((
sequence_length
,
target_length
),
fill_value
=
min_dtype
,
dtype
=
dtype
,
device
=
device
)
causal_mask
=
torch
.
full
((
sequence_length
,
target_length
),
fill_value
=
min_dtype
,
dtype
=
dtype
,
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment