Unverified Commit c754c41c authored by Ola Piktus's avatar Ola Piktus Committed by GitHub
Browse files

RAG (#6813)

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* Formatting / renaming prior to actual work

* First commit

* improve comments

* Retrieval evaluation scripts

* refactor to include modeling outputs + MPI retriever

* Fix rag-token model + refactor

* Various fixes + finetuning logic

* use_bos fix

* Retrieval refactor

* Finetuning refactoring and cleanup

* Add documentation and cleanup

* Remove set_up_rag_env.sh file

* Fix retrieval wit HF index

* Fix import errors

* Fix quality errors

* Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867



* fix quality

* Fix RAG Sequence generation

* minor cleanup plus initial tests

* fix test

* fix tests 2

* Comments fix

* post-merge fixes

* Improve readme + post-rebase refactor

* Extra dependencied for tests

* Fix tests

* Fix tests 2

* Refactor test requirements

* Fix tests 3

* Post-rebase refactor

* rename nlp->datasets

* RAG integration tests

* add tokenizer to slow integration test and allow retriever to run on cpu

* add tests; fix position ids warning

* change structure

* change structure

* add from encoder generator

* save working solution

* make all integration tests pass

* add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained

* don't save paths

* delete unnecessary imports

* pass config to AutoTokenizer.from_pretrained for Rag tokenizers

* init wiki_dpr only once

* hardcode legacy index and passages paths (todo: add the right urls)

* finalize config

* finalize retriver api and config api

* LegacyIndex index download refactor

* add dpr to autotokenizer

* make from pretrained more flexible

* fix ragfortokengeneration

* small name changes in tokenizer

* add labels to models

* change default index name

* add retrieval tests

* finish token generate

* align test with previous version and make all tests pass

* add tests

* finalize tests

* implement thoms suggestions

* add first version of test

* make first tests work

* make retriever platform agnostic

* naming

* style

* add legacy index URL

* docstrings + simple retrieval test for distributed

* clean model api

* add doc_ids to retriever's outputs

* fix retrieval tests

* finish model outputs

* finalize model api

* fix generate problem for rag

* fix generate for other modles

* fix some tests

* save intermediate

* set generate to default

* big refactor generate

* delete rag_api

* correct pip faiss install

* fix auto tokenization test

* fix faiss install

* fix test

* move the distributed logic to examples

* model page

* docs

* finish tests

* fix dependencies

* fix import in __init__

* Refactor eval_rag and finetune scripts

* start docstring

* add psutil to test

* fix tf test

* move require torch to top

* fix retrieval test

* align naming

* finish automodel

* fix repo consistency

* test ragtokenizer save/load

* add rag model output docs

* fix ragtokenizer save/load from pretrained

* fix tokenizer dir

* remove torch in retrieval

* fix docs

* fixe finetune scripts

* finish model docs

* finish docs

* remove auto model for now

* add require torch

* remove solved todos

* integrate sylvains suggestions

* sams comments

* correct mistake on purpose

* improve README

* Add generation test cases

* fix rag token

* clean token generate

* fix test

* add note to test

* fix attention mask

* add t5 test for rag

* Fix handling prefix in finetune.py

* don't overwrite index_name
Co-authored-by: default avatarPatrick Lewis <plewis@fb.com>
Co-authored-by: default avatarAleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: default avatarAleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: default avatarAleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: default avatarYour Name <you@example.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarQuentin Lhoest <lhoest.q@gmail.com>
parent 1ee2194f
......@@ -89,7 +89,8 @@ extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil", "parameterized"]
extras["retrieval"] = ["faiss-cpu", "datasets"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"]
# sphinx-rtd-theme==0.5.0 introduced big changes in the style.
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"]
extras["quality"] = ["black >= 20.8b1", "isort >= 5", "flake8 >= 3.8.3"]
......
......@@ -42,6 +42,7 @@ from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_rag import RagConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
......@@ -86,6 +87,7 @@ from .file_utils import (
cached_path,
is_apex_available,
is_datasets_available,
is_faiss_available,
is_psutil_available,
is_py3nvml_available,
is_tf_available,
......@@ -140,6 +142,9 @@ from .pipelines import (
pipeline,
)
# Retriever
from .retrieval_rag import RagRetriever
# Tokenizers
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
......@@ -172,6 +177,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
......@@ -416,6 +422,7 @@ if is_torch_available():
load_tf_weights_in_openai_gpt,
)
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .modeling_reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,
......
......@@ -24,6 +24,7 @@ from .configuration_bert_generation import BertGenerationConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
......@@ -38,6 +39,7 @@ from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfi
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_rag import RagConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
......@@ -75,6 +77,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
......@@ -110,7 +113,9 @@ CONFIG_MAPPING = OrderedDict(
("encoder-decoder", EncoderDecoderConfig),
("funnel", FunnelConfig),
("lxmert", LxmertConfig),
("dpr", DPRConfig),
("layoutlm", LayoutLMConfig),
("rag", RagConfig),
]
)
......@@ -145,6 +150,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("layoutlm", "LayoutLM"),
("dpr", "DPR"),
("rag", "RAG"),
]
)
......
......@@ -14,7 +14,7 @@
# limitations under the License.
""" DPR model configuration """
from .configuration_bert import BertConfig
from .configuration_utils import PretrainedConfig
from .utils import logging
......@@ -27,7 +27,7 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
class DPRConfig(BertConfig):
class DPRConfig(PretrainedConfig):
r"""
:class:`~transformers.DPRConfig` is the configuration class to store the configuration of a
`DPRModel`.
......@@ -36,12 +36,73 @@ class DPRConfig(BertConfig):
It is used to instantiate the components of the DPR model.
Args:
projection_dim (:obj:`int`, optional, defaults to 0):
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the DPR model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
projection_dim (:obj:`int`, `optional`, defaults to 0):
Dimension of the projection for the context and question encoders.
If it is set to zero (default), then no projection is done.
"""
model_type = "dpr"
def __init__(self, projection_dim: int = 0, **kwargs): # projection of the encoders, 0 for no projection
super().__init__(**kwargs)
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
projection_dim: int = 0,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.projection_dim = projection_dim
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RAG model configuration """
import copy
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings
RAG_CONFIG_DOC = r"""
:class:`~transformers.RagConfig` stores the configuration of a `RagModel`.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
title_sep (:obj:`str`, `optional`, defaults to ``" / "``):
Separator inserted between the title and the text of the retrieved document when calling :class:`~transformers.RagRetriever`.
doc_sep (:obj:`str`, `optional`, defaults to ``" // "``):
Separator inserted between the the text of the retrieved document and the original input when calliang :class:`~transformers.RagRetriever`.
n_docs (:obj:`int`, `optional`, defaults to 5):
Number of documents to retrieve.
max_combined_length (:obj:`int`, `optional`, defaults to 300):
Max length of contextualized input returned by :meth:`~transformers.RagRetriever.__call__`.
retrieval_vector_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`.
retrieval_batch_size (:obj:`int`, `optional`, defaults to 8):
Retrieval batch size, defined as the number of queries issues concurrently to the faiss index excapsulated :class:`~transformers.RagRetriever`.
dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
A datatset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids using :obj:`datasets.list_datasets()`).
dataset_split (:obj:`str`, `optional`, defaults to :obj:`train`)
Which split of the ``dataset`` to load.
index_name (:obj:`str`, `optional`, defaults to :obj:`compressed`)
The index_name of the index associated with the :obj:`dataset`. One can choose between :obj:`legacy`, :obj:`exact` and :obj:`compressed`.
index_path (:obj:`str`, `optional`)
The path to the serialized faiss index on disk.
passages_path: (:obj:`str`, `optional`):
A path to text passages compatible with the faiss index. Required if using :class:`~transformers.retrieval_rag.LegacyIndex`
use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``)
Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`.
label_smoothing (:obj:`float`, `optional`, defaults to 0.0):
Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label smoothing in the loss calculation.
If set to ``0.0``, no label smoothing is performed.
do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the logits are marginalized over all documents
by making use of ``torch.nn.functional.log_softmax``.
reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation.
do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`):
Controls whether we want to deduplicate the generations from different context documents for a given input.
Has to be set to :obj:`False` if used while training with distributed backend.
exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the score of the BOS token is disregarded when computing
the loss.
output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and :obj:`context_attention_mask` are returned. See returned tensors for more detail.
"""
@add_start_docstrings(RAG_CONFIG_DOC)
class RagConfig(PretrainedConfig):
model_type = "rag"
def __init__(
self,
vocab_size=None,
is_encoder_decoder=True,
prefix=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
decoder_start_token_id=None,
title_sep=" / ",
doc_sep=" // ",
n_docs=5,
max_combined_length=300,
retrieval_vector_size=768,
retrieval_batch_size=8,
dataset="wiki_dpr",
dataset_split="train",
index_name="compressed",
index_path=None,
passages_path=None,
use_dummy_dataset=False,
reduce_loss=False,
label_smoothing=0.0,
do_deduplication=True,
exclude_bos_score=False,
do_marginalize=False,
output_retrieved=False,
**kwargs
):
super().__init__(
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
is_encoder_decoder=is_encoder_decoder,
prefix=prefix,
vocab_size=vocab_size,
**kwargs,
)
assert (
"question_encoder" in kwargs and "generator" in kwargs
), "Config has to be initialized with question_encoder and generator config"
question_encoder_config = kwargs.pop("question_encoder")
question_encoder_model_type = question_encoder_config.pop("model_type")
decoder_config = kwargs.pop("generator")
decoder_model_type = decoder_config.pop("model_type")
from .configuration_auto import AutoConfig
self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.reduce_loss = reduce_loss
self.label_smoothing = label_smoothing
self.exclude_bos_score = exclude_bos_score
self.do_marginalize = do_marginalize
self.title_sep = title_sep
self.doc_sep = doc_sep
self.n_docs = n_docs
self.max_combined_length = max_combined_length
self.dataset = dataset
self.dataset_split = dataset_split
self.index_name = index_name
self.retrieval_vector_size = retrieval_vector_size
self.retrieval_batch_size = retrieval_batch_size
self.passages_path = passages_path
self.index_path = index_path
self.use_dummy_dataset = use_dummy_dataset
self.output_retrieved = output_retrieved
self.do_deduplication = do_deduplication
@classmethod
def from_question_encoder_generator_configs(
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
Returns:
:class:`EncoderDecoderConfig`: An instance of a configuration object
"""
return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default :meth:`~transformers.PretrainedConfig.to_dict`.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["question_encoder"] = self.question_encoder.to_dict()
output["generator"] = self.generator.to_dict()
output["model_type"] = self.__class__.model_type
return output
......@@ -69,6 +69,7 @@ try:
import datasets # noqa: F401
_datasets_available = True
logger.debug(f"Succesfully imported datasets version {datasets.__version__}")
except ImportError:
_datasets_available = False
......@@ -119,6 +120,16 @@ try:
except ImportError:
_has_apex = False
try:
import faiss # noqa: F401
_faiss_available = True
logger.debug(f"Succesfully imported faiss version {faiss.__version__}")
except ImportError:
_faiss_available = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
......@@ -175,6 +186,10 @@ def is_apex_available():
return _has_apex
def is_faiss_available():
return _faiss_available
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
......
......@@ -27,6 +27,7 @@ from .configuration_auto import (
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
......@@ -97,6 +98,7 @@ from .modeling_distilbert import (
DistilBertForTokenClassification,
DistilBertModel,
)
from .modeling_dpr import DPRQuestionEncoder
from .modeling_electra import (
ElectraForMaskedLM,
ElectraForMultipleChoice,
......@@ -148,6 +150,11 @@ from .modeling_mobilebert import (
)
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
RagModel,
RagSequenceForGeneration,
RagTokenForGeneration,
)
from .modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
......@@ -224,6 +231,7 @@ MODEL_MAPPING = OrderedDict(
(FunnelConfig, FunnelModel),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
(DPRConfig, DPRQuestionEncoder),
]
)
......@@ -412,7 +420,6 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
]
)
AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
The model class to instantiate is selected based on the :obj:`model_type` property of the config object
......
......@@ -272,6 +272,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "ctx_encoder"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.ctx_encoder.init_weights()
......@@ -285,6 +286,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "question_encoder"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.question_encoder.init_weights()
......@@ -298,6 +300,7 @@ class DPRPretrainedReader(PreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "span_predictor"
authorized_missing_keys = [r"position_ids"]
def init_weights(self):
self.span_predictor.encoder.init_weights()
......
This diff is collapsed.
This diff is collapsed.
......@@ -10,7 +10,7 @@ from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
from .file_utils import _datasets_available, _faiss_available, _tf_available, _torch_available, _torch_tpu_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
......@@ -161,6 +161,21 @@ def require_torch_and_cuda(test_case):
return test_case
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
if not _datasets_available:
test_case = unittest.skip("test requires Datasets")(test_case)
return test_case
def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
if not _faiss_available:
test_case = unittest.skip("test requires Faiss")(test_case)
return test_case
def get_tests_dir():
"""
returns the full path to the `tests` dir, so that the tests can be invoked from anywhere
......
......@@ -26,6 +26,7 @@ from .configuration_auto import (
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
......@@ -40,6 +41,7 @@ from .configuration_auto import (
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
RagConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
......@@ -60,6 +62,7 @@ from .tokenization_bertweet import BertweetTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
from .tokenization_dpr import DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_fsmt import FSMTTokenizer
......@@ -74,6 +77,7 @@ from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFas
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
......@@ -110,6 +114,7 @@ TOKENIZER_MAPPING = OrderedDict(
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
(LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
(DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
......@@ -121,6 +126,7 @@ TOKENIZER_MAPPING = OrderedDict(
(FSMTConfig, (FSMTTokenizer, None)),
(BertGenerationConfig, (BertGenerationTokenizer, None)),
(LayoutLMConfig, (LayoutLMTokenizer, None)),
(RagConfig, (RagTokenizer, None)),
]
)
......
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RAG."""
import os
from typing import List, Optional
from .configuration_rag import RagConfig
from .tokenization_utils_base import BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__)
class RagTokenizer:
def __init__(self, question_encoder, generator):
self.question_encoder = question_encoder
self.generator = generator
def save_pretrained(self, save_directory):
if os.path.isfile(save_directory):
raise ValueError("Provided path ({}) should be a directory, not a file".format(save_directory))
os.makedirs(save_directory, exist_ok=True)
question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
generator_path = os.path.join(save_directory, "generator_tokenizer")
self.question_encoder.save_pretrained(question_encoder_path)
self.generator.save_pretrained(generator_path)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
# dynamically import AutoTokenizer
from .tokenization_auto import AutoTokenizer
config = kwargs.pop("config", None)
if config is None:
config = RagConfig.from_pretrained(pretrained_model_name_or_path)
question_encoder_path = os.path.join(pretrained_model_name_or_path, "question_encoder_tokenizer")
generator_path = os.path.join(pretrained_model_name_or_path, "generator_tokenizer")
question_encoder = AutoTokenizer.from_pretrained(question_encoder_path, config=config.question_encoder)
generator = AutoTokenizer.from_pretrained(generator_path, config=config.generator)
return cls(question_encoder=question_encoder, generator=generator)
def __call__(self, *args, **kwargs):
return self.question_encoder(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
return self.generator.batch_decode(*args, **kwargs)
def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = "np",
truncation=True,
**kwargs,
) -> BatchEncoding:
r"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.RagModel`.
Args:
src_texts: (:obj:`List[str]`):
List of documents to summarize or source language texts.
tgt_texts: (:obj:`List[str]`, `optional`):
List of summaries or target language texts.
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
**kwargs:
Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **labels** -- List of token ids for tgt_texts
The full set of keys ``[input_ids, attention_mask, labels]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if max_length is None:
max_length = self.question_encoder.model_max_length
model_inputs: BatchEncoding = self.question_encoder(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = self.generator.model_max_length
labels = self.generator(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=truncation,
**kwargs,
)["input_ids"]
model_inputs["labels"] = labels
return model_inputs
This diff is collapsed.
......@@ -35,28 +35,53 @@ if is_torch_available():
class T5ModelTester:
def __init__(self, parent):
def __init__(
self,
parent,
vocab_size=99,
n_positions=14,
batch_size=13,
encoder_seq_length=7,
decoder_seq_length=9,
# For common tests
seq_length=7,
is_training=True,
use_attention_mask=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_id=1,
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = 13
self.encoder_seq_length = 7
self.decoder_seq_length = 9
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = True
self.use_attention_mask = True
self.use_labels = True
self.vocab_size = 99
self.n_positions = 14
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.d_ff = 37
self.relative_attention_num_buckets = 8
self.dropout_rate = 0.1
self.initializer_factor = 0.002
self.eos_token_id = 1
self.pad_token_id = 0
self.decoder_start_token_id = 0
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_positions = n_positions
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
def prepare_config_and_inputs(self):
......
import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch
import numpy as np
from datasets import Dataset
import faiss
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
from transformers.retrieval_rag import RagRetriever
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
@require_faiss
@require_datasets
class RagRetrieverTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_dummy_hf_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
)
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = dataset
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
)
return retriever
def get_dummy_legacy_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
pickle.dump(passages, open(passages_file_name, "wb"))
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="legacy",
index_path=self.tmpdirname,
passages_path=self.tmpdirname,
)
retriever = RagRetriever(
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
)
return retriever
def test_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_hf_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
def test_legacy_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_legacy_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
@require_torch
def test_hf_index_retriever_call(self):
import torch
n_docs = 1
retriever = self.get_dummy_hf_index_retriever()
question_input_ids = [[5, 7], [10, 11]]
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
context_input_ids, context_attention_mask, retrieved_doc_embeds = (
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, list)
self.assertIsInstance(context_attention_mask, list)
self.assertIsInstance(retrieved_doc_embeds, np.ndarray)
out = retriever(
question_input_ids,
hidden_states,
prefix=retriever.config.generator.prefix,
n_docs=n_docs,
return_tensors="pt",
)
context_input_ids, context_attention_mask, retrieved_doc_embeds, doc_ids = ( # noqa: F841
out["context_input_ids"],
out["context_attention_mask"],
out["retrieved_doc_embeds"],
out["doc_ids"],
)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertIsInstance(context_input_ids, torch.Tensor)
self.assertIsInstance(context_attention_mask, torch.Tensor)
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
import json
import os
import shutil
import tempfile
from unittest import TestCase
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_datasets, require_faiss, require_torch
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
if is_torch_available() and is_datasets_available() and is_faiss_available():
from transformers.configuration_rag import RagConfig
from transformers.tokenization_rag import RagTokenizer
@require_faiss
@require_datasets
@require_torch
class RagTokenizerTest(TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
self.retrieval_vector_size = 8
# DPR tok
vocab_tokens = [
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
]
dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
os.makedirs(dpr_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
# BART tok
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
os.makedirs(bart_tokenizer_path, exist_ok=True)
self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_with_saved_config(self):
save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
rag_config.save_pretrained(save_dir)
rag_tokenizer.save_pretrained(save_dir)
new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer)
self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment