Unverified commit f7354a3b, authored by Matt, committed by GitHub
Browse files

Remove token_type_ids from default TF GPT-2 signature (#26962)

Remove token_type_ids from default GPT-2 signature
parent c0b5ad94
...@@ -521,6 +521,16 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel): ...@@ -521,6 +521,16 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"] _keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"]
@property
def input_signature(self):
    """Return the default input signature for this model, without `token_type_ids`.

    GPT-2 can accept `token_type_ids` in theory, but they are rarely used in practice, and because of how they
    are implemented, passing `token_type_ids=0` gives different outputs from passing `token_type_ids=None`.
    The argument would usually be part of the signature; it is deliberately left out by default here.

    Returns:
        A dict mapping each input name to a `tf.TensorSpec` of shape `(None, None)` and dtype `tf.int32`.
    """
    # Both inputs share the same spec; only the name differs.
    input_names = ("input_ids", "attention_mask")
    return {
        input_name: tf.TensorSpec((None, None), tf.int32, name=input_name)
        for input_name in input_names
    }
@dataclass @dataclass
class TFGPT2DoubleHeadsModelOutput(ModelOutput): class TFGPT2DoubleHeadsModelOutput(ModelOutput):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment