Unverified Commit 6e1e88fd authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Add TFCamembertForCausalLM and ONNX integration test (#16073)

* Make Camembert great again!

* Add Camembert to TensorFlow ONNX tests
parent 20ab1582
......@@ -85,6 +85,10 @@ This model was contributed by [camembert](https://huggingface.co/camembert). The
[[autodoc]] TFCamembertModel
## TFCamembertForCasualLM
[[autodoc]] TFCamembertForCausalLM
## TFCamembertForMaskedLM
[[autodoc]] TFCamembertForMaskedLM
......
......@@ -1744,6 +1744,7 @@ if is_tf_available():
_import_structure["models.camembert"].extend(
[
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFCamembertForCausalLM",
"TFCamembertForMaskedLM",
"TFCamembertForMultipleChoice",
"TFCamembertForQuestionAnswering",
......@@ -3812,6 +3813,7 @@ if TYPE_CHECKING:
)
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
TFCamembertForMaskedLM,
TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering,
......
......@@ -139,6 +139,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("camembert", "TFCamembertForCausalLM"),
("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
......
......@@ -52,6 +52,7 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_camembert"] = [
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFCamembertForCausalLM",
"TFCamembertForMaskedLM",
"TFCamembertForMultipleChoice",
"TFCamembertForQuestionAnswering",
......@@ -85,6 +86,7 @@ if TYPE_CHECKING:
if is_tf_available():
from .modeling_tf_camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
TFCamembertForMaskedLM,
TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering,
......
......@@ -18,6 +18,7 @@
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..roberta.modeling_tf_roberta import (
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
......@@ -161,3 +162,15 @@ class TFCamembertForQuestionAnswering(TFRobertaForQuestionAnswering):
"""
config_class = CamembertConfig
@add_start_docstrings(
"""CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING
)
class TFCamembertForCausalLM(TFRobertaForCausalLM):
"""
This class overrides [`TFRobertaForCausalLM`]. Please check the superclass for the appropriate documentation
alongside usage examples.
"""
config_class = CamembertConfig
......@@ -537,6 +537,13 @@ class TFBlenderbotSmallPreTrainedModel(metaclass=DummyObject):
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFCamembertForCausalLM(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFCamembertForMaskedLM(metaclass=DummyObject):
_backends = ["tf"]
......
......@@ -200,6 +200,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
("roberta", "roberta-base"),
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment