Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9b3aab2c
Unverified
Commit
9b3aab2c
authored
Jul 12, 2021
by
Sylvain Gugger
Committed by
GitHub
Jul 12, 2021
Browse files
Pickle auto models (#12654)
* PoC, it pickles! * Remove old method. * Apply to every auto object
parent
379f6494
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
230 additions
and
125 deletions
+230
-125
src/transformers/file_utils.py
src/transformers/file_utils.py
+1
-1
src/transformers/models/auto/auto_factory.py
src/transformers/models/auto/auto_factory.py
+7
-9
src/transformers/models/auto/modeling_auto.py
src/transformers/models/auto/modeling_auto.py
+82
-40
src/transformers/models/auto/modeling_flax_auto.py
src/transformers/models/auto/modeling_flax_auto.py
+70
-40
src/transformers/models/auto/modeling_tf_auto.py
src/transformers/models/auto/modeling_tf_auto.py
+70
-35
No files found.
src/transformers/file_utils.py
View file @
9b3aab2c
...
...
@@ -1938,7 +1938,7 @@ class _LazyModule(ModuleType):
return
importlib
.
import_module
(
"."
+
module_name
,
self
.
__name__
)
def
__reduce__
(
self
):
return
(
self
.
__class__
,
(
self
.
_name
,
self
.
_import_structure
))
return
(
self
.
__class__
,
(
self
.
_name
,
self
.
__file__
,
self
.
_import_structure
))
def
copy_func
(
f
):
...
...
src/transformers/models/auto/auto_factory.py
View file @
9b3aab2c
...
...
@@ -14,8 +14,6 @@
# limitations under the License.
"""Factory function to build auto-model classes."""
import
types
from
...configuration_utils
import
PretrainedConfig
from
...file_utils
import
copy_func
from
...utils
import
logging
...
...
@@ -401,12 +399,12 @@ def insert_head_doc(docstring, head_doc=""):
)
def
auto_class_
factory
(
name
,
model_mapping
,
checkpoint_for_example
=
"bert-base-cased"
,
head_doc
=
""
):
def
auto_class_
update
(
cls
,
checkpoint_for_example
=
"bert-base-cased"
,
head_doc
=
""
):
# Create a new class with the right name from the base class
new_class
=
types
.
new_class
(
name
,
(
_BaseAutoModelClass
,))
n
ew_class
.
_model_mapping
=
model_mapping
model_mapping
=
cls
.
_model_mapping
n
ame
=
cls
.
__name__
class_docstring
=
insert_head_doc
(
CLASS_DOCSTRING
,
head_doc
=
head_doc
)
new_clas
s
.
__doc__
=
class_docstring
.
replace
(
"BaseAutoModelClass"
,
name
)
cl
s
.
__doc__
=
class_docstring
.
replace
(
"BaseAutoModelClass"
,
name
)
# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
# have a specific docstrings for them.
...
...
@@ -416,7 +414,7 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_config_docstring
=
from_config_docstring
.
replace
(
"checkpoint_placeholder"
,
checkpoint_for_example
)
from_config
.
__doc__
=
from_config_docstring
from_config
=
replace_list_option_in_docstrings
(
model_mapping
,
use_model_types
=
False
)(
from_config
)
new_clas
s
.
from_config
=
classmethod
(
from_config
)
cl
s
.
from_config
=
classmethod
(
from_config
)
if
name
.
startswith
(
"TF"
):
from_pretrained_docstring
=
FROM_PRETRAINED_TF_DOCSTRING
...
...
@@ -432,8 +430,8 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_pretrained_docstring
=
from_pretrained_docstring
.
replace
(
"shortcut_placeholder"
,
shortcut
)
from_pretrained
.
__doc__
=
from_pretrained_docstring
from_pretrained
=
replace_list_option_in_docstrings
(
model_mapping
)(
from_pretrained
)
new_clas
s
.
from_pretrained
=
classmethod
(
from_pretrained
)
return
new_clas
s
cl
s
.
from_pretrained
=
classmethod
(
from_pretrained
)
return
cl
s
def
get_values
(
model_mapping
):
...
...
src/transformers/models/auto/modeling_auto.py
View file @
9b3aab2c
...
...
@@ -308,7 +308,7 @@ from ..xlnet.modeling_xlnet import (
XLNetLMHeadModel
,
XLNetModel
,
)
from
.auto_factory
import
auto_class_
factory
from
.auto_factory
import
_BaseAutoModelClass
,
auto_class_
update
from
.configuration_auto
import
(
AlbertConfig
,
BartConfig
,
...
...
@@ -780,66 +780,108 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
)
AutoModel
=
auto_class_factory
(
"AutoModel"
,
MODEL_MAPPING
)
class
AutoModel
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_MAPPING
AutoModel
=
auto_class_update
(
AutoModel
)
class
AutoModelForPreTraining
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_PRETRAINING_MAPPING
AutoModelForPreTraining
=
auto_class_update
(
AutoModelForPreTraining
,
head_doc
=
"pretraining"
)
AutoModelForPreTraining
=
auto_class_factory
(
"AutoModelForPreTraining"
,
MODEL_FOR_PRETRAINING_MAPPING
,
head_doc
=
"pretraining"
)
# Private on purpose, the public class will add the deprecation warnings.
_AutoModelWithLMHead
=
auto_class_factory
(
"AutoModelWithLMHead"
,
MODEL_WITH_LM_HEAD_MAPPING
,
head_doc
=
"language modeling"
)
class
_AutoModelWithLMHead
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_WITH_LM_HEAD_MAPPING
AutoModelForCausalLM
=
auto_class_factory
(
"AutoModelForCausalLM"
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
head_doc
=
"causal language modeling"
)
AutoModelForMaskedLM
=
auto_class_factory
(
"AutoModelForMaskedLM"
,
MODEL_FOR_MASKED_LM_MAPPING
,
head_doc
=
"masked language modeling"
)
_AutoModelWithLMHead
=
auto_class_update
(
_AutoModelWithLMHead
,
head_doc
=
"language modeling"
)
AutoModelForSeq2SeqLM
=
auto_class_factory
(
"AutoModelForSeq2SeqLM"
,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
head_doc
=
"sequence-to-sequence language modeling"
,
checkpoint_for_example
=
"t5-base"
,
)
AutoModelForSequenceClassification
=
auto_class_factory
(
"AutoModelForSequenceClassification"
,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
,
head_doc
=
"sequence classification"
class
AutoModelForCausalLM
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_CAUSAL_LM_MAPPING
AutoModelForCausalLM
=
auto_class_update
(
AutoModelForCausalLM
,
head_doc
=
"causal language modeling"
)
class
AutoModelForMaskedLM
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_MASKED_LM_MAPPING
AutoModelForMaskedLM
=
auto_class_update
(
AutoModelForMaskedLM
,
head_doc
=
"masked language modeling"
)
class
AutoModelForSeq2SeqLM
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
AutoModelForSeq2SeqLM
=
auto_class_update
(
AutoModelForSeq2SeqLM
,
head_doc
=
"sequence-to-sequence language modeling"
,
checkpoint_for_example
=
"t5-base"
)
AutoModelForQuestionAnswering
=
auto_class_factory
(
"AutoModelForQuestionAnswering"
,
MODEL_FOR_QUESTION_ANSWERING_MAPPING
,
head_doc
=
"question answering"
class
AutoModelForSequenceClassification
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
AutoModelForSequenceClassification
=
auto_class_update
(
AutoModelForSequenceClassification
,
head_doc
=
"sequence classification"
)
AutoModelForTableQuestionAnswering
=
auto_class_factory
(
"AutoModelForTableQuestionAnswering"
,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
,
class
AutoModelForQuestionAnswering
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_QUESTION_ANSWERING_MAPPING
AutoModelForQuestionAnswering
=
auto_class_update
(
AutoModelForQuestionAnswering
,
head_doc
=
"question answering"
)
class
AutoModelForTableQuestionAnswering
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
AutoModelForTableQuestionAnswering
=
auto_class_update
(
AutoModelForTableQuestionAnswering
,
head_doc
=
"table question answering"
,
checkpoint_for_example
=
"google/tapas-base-finetuned-wtq"
,
)
AutoModelForTokenClassification
=
auto_class_factory
(
"AutoModelForTokenClassification"
,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
head_doc
=
"token classification"
)
AutoModelForMultipleChoice
=
auto_class_factory
(
"AutoModelForMultipleChoice"
,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING
,
head_doc
=
"multiple choice"
)
class
AutoModelForTokenClassification
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
AutoModelForNextSentencePrediction
=
auto_class_factory
(
"AutoModelForNextSentencePrediction"
,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
,
head_doc
=
"next sentence prediction"
,
)
AutoModelForImageClassification
=
auto_class_factory
(
"AutoModelForImageClassification"
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
head_doc
=
"image classification"
AutoModelForTokenClassification
=
auto_class_update
(
AutoModelForTokenClassification
,
head_doc
=
"token classification"
)
class
AutoModelForMultipleChoice
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_MULTIPLE_CHOICE_MAPPING
AutoModelForMultipleChoice
=
auto_class_update
(
AutoModelForMultipleChoice
,
head_doc
=
"multiple choice"
)
class
AutoModelForNextSentencePrediction
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
AutoModelForNextSentencePrediction
=
auto_class_update
(
AutoModelForNextSentencePrediction
,
head_doc
=
"next sentence prediction"
)
class
AutoModelForImageClassification
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
AutoModelForImageClassification
=
auto_class_update
(
AutoModelForImageClassification
,
head_doc
=
"image classification"
)
class
AutoModelWithLMHead
(
_AutoModelWithLMHead
):
@
classmethod
def
from_config
(
cls
,
config
):
...
...
src/transformers/models/auto/modeling_flax_auto.py
View file @
9b3aab2c
...
...
@@ -73,7 +73,7 @@ from ..roberta.modeling_flax_roberta import (
from
..t5.modeling_flax_t5
import
FlaxT5ForConditionalGeneration
,
FlaxT5Model
from
..vit.modeling_flax_vit
import
FlaxViTForImageClassification
,
FlaxViTModel
from
..wav2vec2.modeling_flax_wav2vec2
import
FlaxWav2Vec2ForPreTraining
,
FlaxWav2Vec2Model
from
.auto_factory
import
auto_class_
factory
from
.auto_factory
import
_BaseAutoModelClass
,
auto_class_
update
from
.configuration_auto
import
(
BartConfig
,
BertConfig
,
...
...
@@ -217,59 +217,89 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
]
)
FlaxAutoModel
=
auto_class_factory
(
"FlaxAutoModel"
,
FLAX_MODEL_MAPPING
)
FlaxAutoModelForImageClassification
=
auto_class_factory
(
"FlaxAutoModelForImageClassification"
,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
head_doc
=
"image classification modeling"
,
)
class
FlaxAutoModel
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_MAPPING
FlaxAutoModelForCausalLM
=
auto_class_factory
(
"FlaxAutoModelForCausalLM"
,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
,
head_doc
=
"causal language modeling"
)
FlaxAutoModelForPreTraining
=
auto_class_factory
(
"FlaxAutoModelForPreTraining"
,
FLAX_MODEL_FOR_PRETRAINING_MAPPING
,
head_doc
=
"pretraining"
)
FlaxAutoModel
=
auto_class_update
(
FlaxAutoModel
)
FlaxAutoModelForMaskedLM
=
auto_class_factory
(
"FlaxAutoModelForMaskedLM"
,
FLAX_MODEL_FOR_MASKED_LM_MAPPING
,
head_doc
=
"masked language modeling"
)
class
FlaxAutoModelForPreTraining
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_PRETRAINING_MAPPING
FlaxAutoModelForSeq2SeqLM
=
auto_class_factory
(
"FlaxAutoModelForSeq2SeqLM"
,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
head_doc
=
"sequence-to-sequence language modeling"
,
)
FlaxAutoModelForSequenceClassification
=
auto_class_factory
(
"FlaxAutoModelForSequenceClassification"
,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
,
head_doc
=
"sequence classification"
,
)
FlaxAutoModelForPreTraining
=
auto_class_update
(
FlaxAutoModelForPreTraining
,
head_doc
=
"pretraining"
)
class
FlaxAutoModelForCausalLM
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
FlaxAutoModelForCausalLM
=
auto_class_update
(
FlaxAutoModelForCausalLM
,
head_doc
=
"causal language modeling"
)
FlaxAutoModelForQuestionAnswering
=
auto_class_factory
(
"FlaxAutoModelForQuestionAnswering"
,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
,
head_doc
=
"question answering"
class
FlaxAutoModelForMaskedLM
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_MASKED_LM_MAPPING
FlaxAutoModelForMaskedLM
=
auto_class_update
(
FlaxAutoModelForMaskedLM
,
head_doc
=
"masked language modeling"
)
class
FlaxAutoModelForSeq2SeqLM
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
FlaxAutoModelForSeq2SeqLM
=
auto_class_update
(
FlaxAutoModelForSeq2SeqLM
,
head_doc
=
"sequence-to-sequence language modeling"
,
checkpoint_for_example
=
"t5-base"
)
FlaxAutoModelForTokenClassification
=
auto_class_factory
(
"FlaxAutoModelForTokenClassification"
,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
head_doc
=
"token classification"
class
FlaxAutoModelForSequenceClassification
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
FlaxAutoModelForSequenceClassification
=
auto_class_update
(
FlaxAutoModelForSequenceClassification
,
head_doc
=
"sequence classification"
)
FlaxAutoModelForMultipleChoice
=
auto_class_factory
(
"AutoModelForMultipleChoice"
,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
,
head_doc
=
"multiple choice"
class
FlaxAutoModelForQuestionAnswering
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
FlaxAutoModelForQuestionAnswering
=
auto_class_update
(
FlaxAutoModelForQuestionAnswering
,
head_doc
=
"question answering"
)
class
FlaxAutoModelForTokenClassification
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
FlaxAutoModelForTokenClassification
=
auto_class_update
(
FlaxAutoModelForTokenClassification
,
head_doc
=
"token classification"
)
FlaxAutoModelForNextSentencePrediction
=
auto_class_factory
(
"FlaxAutoModelForNextSentencePrediction"
,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
,
head_doc
=
"next sentence prediction"
,
class
FlaxAutoModelForMultipleChoice
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
FlaxAutoModelForMultipleChoice
=
auto_class_update
(
FlaxAutoModelForMultipleChoice
,
head_doc
=
"multiple choice"
)
class
FlaxAutoModelForNextSentencePrediction
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
FlaxAutoModelForNextSentencePrediction
=
auto_class_update
(
FlaxAutoModelForNextSentencePrediction
,
head_doc
=
"next sentence prediction"
)
FlaxAutoModelForSeq2SeqLM
=
auto_class_factory
(
"FlaxAutoModelForSeq2SeqLM"
,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
head_doc
=
"sequence-to-sequence language modeling"
,
class
FlaxAutoModelForImageClassification
(
_BaseAutoModelClass
):
_model_mapping
=
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
FlaxAutoModelForImageClassification
=
auto_class_update
(
FlaxAutoModelForImageClassification
,
head_doc
=
"image classification"
)
src/transformers/models/auto/modeling_tf_auto.py
View file @
9b3aab2c
...
...
@@ -189,7 +189,7 @@ from ..xlnet.modeling_tf_xlnet import (
TFXLNetLMHeadModel
,
TFXLNetModel
,
)
from
.auto_factory
import
auto_class_
factory
from
.auto_factory
import
_BaseAutoModelClass
,
auto_class_
update
from
.configuration_auto
import
(
AlbertConfig
,
BartConfig
,
...
...
@@ -487,54 +487,89 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
)
TFAutoModel
=
auto_class_factory
(
"TFAutoModel"
,
TF_MODEL_MAPPING
)
class
TFAutoModel
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_MAPPING
TFAutoModel
=
auto_class_update
(
TFAutoModel
)
class
TFAutoModelForPreTraining
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_PRETRAINING_MAPPING
TFAutoModelForPreTraining
=
auto_class_update
(
TFAutoModelForPreTraining
,
head_doc
=
"pretraining"
)
TFAutoModelForPreTraining
=
auto_class_factory
(
"TFAutoModelForPreTraining"
,
TF_MODEL_FOR_PRETRAINING_MAPPING
,
head_doc
=
"pretraining"
)
# Private on purpose, the public class will add the deprecation warnings.
_TFAutoModelWithLMHead
=
auto_class_factory
(
"TFAutoModelWithLMHead"
,
TF_MODEL_WITH_LM_HEAD_MAPPING
,
head_doc
=
"language modeling"
)
class
_TFAutoModelWithLMHead
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_WITH_LM_HEAD_MAPPING
TFAutoModelForCausalLM
=
auto_class_factory
(
"TFAutoModelForCausalLM"
,
TF_MODEL_FOR_CAUSAL_LM_MAPPING
,
head_doc
=
"causal language modeling"
)
TFAutoModelForMaskedLM
=
auto_class_factory
(
"TFAutoModelForMaskedLM"
,
TF_MODEL_FOR_MASKED_LM_MAPPING
,
head_doc
=
"masked language modeling"
)
_TFAutoModelWithLMHead
=
auto_class_update
(
_TFAutoModelWithLMHead
,
head_doc
=
"language modeling"
)
TFAutoModelForSeq2SeqLM
=
auto_class_factory
(
"TFAutoModelForSeq2SeqLM"
,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
head_doc
=
"sequence-to-sequence language modeling"
,
checkpoint_for_example
=
"t5-base"
,
)
TFAutoModelForSequenceClassification
=
auto_class_factory
(
"TFAutoModelForSequenceClassification"
,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
,
head_doc
=
"sequence classification"
,
)
class
TFAutoModelForCausalLM
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_CAUSAL_LM_MAPPING
TFAutoModelForCausalLM
=
auto_class_update
(
TFAutoModelForCausalLM
,
head_doc
=
"causal language modeling"
)
class
TFAutoModelForMaskedLM
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_MASKED_LM_MAPPING
TFAutoModelForMaskedLM
=
auto_class_update
(
TFAutoModelForMaskedLM
,
head_doc
=
"masked language modeling"
)
TFAutoModelForQuestionAnswering
=
auto_class_factory
(
"TFAutoModelForQuestionAnswering"
,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
,
head_doc
=
"question answering"
class
TFAutoModelForSeq2SeqLM
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
TFAutoModelForSeq2SeqLM
=
auto_class_update
(
TFAutoModelForSeq2SeqLM
,
head_doc
=
"sequence-to-sequence language modeling"
,
checkpoint_for_example
=
"t5-base"
)
TFAutoModelForTokenClassification
=
auto_class_factory
(
"TFAutoModelForTokenClassification"
,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
head_doc
=
"token classification"
class
TFAutoModelForSequenceClassification
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
TFAutoModelForSequenceClassification
=
auto_class_update
(
TFAutoModelForSequenceClassification
,
head_doc
=
"sequence classification"
)
TFAutoModelForMultipleChoice
=
auto_class_factory
(
"TFAutoModelForMultipleChoice"
,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
,
head_doc
=
"multiple choice"
class
TFAutoModelForQuestionAnswering
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
TFAutoModelForQuestionAnswering
=
auto_class_update
(
TFAutoModelForQuestionAnswering
,
head_doc
=
"question answering"
)
class
TFAutoModelForTokenClassification
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
TFAutoModelForTokenClassification
=
auto_class_update
(
TFAutoModelForTokenClassification
,
head_doc
=
"token classification"
)
TFAutoModelForNextSentencePrediction
=
auto_class_factory
(
"TFAutoModelForNextSentencePrediction"
,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
,
head_doc
=
"next sentence prediction"
,
class
TFAutoModelForMultipleChoice
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
TFAutoModelForMultipleChoice
=
auto_class_update
(
TFAutoModelForMultipleChoice
,
head_doc
=
"multiple choice"
)
class
TFAutoModelForNextSentencePrediction
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
TFAutoModelForNextSentencePrediction
=
auto_class_update
(
TFAutoModelForNextSentencePrediction
,
head_doc
=
"next sentence prediction"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment