chenpangpang / transformers · Commits · 78a57c5e

Unverified commit 78a57c5e, authored Apr 30, 2024 by Joao Gante, committed by GitHub on Apr 30, 2024

DBRX: make fixup (#30578)

parent 1bff6a0b
Changes: showing 1 changed file with 5 additions and 1 deletion.

src/transformers/models/dbrx/modeling_dbrx.py (+5, -1)
...
@@ -1215,6 +1215,7 @@ class DbrxModel(DbrxPreTrainedModel):
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
         # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
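For context, the flash-attention-2 branch above keeps the user-supplied 2D mask only when it actually carries padding information, i.e. when some position is 0.0; `in` on a torch tensor is an elementwise membership test. A minimal sketch of that check, assuming plain tokenizer-style masks (the tensor values below are illustrative, not from the commit):

import torch

# 2D attention masks as produced by a tokenizer: 1.0 = real token, 0.0 = padding
padded_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # batch with right-padding
full_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0]])    # no padding at all

# `0.0 in mask` asks "does any element equal 0.0?"
print(0.0 in padded_mask)  # True  -> the mask carries information, keep it
print(0.0 in full_mask)    # False -> purely causal, the kernel does not need it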
...
@@ -1227,7 +1228,10 @@ class DbrxModel(DbrxPreTrainedModel):
         using_static_cache = isinstance(past_key_values, StaticCache)
         if self.config._attn_implementation == "sdpa" and not using_static_cache:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
             ):
                 return None
...
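The added keyword matches the helper's signature in transformers.modeling_attn_mask_utils, so the model's training/eval state is now taken into account when deciding whether the causal mask can be skipped for SDPA. A minimal sketch of probing that helper outside the model, assuming a transformers release that already includes this change (roughly v4.41+); note that _ignore_causal_mask_sdpa is a private API and may change:

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# A batch with no padding: every position is a real token.
inputs_embeds = torch.randn(1, 4, 16)  # (batch, seq_len, hidden)
attention_mask = torch.ones(1, 4)      # all ones -> the mask adds no information

# When the mask is redundant, the model can return None and let SDPA's
# built-in `is_causal` path handle causality (faster and compile-friendly).
skip = AttentionMaskConverter._ignore_causal_mask_sdpa(
    attention_mask,
    inputs_embeds=inputs_embeds,
    past_key_values_length=0,
    is_training=False,  # the commit forwards `self.training` here instead of the default
)
print(skip)  # expected: True, since an all-ones mask adds nothing beyond causal masking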