Unverified Commit 6e1e88fd authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Add TFCamembertForCausalLM and ONNX integration test (#16073)

* Make Camembert great again!

* Add Camembert to TensorFlow ONNX tests
parent 20ab1582
...@@ -85,6 +85,10 @@ This model was contributed by [camembert](https://huggingface.co/camembert). The ...@@ -85,6 +85,10 @@ This model was contributed by [camembert](https://huggingface.co/camembert). The
[[autodoc]] TFCamembertModel [[autodoc]] TFCamembertModel
## TFCamembertForCasualLM
[[autodoc]] TFCamembertForCausalLM
## TFCamembertForMaskedLM ## TFCamembertForMaskedLM
[[autodoc]] TFCamembertForMaskedLM [[autodoc]] TFCamembertForMaskedLM
......
...@@ -1744,6 +1744,7 @@ if is_tf_available(): ...@@ -1744,6 +1744,7 @@ if is_tf_available():
_import_structure["models.camembert"].extend( _import_structure["models.camembert"].extend(
[ [
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFCamembertForCausalLM",
"TFCamembertForMaskedLM", "TFCamembertForMaskedLM",
"TFCamembertForMultipleChoice", "TFCamembertForMultipleChoice",
"TFCamembertForQuestionAnswering", "TFCamembertForQuestionAnswering",
...@@ -3812,6 +3813,7 @@ if TYPE_CHECKING: ...@@ -3812,6 +3813,7 @@ if TYPE_CHECKING:
) )
from .models.camembert import ( from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
TFCamembertForMaskedLM, TFCamembertForMaskedLM,
TFCamembertForMultipleChoice, TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering, TFCamembertForQuestionAnswering,
......
...@@ -139,6 +139,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ...@@ -139,6 +139,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[ [
# Model for Causal LM mapping # Model for Causal LM mapping
("camembert", "TFCamembertForCausalLM"),
("rembert", "TFRemBertForCausalLM"), ("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"), ("roformer", "TFRoFormerForCausalLM"),
("roberta", "TFRobertaForCausalLM"), ("roberta", "TFRobertaForCausalLM"),
......
...@@ -52,6 +52,7 @@ if is_torch_available(): ...@@ -52,6 +52,7 @@ if is_torch_available():
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_camembert"] = [ _import_structure["modeling_tf_camembert"] = [
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFCamembertForCausalLM",
"TFCamembertForMaskedLM", "TFCamembertForMaskedLM",
"TFCamembertForMultipleChoice", "TFCamembertForMultipleChoice",
"TFCamembertForQuestionAnswering", "TFCamembertForQuestionAnswering",
...@@ -85,6 +86,7 @@ if TYPE_CHECKING: ...@@ -85,6 +86,7 @@ if TYPE_CHECKING:
if is_tf_available(): if is_tf_available():
from .modeling_tf_camembert import ( from .modeling_tf_camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
TFCamembertForMaskedLM, TFCamembertForMaskedLM,
TFCamembertForMultipleChoice, TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering, TFCamembertForQuestionAnswering,
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from ...file_utils import add_start_docstrings from ...file_utils import add_start_docstrings
from ...utils import logging from ...utils import logging
from ..roberta.modeling_tf_roberta import ( from ..roberta.modeling_tf_roberta import (
TFRobertaForCausalLM,
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
TFRobertaForMultipleChoice, TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering, TFRobertaForQuestionAnswering,
...@@ -161,3 +162,15 @@ class TFCamembertForQuestionAnswering(TFRobertaForQuestionAnswering): ...@@ -161,3 +162,15 @@ class TFCamembertForQuestionAnswering(TFRobertaForQuestionAnswering):
""" """
config_class = CamembertConfig config_class = CamembertConfig
@add_start_docstrings(
"""CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING
)
class TFCamembertForCausalLM(TFRobertaForCausalLM):
"""
This class overrides [`TFRobertaForCausalLM`]. Please check the superclass for the appropriate documentation
alongside usage examples.
"""
config_class = CamembertConfig
...@@ -537,6 +537,13 @@ class TFBlenderbotSmallPreTrainedModel(metaclass=DummyObject): ...@@ -537,6 +537,13 @@ class TFBlenderbotSmallPreTrainedModel(metaclass=DummyObject):
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFCamembertForCausalLM(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFCamembertForMaskedLM(metaclass=DummyObject): class TFCamembertForMaskedLM(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
......
...@@ -200,6 +200,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { ...@@ -200,6 +200,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
TENSORFLOW_EXPORT_DEFAULT_MODELS = { TENSORFLOW_EXPORT_DEFAULT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"), ("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"), ("bert", "bert-base-cased"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment