"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3e56e2ce04ea3c2f6fa0934bd6d422d8dab17201"
Unverified Commit ede051f1 authored by fxmarty, committed by GitHub

Fix key dtype in GPTJ and CodeGen (#26836)

* fix key dtype in gptj and codegen

* delay the key cast to a later point

* fix
parent 32f799db
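
To illustrate the issue the commit message describes, here is a minimal sketch (not code from this repository; the rotary computation is a deliberately simplified stand-in and the shapes are illustrative): CodeGen and GPT-J build their sin/cos tables in float32, so when the model runs in float16 the rotated key is silently promoted to float32. Caching it as-is leaves the cached key and value with different dtypes; casting back to hidden_states.dtype when the cache tuple is built, as the diff below does, restores the model dtype while the rotary math itself stays in fp32.

import torch

# Assumed model dtype is float16; names and shapes are illustrative only.
hidden_states = torch.randn(1, 4, 64, dtype=torch.float16)
key = torch.randn(1, 4, 64, dtype=torch.float16)
value = torch.randn(1, 4, 64, dtype=torch.float16)

# sin/cos tables are created in float32 in the original CodeGen/GPT-J code.
sin = torch.randn(1, 4, 64, dtype=torch.float32)
cos = torch.randn(1, 4, 64, dtype=torch.float32)

# Simplified stand-in for the rotary embedding: mixing float16 with float32
# promotes the rotated key to float32.
key_rot = key * cos + key * sin
print(key_rot.dtype)  # torch.float32

# The fix: delay the cast until the cache tuple is built, matching the model dtype.
present = (key_rot.to(hidden_states.dtype), value)
print(present[0].dtype, present[1].dtype)  # torch.float16 torch.float16

Delaying the cast to the point where `present` is assembled, rather than casting before ROPE, keeps the rotary computation in float32 as the reference implementations linked in the inline comments do.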
@@ -224,7 +224,9 @@ class CodeGenAttention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
+            # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
...
@@ -249,7 +249,9 @@ class GPTJAttention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
...
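
As a quick way to observe the effect of the change, the snippet below is a hypothetical sanity check (not part of this commit; it assumes the EleutherAI/gpt-j-6B checkpoint can be downloaded and loaded in float16): with the cast in place, both tensors returned in past_key_values should come back in the model dtype.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical check (not part of this commit): the cached key should now match
# the value dtype when the model runs in float16.
model_id = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)

key, value = outputs.past_key_values[0]
print(key.dtype, value.dtype)  # expected: torch.float16 torch.float16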
@@ -306,6 +306,7 @@ class MistralMLP(nn.Module):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
...