"src/include/utility.hpp" did not exist on "b2888adfbe103ae3d9006af87d5871b69cbf00ba"
Unverified Commit 026a2ff2 authored by Ratthachat (Jung), committed by GitHub

Add TFDPR (#8203)

* Create modeling_tf_dpr.py

* Add TFDPR

* Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot

The last commit accidentally deleted these 4 lines, so I restored them.

* Add TFDPR

* Add TFDPR

* clean up some comments, add TF input-style doc string

* Add TFDPR

* Make return_dict=False the default

* Fix return_dict bug (in .from_pretrained)

* Add get_input_embeddings()

* Create test_modeling_tf_dpr.py

The current version already passes all 27 tests!
Please see the test run at:
https://colab.research.google.com/drive/1czS_m9zy5k-iSJbzA_DP1k1xAAC_sdkf?usp=sharing



* fix quality

* delete init weights

* run fix copies

* fix repo consistency

* del config_class, load_tf_weights

They should be PyTorch-only.

* add config_class back

After removing it, tests failed, so in the end only "use_tf_weights = None" is removed, per Lysandre's suggestion.

* newline after .. note::

* import tf, np (Necessary for ModelIntegrationTest)

* slow_test from_pretrained with from_pt=True

At the moment we don't have TF weights (since there is no official TF model yet).
Previously I did not run the slow tests, so I missed this bug.

* Add simple TFDPRModelIntegrationTest

Note that this only tests that TF and PyTorch give approximately the same output.
However, I could not compare against the official DPR repo's output yet.

* upload correct tf model

* remove position_ids as missing keys
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patrickvonplaten <patrick@huggingface.co>
parent a38d1c7c
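The integration test mentioned above checks that the TF model, loaded from the PyTorch checkpoint with from_pt=True, matches the PyTorch model. A minimal sketch of that comparison (not part of this diff; the question string and tolerance are illustrative assumptions):

import numpy as np
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, TFDPRQuestionEncoder

name = "facebook/dpr-question_encoder-single-nq-base"
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)

# No official TF weights exist yet, so the TF model is built from the PyTorch checkpoint.
tf_model = TFDPRQuestionEncoder.from_pretrained(name, from_pt=True, return_dict=True)
pt_model = DPRQuestionEncoder.from_pretrained(name, return_dict=True)

question = "Hello, is my dog cute ?"
tf_emb = tf_model(tokenizer(question, return_tensors="tf")["input_ids"]).pooler_output
with torch.no_grad():
    pt_emb = pt_model(tokenizer(question, return_tensors="pt")["input_ids"]).pooler_output

# TF and PyTorch should agree up to small numerical differences (atol is an assumption).
assert np.allclose(tf_emb.numpy(), pt_emb.numpy(), atol=1e-4)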
@@ -99,3 +99,22 @@ DPRReader
.. autoclass:: transformers.DPRReader
:members: forward
TFDPRContextEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFDPRContextEncoder
:members: call
TFDPRQuestionEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFDPRQuestionEncoder
:members: call
TFDPRReader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFDPRReader
:members: call
@@ -406,6 +406,9 @@ if is_torch_available():
DistilBertPreTrainedModel,
)
from .modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder,
DPRPretrainedContextEncoder,
DPRPretrainedQuestionEncoder,
@@ -713,6 +716,17 @@ if is_tf_available():
TFDistilBertModel,
TFDistilBertPreTrainedModel,
)
from .modeling_tf_dpr import (
TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDPRContextEncoder,
TFDPRPretrainedContextEncoder,
TFDPRPretrainedQuestionEncoder,
TFDPRPretrainedReader,
TFDPRQuestionEncoder,
TFDPRReader,
)
from .modeling_tf_electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
......
@@ -25,6 +25,9 @@ from transformers import (
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -43,6 +46,7 @@ from transformers import (
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
GPT2Config,
@@ -59,6 +63,9 @@ from transformers import (
TFCTRLLMHeadModel,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDPRContextEncoder,
TFDPRQuestionEncoder,
TFDPRReader,
TFElectraForPreTraining,
TFFlaubertWithLMHeadModel,
TFGPT2LMHeadModel,
@@ -98,6 +105,9 @@ if is_torch_available():
CTRLLMHeadModel,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DPRContextEncoder,
DPRQuestionEncoder,
DPRReader,
ElectraForPreTraining,
FlaubertWithLMHeadModel,
GPT2LMHeadModel,
@@ -147,6 +157,18 @@ MODEL_CLASSES = {
BertForSequenceClassification,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"dpr": (
DPRConfig,
TFDPRQuestionEncoder,
TFDPRContextEncoder,
TFDPRReader,
DPRQuestionEncoder,
DPRContextEncoder,
DPRReader,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"gpt2": ( "gpt2": (
GPT2Config, GPT2Config,
TFGPT2LMHeadModel, TFGPT2LMHeadModel,
......
@@ -43,6 +43,7 @@ from .configuration_auto import (
replace_list_option_in_docstrings,
)
from .configuration_blenderbot import BlenderbotConfig
from .configuration_dpr import DPRConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBartConfig
from .configuration_pegasus import PegasusConfig
@@ -87,6 +88,7 @@ from .modeling_tf_distilbert import (
TFDistilBertForTokenClassification,
TFDistilBertModel,
)
from .modeling_tf_dpr import TFDPRQuestionEncoder
from .modeling_tf_electra import (
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
@@ -192,6 +194,7 @@ TF_MODEL_MAPPING = OrderedDict(
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, TFFunnelModel),
(DPRConfig, TFDPRQuestionEncoder),
]
)
......
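The new (DPRConfig, TFDPRQuestionEncoder) entry in TF_MODEL_MAPPING means the TF auto class can now resolve DPR checkpoints to the question encoder. A minimal sketch, assuming only PyTorch weights are published for now (hence from_pt=True):

from transformers import TFAutoModel

# DPRConfig now maps to TFDPRQuestionEncoder in TF_MODEL_MAPPING, so TFAutoModel
# instantiates the question encoder for a DPR checkpoint.
model = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base", from_pt=True)
print(type(model).__name__)  # TFDPRQuestionEncoder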
# coding=utf-8
# Copyright 2018 DPR Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow DPR model for Open Domain Question Answering."""
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Dense
from .configuration_dpr import DPRConfig
from .file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_tf_bert import TFBertMainLayer
from .modeling_tf_outputs import TFBaseModelOutputWithPooling
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DPRConfig"
TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-ctx_encoder-single-nq-base",
"facebook/dpr-ctx_encoder-multiset-base",
]
TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-question_encoder-single-nq-base",
"facebook/dpr-question_encoder-multiset-base",
]
TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-reader-single-nq-base",
"facebook/dpr-reader-multiset-base",
]
##########
# Outputs
##########
@dataclass
class TFDPRContextEncoderOutput(ModelOutput):
r"""
Class for outputs of :class:`~transformers.TFDPRContextEncoder`.
Args:
pooler_output: (:obj:``tf.Tensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the context representation. Last layer
hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
pooler_output: tf.Tensor
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFDPRQuestionEncoderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.TFDPRQuestionEncoder`.
Args:
pooler_output: (:obj:``tf.Tensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the question representation. Last layer
hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
This output is to be used to embed questions for nearest neighbors queries with context embeddings.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
pooler_output: tf.Tensor
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFDPRReaderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.TFDPRReader`.
Args:
start_logits: (:obj:``tf.Tensor`` of shape ``(n_passages, sequence_length)``):
Logits of the start index of the span for each passage.
end_logits: (:obj:``tf.Tensor`` of shape ``(n_passages, sequence_length)``):
Logits of the end index of the span for each passage.
relevance_logits: (:obj:``tf.Tensor`` of shape ``(n_passages, )``):
Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
question, compared to all the other passages.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: tf.Tensor
end_logits: tf.Tensor = None
relevance_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
class TFDPREncoder(TFPreTrainedModel):
base_model_prefix = "bert_model"
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
# resolve name conflict with TFBertMainLayer instead of TFBertModel
self.bert_model = TFBertMainLayer(config, name="bert_model")
self.bert_model.config = config
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
self.projection_dim = config.projection_dim
if self.projection_dim > 0:
self.encode_proj = Dense(
config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj"
)
def call(
self,
input_ids: Tensor,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = None,
training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[Tensor, ...]]:
return_dict = return_dict if return_dict is not None else self.bert_model.return_dict
outputs = self.bert_model(
inputs=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output, pooled_output = outputs[:2]
pooled_output = sequence_output[:, 0, :]
if self.projection_dim > 0:
pooled_output = self.encode_proj(pooled_output)
if not return_dict:
return (sequence_output, pooled_output) + outputs[2:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@property
def embeddings_size(self) -> int:
if self.projection_dim > 0:
return self.projection_dim
return self.bert_model.config.hidden_size
class TFDPRSpanPredictor(TFPreTrainedModel):
base_model_prefix = "encoder"
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.encoder = TFDPREncoder(config, name="encoder")
self.qa_outputs = Dense(2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs")
self.qa_classifier = Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier"
)
def call(
self,
input_ids: Tensor,
attention_mask: Tensor,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = False,
training: bool = False,
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]:
# notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2]
# feed encoder
outputs = self.encoder(
input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
# compute logits
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
# resize
start_logits = tf.reshape(start_logits, [n_passages, sequence_length])
end_logits = tf.reshape(end_logits, [n_passages, sequence_length])
relevance_logits = tf.reshape(relevance_logits, [n_passages])
if not return_dict:
return (start_logits, end_logits, relevance_logits) + outputs[2:]
return TFDPRReaderOutput(
start_logits=start_logits,
end_logits=end_logits,
relevance_logits=relevance_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
##################
# PreTrainedModel
##################
class TFDPRPretrainedContextEncoder(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DPRConfig
base_model_prefix = "ctx_encoder"
class TFDPRPretrainedQuestionEncoder(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DPRConfig
base_model_prefix = "question_encoder"
class TFDPRPretrainedReader(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DPRConfig
base_model_prefix = "reader"
###############
# Actual Models
###############
TF_DPR_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a Tensorflow `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__
subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to
general usage and behavior.
.. note::

    TF 2.0 models accept two formats as inputs:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    This second option is useful when using the :meth:`tf.keras.Model.fit` method, which currently requires having
    all the tensors in the first argument of the model call function: :obj:`model(inputs)`. If you choose this
    second option, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
      :obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Parameters:
config (:class:`~transformers.DPRConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the
model weights.
"""
TF_DPR_ENCODERS_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
formatted with [CLS] and [SEP] tokens as follows:
(a) For sequence pairs (for a pair title+text for example):
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
(b) For single sequences (for a question for example):
``tokens: [CLS] the dog is hairy . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0``
DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
rather than the left.
Indices can be obtained using :class:`~transformers.DPRTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`_
inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
TF_DPR_READER_INPUTS_DOCSTRING = r"""
Args:
input_ids: (:obj:`Numpy array` or :obj:`tf.Tensor` of shapes :obj:`(n_passages, sequence_length)`):
Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question
and 2) the passages titles and 3) the passages texts To match pretraining, DPR :obj:`input_ids` sequence
should be formatted with [CLS] and [SEP] with the format:
``[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>``
DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
rather than the left.
Indices can be obtained using :class:`~transformers.DPRReaderTokenizer`. See this class documentation for
more details.
attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(n_passages, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(n_passages, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@add_start_docstrings(
"The bare DPRContextEncoder transformer outputting pooler outputs as context representations.",
TF_DPR_START_DOCSTRING,
)
class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.config = config
self.ctx_encoder = TFDPREncoder(config, name="ctx_encoder")
def get_input_embeddings(self):
return self.ctx_encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training: bool = False,
) -> Union[TFDPRContextEncoderOutput, Tuple[Tensor, ...]]:
r"""
Return:
Examples::
>>> from transformers import TFDPRContextEncoder, DPRContextEncoderTokenizer
>>> tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
>>> model = TFDPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', return_dict=True, from_pt=True)
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
>>> embeddings = model(input_ids).pooler_output
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
return_dict = inputs[5] if len(inputs) > 5 else return_dict
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if input_ids is None
else (input_ids != self.config.pad_token_id)
)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.ctx_encoder(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
return outputs[1:]
return TFDPRContextEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
@add_start_docstrings(
"The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
TF_DPR_START_DOCSTRING,
)
class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.config = config
self.question_encoder = TFDPREncoder(config, name="question_encoder")
def get_input_embeddings(self):
return self.question_encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training: bool = False,
) -> Union[TFDPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
r"""
Return:
Examples::
>>> from transformers import TFDPRQuestionEncoder, DPRQuestionEncoderTokenizer
>>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
>>> model = TFDPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base', return_dict=True, from_pt=True)
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
>>> embeddings = model(input_ids).pooler_output
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
return_dict = inputs[5] if len(inputs) > 5 else return_dict
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if input_ids is None
else (input_ids != self.config.pad_token_id)
)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.question_encoder(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
return outputs[1:]
return TFDPRQuestionEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
@add_start_docstrings(
"The bare DPRReader transformer outputting span predictions.",
TF_DPR_START_DOCSTRING,
)
class TFDPRReader(TFDPRPretrainedReader):
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.config = config
self.span_predictor = TFDPRSpanPredictor(config, name="span_predictor")
def get_input_embeddings(self):
return self.span_predictor.encoder.bert_model.get_input_embeddings()
@add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
attention_mask: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
return_dict=None,
training: bool = False,
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]:
r"""
Return:
Examples::
>>> from transformers import TFDPRReader, DPRReaderTokenizer
>>> tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
>>> model = TFDPRReader.from_pretrained('facebook/dpr-reader-single-nq-base', return_dict=True, from_pt=True)
>>> encoded_inputs = tokenizer(
... questions=["What is love ?"],
... titles=["Haddaway"],
... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
... return_tensors='tf'
... )
>>> outputs = model(encoded_inputs)
>>> start_logits = outputs.start_logits
>>> end_logits = outputs.end_logits
>>> relevance_logits = outputs.relevance_logits
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
return_dict = inputs[5] if len(inputs) > 5 else return_dict
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
return self.span_predictor(
input_ids,
attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
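The pooler_output documented in this file is what the DPR bi-encoder uses for retrieval: question and context embeddings are compared with an inner product. A minimal sketch of that use (not part of this file; the passages and the from_pt=True loading are illustrative assumptions):

import tensorflow as tf
from transformers import (
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    TFDPRContextEncoder,
    TFDPRQuestionEncoder,
)

q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_encoder = TFDPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base", from_pt=True, return_dict=True
)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_encoder = TFDPRContextEncoder.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base", from_pt=True, return_dict=True
)

question = "What is love ?"
passages = [
    "'What Is Love' is a song recorded by the artist Haddaway",
    "DPR retrieves passages with a bi-encoder over dense embeddings",
]

# Embed the question and the candidate passages with the two encoders.
q_emb = q_encoder(q_tokenizer(question, return_tensors="tf")["input_ids"]).pooler_output  # (1, 768)
ctx_ids = ctx_tokenizer(passages, return_tensors="tf", padding=True)["input_ids"]
ctx_emb = ctx_encoder(ctx_ids).pooler_output  # (2, 768)

# DPR scores each passage by the inner product between question and context embeddings.
scores = tf.matmul(q_emb, ctx_emb, transpose_b=True)  # shape (1, 2)
best_passage = passages[int(tf.argmax(scores, axis=-1)[0])]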
@@ -735,6 +735,15 @@ class DistilBertPreTrainedModel:
requires_pytorch(self)
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = None
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = None
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DPRContextEncoder:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
......
@@ -495,6 +495,45 @@ class TFDistilBertPreTrainedModel:
requires_tf(self)
TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = None
TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = None
TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFDPRContextEncoder:
def __init__(self, *args, **kwargs):
requires_tf(self)
class TFDPRPretrainedContextEncoder:
def __init__(self, *args, **kwargs):
requires_tf(self)
class TFDPRPretrainedQuestionEncoder:
def __init__(self, *args, **kwargs):
requires_tf(self)
class TFDPRPretrainedReader:
def __init__(self, *args, **kwargs):
requires_tf(self)
class TFDPRQuestionEncoder:
def __init__(self, *args, **kwargs):
requires_tf(self)
class TFDPRReader:
def __init__(self, *args, **kwargs):
requires_tf(self)
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
@@ -24,6 +24,8 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
from transformers.modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -227,3 +229,36 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
for model_name in DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = DPRReader.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class DPRModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", return_dict=False)
model.to(torch_device)
input_ids = torch.tensor(
[[101, 7592, 1010, 2003, 2026, 3899, 10140, 1029, 102]], dtype=torch.long, device=torch_device
) # [CLS] hello, is my dog cute? [SEP]
output = model(input_ids)[0] # embedding shape = (1, 768)
# compare the actual values for a slice.
expected_slice = torch.tensor(
[
[
0.03236253,
0.12753335,
0.16818509,
0.00279786,
0.3896933,
0.24264945,
0.2178971,
-0.02335227,
-0.08481959,
-0.14324117,
]
],
dtype=torch.float,
device=torch_device,
)
self.assertTrue(torch.allclose(output[:, :10], expected_slice, atol=1e-4))
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import numpy
import tensorflow as tf
from transformers import (
TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
BertConfig,
DPRConfig,
TFDPRContextEncoder,
TFDPRQuestionEncoder,
TFDPRReader,
)
class TFDPRModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
projection_dim=0,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.projection_dim = projection_dim
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor(
[self.batch_size, self.seq_length], vocab_size=2
) # follow test_modeling_tf_ctrl.py
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = BertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
# MODIFY
return_dict=False,
)
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_dpr_context_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFDPRContextEncoder(config=config)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids, return_dict=True) # MODIFY
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))
def create_and_check_dpr_question_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFDPRQuestionEncoder(config=config)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids, return_dict=True) # MODIFY
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))
def create_and_check_dpr_reader(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFDPRReader(config=config)
result = model(input_ids, attention_mask=input_mask, return_dict=True) # MODIFY
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.relevance_logits.shape, (self.batch_size,))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
@require_tf
class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFDPRContextEncoder,
TFDPRQuestionEncoder,
TFDPRReader,
)
if is_tf_available()
else ()
)
test_resize_embeddings = False
test_missing_keys = False
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = TFDPRModelTester(self)
self.config_tester = ConfigTester(self, config_class=DPRConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_dpr_context_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_context_encoder(*config_and_inputs)
def test_dpr_question_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_question_encoder(*config_and_inputs)
def test_dpr_reader_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_reader(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFDPRContextEncoder.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
for model_name in TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFDPRQuestionEncoder.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
for model_name in TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFDPRReader.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
@require_tf
class TFDPRModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", return_dict=False)
input_ids = tf.constant(
[[101, 7592, 1010, 2003, 2026, 3899, 10140, 1029, 102]]
) # [CLS] hello, is my dog cute? [SEP]
output = model(input_ids)[0] # embedding shape = (1, 768)
# compare the actual values for a slice.
expected_slice = tf.constant(
[
[
0.03236253,
0.12753335,
0.16818509,
0.00279786,
0.3896933,
0.24264945,
0.2178971,
-0.02335227,
-0.08481959,
-0.14324117,
]
]
)
self.assertTrue(numpy.allclose(output[:, :10].numpy(), expected_slice.numpy(), atol=1e-4))
@@ -33,6 +33,8 @@ IGNORE_NON_TESTED = [
"DPRSpanPredictor", # Building part of bigger (tested) model.
"ReformerForMaskedLM", # Needs to be setup as decoder.
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
"TFDPRSpanPredictor", # Building part of bigger (tested) model.
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice", # TODO: fix
]
@@ -57,6 +59,8 @@ IGNORE_NON_DOCUMENTED = [
"DPREncoder", # Building part of bigger (documented) model.
"DPRSpanPredictor", # Building part of bigger (documented) model.
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (documented) model.
"TFDPRSpanPredictor", # Building part of bigger (documented) model.
"TFElectraMainLayer", # Building part of bigger (documented) model (should it be a TFPreTrainedModel ?)
]
@@ -87,6 +91,10 @@ IGNORE_NON_AUTO_CONFIGURED = [
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFDPRContextEncoder",
"TFDPREncoder",
"TFDPRReader",
"TFDPRSpanPredictor",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel",
......