Unverified Commit ed2ee373 authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Add TF implementation of GPT-J (#15623)

* Initial commit

* Add TFGPTJModel

* Fix a forward pass

* Add TFGPTJCausalLM

* Add TFGPTJForSequenceClassification

* Add TFGPTJForQuestionAnswering

* Fix docs

* Deal with TF dynamic shapes

* Add Loss parents to models

* Adjust split and merge heads to handle 4 and 5-dim tensors

* Update outputs for @tooslow tests
parent aa4c0a86
...@@ -205,7 +205,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -205,7 +205,7 @@ Flax), PyTorch, and/or TensorFlow.
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ | | Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
| GLPN | ❌ | ❌ | ✅ | ❌ | ❌ | | GLPN | ❌ | ❌ | ✅ | ❌ | ❌ |
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ | | GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
| GPT-J | ❌ | ❌ | ✅ | | ✅ | | GPT-J | ❌ | ❌ | ✅ | | ✅ |
| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ | | Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ | | ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ |
......
...@@ -130,6 +130,26 @@ model. ...@@ -130,6 +130,26 @@ model.
[[autodoc]] GPTJForQuestionAnswering [[autodoc]] GPTJForQuestionAnswering
- forward - forward
## TFGPTJModel
[[autodoc]] TFGPTJModel
- call
## TFGPTJForCausalLM
[[autodoc]] TFGPTJForCausalLM
- call
## TFGPTJForSequenceClassification
[[autodoc]] TFGPTJForSequenceClassification
- call
## TFGPTJForQuestionAnswering
[[autodoc]] TFGPTJForQuestionAnswering
- call
## FlaxGPTJModel ## FlaxGPTJModel
[[autodoc]] FlaxGPTJModel [[autodoc]] FlaxGPTJModel
......
...@@ -1929,6 +1929,15 @@ if is_tf_available(): ...@@ -1929,6 +1929,15 @@ if is_tf_available():
"TFGPT2PreTrainedModel", "TFGPT2PreTrainedModel",
] ]
) )
_import_structure["models.gptj"].extend(
[
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
"TFGPTJForSequenceClassification",
"TFGPTJModel",
"TFGPTJPreTrainedModel",
]
)
_import_structure["models.hubert"].extend( _import_structure["models.hubert"].extend(
[ [
"TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -4003,6 +4012,13 @@ if TYPE_CHECKING: ...@@ -4003,6 +4012,13 @@ if TYPE_CHECKING:
TFGPT2Model, TFGPT2Model,
TFGPT2PreTrainedModel, TFGPT2PreTrainedModel,
) )
from .models.gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
TFGPTJForSequenceClassification,
TFGPTJModel,
TFGPTJPreTrainedModel,
)
from .models.hubert import ( from .models.hubert import (
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFHubertForCTC, TFHubertForCTC,
......
...@@ -52,6 +52,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict( ...@@ -52,6 +52,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("bert", "TFBertModel"), ("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"), ("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"), ("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("mobilebert", "TFMobileBertModel"), ("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"), ("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"), ("xlnet", "TFXLNetModel"),
...@@ -123,6 +124,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ...@@ -123,6 +124,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("bert", "TFBertForMaskedLM"), ("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("mobilebert", "TFMobileBertForMaskedLM"), ("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"), ("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"), ("xlnet", "TFXLNetLMHeadModel"),
...@@ -146,6 +148,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -146,6 +148,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("bert", "TFBertLMHeadModel"), ("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"), ("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"), ("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"), ("xlm", "TFXLMWithLMHeadModel"),
...@@ -239,6 +242,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -239,6 +242,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("tapas", "TFTapasForSequenceClassification"), ("tapas", "TFTapasForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"), ("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"), ("gpt2", "TFGPT2ForSequenceClassification"),
("gptj", "TFGPTJForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"), ("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"), ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"), ("transfo-xl", "TFTransfoXLForSequenceClassification"),
...@@ -267,6 +271,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ...@@ -267,6 +271,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("xlm", "TFXLMForQuestionAnsweringSimple"), ("xlm", "TFXLMForQuestionAnsweringSimple"),
("electra", "TFElectraForQuestionAnswering"), ("electra", "TFElectraForQuestionAnswering"),
("funnel", "TFFunnelForQuestionAnswering"), ("funnel", "TFFunnelForQuestionAnswering"),
("gptj", "TFGPTJForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"), ("mpnet", "TFMPNetForQuestionAnswering"),
] ]
) )
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_torch_available from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -34,6 +34,15 @@ if is_torch_available(): ...@@ -34,6 +34,15 @@ if is_torch_available():
"GPTJPreTrainedModel", "GPTJPreTrainedModel",
] ]
if is_tf_available():
_import_structure["modeling_tf_gptj"] = [
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
"TFGPTJForSequenceClassification",
"TFGPTJModel",
"TFGPTJPreTrainedModel",
]
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_gptj"] = [ _import_structure["modeling_flax_gptj"] = [
"FlaxGPTJForCausalLM", "FlaxGPTJForCausalLM",
...@@ -55,6 +64,15 @@ if TYPE_CHECKING: ...@@ -55,6 +64,15 @@ if TYPE_CHECKING:
GPTJPreTrainedModel, GPTJPreTrainedModel,
) )
if is_tf_available():
from .modeling_tf_gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
TFGPTJForSequenceClassification,
TFGPTJModel,
TFGPTJPreTrainedModel,
)
if is_flax_available(): if is_flax_available():
from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
......
This diff is collapsed.
...@@ -1157,6 +1157,41 @@ class TFGPT2PreTrainedModel(metaclass=DummyObject): ...@@ -1157,6 +1157,41 @@ class TFGPT2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFGPTJForCausalLM(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJForQuestionAnswering(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJForSequenceClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment