Unverified Commit 49c52025 authored by Yacine Jernite's avatar Yacine Jernite Committed by GitHub
Browse files

Eli5 examples (#4968)



* add eli5 examples

* add dense query script

* query_di

* merging

* merging

* add_utils

* adds nearest neighbor wikipedia

* batch queries

* training_retriever

* new notebooks

* moved retriever traiing script

* finished wiki40b

* max_len_fix

* train_s2s

* retriever_batch_checkpointing

* cleanup

* merge

* dim_fix

* fix_indexer

* fix_wiki40b_snippets

* fix_embed_for_r

* fp32 index

* fix_sparse_q

* joint_training

* remove obsolete datasets

* add_passage_nn_results

* add_passage_nn_results

* add_batch_nn

* add_batch_nn

* add_data_scripts

* notebook

* notebook

* notebook

* fix_multi_gpu

* add_app

* full_caching

* full_caching

* notebook

* sparse_done

* images

* notebook

* add_image_gif

* with_Gif

* add_contr_image

* notebook

* notebook

* notebook

* train_functions

* notebook

* min_retrieval_length

* pandas_option

* notebook

* min_retrieval_length

* notebook

* notebook

* eval_Retriever

* notebook

* images

* notebook

* add_example

* add_example

* notebook

* fireworks

* notebook

* notebook

* joe's notebook comments

* app_update

* notebook

* notebook_link

* captions

* notebook

* assing RetriBert model

* add RetriBert to Auto

* change AutoLMHead to AutoSeq2Seq

* notebook downloads from hf models

* style_black

* style_black

* app_update

* app_update

* fix_app_update

* style

* style

* isort

* Delete WikiELI5training.ipynb

* Delete evaluate_eli5.py

* Delete WikiELI5explore.ipynb

* Delete ExploreWikiELI5Support.html

* Delete explainlikeimfive.py

* Delete wiki_snippets.py

* children before parent

* children before parent

* style_black

* style_black_only

* isort

* isort_new

* Update src/transformers/modeling_retribert.py
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>

* typo fixes

* app_without_asset

* cleanup

* Delete ELI5animation.gif

* Delete ELI5contrastive.svg

* Delete ELI5wiki_index.svg

* Delete choco_bis.svg

* Delete fireworks.gif

* Delete huggingface_logo.jpg

* Delete huggingface_logo.svg

* Delete Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb

* Delete eli5_app.py

* Delete eli5_utils.py

* readme

* Update README.md

* unused imports

* moved_info

* default_beam

* ftuned model

* disclaimer

* Update src/transformers/modeling_retribert.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* black

* add_doc

* names

* isort_Examples

* isort_Examples

* Add doc to index
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
parent c3e60749
......@@ -111,3 +111,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
model_doc/reformer
model_doc/marian
model_doc/longformer
model_doc/retribert
RetriBERT
----------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~
The RetriBERT model was proposed in the blog post
`Explain Anything Like I'm Five: A Model for Open Domain Long Form Question Answering <https://yjernite.github.io/lfqa.html>`__,
RetriBERT is a small model that uses either a single or pair of Bert encoders with lower-dimension projection for dense semantic indexing of text.
Code to train and use the model can be found `here <https://github.com/huggingface/transformers/tree/master/examples/distillation>`_.
RetriBertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertConfig
:members:
RetriBertTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertTokenizer
:members:
RetriBertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertTokenizerFast
:members:
RetriBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertModel
:members:
# Long Form Question Answering
This folder contains the code for the Long Form Question answering [demo](http://35.226.96.115:8080/) as well as methods to train and use a fully end-to-end Long Form Question Answering system using the [🤗transformers](https://github.com/huggingface/transformers) and [🤗nlp](https://github.com/huggingface/nlp) libraries.
You can use these mothods to train your own system by following along the associate [notebook](https://github.com/huggingface/notebooks/blob/master/longform-qa/Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb) or [blog post](https://yjernite.github.io/lfqa.html).
import numpy as np
import torch
import faiss
import nlp
import streamlit as st
import transformers
from elasticsearch import Elasticsearch
from eli5_utils import (
embed_questions_for_retrieval,
make_qa_s2s_model,
qa_s2s_generate,
query_es_index,
query_qa_dense_index,
)
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True
@st.cache(allow_output_mutation=True)
def load_models():
if LOAD_DENSE_INDEX:
qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
_ = qar_model.eval()
else:
qar_tokenizer, qar_model = (None, None)
if MODEL_TYPE == "bart":
s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
s2s_model.load_state_dict(save_dict["model"])
_ = s2s_model.eval()
else:
s2s_tokenizer, s2s_model = make_qa_s2s_model(
model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
)
return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)
@st.cache(allow_output_mutation=True)
def load_indexes():
if LOAD_DENSE_INDEX:
faiss_res = faiss.StandardGpuResources()
wiki40b_passages = nlp.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
wiki40b_passage_reps = np.memmap(
"wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
dtype="float32",
mode="r",
shape=(wiki40b_passages.num_rows, 128),
)
wiki40b_index_flat = faiss.IndexFlatIP(128)
wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
wiki40b_gpu_index_flat.add(wiki40b_passage_reps) # TODO fix for larger GPU
else:
wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)
@st.cache(allow_output_mutation=True)
def load_train_data():
eli5 = nlp.load_dataset("eli5", name="LFQA_reddit")
eli5_train = eli5["train_eli5"]
eli5_train_q_reps = np.memmap(
"eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
)
eli5_train_q_index = faiss.IndexFlatIP(128)
eli5_train_q_index.add(eli5_train_q_reps)
return (eli5_train, eli5_train_q_index)
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()
def find_nearest_training(question, n_results=10):
q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
D, I = eli5_train_q_index.search(q_rep, n_results)
nn_examples = [eli5_train[int(i)] for i in I[0]]
return nn_examples
def make_support(question, source="wiki40b", method="dense", n_results=10):
if source == "none":
support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
else:
if method == "dense":
support_doc, hit_lst = query_qa_dense_index(
question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
)
else:
support_doc, hit_lst = query_es_index(
question, es_client, index_name="english_wiki40b_snippets_100w", n_results=n_results,
)
support_list = [
(res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
]
question_doc = "question: {} context: {}".format(question, support_doc)
return question_doc, support_list
@st.cache(hash_funcs={torch.Tensor: (lambda _: None), transformers.tokenization_bart.BartTokenizer: (lambda _: None)})
def answer_question(
question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
with torch.no_grad():
answer = qa_s2s_generate(
question_doc,
s2s_model,
s2s_tokenizer,
num_answers=1,
num_beams=n_beams,
min_len=min_len,
max_len=max_len,
do_sample=sampling,
temp=temp,
top_p=top_p,
top_k=None,
max_input_length=1024,
device="cuda:0",
)[0]
return (answer, support_list)
st.title("Long Form Question Answering with ELI5")
# Start sidebar
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
<head>
<style>
.img-container {
padding-left: 90px;
padding-right: 90px;
padding-top: 50px;
padding-bottom: 50px;
background-color: #f0f3f9;
}
</style>
</head>
<body>
<span class="img-container"> <!-- Inline parent element -->
%s
</span>
</body>
</html>
""" % (
header_html,
)
st.sidebar.markdown(
header_full, unsafe_allow_html=True,
)
# Long Form QA with ELI5 and Wikipedia
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)
action_list = [
"Answer the question",
"View the retrieved document only",
"View the most similar ELI5 question and answer",
"Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
action_st = st.sidebar.selectbox("", action_list, index=3,)
action = action_list.index(action_st)
show_type = st.sidebar.selectbox("", ["Show full text of passages", "Show passage section titles"], index=0,)
show_passages = show_type == "Show full text of passages"
else:
action = 3
show_passages = True
retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
retriever_info = """
### Information retriever options
The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding
trained using the [ELI5](https://arxiv.org/abs/1907.09190) questions-answer pairs.
The answer is then generated by sequence to sequence model which takes the question and retrieved document as input.
"""
st.sidebar.markdown(retriever_info)
wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
wiki_source = "wiki40b"
index_type = "dense"
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
generate_info = """
### Answer generation options
The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can use the model for greedy decoding with
**beam** search, or **sample** from the decoder's output probabilities.
"""
st.sidebar.markdown(generate_info)
sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
min_len = st.sidebar.slider(
"Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
)
max_len = st.sidebar.slider(
"Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
)
if sampled == "beam":
n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
else:
top_p = st.sidebar.slider(
"Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
)
temp = st.sidebar.slider(
"Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
)
n_beams = None
# start main text
questions_list = [
"<MY QUESTION>",
"How do people make chocolate?",
"Why do we get a fever when we are sick?",
"How can different animals perceive different colors?",
"What is natural language processing?",
"What's the best way to treat a sunburn?",
"What exactly are vitamins ?",
"How does nuclear energy provide electricity?",
"What's the difference between viruses and bacteria?",
"Why are flutes classified as woodwinds when most of them are made out of metal ?",
"Why do people like drinking coffee even though it tastes so bad?",
"What happens when wine ages? How does it make the wine taste better?",
"If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
"How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
"How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
"What would you like to ask? ---- select <MY QUESTION> to enter a new query", questions_list, index=1,
)
if question_s == "<MY QUESTION>":
question = st.text_input("Enter your question here:", "")
else:
question = question_s
if st.button("Show me!"):
if action in [0, 1, 3]:
if index_type == "mixed":
_, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
_, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
support_list = []
for res_d, res_s in zip(support_list_dense, support_list_sparse):
if tuple(res_d) not in support_list:
support_list += [tuple(res_d)]
if tuple(res_s) not in support_list:
support_list += [tuple(res_s)]
support_list = support_list[:10]
question_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
else:
question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
if action in [0, 3]:
answer, support_list = answer_question(
question_doc,
s2s_model,
s2s_tokenizer,
min_len=min_len,
max_len=int(max_len),
sampling=(sampled == "sampled"),
n_beams=n_beams,
top_p=top_p,
temp=temp,
)
st.markdown("### The model generated answer is:")
st.write(answer)
if action in [0, 1, 3] and wiki_source != "none":
st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
for i, res in enumerate(support_list):
wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
sec_titles = res[1].strip()
if sec_titles == "":
sections = "[{}]({})".format(res[0], wiki_url)
else:
sec_list = sec_titles.split(" & ")
sections = " & ".join(
["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
)
st.markdown(
"{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
unsafe_allow_html=True,
)
if show_passages:
st.write(
'> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
)
if action in [2, 3]:
nn_train_list = find_nearest_training(question)
train_exple = nn_train_list[0]
st.markdown(
"--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
)
answers_st = [
"{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
if i == 0 or sc > 2
]
st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
disclaimer = """
---
**Disclaimer**
*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)
This diff is collapsed.
......@@ -36,6 +36,7 @@ from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
......@@ -130,6 +131,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
......@@ -356,6 +358,12 @@ if is_torch_available():
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_retribert import (
RetriBertPreTrainedModel,
RetriBertModel,
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
# Optimization
from .optimization import (
AdamW,
......
......@@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
from .configuration_marian import MarianConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
......@@ -64,6 +65,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
......@@ -71,6 +73,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CONFIG_MAPPING = OrderedDict(
[
("retribert", RetriBertConfig,),
("t5", T5Config,),
("distilbert", DistilBertConfig,),
("albert", AlbertConfig,),
......
......@@ -28,6 +28,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
"yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RetriBERT model configuration """
import logging
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
# TODO: uploadto AWS
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
}
class RetriBertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.RetriBertModel`.
It is used to instantiate a RetriBertModel model according to the specified arguments, defining the model
architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
vocab_size (:obj:`int`, optional, defaults to 30522):
Vocabulary size of the BERT model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
hidden_size (:obj:`int`, optional, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, optional, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, optional, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
The non-linear activation function (function or string) in the encoder and pooler.
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, optional, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, optional, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
share_encoders (:obj:`bool`, optional, defaults to True):
Whether to use the same Bert-type encoder for the queries and document
projection_dim (:obj:`int`, optional, defaults to 128):
Final dimension of the query and document representation after projection
"""
model_type = "retribert"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=8,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
share_encoders=True,
projection_dim=128,
pad_token_id=0,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.share_encoders = share_encoders
self.projection_dim = projection_dim
......@@ -34,6 +34,7 @@ from .configuration_auto import (
LongformerConfig,
OpenAIGPTConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
......@@ -111,6 +112,7 @@ from .modeling_longformer import (
from .modeling_marian import MarianMTModel
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
from .modeling_retribert import RetriBertModel
from .modeling_roberta import (
RobertaForMaskedLM,
RobertaForMultipleChoice,
......@@ -151,6 +153,7 @@ logger = logging.getLogger(__name__)
MODEL_MAPPING = OrderedDict(
[
(RetriBertConfig, RetriBertModel),
(T5Config, T5Model),
(DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel),
......@@ -174,6 +177,7 @@ MODEL_MAPPING = OrderedDict(
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(RetriBertConfig, RetriBertModel),
(T5Config, T5ForConditionalGeneration),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForPreTraining),
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RetriBERT model
"""
import logging
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from .configuration_retribert import RetriBertConfig
from .file_utils import add_start_docstrings
from .modeling_bert import BertLayerNorm, BertModel
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"yjernite/retribert-base-uncased",
# See all RetriBert models at https://huggingface.co/models?filter=retribert
]
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class RetriBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = RetriBertConfig
load_tf_weights = None
base_model_prefix = "retribert"
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
RETRIBERT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config (:class:`~transformers.RetriBertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
@add_start_docstrings(
"""Bert Based model to embed queries or document for document retreival. """, RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.projection_dim = config.projection_dim
self.bert_query = BertModel(config)
self.bert_doc = None if config.share_encoders else BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
self.init_weights()
def embed_sentences_checkpointed(
self, input_ids, attention_mask, sent_encoder, checkpoint_batch_size=-1,
):
# reproduces BERT forward pass with checkpointing
if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
return sent_encoder(input_ids, attention_mask=attention_mask)[1]
else:
# prepare implicit variables
device = input_ids.device
input_shape = input_ids.size()
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
)
# define function for cehckpointing
def partial_encode(*inputs):
encoder_outputs = sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
sequence_output = encoder_outputs[0]
pooled_output = sent_encoder.pooler(sequence_output)
return pooled_output
# run embedding layer on everything at once
embedding_output = sent_encoder.embeddings(
input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
)
# run encoding and pooling on one mini-batch at a time
pooled_output_list = []
for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
pooled_output_list.append(pooled_output)
return torch.cat(pooled_output_list, dim=0)
def embed_questions(
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
):
q_reps = self.embed_sentences_checkpointed(input_ids, attention_mask, self.bert_query, checkpoint_batch_size,)
return self.project_query(q_reps)
def embed_answers(
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
):
a_reps = self.embed_sentences_checkpointed(
input_ids,
attention_mask,
self.bert_query if self.bert_doc is None else self.bert_doc,
checkpoint_batch_size,
)
return self.project_doc(a_reps)
def forward(
self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1
):
r"""
Args:
input_ids_query (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the queries in a batch.
Indices can be obtained using :class:`transformers.RetriBertTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask_query (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on queries padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
input_ids_doc (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the documents in a batch.
attention_mask_doc (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on documents padding token indices.
checkpoint_batch_size (:obj:`int`, `optional`, defaults to `:obj:`-1`):
If greater than 0, uses gradient checkpointing to only compute sequence representation on checkpoint_batch_size examples at a time
on the GPU. All query representations are still compared to all document representations in the batch.
Return:
:obj:`torch.FloatTensor` the bi-directional cross-entropy loss obtained while trying to match each query to its corresponding document
and each cocument to its corresponding query in the batch
"""
device = input_ids_query.device
q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size)
a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size)
compare_scores = torch.mm(q_reps, a_reps.t())
loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
loss = (loss_qa + loss_aq) / 2
return loss
......@@ -32,6 +32,7 @@ from .configuration_auto import (
LongformerConfig,
OpenAIGPTConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
......@@ -55,6 +56,7 @@ from .tokenization_longformer import LongformerTokenizer
from .tokenization_marian import MarianTokenizer
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
......@@ -68,6 +70,7 @@ logger = logging.getLogger(__name__)
TOKENIZER_MAPPING = OrderedDict(
[
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
(T5Config, (T5Tokenizer, None)),
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, None)),
......
......@@ -33,6 +33,7 @@ _all_bart_models = [
"facebook/bart-large-mnli",
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"yjernite/bart_eli5",
]
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RetriBERT."""
import logging
from .tokenization_bert import BertTokenizer, BertTokenizerFast
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"yjernite/retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"yjernite/retribert-base-uncased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"yjernite/retribert-base-uncased": {"do_lower_case": True},
}
class RetriBertTokenizer(BertTokenizer):
r"""
Constructs a retribert.
:class:`~transformers.retribert is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
class RetriBertTokenizerFast(BertTokenizerFast):
r"""
Constructs a "Fast" RetriBertTokenizerFast (backed by HuggingFace's `tokenizers` library).
:class:`~transformers.RetriBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment