Unverified Commit 49c52025 authored by Yacine Jernite, committed by GitHub

Eli5 examples (#4968)



* add eli5 examples

* add dense query script

* query_di

* merging

* merging

* add_utils

* adds nearest neighbor wikipedia

* batch queries

* training_retriever

* new notebooks

* moved retriever traiing script

* finished wiki40b

* max_len_fix

* train_s2s

* retriever_batch_checkpointing

* cleanup

* merge

* dim_fix

* fix_indexer

* fix_wiki40b_snippets

* fix_embed_for_r

* fp32 index

* fix_sparse_q

* joint_training

* remove obsolete datasets

* add_passage_nn_results

* add_passage_nn_results

* add_batch_nn

* add_batch_nn

* add_data_scripts

* notebook

* notebook

* notebook

* fix_multi_gpu

* add_app

* full_caching

* full_caching

* notebook

* sparse_done

* images

* notebook

* add_image_gif

* with_Gif

* add_contr_image

* notebook

* notebook

* notebook

* train_functions

* notebook

* min_retrieval_length

* pandas_option

* notebook

* min_retrieval_length

* notebook

* notebook

* eval_Retriever

* notebook

* images

* notebook

* add_example

* add_example

* notebook

* fireworks

* notebook

* notebook

* joe's notebook comments

* app_update

* notebook

* notebook_link

* captions

* notebook

* assing RetriBert model

* add RetriBert to Auto

* change AutoLMHead to AutoSeq2Seq

* notebook downloads from hf models

* style_black

* style_black

* app_update

* app_update

* fix_app_update

* style

* style

* isort

* Delete WikiELI5training.ipynb

* Delete evaluate_eli5.py

* Delete WikiELI5explore.ipynb

* Delete ExploreWikiELI5Support.html

* Delete explainlikeimfive.py

* Delete wiki_snippets.py

* children before parent

* children before parent

* style_black

* style_black_only

* isort

* isort_new

* Update src/transformers/modeling_retribert.py
Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* typo fixes

* app_without_asset

* cleanup

* Delete ELI5animation.gif

* Delete ELI5contrastive.svg

* Delete ELI5wiki_index.svg

* Delete choco_bis.svg

* Delete fireworks.gif

* Delete huggingface_logo.jpg

* Delete huggingface_logo.svg

* Delete Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb

* Delete eli5_app.py

* Delete eli5_utils.py

* readme

* Update README.md

* unused imports

* moved_info

* default_beam

* ftuned model

* disclaimer

* Update src/transformers/modeling_retribert.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* black

* add_doc

* names

* isort_Examples

* isort_Examples

* Add doc to index
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent c3e60749
......@@ -111,3 +111,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
model_doc/reformer
model_doc/marian
model_doc/longformer
model_doc/retribert
RetriBERT
----------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~
The RetriBERT model was proposed in the blog post
`Explain Anything Like I'm Five: A Model for Open Domain Long Form Question Answering <https://yjernite.github.io/lfqa.html>`__.
RetriBERT is a small model that uses either a single BERT encoder or a pair of BERT encoders with a lower-dimension projection for dense semantic indexing of text.
Code to train and use the model can be found `here <https://github.com/huggingface/transformers/tree/master/examples/distillation>`_.
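The snippet below is a minimal sketch of embedding questions for retrieval with the released ``yjernite/retribert-base-uncased`` checkpoint; the 128-token maximum length mirrors the example scripts and is an illustrative choice rather than a requirement.

.. code-block:: python

    import torch
    from transformers import RetriBertModel, RetriBertTokenizer

    tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")
    model = RetriBertModel.from_pretrained("yjernite/retribert-base-uncased")
    model.eval()

    inputs = tokenizer.batch_encode_plus(
        ["How do people make chocolate?"], max_length=128, pad_to_max_length=True, return_tensors="pt"
    )
    with torch.no_grad():
        # one vector of size config.projection_dim (128 by default) per question
        q_reps = model.embed_questions(inputs["input_ids"], inputs["attention_mask"])
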
RetriBertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertConfig
:members:
RetriBertTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertTokenizer
:members:
RetriBertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertTokenizerFast
:members:
RetriBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RetriBertModel
:members:
# Long Form Question Answering
This folder contains the code for the Long Form Question Answering [demo](http://35.226.96.115:8080/) as well as methods to train and use a fully end-to-end Long Form Question Answering system using the [🤗transformers](https://github.com/huggingface/transformers) and [🤗nlp](https://github.com/huggingface/nlp) libraries.
You can use these methods to train your own system by following along with the associated [notebook](https://github.com/huggingface/notebooks/blob/master/longform-qa/Long_Form_Question_Answering_with_ELI5_and_Wikipedia.ipynb) or [blog post](https://yjernite.github.io/lfqa.html).
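The sketch below shows roughly how the pieces in `eli5_utils.py` fit together at inference time. It assumes a GPU, a trained retriever (`qar_model`, `qar_tokenizer`), and a dense passage index already built with `make_qa_dense_index` and wrapped in a FAISS index (`wiki_passages`, `wiki_index`); these variable names are illustrative and are not defined in this folder.

```python
from eli5_utils import make_qa_s2s_model, qa_s2s_generate, query_qa_dense_index

# retrieve supporting passages for the question with the dense retriever
question = "How do people make chocolate?"
doc, res_list = query_qa_dense_index(question, qar_model, qar_tokenizer, wiki_passages, wiki_index)

# condition the seq2seq model on the question and the retrieved support document
question_doc = "question: {} context: {}".format(question, doc)
s2s_tokenizer, s2s_model = make_qa_s2s_model(model_name="yjernite/bart_eli5")
answer = qa_s2s_generate(question_doc, s2s_model, s2s_tokenizer, num_answers=1, num_beams=8, max_len=256)[0]
```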
import numpy as np
import torch
import faiss
import nlp
import streamlit as st
import transformers
from elasticsearch import Elasticsearch
from eli5_utils import (
embed_questions_for_retrieval,
make_qa_s2s_model,
qa_s2s_generate,
query_es_index,
query_qa_dense_index,
)
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True
@st.cache(allow_output_mutation=True)
def load_models():
if LOAD_DENSE_INDEX:
qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
_ = qar_model.eval()
else:
qar_tokenizer, qar_model = (None, None)
if MODEL_TYPE == "bart":
s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
s2s_model.load_state_dict(save_dict["model"])
_ = s2s_model.eval()
else:
s2s_tokenizer, s2s_model = make_qa_s2s_model(
model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
)
return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)
@st.cache(allow_output_mutation=True)
def load_indexes():
if LOAD_DENSE_INDEX:
faiss_res = faiss.StandardGpuResources()
wiki40b_passages = nlp.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
wiki40b_passage_reps = np.memmap(
"wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
dtype="float32",
mode="r",
shape=(wiki40b_passages.num_rows, 128),
)
wiki40b_index_flat = faiss.IndexFlatIP(128)
wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
wiki40b_gpu_index_flat.add(wiki40b_passage_reps) # TODO fix for larger GPU
else:
wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)
@st.cache(allow_output_mutation=True)
def load_train_data():
eli5 = nlp.load_dataset("eli5", name="LFQA_reddit")
eli5_train = eli5["train_eli5"]
eli5_train_q_reps = np.memmap(
"eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
)
eli5_train_q_index = faiss.IndexFlatIP(128)
eli5_train_q_index.add(eli5_train_q_reps)
return (eli5_train, eli5_train_q_index)
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()
def find_nearest_training(question, n_results=10):
q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
D, I = eli5_train_q_index.search(q_rep, n_results)
nn_examples = [eli5_train[int(i)] for i in I[0]]
return nn_examples
def make_support(question, source="wiki40b", method="dense", n_results=10):
if source == "none":
support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
else:
if method == "dense":
support_doc, hit_lst = query_qa_dense_index(
question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
)
else:
support_doc, hit_lst = query_es_index(
question, es_client, index_name="english_wiki40b_snippets_100w", n_results=n_results,
)
support_list = [
(res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
]
question_doc = "question: {} context: {}".format(question, support_doc)
return question_doc, support_list
@st.cache(hash_funcs={torch.Tensor: (lambda _: None), transformers.tokenization_bart.BartTokenizer: (lambda _: None)})
def answer_question(
question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
with torch.no_grad():
answer = qa_s2s_generate(
question_doc,
s2s_model,
s2s_tokenizer,
num_answers=1,
num_beams=n_beams,
min_len=min_len,
max_len=max_len,
do_sample=sampling,
temp=temp,
top_p=top_p,
top_k=None,
max_input_length=1024,
device="cuda:0",
)[0]
return (answer, support_list)
st.title("Long Form Question Answering with ELI5")
# Start sidebar
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
<head>
<style>
.img-container {
padding-left: 90px;
padding-right: 90px;
padding-top: 50px;
padding-bottom: 50px;
background-color: #f0f3f9;
}
</style>
</head>
<body>
<span class="img-container"> <!-- Inline parent element -->
%s
</span>
</body>
</html>
""" % (
header_html,
)
st.sidebar.markdown(
header_full, unsafe_allow_html=True,
)
# Long Form QA with ELI5 and Wikipedia
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of relevant Wikipedia passages for the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)
action_list = [
"Answer the question",
"View the retrieved document only",
"View the most similar ELI5 question and answer",
"Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
action_st = st.sidebar.selectbox("", action_list, index=3,)
action = action_list.index(action_st)
show_type = st.sidebar.selectbox("", ["Show full text of passages", "Show passage section titles"], index=0,)
show_passages = show_type == "Show full text of passages"
else:
action = 3
show_passages = True
retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
retriever_info = """
### Information retriever options
The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between question and passage embeddings
trained on the [ELI5](https://arxiv.org/abs/1907.09190) question-answer pairs.
The answer is then generated by a sequence-to-sequence model which takes the question and retrieved document as input.
"""
st.sidebar.markdown(retriever_info)
wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
wiki_source = "wiki40b"
index_type = "dense"
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
generate_info = """
### Answer generation options
The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can generate answers with
**beam** search, or **sample** from the decoder's output probabilities.
"""
st.sidebar.markdown(generate_info)
sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
min_len = st.sidebar.slider(
"Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
)
max_len = st.sidebar.slider(
"Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
)
if sampled == "beam":
n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
else:
top_p = st.sidebar.slider(
"Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
)
temp = st.sidebar.slider(
"Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
)
n_beams = None
# start main text
questions_list = [
"<MY QUESTION>",
"How do people make chocolate?",
"Why do we get a fever when we are sick?",
"How can different animals perceive different colors?",
"What is natural language processing?",
"What's the best way to treat a sunburn?",
"What exactly are vitamins ?",
"How does nuclear energy provide electricity?",
"What's the difference between viruses and bacteria?",
"Why are flutes classified as woodwinds when most of them are made out of metal ?",
"Why do people like drinking coffee even though it tastes so bad?",
"What happens when wine ages? How does it make the wine taste better?",
"If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
"How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
"How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
"What would you like to ask? ---- select <MY QUESTION> to enter a new query", questions_list, index=1,
)
if question_s == "<MY QUESTION>":
question = st.text_input("Enter your question here:", "")
else:
question = question_s
if st.button("Show me!"):
if action in [0, 1, 3]:
if index_type == "mixed":
_, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
_, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
support_list = []
for res_d, res_s in zip(support_list_dense, support_list_sparse):
if tuple(res_d) not in support_list:
support_list += [tuple(res_d)]
if tuple(res_s) not in support_list:
support_list += [tuple(res_s)]
support_list = support_list[:10]
question_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
else:
question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
if action in [0, 3]:
answer, support_list = answer_question(
question_doc,
s2s_model,
s2s_tokenizer,
min_len=min_len,
max_len=int(max_len),
sampling=(sampled == "sampled"),
n_beams=n_beams,
top_p=top_p,
temp=temp,
)
st.markdown("### The model generated answer is:")
st.write(answer)
if action in [0, 1, 3] and wiki_source != "none":
st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
for i, res in enumerate(support_list):
wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
sec_titles = res[1].strip()
if sec_titles == "":
sections = "[{}]({})".format(res[0], wiki_url)
else:
sec_list = sec_titles.split(" & ")
sections = " & ".join(
["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
)
st.markdown(
"{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
unsafe_allow_html=True,
)
if show_passages:
st.write(
'> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
)
if action in [2, 3]:
nn_train_list = find_nearest_training(question)
train_exple = nn_train_list[0]
st.markdown(
"--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
)
answers_st = [
"{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
if i == 0 or sc > 2
]
st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
disclaimer = """
---
**Disclaimer**
*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)
import functools
import math
import os # noqa: F401
from random import choice, randint
from time import time
import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
import faiss # noqa: F401
import nlp # noqa: F401
import pandas as pd
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup
pd.set_option("display.max_colwidth", None)
###############
# Sparse index
###############
def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"):
index_config = {
"settings": {
"number_of_shards": 1,
"analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}},
},
"mappings": {
"properties": {
"article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
"section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
"passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
}
},
}
es_client.indices.create(index=index_name, body=index_config)
number_of_docs = passages_dset.num_rows
progress = tqdm(unit="docs", total=number_of_docs)
successes = 0
def passage_generator():
for passage in passages_dset:
yield passage
# create the ES index
for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),):
progress.update(1)
successes += ok
print("Indexed %d documents" % (successes,))
def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20):
q = question.lower()
banned = ["how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"]
q = " ".join([w for w in q.split() if w not in banned])
response = es_client.search(
index=index_name,
body={
"query": {
"multi_match": {
"query": q,
"fields": ["article_title", "section_title", "passage_text^2"],
"type": "cross_fields",
}
},
"size": 2 * n_results,
},
)
hits = response["hits"]["hits"]
support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
for r, hit in zip(res_list, hits):
r["passage_id"] = hit["_id"]
r["score"] = hit["_score"]
r["passage_text"] = hit["_source"]["passage_text"]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
return support_doc, res_list
###############
# ELI5 retriever training
###############
class ELI5DatasetQARetriver(Dataset):
def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
self.data = examples_array
self.answer_thres = extra_answer_threshold
self.min_length = min_answer_length
self.training = training
self.n_samples = self.data.num_rows if n_samples is None else n_samples
def __len__(self):
return self.n_samples
def make_example(self, idx):
example = self.data[idx]
question = example["title"]
if self.training:
answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))]
answer_tab = choice(answers).split(" ")
start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
answer_span = " ".join(answer_tab[start_idx:])
else:
answer_span = example["answers"]["text"][0]
return (question, answer_span)
def __getitem__(self, idx):
return self.make_example(idx % self.data.num_rows)
class RetrievalQAEmbedder(torch.nn.Module):
def __init__(self, sent_encoder, dim):
super(RetrievalQAEmbedder, self).__init__()
self.sent_encoder = sent_encoder
self.output_dim = 128
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
# reproduces BERT forward pass with checkpointing
if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]
else:
# prepare implicit variables
device = input_ids.device
input_shape = input_ids.size()
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
)
# define function for checkpointing
def partial_encode(*inputs):
encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
sequence_output = encoder_outputs[0]
pooled_output = self.sent_encoder.pooler(sequence_output)
return pooled_output
# run embedding layer on everything at once
embedding_output = self.sent_encoder.embeddings(
input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
)
# run encoding and pooling on one mini-batch at a time
pooled_output_list = []
for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
pooled_output_list.append(pooled_output)
return torch.cat(pooled_output_list, dim=0)
def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
return self.project_q(q_reps)
def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
return self.project_a(a_reps)
def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
device = q_ids.device
q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
compare_scores = torch.mm(q_reps, a_reps.t())
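# in-batch training signal: compare_scores[i, j] is the dot product of question i and answer j,
# so each question should score highest with its own answer (and vice versa below)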
loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
loss = (loss_qa + loss_aq) / 2
return loss
def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"):
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name).to(device)
# run bert_model on a dummy batch to get output dimension
d_ids = torch.LongTensor(
[[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]]
).to(device)
d_mask = torch.LongTensor([[1]]).to(device)
sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1]
qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device)
if from_file is not None:
param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states
qa_embedder.load_state_dict(param_dict["model"])
return tokenizer, qa_embedder
def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
q_ls = [q for q, a in qa_list]
a_ls = [a for q, a in qa_list]
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
)
return (q_ids, q_mask, a_ids, a_mask)
def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):
model.train()
# make iterator
train_sampler = RandomSampler(dataset)
model_collate_fn = functools.partial(
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
for step, batch in enumerate(epoch_iterator):
q_ids, q_mask, a_ids, a_mask = batch
pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
loss = pre_loss.sum()
# optimizer
loss.backward()
optimizer.step()
scheduler.step()
model.zero_grad()
# some printing within the epoch
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0 or step == 1:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
loc_loss = 0
loc_steps = 0
def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):
model.train()
model_collate_fn = functools.partial(
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
# make iterator
train_samplers = [RandomSampler(dataset) for dataset in dataset_list]
data_loaders = [
DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
for dataset, train_sampler in zip(dataset_list, train_samplers)
]
iterators = [iter(dloader) for dloader in data_loaders]
joint_iter = zip(*iterators)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
for step, (batches,) in enumerate(zip(joint_iter)):
for batch in batches:
q_ids, q_mask, a_ids, a_mask = batch
loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
# optimizer
loss.backward()
optimizer.step()
scheduler.step()
model.zero_grad()
# some printing within the epoch
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
loc_loss = 0
loc_steps = 0
def evaluate_qa_retriever(model, dataset, tokenizer, args):
model.eval()
# make iterator
eval_sampler = SequentialSampler(dataset)
model_collate_fn = functools.partial(
make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
tot_loss = 0.0
with torch.no_grad():
for step, batch in enumerate(epoch_iterator):
q_ids, q_mask, a_ids, a_mask = batch
loss = model(q_ids, q_mask, a_ids, a_mask)
tot_loss += loss.item()
return tot_loss / (step + 1)
def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
qar_scheduler = get_linear_schedule_with_warmup(
qar_optimizer,
num_warmup_steps=100,
num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),
)
for e in range(qar_args.num_epochs):
train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)
m_save_dict = {
"model": qar_model.state_dict(),
"optimizer": qar_optimizer.state_dict(),
"scheduler": qar_scheduler.state_dict(),
}
print("Saving model {}".format(qar_args.model_save_name))
torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))
eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))
###############
# ELI5 seq2seq model training
###############
class ELI5DatasetS2S(Dataset):
def __init__(
self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True
):
self.training = training
self.data = examples_array
self.make_doc_function = make_doc_fun
self.document_cache = {} if document_cache is None else document_cache
assert not (make_doc_fun is None and document_cache is None)
# make index of specific question-answer pairs from multi-answers
if self.training:
self.qa_id_list = [
(i, j)
for i, qa in enumerate(self.data)
for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
if j == 0 or sc >= extra_answer_threshold
]
else:
self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)]
def __len__(self):
return len(self.qa_id_list)
def make_example(self, idx):
i, j = self.qa_id_list[idx]
example = self.data[i]
question = example["title"] + " " + example["selftext"]
answer = example["answers"]["text"][j]
q_id = example["q_id"]
if self.make_doc_function is not None:
self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))
document = self.document_cache[q_id]
in_st = "question: {} context: {}".format(
question.lower().replace(" --t--", "").strip(), document.lower().strip(),
)
out_st = answer
return (in_st, out_st)
def __getitem__(self, idx):
return self.make_example(idx)
def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if from_file is not None:
param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states
model.load_state_dict(param_dict["model"])
return tokenizer, model
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
q_ls = [q for q, a in qa_list]
a_ls = [a for q, a in qa_list]
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
)
lm_labels = a_ids[:, 1:].contiguous().clone()
lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
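# padded target positions get label -100 so the seq2seq cross-entropy loss ignores them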
model_inputs = {
"input_ids": q_ids,
"attention_mask": q_mask,
"decoder_input_ids": a_ids[:, :-1].contiguous(),
"lm_labels": lm_labels,
}
return model_inputs
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
model.train()
# make iterator
if curriculum:
train_sampler = SequentialSampler(dataset)
else:
train_sampler = RandomSampler(dataset)
model_collate_fn = functools.partial(
make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
for step, batch_inputs in enumerate(epoch_iterator):
pre_loss = model(**batch_inputs)[0]
loss = pre_loss.sum() / pre_loss.shape[0]
loss.backward()
# optimizer
if step % args.backward_freq == 0:
optimizer.step()
scheduler.step()
model.zero_grad()
# some printing within the epoch
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0 or step == 1:
print(
"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
loc_loss = 0
loc_steps = 0
def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
model.eval()
# make iterator
train_sampler = SequentialSampler(dataset)
model_collate_fn = functools.partial(
make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
# accumulate loss since last print
loc_steps = 0
loc_loss = 0.0
st_time = time()
with torch.no_grad():
for step, batch_inputs in enumerate(epoch_iterator):
pre_loss = model(**batch_inputs)[0]
loss = pre_loss.sum() / pre_loss.shape[0]
loc_loss += loss.item()
loc_steps += 1
if step % args.print_freq == 0:
print(
"{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
)
)
print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
s2s_scheduler = get_linear_schedule_with_warmup(
s2s_optimizer,
num_warmup_steps=400,
num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
)
for e in range(s2s_args.num_epochs):
train_qa_s2s_epoch(
qa_s2s_model,
s2s_train_dset,
qa_s2s_tokenizer,
s2s_optimizer,
s2s_scheduler,
s2s_args,
e,
curriculum=(e == 0),
)
m_save_dict = {
"model": qa_s2s_model.state_dict(),
"optimizer": s2s_optimizer.state_dict(),
"scheduler": s2s_scheduler.state_dict(),
}
print("Saving model {}".format(s2s_args.model_save_name))
eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))
# generate answer from input "question: ... context: <p> ..."
def qa_s2s_generate(
question_doc,
qa_s2s_model,
qa_s2s_tokenizer,
num_answers=1,
num_beams=None,
min_len=64,
max_len=256,
do_sample=False,
temp=1.0,
top_p=None,
top_k=None,
max_input_length=512,
device="cuda:0",
):
model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
generated_ids = qa_s2s_model.generate(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
min_length=min_len,
max_length=max_len,
do_sample=do_sample,
early_stopping=True,
num_beams=1 if do_sample else n_beams,
temperature=temp,
top_k=top_k,
top_p=top_p,
eos_token_id=qa_s2s_tokenizer.eos_token_id,
no_repeat_ngram_size=3,
num_return_sequences=num_answers,
decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
)
return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]
###############
# ELI5-trained retrieval model usage
###############
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
)
with torch.no_grad():
a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)
return a_reps.numpy()
def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
with torch.no_grad():
q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)
return q_reps.numpy()
def make_qa_dense_index(
qa_embedder,
tokenizer,
passages_dset,
batch_size=512,
max_length=128,
index_name="kilt_passages_reps.dat",
dtype="float32",
device="cuda:0",
):
st_time = time()
fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
n_batches = math.ceil(passages_dset.num_rows / batch_size)
for i in range(n_batches):
passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
fp[i * batch_size : (i + 1) * batch_size] = reps
if i % 50 == 0:
print(i, time() - st_time)
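# Example (sketch): once the memmap is written, it can be loaded and added to a FAISS
# inner-product index, mirroring what the demo app does at startup:
#   reps = np.memmap("kilt_passages_reps.dat", dtype="float32", mode="r", shape=(passages_dset.num_rows, 128))
#   wiki_index = faiss.IndexFlatIP(128)
#   wiki_index.add(reps)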
def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):
total_retriever_time = 0.0
total_retriever_score = 0.0
st_time = time()
for i, (question, answer) in enumerate(qa_list):
r_time = time()
retrieved_passages = retriever_func(question, n_ret)
total_retriever_time += time() - r_time
total_retriever_score += scoring_func(retrieved_passages, answer)
if verbose and ((i + 1) % 500 == 0 or i <= 1):
print(
"{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(
i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time
)
)
return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)}
# build a support document for the question out of Wikipedia snippets
def query_qa_dense_index(
question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0"
):
q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)
D, I = wiki_index.search(q_rep, 2 * n_results)
res_passages = [wiki_passages[int(i)] for i in I[0]]
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
for r, sc in zip(res_list, D[0]):
r["score"] = float(sc)
return support_doc, res_list
def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder)
D, I = wiki_index.search(q_rep, n_results)
res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
support_doc_lst = [
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
for (res_passages, dl) in zip(res_passages_lst, D):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc in zip(res_list, dl):
r["score"] = float(sc)
all_res_lists += [res_list[:]]
return support_doc_lst, all_res_lists
# find nearest neighbors of an answer or declarative text in Wikipedia snippets
def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):
a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)
D, I = wiki_index.search(a_rep, 2 * n_results)
res_passages = [wiki_passages[int(i)] for i in I[0]]
support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
for r, sc, i in zip(res_list, D[0], I[0]):
r["passage_id"] = int(i)
r["score"] = float(sc)
return support_doc, res_list
def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)
D, I = wiki_index.search(a_reps, n_results)
res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
support_doc_lst = [
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
for (res_passages, dl, il) in zip(res_passages_lst, D, I):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc, i in zip(res_list, dl, il):
r["passage_id"] = int(i)
r["score"] = float(sc)
all_res_lists += [res_list[:]]
return support_doc_lst, all_res_lists
......@@ -36,6 +36,7 @@ from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
......@@ -130,6 +131,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
......@@ -356,6 +358,12 @@ if is_torch_available():
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_retribert import (
RetriBertPreTrainedModel,
RetriBertModel,
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
# Optimization
from .optimization import (
AdamW,
......
......@@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
from .configuration_marian import MarianConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
......@@ -64,6 +65,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
......@@ -71,6 +73,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CONFIG_MAPPING = OrderedDict(
[
("retribert", RetriBertConfig,),
("t5", T5Config,),
("distilbert", DistilBertConfig,),
("albert", AlbertConfig,),
......
......@@ -28,6 +28,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
"yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RetriBERT model configuration """
import logging
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
# TODO: upload to AWS
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
}
class RetriBertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.RetriBertModel`.
It is used to instantiate a RetriBertModel model according to the specified arguments, defining the model
architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
vocab_size (:obj:`int`, optional, defaults to 30522):
Vocabulary size of the BERT model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
hidden_size (:obj:`int`, optional, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, optional, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, optional, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
The non-linear activation function (function or string) in the encoder and pooler.
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, optional, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, optional, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
share_encoders (:obj:`bool`, optional, defaults to True):
Whether to use the same Bert-type encoder for the queries and documents
projection_dim (:obj:`int`, optional, defaults to 128):
Final dimension of the query and document representation after projection
"""
model_type = "retribert"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=8,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
share_encoders=True,
projection_dim=128,
pad_token_id=0,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.share_encoders = share_encoders
self.projection_dim = projection_dim
......@@ -34,6 +34,7 @@ from .configuration_auto import (
LongformerConfig,
OpenAIGPTConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
......@@ -111,6 +112,7 @@ from .modeling_longformer import (
from .modeling_marian import MarianMTModel
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
from .modeling_retribert import RetriBertModel
from .modeling_roberta import (
RobertaForMaskedLM,
RobertaForMultipleChoice,
......@@ -151,6 +153,7 @@ logger = logging.getLogger(__name__)
MODEL_MAPPING = OrderedDict(
[
(RetriBertConfig, RetriBertModel),
(T5Config, T5Model),
(DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel),
......@@ -174,6 +177,7 @@ MODEL_MAPPING = OrderedDict(
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(RetriBertConfig, RetriBertModel),
(T5Config, T5ForConditionalGeneration),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForPreTraining),
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RetriBERT model
"""
import logging
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from .configuration_retribert import RetriBertConfig
from .file_utils import add_start_docstrings
from .modeling_bert import BertLayerNorm, BertModel
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"yjernite/retribert-base-uncased",
# See all RetriBert models at https://huggingface.co/models?filter=retribert
]
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class RetriBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = RetriBertConfig
load_tf_weights = None
base_model_prefix = "retribert"
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
RETRIBERT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config (:class:`~transformers.RetriBertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
@add_start_docstrings(
"""Bert Based model to embed queries or document for document retreival. """, RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.projection_dim = config.projection_dim
self.bert_query = BertModel(config)
self.bert_doc = None if config.share_encoders else BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
self.init_weights()
def embed_sentences_checkpointed(
self, input_ids, attention_mask, sent_encoder, checkpoint_batch_size=-1,
):
# reproduces BERT forward pass with checkpointing
if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
return sent_encoder(input_ids, attention_mask=attention_mask)[1]
else:
# prepare implicit variables
device = input_ids.device
input_shape = input_ids.size()
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
)
# define function for checkpointing
def partial_encode(*inputs):
encoder_outputs = sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
sequence_output = encoder_outputs[0]
pooled_output = sent_encoder.pooler(sequence_output)
return pooled_output
# run embedding layer on everything at once
embedding_output = sent_encoder.embeddings(
input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
)
# run encoding and pooling on one mini-batch at a time
pooled_output_list = []
for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
pooled_output_list.append(pooled_output)
return torch.cat(pooled_output_list, dim=0)
def embed_questions(
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
):
q_reps = self.embed_sentences_checkpointed(input_ids, attention_mask, self.bert_query, checkpoint_batch_size,)
return self.project_query(q_reps)
def embed_answers(
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
):
a_reps = self.embed_sentences_checkpointed(
input_ids,
attention_mask,
self.bert_query if self.bert_doc is None else self.bert_doc,
checkpoint_batch_size,
)
return self.project_doc(a_reps)
def forward(
self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1
):
r"""
Args:
input_ids_query (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the queries in a batch.
Indices can be obtained using :class:`transformers.RetriBertTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask_query (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on queries padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
input_ids_doc (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the documents in a batch.
attention_mask_doc (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on documents padding token indices.
checkpoint_batch_size (:obj:`int`, `optional`, defaults to :obj:`-1`):
If greater than 0, uses gradient checkpointing to only compute sequence representation on checkpoint_batch_size examples at a time
on the GPU. All query representations are still compared to all document representations in the batch.
Return:
:obj:`torch.FloatTensor` the bi-directional cross-entropy loss obtained while trying to match each query to its corresponding document
and each document to its corresponding query in the batch
"""
device = input_ids_query.device
q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size)
a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size)
compare_scores = torch.mm(q_reps, a_reps.t())
loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
loss = (loss_qa + loss_aq) / 2
return loss
......@@ -32,6 +32,7 @@ from .configuration_auto import (
LongformerConfig,
OpenAIGPTConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
......@@ -55,6 +56,7 @@ from .tokenization_longformer import LongformerTokenizer
from .tokenization_marian import MarianTokenizer
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
......@@ -68,6 +70,7 @@ logger = logging.getLogger(__name__)
TOKENIZER_MAPPING = OrderedDict(
[
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
(T5Config, (T5Tokenizer, None)),
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, None)),
......
......@@ -33,6 +33,7 @@ _all_bart_models = [
"facebook/bart-large-mnli",
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"yjernite/bart_eli5",
]
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RetriBERT."""
import logging
from .tokenization_bert import BertTokenizer, BertTokenizerFast
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"yjernite/retribert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"yjernite/retribert-base-uncased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"yjernite/retribert-base-uncased": {"do_lower_case": True},
}
class RetriBertTokenizer(BertTokenizer):
r"""
Constructs a RetriBERT tokenizer.
:class:`~transformers.RetriBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
class RetriBertTokenizerFast(BertTokenizerFast):
r"""
Constructs a "Fast" RetriBertTokenizerFast (backed by HuggingFace's `tokenizers` library).
:class:`~transformers.RetriBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]