Unverified Commit 9b3aab2c authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Pickle auto models (#12654)

* PoC, it pickles!

* Remove old method.

* Apply to every auto object
parent 379f6494
...@@ -1938,7 +1938,7 @@ class _LazyModule(ModuleType): ...@@ -1938,7 +1938,7 @@ class _LazyModule(ModuleType):
return importlib.import_module("." + module_name, self.__name__) return importlib.import_module("." + module_name, self.__name__)
def __reduce__(self): def __reduce__(self):
return (self.__class__, (self._name, self._import_structure)) return (self.__class__, (self._name, self.__file__, self._import_structure))
def copy_func(f): def copy_func(f):
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
"""Factory function to build auto-model classes.""" """Factory function to build auto-model classes."""
import types
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func from ...file_utils import copy_func
from ...utils import logging from ...utils import logging
...@@ -401,12 +399,12 @@ def insert_head_doc(docstring, head_doc=""): ...@@ -401,12 +399,12 @@ def insert_head_doc(docstring, head_doc=""):
) )
def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-cased", head_doc=""): def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""):
# Create a new class with the right name from the base class # Create a new class with the right name from the base class
new_class = types.new_class(name, (_BaseAutoModelClass,)) model_mapping = cls._model_mapping
new_class._model_mapping = model_mapping name = cls.__name__
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc) class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
new_class.__doc__ = class_docstring.replace("BaseAutoModelClass", name) cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
# have a specific docstrings for them. # have a specific docstrings for them.
...@@ -416,7 +414,7 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca ...@@ -416,7 +414,7 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example) from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config) from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
new_class.from_config = classmethod(from_config) cls.from_config = classmethod(from_config)
if name.startswith("TF"): if name.startswith("TF"):
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
...@@ -432,8 +430,8 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca ...@@ -432,8 +430,8 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut) from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained) from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
new_class.from_pretrained = classmethod(from_pretrained) cls.from_pretrained = classmethod(from_pretrained)
return new_class return cls
def get_values(model_mapping): def get_values(model_mapping):
......
...@@ -308,7 +308,7 @@ from ..xlnet.modeling_xlnet import ( ...@@ -308,7 +308,7 @@ from ..xlnet.modeling_xlnet import (
XLNetLMHeadModel, XLNetLMHeadModel,
XLNetModel, XLNetModel,
) )
from .auto_factory import auto_class_factory from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import ( from .configuration_auto import (
AlbertConfig, AlbertConfig,
BartConfig, BartConfig,
...@@ -780,66 +780,108 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( ...@@ -780,66 +780,108 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
) )
AutoModel = auto_class_factory("AutoModel", MODEL_MAPPING) class AutoModel(_BaseAutoModelClass):
_model_mapping = MODEL_MAPPING
AutoModel = auto_class_update(AutoModel)
class AutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_PRETRAINING_MAPPING
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
AutoModelForPreTraining = auto_class_factory(
"AutoModelForPreTraining", MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
)
# Private on purpose, the public class will add the deprecation warnings. # Private on purpose, the public class will add the deprecation warnings.
_AutoModelWithLMHead = auto_class_factory( class _AutoModelWithLMHead(_BaseAutoModelClass):
"AutoModelWithLMHead", MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling" _model_mapping = MODEL_WITH_LM_HEAD_MAPPING
)
AutoModelForCausalLM = auto_class_factory(
"AutoModelForCausalLM", MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
)
AutoModelForMaskedLM = auto_class_factory( _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
"AutoModelForMaskedLM", MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
)
AutoModelForSeq2SeqLM = auto_class_factory(
"AutoModelForSeq2SeqLM",
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="t5-base",
)
AutoModelForSequenceClassification = auto_class_factory( class AutoModelForCausalLM(_BaseAutoModelClass):
"AutoModelForSequenceClassification", MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, head_doc="sequence classification" _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
class AutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
AutoModelForSeq2SeqLM = auto_class_update(
AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
) )
AutoModelForQuestionAnswering = auto_class_factory(
"AutoModelForQuestionAnswering", MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering" class AutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
AutoModelForSequenceClassification = auto_class_update(
AutoModelForSequenceClassification, head_doc="sequence classification"
) )
AutoModelForTableQuestionAnswering = auto_class_factory(
"AutoModelForTableQuestionAnswering", class AutoModelForQuestionAnswering(_BaseAutoModelClass):
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
AutoModelForTableQuestionAnswering = auto_class_update(
AutoModelForTableQuestionAnswering,
head_doc="table question answering", head_doc="table question answering",
checkpoint_for_example="google/tapas-base-finetuned-wtq", checkpoint_for_example="google/tapas-base-finetuned-wtq",
) )
AutoModelForTokenClassification = auto_class_factory(
"AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
)
AutoModelForMultipleChoice = auto_class_factory( class AutoModelForTokenClassification(_BaseAutoModelClass):
"AutoModelForMultipleChoice", MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice" _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
)
AutoModelForNextSentencePrediction = auto_class_factory(
"AutoModelForNextSentencePrediction",
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
head_doc="next sentence prediction",
)
AutoModelForImageClassification = auto_class_factory( AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
"AutoModelForImageClassification", MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, head_doc="image classification"
class AutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
AutoModelForNextSentencePrediction = auto_class_update(
AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
) )
class AutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
class AutoModelWithLMHead(_AutoModelWithLMHead): class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):
......
...@@ -73,7 +73,7 @@ from ..roberta.modeling_flax_roberta import ( ...@@ -73,7 +73,7 @@ from ..roberta.modeling_flax_roberta import (
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
from .auto_factory import auto_class_factory from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import ( from .configuration_auto import (
BartConfig, BartConfig,
BertConfig, BertConfig,
...@@ -217,59 +217,89 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( ...@@ -217,59 +217,89 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
] ]
) )
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
FlaxAutoModelForImageClassification = auto_class_factory( class FlaxAutoModel(_BaseAutoModelClass):
"FlaxAutoModelForImageClassification", _model_mapping = FLAX_MODEL_MAPPING
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
head_doc="image classification modeling",
)
FlaxAutoModelForCausalLM = auto_class_factory(
"FlaxAutoModelForCausalLM", FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
)
FlaxAutoModelForPreTraining = auto_class_factory( FlaxAutoModel = auto_class_update(FlaxAutoModel)
"FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
)
FlaxAutoModelForMaskedLM = auto_class_factory(
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
)
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM",
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
)
FlaxAutoModelForSequenceClassification = auto_class_factory( FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
"FlaxAutoModelForSequenceClassification",
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
head_doc="sequence classification", class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
) _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
FlaxAutoModelForQuestionAnswering = auto_class_factory(
"FlaxAutoModelForQuestionAnswering", FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering" class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
FlaxAutoModelForSeq2SeqLM = auto_class_update(
FlaxAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
) )
FlaxAutoModelForTokenClassification = auto_class_factory(
"FlaxAutoModelForTokenClassification", FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification" class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
FlaxAutoModelForSequenceClassification = auto_class_update(
FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
) )
FlaxAutoModelForMultipleChoice = auto_class_factory(
"AutoModelForMultipleChoice", FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice" class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
FlaxAutoModelForTokenClassification = auto_class_update(
FlaxAutoModelForTokenClassification, head_doc="token classification"
) )
FlaxAutoModelForNextSentencePrediction = auto_class_factory(
"FlaxAutoModelForNextSentencePrediction", class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
head_doc="next sentence prediction",
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
FlaxAutoModelForNextSentencePrediction = auto_class_update(
FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
) )
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM", class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
head_doc="sequence-to-sequence language modeling",
FlaxAutoModelForImageClassification = auto_class_update(
FlaxAutoModelForImageClassification, head_doc="image classification"
) )
...@@ -189,7 +189,7 @@ from ..xlnet.modeling_tf_xlnet import ( ...@@ -189,7 +189,7 @@ from ..xlnet.modeling_tf_xlnet import (
TFXLNetLMHeadModel, TFXLNetLMHeadModel,
TFXLNetModel, TFXLNetModel,
) )
from .auto_factory import auto_class_factory from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import ( from .configuration_auto import (
AlbertConfig, AlbertConfig,
BartConfig, BartConfig,
...@@ -487,54 +487,89 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( ...@@ -487,54 +487,89 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
) )
TFAutoModel = auto_class_factory("TFAutoModel", TF_MODEL_MAPPING) class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
TFAutoModelForPreTraining = auto_class_factory(
"TFAutoModelForPreTraining", TF_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
)
# Private on purpose, the public class will add the deprecation warnings. # Private on purpose, the public class will add the deprecation warnings.
_TFAutoModelWithLMHead = auto_class_factory( class _TFAutoModelWithLMHead(_BaseAutoModelClass):
"TFAutoModelWithLMHead", TF_MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling" _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
)
TFAutoModelForCausalLM = auto_class_factory(
"TFAutoModelForCausalLM", TF_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
)
TFAutoModelForMaskedLM = auto_class_factory( _TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
"TFAutoModelForMaskedLM", TF_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
)
TFAutoModelForSeq2SeqLM = auto_class_factory(
"TFAutoModelForSeq2SeqLM",
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="t5-base",
)
TFAutoModelForSequenceClassification = auto_class_factory( class TFAutoModelForCausalLM(_BaseAutoModelClass):
"TFAutoModelForSequenceClassification", _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
head_doc="sequence classification",
) TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
TFAutoModelForQuestionAnswering = auto_class_factory( class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
"TFAutoModelForQuestionAnswering", TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering" _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
TFAutoModelForSeq2SeqLM = auto_class_update(
TFAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
) )
TFAutoModelForTokenClassification = auto_class_factory(
"TFAutoModelForTokenClassification", TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification" class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
TFAutoModelForSequenceClassification = auto_class_update(
TFAutoModelForSequenceClassification, head_doc="sequence classification"
) )
TFAutoModelForMultipleChoice = auto_class_factory(
"TFAutoModelForMultipleChoice", TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice" class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
TFAutoModelForTokenClassification = auto_class_update(
TFAutoModelForTokenClassification, head_doc="token classification"
) )
TFAutoModelForNextSentencePrediction = auto_class_factory(
"TFAutoModelForNextSentencePrediction", class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
head_doc="next sentence prediction",
TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
TFAutoModelForNextSentencePrediction = auto_class_update(
TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment