Unverified commit 7c71b61d, authored by Phillip Rust, committed by GitHub
Browse files

Fix autocast incompatibility in RecurrentGemma (#30832)

parent b275a410
......@@ -254,8 +254,8 @@ class RecurrentGemmaSdpaAttention(nn.Module):
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
+ k_out[:, :, cache_position] = key_states.to(k_out.dtype)
+ v_out[:, :, cache_position] = value_states.to(v_out.dtype)
self.key_states, self.value_states = k_out, v_out
return k_out, v_out
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment