"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d606d566ab0a2635718cb5210d1dad8fae4ce112"
Unverified commit 78a57c5e, authored by Joao Gante, committed by GitHub

DBRX: make fixup (#30578)

parent 1bff6a0b
@@ -1215,6 +1215,7 @@ class DbrxModel(DbrxPreTrainedModel):
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
         # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1227,7 +1228,10 @@ class DbrxModel(DbrxPreTrainedModel):
         using_static_cache = isinstance(past_key_values, StaticCache)
         if self.config._attn_implementation == "sdpa" and not using_static_cache:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
             ):
                 return None
...
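For context, `AttentionMaskConverter._ignore_causal_mask_sdpa` decides whether the explicit 4D causal mask can be skipped entirely and causality delegated to the SDPA kernel; the `is_training` flag added here feeds into that decision. The sketch below is not the transformers implementation, just a minimal illustration (assuming only PyTorch >= 2.0 for `scaled_dot_product_attention`) of the equivalence that makes skipping the mask safe when there is no padding:

# Minimal sketch (assumes PyTorch >= 2.0; not the transformers code): with no
# padding, an explicit lower-triangular mask and `is_causal=True` produce the
# same SDPA output, which is why the 4D causal mask can sometimes be dropped.
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# Explicit causal mask: True where a query position may attend to a key.
causal_mask = torch.tril(torch.ones(5, 5, dtype=torch.bool))

out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_implicit = F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.testing.assert_close(out_explicit, out_implicit)

When the mask is ignored, the model avoids materializing a (batch, 1, seq_len, seq_len) tensor and lets the fused kernel handle causality, which also helps keep shapes static for the torch.compile/cudagraph concern described in the comment above.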