Unverified Commit 57ffd8ab authored by Younes Belkada, committed by GitHub

[`GPT-J`] Fix causal mask dtype (#23147)

* fix #23136

* better fix

* same fix for `masked_bias`
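Context for the fix: in PyTorch, a buffer registered with `persistent=False` is excluded from the module's `state_dict`, so nothing in a loaded checkpoint can overwrite it or cast its dtype; the causal mask therefore stays `torch.bool` no matter which `torch_dtype` the weights are loaded in. A minimal sketch of that behavior (toy module, not the transformers code):

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self, max_positions=4):
        super().__init__()
        # Same pattern as the patch: non-persistent buffers are left out of
        # state_dict, so checkpoint loading can never touch them.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)


m = Toy()
print(list(m.state_dict().keys()))  # [] -- neither buffer is serialized
print(m.bias.dtype)                 # torch.bool, untouched by any checkpoint
```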
parent 83b38fbe
@@ -89,8 +89,9 @@ class GPTJAttention(nn.Module):
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
+            persistent=False,
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
+        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -732,7 +733,7 @@ class GPTJModel(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]

     def __init__(self, config):
         super().__init__(config)
...
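Because the buffers are no longer persistent, checkpoints saved with earlier versions of the code still contain `h.<n>.attn.bias` and `h.<n>.attn.masked_bias` entries; on load those keys now surface as unexpected rather than missing, which is why the ignore list moves from `_keys_to_ignore_on_load_missing` to `_keys_to_ignore_on_load_unexpected`. A quick way to sanity-check the fix (the checkpoint name is assumed to be the standard GPT-J release, and the printed dtype reflects the intended behavior, not a verified log):

```python
import torch
from transformers import GPTJForCausalLM

# Load in half precision; with the mask buffers non-persistent, the boolean
# causal mask should no longer be cast along with the fp16 weights.
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
print(model.transformer.h[0].attn.bias.dtype)  # expected: torch.bool
```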