Unverified Commit 7c71b61d authored by Phillip Rust's avatar Phillip Rust Committed by GitHub
Browse files

Fix autocast incompatibility in RecurrentGemma (#30832)

parent b275a410
...@@ -254,8 +254,8 @@ class RecurrentGemmaSdpaAttention(nn.Module): ...@@ -254,8 +254,8 @@ class RecurrentGemmaSdpaAttention(nn.Module):
k_out = k_out[:, :, indices] k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices] v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states k_out[:, :, cache_position] = key_states.to(k_out.dtype)
v_out[:, :, cache_position] = value_states v_out[:, :, cache_position] = value_states.to(v_out.dtype)
self.key_states, self.value_states = k_out, v_out self.key_states, self.value_states = k_out, v_out
return k_out, v_out return k_out, v_out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment