Unverified Commit ed2ee373 authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Add TF implementation of GPT-J (#15623)

* Initial commit

* Add TFGPTJModel

* Fix a forward pass

* Add TFGPTJCausalLM

* Add TFGPTJForSequenceClassification

* Add TFGPTJForQuestionAnswering

* Fix docs

* Deal with TF dynamic shapes

* Add Loss parents to models

* Adjust split and merge heads to handle 4 and 5-dim tensors

* Update outputs for @tooslow tests
parent aa4c0a86
......@@ -205,7 +205,7 @@ Flax), PyTorch, and/or TensorFlow.
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
| GLPN | ❌ | ❌ | ✅ | ❌ | ❌ |
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
| GPT-J | ❌ | ❌ | ✅ | | ✅ |
| GPT-J | ❌ | ❌ | ✅ | | ✅ |
| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ |
......
......@@ -130,6 +130,26 @@ model.
[[autodoc]] GPTJForQuestionAnswering
- forward
## TFGPTJModel
[[autodoc]] TFGPTJModel
- call
## TFGPTJForCausalLM
[[autodoc]] TFGPTJForCausalLM
- call
## TFGPTJForSequenceClassification
[[autodoc]] TFGPTJForSequenceClassification
- call
## TFGPTJForQuestionAnswering
[[autodoc]] TFGPTJForQuestionAnswering
- call
## FlaxGPTJModel
[[autodoc]] FlaxGPTJModel
......
......@@ -1929,6 +1929,15 @@ if is_tf_available():
"TFGPT2PreTrainedModel",
]
)
_import_structure["models.gptj"].extend(
[
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
"TFGPTJForSequenceClassification",
"TFGPTJModel",
"TFGPTJPreTrainedModel",
]
)
_import_structure["models.hubert"].extend(
[
"TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -4003,6 +4012,13 @@ if TYPE_CHECKING:
TFGPT2Model,
TFGPT2PreTrainedModel,
)
from .models.gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
TFGPTJForSequenceClassification,
TFGPTJModel,
TFGPTJPreTrainedModel,
)
from .models.hubert import (
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFHubertForCTC,
......
......@@ -52,6 +52,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"),
......@@ -123,6 +124,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
......@@ -146,6 +148,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
......@@ -239,6 +242,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("tapas", "TFTapasForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("gptj", "TFGPTJForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
......@@ -267,6 +271,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("xlm", "TFXLMForQuestionAnsweringSimple"),
("electra", "TFElectraForQuestionAnswering"),
("funnel", "TFFunnelForQuestionAnswering"),
("gptj", "TFGPTJForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
]
)
......
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_torch_available
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
......@@ -34,6 +34,15 @@ if is_torch_available():
"GPTJPreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_gptj"] = [
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
"TFGPTJForSequenceClassification",
"TFGPTJModel",
"TFGPTJPreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_gptj"] = [
"FlaxGPTJForCausalLM",
......@@ -55,6 +64,15 @@ if TYPE_CHECKING:
GPTJPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
TFGPTJForSequenceClassification,
TFGPTJModel,
TFGPTJPreTrainedModel,
)
if is_flax_available():
from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
......
This diff is collapsed.
......@@ -1157,6 +1157,41 @@ class TFGPT2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFGPTJForCausalLM(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJForQuestionAnswering(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJForSequenceClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGPTJPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment