Unverified commit 8b169142, authored by Younes Belkada and committed via GitHub
Browse files

[`GPT2`] Add correct keys on `_keys_to_ignore_on_load_unexpected` on all child...

[`GPT2`] Add correct keys on `_keys_to_ignore_on_load_unexpected` on all child classes of `GPT2PreTrainedModel` (#24113)

* add correct keys on `_keys_to_ignore_on_load_unexpected`

* oops
parent 71a114d3
...@@ -668,7 +668,8 @@ DEPARALLELIZE_DOCSTRING = r""" ...@@ -668,7 +668,8 @@ DEPARALLELIZE_DOCSTRING = r"""
GPT2_START_DOCSTRING, GPT2_START_DOCSTRING,
) )
class GPT2Model(GPT2PreTrainedModel): class GPT2Model(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"] _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
...@@ -1149,6 +1150,7 @@ input sequence). ...@@ -1149,6 +1150,7 @@ input sequence).
GPT2_START_DOCSTRING, GPT2_START_DOCSTRING,
) )
class GPT2DoubleHeadsModel(GPT2PreTrainedModel): class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
def __init__(self, config): def __init__(self, config):
...@@ -1377,6 +1379,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -1377,6 +1379,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
GPT2_START_DOCSTRING, GPT2_START_DOCSTRING,
) )
class GPT2ForSequenceClassification(GPT2PreTrainedModel): class GPT2ForSequenceClassification(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"] _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config): def __init__(self, config):
...@@ -1600,6 +1603,7 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel): ...@@ -1600,6 +1603,7 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
GPT2_START_DOCSTRING, GPT2_START_DOCSTRING,
) )
class GPT2ForQuestionAnswering(GPT2PreTrainedModel): class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"] _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config): def __init__(self, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment