"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "0e64ab6831a90dbd0101c66a0dba718531bd6c7c"
Unverified commit ca26699f, authored by Younes Belkada and committed by GitHub
Browse files

[`gpt`] Gpt2 fix half precision causal mask (#23256)

* fix gpt2 inference

* fixup

* no need to be in `_keys_to_ignore_on_load_missing`
parent 9088fcae
......@@ -118,8 +118,9 @@ class DecisionTransformerGPT2Attention(nn.Module):
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
+persistent=False,
)
-self.register_buffer("masked_bias", torch.tensor(-1e4))
+self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
......@@ -747,6 +748,7 @@ class DecisionTransformerPreTrainedModel(PreTrainedModel):
main_input_name = "states"
supports_gradient_checkpointing = False
_keys_to_ignore_on_load_missing = [r"position_ids"]
+_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -131,8 +131,9 @@ class GPT2Attention(nn.Module):
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
+persistent=False,
)
-self.register_buffer("masked_bias", torch.tensor(-1e4))
+self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
......@@ -954,7 +955,8 @@ class GPT2Model(GPT2PreTrainedModel):
GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
-_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
+_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
def __init__(self, config):
super().__init__(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment