Unverified Commit fe36bf5e authored by Canlin Guo's avatar Canlin Guo Committed by GitHub
Browse files

[Model] Remove the unnecessary dtype conversion in MiniCPM (#32523)


Signed-off-by: default avatargcanlin <canlinguosdu@gmail.com>
parent 963dc0b8
...@@ -300,10 +300,7 @@ class MiniCPMAttention(nn.Module): ...@@ -300,10 +300,7 @@ class MiniCPMAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
orig_dtype = q.dtype
q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment