Unverified Commit 7509a0ad authored by Marc Sun, committed by GitHub

Fix RecurrentGemma device_map (#30273)

* Switch to non-persistent buffer

* fix device mismatch issue due to cache

* style
parent 9459efb8
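Context for the diff below: with `device_map="auto"`, accelerate may place different submodules on different devices, so tensors created once at init (the attention key/value cache, the Rglru recurrent state, the `normalizer` buffer) can end up on a different device than the activations that consume them. A minimal, hypothetical sketch of that failure mode and of the `.to(...)` pattern the hunks apply (tensor names are illustrative, not the actual call sites):

```python
import torch

# Illustrative only: a state tensor created at setup time vs. activations that
# land on whichever device the layer was mapped to.
cached_state = torch.zeros(1, 4)                 # e.g. created on CPU / GPU 0
activations = torch.ones(1, 4)

if torch.cuda.is_available():
    activations = activations.to("cuda:0")       # placed by the device map

# Operating on tensors that live on different devices raises a RuntimeError;
# moving the cached tensor to the activations' device first is the pattern
# used in the hunks below.
out = cached_state.to(activations.device) * activations
print(out.device)
```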
@@ -252,7 +252,7 @@ class RecurrentGemmaSdpaAttention(nn.Module):
         to_shift = cache_position >= self.config.attention_window_size - 1
         indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
-        k_out, v_out = self.key_states, self.value_states
+        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
         k_out = k_out[:, :, indices]
         v_out = v_out[:, :, indices]
@@ -376,7 +376,9 @@ class RecurrentGemmaRglru(nn.Module):
             return hidden_states, hidden_states[:, 0].type(acc_dtype)
         else:
-            contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None]
+            contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
+                recurrent_gate.device
+            )
             contextualized_states += hidden_states.type(acc_dtype)
             return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]
@@ -387,7 +389,7 @@ class RecurrentGemmaRglru(nn.Module):
         contextualized_states = torch.zeros_like(hidden_states)
         for t in range(hidden_states.shape[1]):
-            recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
+            recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
             recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
             contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)
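The scenario these `.to(key_states.device)` / `.to(recurrent_gate.device)` moves address is multi-device inference. A hedged repro sketch (the checkpoint id and generation settings are assumed, not taken from the commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/recurrentgemma-2b"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" lets accelerate split the layers across available devices,
# which is where the cached states previously ended up on the wrong device.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```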
@@ -658,7 +660,9 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
-        self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16))
+        self.register_buffer(
+            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
+        )

         # Initialize weights and apply final processing
         self.post_init()
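On the last hunk: `persistent=False` keeps the `normalizer` buffer out of the `state_dict`, so it is simply recreated in `__init__` and still follows the module when it is moved or cast, rather than being treated as a value that checkpoints have to provide. A minimal sketch of those plain PyTorch semantics (the `Demo` class is illustrative, not the model code):

```python
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        # Recomputed at construction time; persistent=False excludes it from
        # state_dict(), so checkpoints never need to carry this value.
        self.register_buffer(
            "normalizer", torch.tensor(hidden_size**0.5, dtype=torch.bfloat16), persistent=False
        )

m = Demo()
print("normalizer" in m.state_dict())  # False: not serialized
m.float()                              # still cast/moved with the module like any buffer
print(m.normalizer.dtype)              # torch.float32
```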