Unverified commit 57ffd8ab, authored by Younes Belkada and committed by GitHub

[`GPT-J`] Fix causal mask dtype (#23147)

* fix #23136

* better fix

* same fix for `masked_bias`
parent 83b38fbe
@@ -89,8 +89,9 @@ class GPTJAttention(nn.Module):
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
+            persistent=False,
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
+        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
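
Why this fixes the dtype issue: a buffer registered with persistent=False is excluded from the module's state_dict, so it is neither saved to nor loaded from checkpoints. Checkpoints produced before this change typically still contain an attn.bias entry, and loading it would overwrite the freshly built torch.bool causal mask with whatever dtype the checkpoint happened to store. A minimal sketch of the mechanism (the ToyAttention module and the 4x4 mask size are illustrative, not part of this commit):

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Illustrative stand-in for GPTJAttention (not the real class)."""

    def __init__(self, max_positions: int = 4):
        super().__init__()
        # persistent=False keeps the buffer out of state_dict(), so a
        # checkpoint can never overwrite it with a different dtype; the
        # torch.bool mask built here is the one used at runtime.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)


module = ToyAttention()
print("bias" in module.state_dict())  # False -> never saved, never loaded
print(module.bias.dtype)              # torch.bool, as set in __init__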
@@ -732,7 +733,7 @@ class GPTJModel(GPTJPreTrainedModel):
     GPTJ_START_DOCSTRING,
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
 
     def __init__(self, config):
         super().__init__(config)
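
Note on the second hunk: because the buffers are now non-persistent, they no longer appear in newly saved state_dicts, while checkpoints produced before this change still carry h.<n>.attn.bias and h.<n>.attn.masked_bias entries. At load time those entries therefore show up as unexpected keys rather than missing ones, which is presumably why the regex list moves from _keys_to_ignore_on_load_missing to _keys_to_ignore_on_load_unexpected: it keeps from_pretrained from emitting spurious warnings when loading older checkpoints.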