utils.py 711 Bytes
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import transformers.data.metrics.squad_metrics as squad_metrics

def process_docs(dataset):
    """Prepare a dataset split for scoring.

    Applies ``process_docs_prepended_question`` to ``dataset`` and then adds
    a boolean ``is_yes_no`` field to every document. ``is_yes_no`` is True
    when every reference answer in ``doc["outputs"]`` normalizes (via
    ``squad_metrics.normalize_answer``) to ``"yes"`` or ``"no"``. A document
    whose ``outputs`` list is empty is also flagged True.

    Args:
        dataset: object with a ``map(fn)`` method — presumably a Hugging
            Face ``datasets.Dataset``; confirm against the caller.

    Returns:
        The result of ``dataset.map`` with ``is_yes_no`` set on each doc.
    """
    # NOTE(review): ``process_docs_prepended_question`` is neither defined
    # nor imported in this module as shown. Confirm where it is meant to
    # come from; otherwise this call raises NameError.
    dataset = process_docs_prepended_question(dataset)

    def _process_doc(doc):
        # True only if every reference answer is a bare yes/no; ``all`` of an
        # empty list is True and stops normalizing at the first other answer.
        doc["is_yes_no"] = all(
            squad_metrics.normalize_answer(answer) in ("yes", "no")
            for answer in doc["outputs"]
        )
        return doc

    return dataset.map(_process_doc)

def process_results(doc, results):
    """Build the ``f1`` metric input for one document.

    Args:
        doc: a processed document; this reads ``doc["is_yes_no"]`` and
            ``doc["outputs"]`` (the reference answers).
        results: model outputs for ``doc``. For yes/no documents,
            ``results[0]`` and ``results[1]`` are compared with ``>``.
            They are presumably the scores for " yes" and " no"; confirm
            against the task's request definition. For other documents,
            ``results[0]`` is the generated answer string.

    Returns:
        ``{"f1": (prediction, doc["outputs"])}``.
    """
    if doc["is_yes_no"]:
        yes_score = results[0]
        no_score = results[1]
        # Strict comparison: a tie resolves to " no".
        prediction = " yes" if yes_score > no_score else " no"
    else:
        generated = results[0]
        # A blank or whitespace-only generation is reported as "Unanswerable".
        if generated.strip():
            prediction = generated
        else:
            prediction = "Unanswerable"

    references = doc["outputs"]
    return {"f1": (prediction, references)}