Unverified Commit fbd87921 authored by Quentin Lhoest, committed by GitHub

Add DPR model (#5279)



* beginning of dpr modeling

* wip

* implement forward

* remove biencoder + better init weights

* export dpr model to embed model for nlp lib

* add new api

* remove old code

* make style

* fix dumb typo

* don't load bert weights

* docs

* docs

* style

* move the `k` parameter

* fix init_weights

* add pretrained configs

* minor

* update config names

* style

* better config

* style

* clean code based on PR comments

* change Dpr to DPR

* fix config

* switch encoder config to a dict

* style

* inheritance -> composition

* add messages in assert statements

* add dpr reader tokenizer

* one tokenizer per model

* fix base_model_prefix

* fix imports

* typo

* add convert script

* docs

* change tokenizers conf names

* style

* change tokenizers conf names

* minor

* minor

* fix wrong names

* minor

* remove unused convert functions

* rename convert script

* use return_tensors in tokenizers

* remove n_questions dim

* move generate logic to tokenizer

* style

* add docs

* docs

* quality

* docs

* add tests

* style

* add tokenization tests

* DPR full tests

* Stay true to the attention mask building

* update docs

* missing param in bert input docs

* docs

* style
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent d2a93991
@@ -121,7 +121,10 @@ conversion utilities for the following models:
    trained using `OPUS <http://opus.nlpl.eu/>`_ pretrained_models data by Jörg Tiedemann.
21. `Longformer <https://github.com/allenai/longformer>`_ (from AllenAI) released with the paper `Longformer: The
    Long-Document Transformer <https://arxiv.org/abs/2004.05150>`_ by Iz Beltagy, Matthew E. Peters, and Arman Cohan.
22. `DPR <https://github.com/facebookresearch/DPR>`_ (from Facebook) released with the paper `Dense Passage Retrieval
    for Open-Domain Question Answering <https://arxiv.org/abs/2004.04906>`_ by Vladimir Karpukhin, Barlas Oğuz, Sewon
    Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
23. `Other community models <https://huggingface.co/models>`_, contributed by the `community
    <https://huggingface.co/users>`_.

.. toctree::

@@ -199,3 +202,4 @@ conversion utilities for the following models:
    model_doc/longformer
    model_doc/retribert
    model_doc/mobilebert
model_doc/dpr
DPR
----------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~
Dense Passage Retrieval (DPR) is a set of tools and models for state-of-the-art open-domain Q&A research.
It is based on the following paper:
Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, Wen-tau Yih, Dense Passage Retrieval for Open-Domain Question Answering.
The abstract from the paper is the following:
*Open-domain question answering relies on efficient passage retrieval to select candidate contexts, where traditional
sparse vector space models, such as TF-IDF or BM25, are the de facto method. In this work, we show that retrieval can
be practically implemented using dense representations alone, where embeddings are learned from a small number of
questions and passages by a simple dual-encoder framework. When evaluated on a wide range of open-domain QA datasets,
our dense retriever outperforms a strong Lucene-BM25 system largely by 9%-19% absolute in terms of top-20 passage
retrieval accuracy, and helps our end-to-end QA system establish new state-of-the-art on multiple open-domain QA
benchmarks.*
The original code can be found `here <https://github.com/facebookresearch/DPR>`_.
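The question and context encoders map questions and passages to fixed-size vectors that can be compared with a simple dot product to rank passages. A minimal sketch using the pretrained checkpoints referenced in this pull request (weights are downloaded on first use)::

    import torch
    from transformers import (
        DPRContextEncoder, DPRContextEncoderTokenizer,
        DPRQuestionEncoder, DPRQuestionEncoderTokenizer,
    )

    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # The first output of each encoder is the fixed-size embedding of the input.
    question_emb = q_encoder(**q_tokenizer("What is love ?", return_tensors="pt"))[0]
    passage_emb = ctx_encoder(**ctx_tokenizer("'What Is Love' is a song recorded by Haddaway", return_tensors="pt"))[0]
    score = torch.matmul(question_emb, passage_emb.T)  # dot-product relevance score

The reader model then extracts answer spans from the highest-scoring passages; see the example in the :class:`~transformers.DPRReaderTokenizer` documentation below.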
DPRConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRConfig
:members:
DPRContextEncoderTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRContextEncoderTokenizer
:members:
DPRContextEncoderTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRContextEncoderTokenizerFast
:members:
DPRQuestionEncoderTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRQuestionEncoderTokenizer
:members:
DPRQuestionEncoderTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRQuestionEncoderTokenizerFast
:members:
DPRReaderTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRReaderTokenizer
:members:
DPRReaderTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRReaderTokenizerFast
:members:
DPRContextEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRContextEncoder
:members:
DPRQuestionEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRQuestionEncoder
:members:
DPRReader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DPRReader
:members:
@@ -27,6 +27,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
@@ -129,6 +130,14 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
from .tokenization_dpr import (
DPRContextEncoderTokenizer,
DPRContextEncoderTokenizerFast,
DPRQuestionEncoderTokenizer,
DPRQuestionEncoderTokenizerFast,
DPRReaderTokenizer,
DPRReaderTokenizerFast,
)
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast

@@ -382,6 +391,14 @@ if is_torch_available():
        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_dpr import (
        DPRPretrainedContextEncoder,
        DPRPretrainedQuestionEncoder,
        DPRPretrainedReader,
        DPRContextEncoder,
        DPRQuestionEncoder,
        DPRReader,
    )
    from .modeling_retribert import (
        RetriBertPreTrainedModel,
        RetriBertModel,
    )
# coding=utf-8
# Copyright 2020, DPR authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DPR model configuration """
import logging
from .configuration_bert import BertConfig
logger = logging.getLogger(__name__)
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-single-nq-base/config.json",
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-reader-single-nq-base/config.json",
}
class DPRConfig(BertConfig):
r"""
:class:`~transformers.DPRConfig` is the configuration class to store the configuration of a
`DPRContextEncoder`, `DPRQuestionEncoder`, or a `DPRReader`.
It is used to instantiate the components of the DPR model.
Args:
projection_dim (:obj:`int`, optional, defaults to 0):
Dimension of the projection for the context and question encoders.
If it is set to zero (default), then no projection is done.
"""
model_type = "dpr"
def __init__(self, projection_dim: int = 0, **kwargs): # projection of the encoders, 0 for no projection
super().__init__(**kwargs)
self.projection_dim = projection_dim
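For instance, a DPR configuration can be built by reusing the hyper-parameters of an existing BERT configuration and adding a projection head; a minimal sketch, mirroring the conversion script below (projection_dim=128 is an arbitrary illustrative value):

from transformers import BertConfig, DPRConfig

# Reuse the bert-base-uncased hyper-parameters and add a 128-d projection.
bert_config_dict = BertConfig.get_config_dict("bert-base-uncased")[0]
config = DPRConfig(projection_dim=128, **bert_config_dict)
assert config.projection_dim == 128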
import argparse
import collections
from pathlib import Path
import torch
from torch.serialization import default_restore_location
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
CheckpointState = collections.namedtuple(
"CheckpointState", ["model_dict", "optimizer_dict", "scheduler_dict", "offset", "epoch", "encoder_params"]
)
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
print("Reading saved model from %s", model_file)
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, "cpu"))
return CheckpointState(**state_dict)
class DPRState:
def __init__(self, src_file: Path):
self.src_file = src_file
def load_dpr_model(self):
raise NotImplementedError
@staticmethod
def from_type(comp_type: str, *args, **kwargs) -> "DPRState":
if comp_type.startswith("c"):
return DPRContextEncoderState(*args, **kwargs)
if comp_type.startswith("q"):
return DPRQuestionEncoderState(*args, **kwargs)
if comp_type.startswith("r"):
return DPRReaderState(*args, **kwargs)
else:
raise ValueError("Component type must be either 'ctx_encoder', 'question_encoder' or 'reader'.")
class DPRContextEncoderState(DPRState):
def load_dpr_model(self):
model = DPRContextEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
print("Loading DPR biencoder from {}".format(self.src_file))
saved_state = load_states_from_checkpoint(self.src_file)
encoder, prefix = model.ctx_encoder, "ctx_model."
state_dict = {}
for key, value in saved_state.model_dict.items():
if key.startswith(prefix):
key = key[len(prefix) :]
if not key.startswith("encode_proj."):
key = "bert_model." + key
state_dict[key] = value
encoder.load_state_dict(state_dict)
return model
class DPRQuestionEncoderState(DPRState):
def load_dpr_model(self):
model = DPRQuestionEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
print("Loading DPR biencoder from {}".format(self.src_file))
saved_state = load_states_from_checkpoint(self.src_file)
encoder, prefix = model.question_encoder, "question_model."
state_dict = {}
for key, value in saved_state.model_dict.items():
if key.startswith(prefix):
key = key[len(prefix) :]
if not key.startswith("encode_proj."):
key = "bert_model." + key
state_dict[key] = value
encoder.load_state_dict(state_dict)
return model
class DPRReaderState(DPRState):
def load_dpr_model(self):
model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
print("Loading DPR reader from {}".format(self.src_file))
saved_state = load_states_from_checkpoint(self.src_file)
state_dict = {}
for key, value in saved_state.model_dict.items():
if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
key = "encoder.bert_model." + key[len("encoder.") :]
state_dict[key] = value
model.span_predictor.load_state_dict(state_dict)
return model
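# In the loop above, a reader checkpoint key such as
# "encoder.encoder.layer.0.output.dense.weight" (hypothetical key name) becomes
# "encoder.bert_model.encoder.layer.0.output.dense.weight", while keys under
# "encoder.encode_proj" and the span-prediction head keys pass through unchanged.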
def convert(comp_type: str, src_file: Path, dest_dir: Path):
dest_dir = Path(dest_dir)
dest_dir.mkdir(exist_ok=True)
dpr_state = DPRState.from_type(comp_type, src_file=src_file)
model = dpr_state.load_dpr_model()
model.save_pretrained(dest_dir)
model.from_pretrained(dest_dir) # sanity check
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--type", type=str, help="Type of the component to convert: 'ctx_encoder', 'question_encoder' or 'reader'."
)
parser.add_argument(
"--src",
type=str,
help="Path to the dpr checkpoint file. They can be downloaded from the official DPR repo https://github.com/facebookresearch/DPR. Note that in the official repo, both encoders are stored in the 'retriever' checkpoints.",
)
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model directory.")
args = parser.parse_args()
src_file = Path(args.src)
dest_dir = f"converted-{src_file.name}" if args.dest is None else args.dest
dest_dir = Path(dest_dir)
assert src_file.exists()
assert (
args.type is not None
), "Please specify the component type of the DPR model to convert: 'ctx_encoder', 'question_encoder' or 'reader'."
convert(args.type, src_file, dest_dir)
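For example, the conversion can also be run programmatically; a sketch with placeholder paths (the actual checkpoint file name depends on which retriever checkpoint is downloaded from the official DPR repo):

from pathlib import Path

# Both retriever encoders live in a single 'retriever' checkpoint file;
# the path below is a placeholder.
convert(
    comp_type="question_encoder",
    src_file=Path("checkpoints/retriever/single/nq/bert-base-encoder.cp"),
    dest_dir=Path("converted-dpr-question_encoder"),
)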
@@ -617,6 +617,8 @@ BERT_INPUTS_DOCSTRING = r"""
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states tensors of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
"""
This diff is collapsed.
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for DPR."""
import collections
import logging
from typing import List, Optional, Tuple, Union
from .file_utils import add_end_docstrings, add_start_docstrings
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .tokenization_utils_base import BatchEncoding, TensorType
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}
}
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}
}
READER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}
}
CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-ctx_encoder-single-nq-base": 512,
}
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-question_encoder-single-nq-base": 512,
}
READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-reader-single-nq-base": 512,
}
CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True},
}
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True},
}
READER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-reader-single-nq-base": {"do_lower_case": True},
}
class DPRContextEncoderTokenizer(BertTokenizer):
r"""
Constructs a DPRContextEncoderTokenizer.
:class:`~transformers.DPRContextEncoderTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
class DPRContextEncoderTokenizerFast(BertTokenizerFast):
r"""
Constructs a "Fast" DPRContextEncoderTokenizer (backed by HuggingFace's `tokenizers` library).
:class:`~transformers.DPRContextEncoderTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
class DPRQuestionEncoderTokenizer(BertTokenizer):
r"""
Constructs a DPRQuestionEncoderTokenizer.
:class:`~transformers.DPRQuestionEncoderTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
class DPRQuestionEncoderTokenizerFast(BertTokenizerFast):
r"""
Constructs a "Fast" DPRQuestionEncoderTokenizer (backed by HuggingFace's `tokenizers` library).
:class:`~transformers.DPRQuestionEncoderTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
DPRSpanPrediction = collections.namedtuple(
"DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
)
DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"])
CUSTOM_DPR_READER_DOCSTRING = r"""
Return a dictionary with the token ids of the input strings and other information to give to :obj:`.decode_best_spans`.
It converts the strings of a question and different passages (title + text) into sequences of ids (integers), using the tokenizer and vocabulary.
The resulting `input_ids` is a matrix of size :obj:`(n_passages, sequence_length)` with the format:
[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>
Inputs:
questions (:obj:`str`, :obj:`List[str]`):
The questions to be encoded.
You can specify one question for many passages. In this case, the question will be duplicated like :obj:`[questions] * n_passages`.
Otherwise you have to specify as many questions as in :obj:`titles` or :obj:`texts`.
titles (:obj:`str`, :obj:`List[str]`):
The passages' titles to be encoded. This can be a string or a list of strings if there are several passages.
texts (:obj:`str`, :obj:`List[str]`):
The passages' texts to be encoded. This can be a string or a list of strings if there are several passages.
padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
Activate and control padding. Accepts the following values:
* `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
* `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`),
* `False` or `'do_not_pad'`: no padding (i.e. the output batch can have sequences of uneven lengths).
truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
Activate and control truncation. Accepts the following values:
* `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`),
* `False` or `'do_not_truncate'`: no truncation (i.e. the output batch can have sequences longer than the model's max admissible input size).
max_length (:obj:`Union[int, None]`, `optional`, defaults to :obj:`512`):
Control the length for padding/truncation. Accepts the following values:
* `None`: use the predefined model max length if required by one of the truncation/padding parameters. If the model has no specific max input length (e.g. XLNet), truncation/padding to max length is deactivated,
* any integer value (e.g. `42`): use this specific maximum length value if required by one of the truncation/padding parameters.
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`None`):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
`What are attention masks? <../glossary.html#attention-mask>`__
Return:
A Dictionary of shape::
{
input_ids: list[list[int]],
attention_mask: list[int] if return_attention_mask is True (default)
}
With the fields:
- ``input_ids``: list of token ids to be fed to a model
- ``attention_mask``: list of indices specifying which tokens should be attended to by the model
"""
@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING)
class CustomDPRReaderTokenizerMixin:
def __call__(
self,
questions,
titles,
texts,
padding: Union[bool, str] = True,
truncation: Union[bool, str] = True,
max_length: Optional[int] = 512,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: Optional[bool] = None,
**kwargs
) -> BatchEncoding:
titles = titles if not isinstance(titles, str) else [titles]
texts = texts if not isinstance(texts, str) else [texts]
n_passages = len(titles)
questions = questions if not isinstance(questions, str) else [questions] * n_passages
assert len(titles) == len(
texts
), "There should be as many titles as texts but got {} titles and {} texts.".format(len(titles), len(texts))
encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"]
encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"]
encoded_inputs = {
"input_ids": [
(encoded_question_and_title + encoded_text)[:max_length]
if max_length is not None and truncation
else encoded_question_and_title + encoded_text
for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts)
]
}
if return_attention_mask is not False:
    # 1 for real tokens, 0 for padding; one mask per passage
    attention_mask = []
    for input_ids in encoded_inputs["input_ids"]:
        attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
    encoded_inputs["attention_mask"] = attention_mask
return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
def decode_best_spans(
self,
reader_input: BatchEncoding,
reader_output: DPRReaderOutput,
num_spans: int = 16,
max_answer_length: int = 64,
num_spans_per_passage: int = 4,
) -> List[DPRSpanPrediction]:
"""
Get the span predictions for the extractive Q&A model.
Outputs: `List` of `DPRSpanPrediction` sorted by descending `(relevance_score, span_score)`.
Each `DPRSpanPrediction` is a `Tuple` with:
**span_score**: ``float`` that corresponds to the score given by the reader for this span compared to other spans
in the same passage. It corresponds to the sum of the start and end logits of the span.
**relevance_score**: ``float`` that corresponds to the score of each passage to answer the question,
compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader.
**doc_id**: ``int`` the id of the passage.
**start_index**: ``int`` the start index of the span (inclusive).
**end_index**: ``int`` the end index of the span (inclusive).
Examples::
from transformers import DPRReader, DPRReaderTokenizer
tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base')
encoded_inputs = tokenizer(
questions=["What is love ?"],
titles=["Haddaway"],
texts=["'What Is Love' is a song recorded by the artist Haddaway"],
return_tensors='pt'
)
outputs = model(**encoded_inputs)
predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs)
print(predicted_spans[0].text) # best span
"""
input_ids = reader_input["input_ids"]
start_logits, end_logits, relevance_logits = reader_output[:3]
n_passages = len(relevance_logits)
sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__)
nbest_spans_predictions: List[DPRSpanPrediction] = []
for doc_id in sorted_docs:
sequence_ids = list(input_ids[doc_id])
# assuming question & title information is at the beginning of the sequence
passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id
if sequence_ids[-1] == self.pad_token_id:
sequence_len = sequence_ids.index(self.pad_token_id)
else:
sequence_len = len(sequence_ids)
best_spans = self._get_best_spans(
start_logits=start_logits[doc_id][passage_offset:sequence_len],
end_logits=end_logits[doc_id][passage_offset:sequence_len],
max_answer_length=max_answer_length,
top_spans=num_spans_per_passage,
)
for start_index, end_index in best_spans:
start_index += passage_offset
end_index += passage_offset
nbest_spans_predictions.append(
DPRSpanPrediction(
span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index],
relevance_score=relevance_logits[doc_id],
doc_id=doc_id,
start_index=start_index,
end_index=end_index,
text=self.decode(sequence_ids[start_index : end_index + 1]),
)
)
if len(nbest_spans_predictions) >= num_spans:
break
return nbest_spans_predictions[:num_spans]
def _get_best_spans(
self, start_logits: List[float], end_logits: List[float], max_answer_length: int, top_spans: int,
) -> List[Tuple[int, int]]:
"""
Finds the best answer spans for the extractive Q&A model for one passage.
It returns the best spans sorted by descending `span_score`, keeping at most `top_spans` spans.
Spans longer than `max_answer_length` are ignored.
"""
scores = []
for (start_index, start_score) in enumerate(start_logits):
for (answer_length, end_score) in enumerate(end_logits[start_index : start_index + max_answer_length]):
scores.append(((start_index, start_index + answer_length), start_score + end_score))
scores = sorted(scores, key=lambda x: x[1], reverse=True)
chosen_span_intervals = []
for (start_index, end_index), score in scores:
assert start_index <= end_index, "Wrong span indices: [{}:{}]".format(start_index, end_index)
length = end_index - start_index + 1
assert length <= max_answer_length, "Span is too long: {} > {}".format(length, max_answer_length)
if any(
[
start_index <= prev_start_index <= prev_end_index <= end_index
or prev_start_index <= start_index <= end_index <= prev_end_index
for (prev_start_index, prev_end_index) in chosen_span_intervals
]
):
continue
chosen_span_intervals.append((start_index, end_index))
if len(chosen_span_intervals) == top_spans:
break
return chosen_span_intervals
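# Worked example with hypothetical logits: given start_logits=[0, 5, 0],
# end_logits=[0, 0, 5] and max_answer_length=3, the top-scoring candidate is
# span (1, 2) with score 10; candidates such as (0, 2) or (1, 1) overlap the
# chosen interval and are skipped, so with top_spans=2 the next span kept is
# (0, 0).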
@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
r"""
Constructs a DPRReaderTokenizer.
:class:`~transformers.DPRReaderTokenizer` is almost identical to :class:`~transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
The difference is that it has three input strings: question, titles and texts, which are combined to be fed to the DPRReader model.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast):
r"""
Constructs a DPRReaderTokenizerFast.
:class:`~transformers.DPRReaderTokenizerFast` is almost identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
The difference is that it has three input strings: question, titles and texts, which are combined to be fed to the DPRReader model.
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
@@ -965,7 +965,7 @@ ENCODE_KWARGS_DOCSTRING = r"""
            >= 7.5 (Volta).
        return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
            Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
-           PyTorch :obj:`torch.Tensor` or Numpy :oj: `np.ndarray` instead of a list of python integers.
+           PyTorch :obj:`torch.Tensor` or Numpy :obj: `np.ndarray` instead of a list of python integers.
"""

ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
@@ -1900,7 +1900,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
        return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
            Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
-           PyTorch :obj:`torch.Tensor` or Numpy :oj: `np.ndarray` instead of a list of python integers.
+           PyTorch :obj:`torch.Tensor` or Numpy :obj: `np.ndarray` instead of a list of python integers.
        verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Set to ``False`` to avoid printing infos and warnings.
"""
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
from transformers.modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
class DPRModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
projection_dim=0,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.projection_dim = projection_dim
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = BertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
)
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_dpr_context_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRContextEncoder(config=config)
model.to(torch_device)
model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids)[0]
result = {
"embeddings": embeddings,
}
self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
)
def create_and_check_dpr_question_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRQuestionEncoder(config=config)
model.to(torch_device)
model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids)[0]
result = {
"embeddings": embeddings,
}
self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
)
def create_and_check_dpr_reader(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRReader(config=config)
model.to(torch_device)
model.eval()
start_logits, end_logits, relevance_logits, *_ = model(input_ids, attention_mask=input_mask,)
result = {
"relevance_logits": relevance_logits,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
@require_torch
class DPRModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (DPRContextEncoder, DPRQuestionEncoder, DPRReader,) if is_torch_available() else ()
test_resize_embeddings = False
test_missing_keys = False # why?
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = DPRModelTester(self)
self.config_tester = ConfigTester(self, config_class=DPRConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_dpr_context_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_context_encoder(*config_and_inputs)
def test_dpr_question_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_question_encoder(*config_and_inputs)
def test_dpr_reader_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_reader(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = DPRContextEncoder.from_pretrained(model_name)
self.assertIsNotNone(model)
for model_name in DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = DPRQuestionEncoder.from_pretrained(model_name)
self.assertIsNotNone(model)
for model_name in DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = DPRReader.from_pretrained(model_name)
self.assertIsNotNone(model)
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.tokenization_dpr import (
DPRContextEncoderTokenizer,
DPRContextEncoderTokenizerFast,
DPRQuestionEncoderTokenizer,
DPRQuestionEncoderTokenizerFast,
DPRReaderOutput,
DPRReaderTokenizer,
DPRReaderTokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding
from .test_tokenization_bert import BertTokenizationTest
from .utils import slow
class DPRContextEncoderTokenizationTest(BertTokenizationTest):
tokenizer_class = DPRContextEncoderTokenizer
def get_rust_tokenizer(self, **kwargs):
return DPRContextEncoderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
class DPRQuestionEncoderTokenizationTest(BertTokenizationTest):
tokenizer_class = DPRQuestionEncoderTokenizer
def get_rust_tokenizer(self, **kwargs):
return DPRQuestionEncoderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
class DPRReaderTokenizationTest(BertTokenizationTest):
tokenizer_class = DPRReaderTokenizer
def get_rust_tokenizer(self, **kwargs):
return DPRReaderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
@slow
def test_decode_best_spans(self):
tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
text_1 = tokenizer.encode("question sequence", add_special_tokens=False)
text_2 = tokenizer.encode("title sequence", add_special_tokens=False)
text_3 = tokenizer.encode("text sequence " * 4, add_special_tokens=False)
input_ids = [[101] + text_1 + [102] + text_2 + [102] + text_3]
reader_input = BatchEncoding({"input_ids": input_ids})
start_logits = [[0] * len(input_ids[0])]
end_logits = [[0] * len(input_ids[0])]
relevance_logits = [0]
reader_output = DPRReaderOutput(start_logits, end_logits, relevance_logits)
start_index, end_index = 8, 9
start_logits[0][start_index] = 10
end_logits[0][end_index] = 10
predicted_spans = tokenizer.decode_best_spans(reader_input, reader_output)
self.assertEqual(predicted_spans[0].start_index, start_index)
self.assertEqual(predicted_spans[0].end_index, end_index)
self.assertEqual(predicted_spans[0].doc_id, 0)
@slow
def test_call(self):
tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
text_1 = tokenizer.encode("question sequence", add_special_tokens=False)
text_2 = tokenizer.encode("title sequence", add_special_tokens=False)
text_3 = tokenizer.encode("text sequence", add_special_tokens=False)
expected_input_ids = [101] + text_1 + [102] + text_2 + [102] + text_3
encoded_input = tokenizer(questions=["question sequence"], titles=["title sequence"], texts=["text sequence"])
self.assertIn("input_ids", encoded_input)
self.assertIn("attention_mask", encoded_input)
self.assertListEqual(encoded_input["input_ids"][0], expected_input_ids)