utils.py 553 Bytes
Newer Older
lintangsutawika's avatar
scrolls  
lintangsutawika committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import transformers.data.metrics.squad_metrics as squad_metrics

def process_docs_bool(dataset):
    """Keep only the documents whose normalized answer is "yes" or "no".

    The dataset is first passed through ``process_docs_prepended_question``
    and then filtered on ``doc["output"]`` as normalized by
    ``squad_metrics.normalize_answer``.

    NOTE(review): ``process_docs_prepended_question`` is not defined in the
    visible part of this file -- it must be available in module scope for
    this function to run; confirm where it is expected to come from.
    """
    boolean_answers = ("yes", "no")

    def _has_boolean_answer(doc):
        answer = squad_metrics.normalize_answer(doc["output"])
        return answer in boolean_answers

    prepared = process_docs_prepended_question(dataset)
    return prepared.filter(_has_boolean_answer)

def process_docs_freeform(dataset):
    """Keep only the documents whose normalized answer is neither "yes" nor "no".

    The dataset is first passed through ``process_docs_prepended_question``
    and then filtered on ``doc["output"]`` as normalized by
    ``squad_metrics.normalize_answer``.

    NOTE(review): ``process_docs_prepended_question`` is not defined in the
    visible part of this file -- it must be available in module scope for
    this function to run; confirm where it is expected to come from.
    """
    boolean_answers = ("yes", "no")

    def _has_freeform_answer(doc):
        answer = squad_metrics.normalize_answer(doc["output"])
        return answer not in boolean_answers

    prepared = process_docs_prepended_question(dataset)
    return prepared.filter(_has_freeform_answer)

def f1(prediction, reference):
    """Score the first prediction against the first reference.

    Only element 0 of ``prediction`` and of ``reference`` is used; both are
    handed, in that order, to ``squad_metrics.compute_f1`` and its result is
    returned unchanged.
    """
    first_prediction, first_reference = prediction[0], reference[0]
    return squad_metrics.compute_f1(first_prediction, first_reference)