Unverified Commit c754c41c authored by Ola Piktus, committed by GitHub

RAG (#6813)

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* Formatting / renaming prior to actual work

* First commit

* improve comments

* Retrieval evaluation scripts

* refactor to include modeling outputs + MPI retriever

* Fix rag-token model + refactor

* Various fixes + finetuning logic

* use_bos fix

* Retrieval refactor

* Finetuning refactoring and cleanup

* Add documentation and cleanup

* Remove set_up_rag_env.sh file

* Fix retrieval with HF index

* Fix import errors

* Fix quality errors

* Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867



* fix quality

* Fix RAG Sequence generation

* minor cleanup plus initial tests

* fix test

* fix tests 2

* Comments fix

* post-merge fixes

* Improve readme + post-rebase refactor

* Extra dependencies for tests

* Fix tests

* Fix tests 2

* Refactor test requirements

* Fix tests 3

* Post-rebase refactor

* rename nlp->datasets

* RAG integration tests

* add tokenizer to slow integration test and allow retriever to run on cpu

* add tests; fix position ids warning

* change structure

* change structure

* add from encoder generator

* save working solution

* make all integration tests pass

* add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained

* don't save paths

* delete unnecessary imports

* pass config to AutoTokenizer.from_pretrained for Rag tokenizers

* init wiki_dpr only once

* hardcode legacy index and passages paths (todo: add the right urls)

* finalize config

* finalize retriver api and config api

* LegacyIndex index download refactor

* add dpr to autotokenizer

* make from pretrained more flexible

* fix ragfortokengeneration

* small name changes in tokenizer

* add labels to models

* change default index name

* add retrieval tests

* finish token generate

* align test with previous version and make all tests pass

* add tests

* finalize tests

* implement thoms suggestions

* add first version of test

* make first tests work

* make retriever platform agnostic

* naming

* style

* add legacy index URL

* docstrings + simple retrieval test for distributed

* clean model api

* add doc_ids to retriever's outputs

* fix retrieval tests

* finish model outputs

* finalize model api

* fix generate problem for rag

* fix generate for other models

* fix some tests

* save intermediate

* set generate to default

* big refactor generate

* delete rag_api

* correct pip faiss install

* fix auto tokenization test

* fix faiss install

* fix test

* move the distributed logic to examples

* model page

* docs

* finish tests

* fix dependencies

* fix import in __init__

* Refactor eval_rag and finetune scripts

* start docstring

* add psutil to test

* fix tf test

* move require torch to top

* fix retrieval test

* align naming

* finish automodel

* fix repo consistency

* test ragtokenizer save/load

* add rag model output docs

* fix ragtokenizer save/load from pretrained

* fix tokenizer dir

* remove torch in retrieval

* fix docs

* fix finetune scripts

* finish model docs

* finish docs

* remove auto model for now

* add require torch

* remove solved todos

* integrate sylvains suggestions

* sams comments

* correct mistake on purpose

* improve README

* Add generation test cases

* fix rag token

* clean token generate

* fix test

* add note to test

* fix attention mask

* add t5 test for rag

* Fix handling prefix in finetune.py

* don't overwrite index_name
Co-authored-by: Patrick Lewis <plewis@fb.com>
Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
parent 1ee2194f
...@@ -89,7 +89,8 @@ extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"]
extras["retrieval"] = ["faiss-cpu", "datasets"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"]
# sphinx-rtd-theme==0.5.0 introduced big changes in the style.
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"]
extras["quality"] = ["black >= 20.8b1", "isort >= 5", "flake8 >= 3.8.3"]
...
...@@ -42,6 +42,7 @@ from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_rag import RagConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
...@@ -86,6 +87,7 @@ from .file_utils import (
cached_path,
is_apex_available,
is_datasets_available,
is_faiss_available,
is_psutil_available,
is_py3nvml_available,
is_tf_available,
...@@ -140,6 +142,9 @@ from .pipelines import (
pipeline,
)
# Retriever
from .retrieval_rag import RagRetriever
# Tokenizers
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
...@@ -172,6 +177,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
...@@ -416,6 +422,7 @@ if is_torch_available():
load_tf_weights_in_openai_gpt,
)
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .modeling_reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,
...
...@@ -24,6 +24,7 @@ from .configuration_bert_generation import BertGenerationConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
...@@ -38,6 +39,7 @@ from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfi
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_rag import RagConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
...@@ -75,6 +77,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
...@@ -110,7 +113,9 @@ CONFIG_MAPPING = OrderedDict(
("encoder-decoder", EncoderDecoderConfig),
("funnel", FunnelConfig),
("lxmert", LxmertConfig),
("dpr", DPRConfig),
("layoutlm", LayoutLMConfig),
("rag", RagConfig),
]
)
...@@ -145,6 +150,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("layoutlm", "LayoutLM"),
("dpr", "DPR"),
("rag", "RAG"),
]
)
...
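For illustration only (not part of the diff): with "dpr" and "rag" registered in CONFIG_MAPPING above, AutoConfig can build a configuration from its model_type string. A minimal sketch; "rag" additionally requires question_encoder and generator sub-configs, so only "dpr" is shown.

```python
# Minimal sketch, assuming the CONFIG_MAPPING entries added in this commit.
from transformers import AutoConfig

dpr_config = AutoConfig.for_model("dpr", projection_dim=128)  # kwargs override defaults
print(type(dpr_config).__name__)  # DPRConfig
```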
...@@ -14,7 +14,7 @@
# limitations under the License.
""" DPR model configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
...@@ -27,7 +27,7 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
class DPRConfig(PretrainedConfig):
r"""
:class:`~transformers.DPRConfig` is the configuration class to store the configuration of a
`DPRModel`.
...@@ -36,12 +36,73 @@ class DPRConfig(BertConfig):
It is used to instantiate the components of the DPR model.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the DPR model. Defines the number of different tokens that
can be represented by the `input_ids` passed to the forward method of :class:`~transformers.BertModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
projection_dim (:obj:`int`, `optional`, defaults to 0):
Dimension of the projection for the context and question encoders.
If it is set to zero (default), then no projection is done.
"""
model_type = "dpr"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
projection_dim: int = 0,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.projection_dim = projection_dim
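A minimal sketch (not from the diff) of what the reworked configuration enables: DPR encoders are now configured through a standalone PretrainedConfig rather than a BertConfig subclass, and projection_dim adds an optional linear projection on top of the pooled output.

```python
from transformers import DPRConfig, DPRQuestionEncoder

config = DPRConfig(projection_dim=128)  # 0 (the default) means no projection
model = DPRQuestionEncoder(config)      # randomly initialized, illustration only
print(config.hidden_size, config.projection_dim)  # 768 128
```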
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RAG model configuration """
import copy
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings
RAG_CONFIG_DOC = r"""
:class:`~transformers.RagConfig` stores the configuration of a `RagModel`.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
title_sep (:obj:`str`, `optional`, defaults to ``" / "``):
Separator inserted between the title and the text of the retrieved document when calling :class:`~transformers.RagRetriever`.
doc_sep (:obj:`str`, `optional`, defaults to ``" // "``):
Separator inserted between the text of the retrieved document and the original input when calling :class:`~transformers.RagRetriever`.
n_docs (:obj:`int`, `optional`, defaults to 5):
Number of documents to retrieve.
max_combined_length (:obj:`int`, `optional`, defaults to 300):
Max length of contextualized input returned by :meth:`~transformers.RagRetriever.__call__`.
retrieval_vector_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`.
retrieval_batch_size (:obj:`int`, `optional`, defaults to 8):
Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated by :class:`~transformers.RagRetriever`.
dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
A dataset identifier of the indexed dataset on the HuggingFace AWS bucket (list all available datasets and ids with :obj:`datasets.list_datasets()`).
dataset_split (:obj:`str`, `optional`, defaults to :obj:`"train"`):
Which split of the ``dataset`` to load.
index_name (:obj:`str`, `optional`, defaults to :obj:`"compressed"`):
The name of the index associated with the :obj:`dataset`. One can choose between :obj:`"legacy"`, :obj:`"exact"` and :obj:`"compressed"`.
index_path (:obj:`str`, `optional`):
The path to the serialized faiss index on disk.
passages_path (:obj:`str`, `optional`):
A path to text passages compatible with the faiss index. Required if using :class:`~transformers.retrieval_rag.LegacyIndex`.
use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``):
Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`.
label_smoothing (:obj:`float`, `optional`, defaults to 0.0):
Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label smoothing in the loss calculation.
If set to ``0.0``, no label smoothing is performed.
do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the logits are marginalized over all documents
by making use of ``torch.nn.functional.log_softmax``.
reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation.
do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`):
Controls whether we want to deduplicate the generations from different context documents for a given input.
Has to be set to :obj:`False` if used while training with distributed backend.
exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the score of the BOS token is disregarded when computing
the loss.
output_retrieved (:obj:`bool`, `optional`, defaults to :obj:`False`):
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and :obj:`context_attention_mask` are returned. See returned tensors for more detail.
"""
@add_start_docstrings(RAG_CONFIG_DOC)
class RagConfig(PretrainedConfig):
model_type = "rag"
def __init__(
self,
vocab_size=None,
is_encoder_decoder=True,
prefix=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
decoder_start_token_id=None,
title_sep=" / ",
doc_sep=" // ",
n_docs=5,
max_combined_length=300,
retrieval_vector_size=768,
retrieval_batch_size=8,
dataset="wiki_dpr",
dataset_split="train",
index_name="compressed",
index_path=None,
passages_path=None,
use_dummy_dataset=False,
reduce_loss=False,
label_smoothing=0.0,
do_deduplication=True,
exclude_bos_score=False,
do_marginalize=False,
output_retrieved=False,
**kwargs
):
super().__init__(
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
is_encoder_decoder=is_encoder_decoder,
prefix=prefix,
vocab_size=vocab_size,
**kwargs,
)
assert (
"question_encoder" in kwargs and "generator" in kwargs
), "Config has to be initialized with question_encoder and generator config"
question_encoder_config = kwargs.pop("question_encoder")
question_encoder_model_type = question_encoder_config.pop("model_type")
decoder_config = kwargs.pop("generator")
decoder_model_type = decoder_config.pop("model_type")
from .configuration_auto import AutoConfig
self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.reduce_loss = reduce_loss
self.label_smoothing = label_smoothing
self.exclude_bos_score = exclude_bos_score
self.do_marginalize = do_marginalize
self.title_sep = title_sep
self.doc_sep = doc_sep
self.n_docs = n_docs
self.max_combined_length = max_combined_length
self.dataset = dataset
self.dataset_split = dataset_split
self.index_name = index_name
self.retrieval_vector_size = retrieval_vector_size
self.retrieval_batch_size = retrieval_batch_size
self.passages_path = passages_path
self.index_path = index_path
self.use_dummy_dataset = use_dummy_dataset
self.output_retrieved = output_retrieved
self.do_deduplication = do_deduplication
@classmethod
def from_question_encoder_generator_configs(
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.RagConfig` (or a derived class) from a question encoder model configuration and a generator model configuration.
Returns:
:class:`RagConfig`: An instance of a configuration object
"""
return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default :meth:`~transformers.PretrainedConfig.to_dict`.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["question_encoder"] = self.question_encoder.to_dict()
output["generator"] = self.generator.to_dict()
output["model_type"] = self.__class__.model_type
return output
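A hedged sketch of how the composite configuration above is meant to be assembled from a question encoder config and a generator config (DPR and BART here, mirroring the RAG paper); the retrieval knobs are the ones documented in RAG_CONFIG_DOC.

```python
from transformers import BartConfig, DPRConfig, RagConfig

rag_config = RagConfig.from_question_encoder_generator_configs(
    DPRConfig(),         # question encoder configuration
    BartConfig(),        # generator (seq2seq) configuration
    index_name="exact",  # "legacy", "exact" or "compressed"
    n_docs=5,
)
print(rag_config.generator.model_type)  # bart
```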
...@@ -69,6 +69,7 @@ try:
import datasets # noqa: F401
_datasets_available = True
logger.debug(f"Successfully imported datasets version {datasets.__version__}")
except ImportError:
_datasets_available = False
...@@ -119,6 +120,16 @@ try:
except ImportError:
_has_apex = False
try:
import faiss # noqa: F401
_faiss_available = True
logger.debug(f"Successfully imported faiss version {faiss.__version__}")
except ImportError:
_faiss_available = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
...@@ -175,6 +186,10 @@ def is_apex_available():
return _has_apex
def is_faiss_available():
return _faiss_available
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
...
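A small sketch of how the new availability flags are intended to be used: retrieval-dependent code can be guarded so imports do not fail when faiss and datasets (the new "retrieval" extra) are missing.

```python
from transformers.file_utils import is_datasets_available, is_faiss_available

if is_datasets_available() and is_faiss_available():
    from transformers import RagRetriever
else:
    RagRetriever = None  # retrieval features are disabled without faiss + datasets
```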
...@@ -27,6 +27,7 @@ from .configuration_auto import (
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
...@@ -97,6 +98,7 @@ from .modeling_distilbert import (
DistilBertForTokenClassification,
DistilBertModel,
)
from .modeling_dpr import DPRQuestionEncoder
from .modeling_electra import (
ElectraForMaskedLM,
ElectraForMultipleChoice,
...@@ -148,6 +150,11 @@ from .modeling_mobilebert import (
)
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
RagModel,
RagSequenceForGeneration,
RagTokenForGeneration,
)
from .modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
...@@ -224,6 +231,7 @@ MODEL_MAPPING = OrderedDict(
(FunnelConfig, FunnelModel),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
(DPRConfig, DPRQuestionEncoder),
]
)
...@@ -412,7 +420,6 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
]
)
AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
The model class to instantiate is selected based on the :obj:`model_type` property of the config object
...
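With DPRConfig wired into MODEL_MAPPING above, the auto classes can resolve DPR checkpoints directly. A sketch, assuming the public facebook/dpr-question_encoder-single-nq-base checkpoint:

```python
from transformers import AutoConfig, AutoModel

name = "facebook/dpr-question_encoder-single-nq-base"  # assumed checkpoint id
config = AutoConfig.from_pretrained(name)  # resolves to DPRConfig
model = AutoModel.from_pretrained(name)    # resolves to DPRQuestionEncoder
```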
...@@ -272,6 +272,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "ctx_encoder"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.ctx_encoder.init_weights()
...@@ -285,6 +286,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "question_encoder"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.question_encoder.init_weights()
...@@ -298,6 +300,7 @@ class DPRPretrainedReader(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "span_predictor"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.span_predictor.encoder.init_weights()
...
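The authorized_missing_keys entries above let DPR checkpoints load even when the position_ids buffer is absent from the state dict. A sketch of the intended loading path (checkpoint id assumed):

```python
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

name = "facebook/dpr-question_encoder-single-nq-base"  # assumed checkpoint id
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
model = DPRQuestionEncoder.from_pretrained(name)  # no position_ids warning expected
with torch.no_grad():
    embedding = model(**tokenizer("who wrote hamlet?", return_tensors="pt"))[0]
print(embedding.shape)  # torch.Size([1, 768])
```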
This diff is collapsed.
This diff is collapsed.
...@@ -10,7 +10,7 @@ from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from .file_utils import _datasets_available, _faiss_available, _tf_available, _torch_available, _torch_tpu_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
...@@ -161,6 +161,21 @@ def require_torch_and_cuda(test_case):
return test_case
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
if not _datasets_available:
test_case = unittest.skip("test requires Datasets")(test_case)
return test_case
def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
if not _faiss_available:
test_case = unittest.skip("test requires Faiss")(test_case)
return test_case
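How the new decorators are meant to be used on retrieval tests; the test class below is hypothetical and shown only as a sketch.

```python
import unittest

from transformers.testing_utils import require_datasets, require_faiss


@require_faiss
@require_datasets
class ToyRetrievalTest(unittest.TestCase):  # hypothetical test class
    def test_faiss_index(self):
        import faiss  # safe: the decorators skip this test when faiss is missing

        index = faiss.IndexFlatIP(8)
        self.assertEqual(index.d, 8)
```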
def get_tests_dir():
"""
returns the full path to the `tests` dir, so that the tests can be invoked from anywhere
...
...@@ -26,6 +26,7 @@ from .configuration_auto import (
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
...@@ -40,6 +41,7 @@ from .configuration_auto import (
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
RagConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
...@@ -60,6 +62,7 @@ from .tokenization_bertweet import BertweetTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
from .tokenization_dpr import DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_fsmt import FSMTTokenizer
...@@ -74,6 +77,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
...@@ -110,6 +114,7 @@ TOKENIZER_MAPPING = OrderedDict(
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
(LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
(DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
...@@ -121,6 +126,7 @@ TOKENIZER_MAPPING = OrderedDict(
(FSMTConfig, (FSMTTokenizer, None)),
(BertGenerationConfig, (BertGenerationTokenizer, None)),
(LayoutLMConfig, (LayoutLMTokenizer, None)),
(RagConfig, (RagTokenizer, None)),
]
)
...
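Because RagConfig now maps to RagTokenizer, AutoTokenizer can resolve a RAG checkpoint directory when the config is passed along, matching the "pass config to AutoTokenizer.from_pretrained for Rag tokenizers" commit above. A sketch with a hypothetical local directory:

```python
from transformers import AutoTokenizer, RagConfig

checkpoint_dir = "./my_rag_checkpoint"  # hypothetical directory saved with save_pretrained
config = RagConfig.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, config=config)
```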
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RAG."""
import os
from typing import List, Optional
from .configuration_rag import RagConfig
from .tokenization_utils_base import BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__)
class RagTokenizer:
def __init__(self, question_encoder, generator):
self.question_encoder = question_encoder
self.generator = generator
def save_pretrained(self, save_directory):
if os.path.isfile(save_directory):
raise ValueError("Provided path ({}) should be a directory, not a file".format(save_directory))
os.makedirs(save_directory, exist_ok=True)
question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
generator_path = os.path.join(save_directory, "generator_tokenizer")
self.question_encoder.save_pretrained(question_encoder_path)
self.generator.save_pretrained(generator_path)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
# dynamically import AutoTokenizer
from .tokenization_auto import AutoTokenizer
config = kwargs.pop("config", None)
if config is None:
config = RagConfig.from_pretrained(pretrained_model_name_or_path)
question_encoder_path = os.path.join(pretrained_model_name_or_path, "question_encoder_tokenizer")
generator_path = os.path.join(pretrained_model_name_or_path, "generator_tokenizer")
question_encoder = AutoTokenizer.from_pretrained(question_encoder_path, config=config.question_encoder)
generator = AutoTokenizer.from_pretrained(generator_path, config=config.generator)
return cls(question_encoder=question_encoder, generator=generator)
def __call__(self, *args, **kwargs):
return self.question_encoder(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
return self.generator.batch_decode(*args, **kwargs)
def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = "np",
truncation=True,
**kwargs,
) -> BatchEncoding:
r"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.RagModel`.
Args:
src_texts: (:obj:`List[str]`):
List of documents to summarize or source language texts.
tgt_texts: (:obj:`List[str]`, `optional`):
List of summaries or target language texts.
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "np"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
**kwargs:
Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **labels** -- List of token ids for tgt_texts
The full set of keys ``[input_ids, attention_mask, labels]`` will only be returned if tgt_texts is passed.
Otherwise, input_ids and attention_mask will be the only keys.
"""
if max_length is None:
max_length = self.question_encoder.model_max_length
model_inputs: BatchEncoding = self.question_encoder(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = self.generator.model_max_length
labels = self.generator(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=truncation,
**kwargs,
)["input_ids"]
model_inputs["labels"] = labels
return model_inputs
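A sketch of the tokenizer in use, assuming a published RAG checkpoint (e.g. facebook/rag-token-nq) with the question_encoder_tokenizer/ and generator_tokenizer/ subfolders saved as above:

```python
from transformers import RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")  # assumed checkpoint id
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["who holds the record in 100m freestyle"],
    tgt_texts=["michael phelps"],
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["labels"].shape)
```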
This diff is collapsed.
...@@ -35,28 +35,53 @@ if is_torch_available():
class T5ModelTester:
def __init__(
self,
parent,
vocab_size=99,
n_positions=14,
batch_size=13,
encoder_seq_length=7,
decoder_seq_length=9,
# For common tests
seq_length=7,
is_training=True,
use_attention_mask=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_id=1,
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_positions = n_positions
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
def prepare_config_and_inputs(self):
...
import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch
import numpy as np
from datasets import Dataset
import faiss
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
from transformers.retrieval_rag import RagRetriever
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
@require_faiss
@require_datasets
class RagRetrieverTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_dummy_hf_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
)
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = dataset
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
)
return retriever
def get_dummy_legacy_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
pickle.dump(passages, open(passages_file_name, "wb"))
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="legacy",
index_path=self.tmpdirname,
passages_path=self.tmpdirname,
)
retriever = RagRetriever(
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
)
return retriever
def test_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_hf_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
def test_legacy_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_legacy_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
@require_torch
def test_hf_index_retriever_call(self):
import torch
n_docs = 1
retriever = self.get_dummy_hf_index_retriever()
question_input_ids = [[5, 7], [10, 11]]
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, list)
self.assertIsInstance(context_attention_mask, list)
self.assertIsInstance(retrieved_doc_embeds, np.ndarray)
out = retriever(
question_input_ids,
hidden_states,
prefix=retriever.config.generator.prefix,
n_docs=n_docs,
return_tensors="pt",
)
context_input_ids, context_attention_mask, retrieved_doc_embeds, doc_ids = ( # noqa: F841
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
out["doc_ids"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, torch.Tensor)
self.assertIsInstance(context_attention_mask, torch.Tensor)
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
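For context, a hedged end-to-end sketch of what the retriever tests above exercise in miniature, assuming the facebook/rag-token-nq checkpoint and the dummy wiki_dpr index to keep downloads small:

```python
import torch
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

name = "facebook/rag-token-nq"  # assumed checkpoint id
retriever = RagRetriever.from_pretrained(name, index_name="exact", use_dummy_dataset=True)
tokenizer = RagTokenizer.from_pretrained(name)
model = RagTokenForGeneration.from_pretrained(name, retriever=retriever)

inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(input_ids=inputs["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```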
import json
import os
import shutil
import tempfile
from unittest import TestCase
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
if is_torch_available() and is_datasets_available() and is_faiss_available():
from transformers.configuration_rag import RagConfig
from transformers.tokenization_rag import RagTokenizer
@require_faiss
@require_datasets
@require_torch
class RagTokenizerTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_with_saved_config(self):
save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
rag_config.save_pretrained(save_dir)
rag_tokenizer.save_pretrained(save_dir)
new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer)
self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)