Unverified Commit 838d141f authored by Raushan Turganbay, committed by GitHub

Gemma2: fix FA2 generation (#32553)

fix FA2: skip the HybridCache 2D-to-4D attention-mask conversion when `config._attn_implementation == "flash_attention_2"`, since FlashAttention-2 consumes the raw 2D padding mask directly
parent 85817d98
@@ -1093,7 +1093,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
-        if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
+        if (
+            isinstance(past_key_values, HybridCache)
+            and attention_mask.ndim == 2
+            and not self.config._attn_implementation == "flash_attention_2"
+        ):
             if model_inputs["inputs_embeds"] is not None:
                 batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                 device = model_inputs["inputs_embeds"].device
...
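
For context, a minimal sketch of the generation setup this commit fixes, assuming a CUDA GPU with the flash-attn package installed; the checkpoint name is illustrative and any Gemma-2 checkpoint would do:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # illustrative; any Gemma-2 checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # triggers the code path guarded above
).to("cuda")

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
# Before this commit, FA2 generation went through the HybridCache 2D-to-4D
# mask conversion; FA2 expects the raw 2D padding mask, so the converted mask
# broke generation. The added condition skips the conversion for FA2.
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))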