Unverified Commit 07384baf authored by Lysandre Debut, committed by GitHub

AutoModelForTableQuestionAnswering (#9154)

* AutoModelForTableQuestionAnswering

* Update src/transformers/models/auto/modeling_auto.py

* Style
parent 34334662
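Before the diff itself, a minimal usage sketch of the class this commit introduces. This is a hedged example, not part of the change: it assumes the google/tapas-base-finetuned-wtq checkpoint named in the docstrings below and the pandas-based table input of the TAPAS tokenizer.

import pandas as pd
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering

# Resolve the TAPAS checkpoint through the auto classes.
tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")

# TAPAS expects the table as a pandas DataFrame whose cells are strings.
table = pd.DataFrame({"City": ["Paris", "London"], "Population": ["2,100,000", "8,900,000"]})
inputs = tokenizer(table=table, queries=["Which city has the larger population?"], return_tensors="pt")
outputs = model(**inputs)  # dispatches to a TapasForQuestionAnswering forward pass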
@@ -114,6 +114,13 @@ AutoModelForQuestionAnswering
:members:
AutoModelForTableQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForTableQuestionAnswering
:members:
TFAutoModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -358,6 +358,7 @@ if is_torch_available():
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
@@ -370,6 +371,7 @@ if is_torch_available():
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
)
......
@@ -40,8 +40,8 @@ deps = {
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
"sphinx": "sphinx==3.2.1",
"starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4",
"tensorflow": "tensorflow>=2.0,<2.4",
"tensorflow-cpu": "tensorflow-cpu>=2.0",
"tensorflow": "tensorflow>=2.0",
"timeout-decorator": "timeout-decorator",
"tokenizers": "tokenizers==0.9.4",
"torch": "torch>=1.0",
......
@@ -31,6 +31,7 @@ if is_torch_available():
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
@@ -43,6 +44,7 @@ if is_torch_available():
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
)
......
@@ -467,6 +467,12 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(FunnelConfig, FunnelForQuestionAnswering),
(LxmertConfig, LxmertForQuestionAnswering),
(MPNetConfig, MPNetForQuestionAnswering),
]
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
# Model for Table Question Answering mapping
(TapasConfig, TapasForQuestionAnswering),
]
)
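The mapping above follows the pattern used by every auto class: keys are configuration classes, values are model classes, and dispatch is by the exact type of the config. A self-contained sketch of that pattern with dummy stand-in classes:

from collections import OrderedDict

class DummyConfig:
    pass

class DummyModelForTableQA:
    def __init__(self, config):
        self.config = config

# Same shape as MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING above.
MAPPING = OrderedDict([(DummyConfig, DummyModelForTableQA)])

config = DummyConfig()
model = MAPPING[type(config)](config)  # exact type lookup, as in from_config below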
@@ -1384,6 +1390,106 @@ class AutoModelForQuestionAnswering:
)
class AutoModelForTableQuestionAnswering:
r"""
This is a generic model class that will be instantiated as one of the model classes of the library---with a table
question answering head---when created with the
:meth:`~transformers.AutoModelForTableQuestionAnswering.from_pretrained` class method or the
:meth:`~transformers.AutoModelForTableQuestionAnswering.from_config` class method.
This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"AutoModelForQuestionAnswering is designed to be instantiated "
"using the `AutoModelForTableQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelForTableQuestionAnswering.from_config(config)` methods."
)
@classmethod
@replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, use_model_types=False)
def from_config(cls, config):
r"""
Instantiates one of the model classes of the library---with a table question answering head---from a
configuration.
Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use :meth:`~transformers.AutoModelForTableQuestionAnswering.from_pretrained` to load
the model weights.
Args:
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
List options
Examples::
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('google/tapas-base-finetuned-wtq')
>>> model = AutoModelForTableQuestionAnswering.from_config(config)
"""
if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys():
return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)](config)
raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys()),
)
)
@classmethod
@replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING)
@add_start_docstrings(
"Instantiate one of the model classes of the library---with a table question answering head---from a "
"pretrained model.",
AUTO_MODEL_PRETRAINED_DOCSTRING,
)
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Examples::
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq')
>>> # Update configuration during loading
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq', output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_json_file('./tf_model/tapas_tf_checkpoint.json')
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('./tf_model/tapas_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys():
return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys()),
)
)
class AutoModelForTokenClassification:
r"""
This is a generic model class that will be instantiated as one of the model classes of the library---with a token
......
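As the docstrings above spell out, from_config only builds the architecture (random weights) while from_pretrained also loads trained weights; a short sketch of the distinction, assuming the same checkpoint as the docstrings:

from transformers import AutoConfig, AutoModelForTableQuestionAnswering

config = AutoConfig.from_pretrained("google/tapas-base-finetuned-wtq")
untrained = AutoModelForTableQuestionAnswering.from_config(config)  # architecture only, randomly initialized
trained = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")  # downloads weights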
@@ -303,6 +303,9 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
@@ -393,6 +396,15 @@ class AutoModelForSequenceClassification:
requires_pytorch(self)
class AutoModelForTableQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_pytorch(cls)
class AutoModelForTokenClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
......
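The dummy class above exists so the name can always be imported; only using it without PyTorch fails. A hedged sketch of the pattern, with a hypothetical simplified stand-in for the real requires_pytorch helper (which checks is_torch_available()):

def requires_pytorch_sketch(obj):
    # Hypothetical stand-in: the real helper raises only when torch is absent.
    name = obj.__name__ if isinstance(obj, type) else obj.__class__.__name__
    raise ImportError(f"{name} requires the PyTorch library, but it was not found in your environment.")

class DummyAutoModelForTableQuestionAnswering:
    def __init__(self, *args, **kwargs):
        requires_pytorch_sketch(self)

try:
    DummyAutoModelForTableQuestionAnswering()  # importable, but unusable without torch
except ImportError as err:
    print(err)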
@@ -17,7 +17,13 @@
import unittest
from transformers import is_torch_available
-from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
+from transformers.testing_utils import (
+    DUMMY_UNKWOWN_IDENTIFIER,
+    SMALL_MODEL_IDENTIFIER,
+    require_scatter,
+    require_torch,
+    slow,
+)
if is_torch_available():
@@ -30,6 +36,7 @@ if is_torch_available():
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
BertConfig,
@@ -44,6 +51,8 @@ if is_torch_available():
RobertaForMaskedLM,
T5Config,
T5ForConditionalGeneration,
TapasConfig,
TapasForQuestionAnswering,
)
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -52,6 +61,7 @@ if is_torch_available():
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
@@ -59,6 +69,7 @@ if is_torch_available():
from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.models.gpt2.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
@require_torch
@@ -168,6 +179,21 @@ class AutoModelTest(unittest.TestCase):
self.assertIsNotNone(model)
self.assertIsInstance(model, BertForQuestionAnswering)
@slow
@require_scatter
def test_table_question_answering_model_from_pretrained(self):
for model_name in TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST[5:6]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, TapasConfig)
model = AutoModelForTableQuestionAnswering.from_pretrained(model_name)
model, loading_info = AutoModelForTableQuestionAnswering.from_pretrained(
model_name, output_loading_info=True
)
self.assertIsNotNone(model)
self.assertIsInstance(model, TapasForQuestionAnswering)
@slow
def test_token_classification_model_from_pretrained(self):
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -200,6 +226,7 @@ class AutoModelTest(unittest.TestCase):
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
......
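The new @require_scatter marker reflects TAPAS's dependency on the torch-scatter package. A hedged sketch of what such a skip decorator typically looks like (require_scatter_sketch is hypothetical; the real decorator lives in transformers.testing_utils):

import unittest

def require_scatter_sketch(test_case):
    # Skip the test when torch-scatter cannot be imported.
    try:
        import torch_scatter  # noqa: F401
        return test_case
    except ImportError:
        return unittest.skip("test requires torch-scatter")(test_case)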
@@ -25,9 +25,9 @@ from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
is_torch_available,
)
@@ -436,7 +436,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
if return_labels:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
-elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+elif model_class in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.values():
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
......
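The branch switched above builds token-level dummy labels for table QA models, one label per sequence position rather than one per example as in the multiple-choice branch. A minimal illustration with hypothetical stand-ins for the model tester's sizes:

import torch

batch_size, seq_length = 2, 8  # hypothetical values; the real ones come from self.model_tester
mc_labels = torch.ones(batch_size, dtype=torch.long)                  # multiple choice: one label per example
tqa_labels = torch.zeros((batch_size, seq_length), dtype=torch.long)  # table QA: one label per token
assert mc_labels.shape == (batch_size,)
assert tqa_labels.shape == (batch_size, seq_length)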