eli5_app.py
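# Streamlit demo for long-form question answering on ELI5: dense/sparse Wikipedia
# retrieval feeding a BART-based sequence-to-sequence generator.
# Launch with: streamlit run eli5_app.py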
import datasets
import faiss
import numpy as np
import streamlit as st
import torch
import transformers
from elasticsearch import Elasticsearch
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer

from eli5_utils import (
    embed_questions_for_retrieval,
    make_qa_s2s_model,
    qa_s2s_generate,
    query_es_index,
    query_qa_dense_index,
)


MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True
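# MODEL_TYPE selects the generator ("bart" here; anything else falls back to the T5 path below).
# LOAD_DENSE_INDEX toggles the RETRIBERT question encoder and the FAISS passage index.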


@st.cache(allow_output_mutation=True)
def load_models():
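    # Cached across Streamlit reruns; the retriever and generator are loaded onto GPU 0 once.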
    if LOAD_DENSE_INDEX:
        qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
        qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
        _ = qar_model.eval()
    else:
        qar_tokenizer, qar_model = (None, None)
    if MODEL_TYPE == "bart":
        s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
        s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
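        # Assumes the fine-tuned BART checkpoint has been downloaded to seq2seq_models/ beforehand.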
        save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
        s2s_model.load_state_dict(save_dict["model"])
        _ = s2s_model.eval()
    else:
        s2s_tokenizer, s2s_model = make_qa_s2s_model(
            model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
        )
    return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)


@st.cache(allow_output_mutation=True)
def load_indexes():
    if LOAD_DENSE_INDEX:
        faiss_res = faiss.StandardGpuResources()
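        # Memory-map the precomputed 128-d passage embeddings rather than loading them into RAM.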
        wiki40b_passages = datasets.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
        wiki40b_passage_reps = np.memmap(
            "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
            dtype="float32",
            mode="r",
            shape=(wiki40b_passages.num_rows, 128),
        )
        wiki40b_index_flat = faiss.IndexFlatIP(128)
        wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
        wiki40b_gpu_index_flat.add(wiki40b_passage_reps)  # TODO fix for larger GPU
    else:
        wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
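    # Assumes a local Elasticsearch instance serving the wiki40b snippets index on port 9200.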
    es_client = Elasticsearch([{"host": "localhost", "port": 9200}])
    return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)


@st.cache(allow_output_mutation=True)
def load_train_data():
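    # Build a flat inner-product index over precomputed embeddings of the ELI5 training questions.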
    eli5 = datasets.load_dataset("eli5", name="LFQA_reddit")
    eli5_train = eli5["train_eli5"]
    eli5_train_q_reps = np.memmap(
        "eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
    )
    eli5_train_q_index = faiss.IndexFlatIP(128)
    eli5_train_q_index.add(eli5_train_q_reps)
    return (eli5_train, eli5_train_q_index)


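# Load indexes, models, and training data once at start-up (cached by Streamlit).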
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()


def find_nearest_training(question, n_results=10):
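    # Embed the query and return its nearest neighbors among ELI5 training questions.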
    q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
    _, nn_ids = eli5_train_q_index.search(q_rep, n_results)
    nn_examples = [eli5_train[int(i)] for i in nn_ids[0]]
    return nn_examples


def make_support(question, source="wiki40b", method="dense", n_results=10):
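    # Build the generator's support document via dense (FAISS) or sparse (Elasticsearch) retrieval.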
    if source == "none":
        support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
    else:
        if method == "dense":
            support_doc, hit_lst = query_qa_dense_index(
                question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
            )
        else:
            support_doc, hit_lst = query_es_index(
                question,
                es_client,
                index_name="english_wiki40b_snippets_100w",
                n_results=n_results,
            )
    support_list = [
        (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
    ]
    question_doc = "question: {} context: {}".format(question, support_doc)
    return question_doc, support_list


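# Note: transformers.tokenization_bart is the pre-v4 import path; transformers >= 4.0
# moved BartTokenizer to transformers.models.bart.tokenization_bart.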
@st.cache(hash_funcs={torch.Tensor: (lambda _: None), transformers.tokenization_bart.BartTokenizer: (lambda _: None)})
def answer_question(
    question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
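    # hash_funcs makes st.cache ignore the model weights and tokenizer when building the cache key.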
    with torch.no_grad():
        answer = qa_s2s_generate(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            num_answers=1,
            num_beams=n_beams,
            min_len=min_len,
            max_len=max_len,
            do_sample=sampling,
            temp=temp,
            top_p=top_p,
            top_k=None,
            max_input_length=1024,
            device="cuda:0",
        )[0]
    return answer


st.title("Long Form Question Answering with ELI5")

# Start sidebar
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
  <head>
    <style>
      .img-container {
        padding-left: 90px;
        padding-right: 90px;
        padding-top: 50px;
        padding-bottom: 50px;
        background-color: #f0f3f9;
      }
    </style>
  </head>
  <body>
    <span class="img-container"> <!-- Inline parent element -->
      %s
    </span>
  </body>
</html>
""" % (
    header_html,
)
st.sidebar.markdown(
    header_full,
    unsafe_allow_html=True,
)

# Long Form QA with ELI5 and Wikipedia
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of Wikipedia passages relevant to the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)

action_list = [
    "Answer the question",
    "View the retrieved document only",
    "View the most similar ELI5 question and answer",
    "Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
    action_st = st.sidebar.selectbox(
        "",
        action_list,
        index=3,
    )
    action = action_list.index(action_st)
    show_type = st.sidebar.selectbox(
        "",
        ["Show full text of passages", "Show passage section titles"],
        index=0,
    )
    show_passages = show_type == "Show full text of passages"
else:
    action = 3
    show_passages = True

retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
    retriever_info = """
    ### Information retriever options

    The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between question and passage embeddings
    trained on the [ELI5](https://arxiv.org/abs/1907.09190) question-answer pairs.
    The answer is then generated by a sequence-to-sequence model which takes the question and retrieved documents as input.
    """
    st.sidebar.markdown(retriever_info)
    wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
    index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
    wiki_source = "wiki40b"
    index_type = "dense"

sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
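# Generation defaults; overridden by the sidebar "Generation options" controls when enabled.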
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
    generate_info = """
    ### Answer generation options

    The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
    weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can either decode the answer with
    **beam** search or **sample** from the decoder's output probabilities.
    """
    st.sidebar.markdown(generate_info)
    sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
    min_len = st.sidebar.slider(
        "Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
    )
    max_len = st.sidebar.slider(
        "Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
    )
    if sampled == "beam":
        n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
    else:
        top_p = st.sidebar.slider(
            "Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
        )
        temp = st.sidebar.slider(
            "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
        )
        n_beams = None

# start main text
questions_list = [
    "<MY QUESTION>",
    "How do people make chocolate?",
    "Why do we get a fever when we are sick?",
    "How can different animals perceive different colors?",
    "What is natural language processing?",
    "What's the best way to treat a sunburn?",
    "What exactly are vitamins ?",
    "How does nuclear energy provide electricity?",
    "What's the difference between viruses and bacteria?",
    "Why are flutes classified as woodwinds when most of them are made out of metal ?",
    "Why do people like drinking coffee even though it tastes so bad?",
    "What happens when wine ages? How does it make the wine taste better?",
    "If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
    "How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
    "How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
    "What would you like to ask? ---- select <MY QUESTION> to enter a new query",
    questions_list,
    index=1,
)
if question_s == "<MY QUESTION>":
    question = st.text_input("Enter your question here:", "")
else:
    question = question_s

if st.button("Show me!"):
    if action in [0, 1, 3]:
        if index_type == "mixed":
            _, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
            _, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
            support_list = []
            for res_d, res_s in zip(support_list_dense, support_list_sparse):
                if tuple(res_d) not in support_list:
                    support_list += [tuple(res_d)]
                if tuple(res_s) not in support_list:
                    support_list += [tuple(res_s)]
            support_list = support_list[:10]
            question_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
        else:
            question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
    if action in [0, 3]:
        answer = answer_question(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            min_len=min_len,
            max_len=int(max_len),
            sampling=(sampled == "sampled"),
            n_beams=n_beams,
            top_p=top_p,
            temp=temp,
        )
        st.markdown("### The model generated answer is:")
        st.write(answer)
    if action in [0, 1, 3] and wiki_source != "none":
        st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
        for i, res in enumerate(support_list):
            wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
            sec_titles = res[1].strip()
            if sec_titles == "":
                sections = "[{}]({})".format(res[0], wiki_url)
            else:
                sec_list = sec_titles.split(" & ")
                sections = " & ".join(
                    ["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
                )
            st.markdown(
                "{0:02d} - **Article**: {1:<18} <br>  _Section_: {2}".format(i + 1, res[0], sections),
                unsafe_allow_html=True,
            )
            if show_passages:
                st.write(
                    '> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
                )
    if action in [2, 3]:
        nn_train_list = find_nearest_training(question)
        train_exple = nn_train_list[0]
        st.markdown(
            "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
        )
        answers_st = [
            "{}. {}".format(i + 1, "  \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
            for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
            if i == 0 or sc > 2
        ]
        st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))


disclaimer = """
---

**Disclaimer**

*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)