# utils.py
1
import datasets
2
import evaluate
3
4
5
6
7
8
9
10


def strip(resps, docs):
    """
    Keep only the first response of every entry in `resps`, with leading and
    trailing whitespace removed.

    Each entry of `resps` is expected to be an indexable collection of
    strings; everything after the first element is discarded. `docs` is
    accepted but never read. The result is a lazy `map` iterator, so it can
    only be consumed once.
    """

    def _first_stripped(candidates):
        return candidates[0].strip()

    return map(_first_stripped, resps)

11

12
13
14
def dr_ar(dataset: datasets.Dataset):
    """Return the subset of `dataset` whose "direction" field equals "dr_ar"."""

    def _is_dr_ar(row):
        return row["direction"] == "dr_ar"

    return dataset.filter(_is_dr_ar)

15

16
17
18
def ar_dr(dataset: datasets.Dataset):
    """Return the subset of `dataset` whose "direction" field equals "ar_dr"."""

    def _is_ar_dr(row):
        return row["direction"] == "ar_dr"

    return dataset.filter(_is_ar_dr)

19

20
21
22
23
def doc_to_text(doc):
    """Return the "content" of the first entry in `doc["messages"]`."""
    first_message = doc["messages"][0]
    return first_message["content"]

24

25
26
27
def doc_to_target(doc):
    """Return the "content" of the second entry in `doc["messages"]`."""
    messages = doc["messages"]
    second_message = messages[1]
    return second_message["content"]

28

29
30
31
def bert(items):
    """
    Hand `items` back to the caller untouched.

    NOTE(review): presumably a pass-through so that the scoring functions in
    this module receive the raw items -- confirm against the task config.
    """
    return items

32

33
def Average(lst):
    """
    Return the arithmetic mean of the numbers in `lst`.

    `lst` must support both `sum()` and `len()`. An empty `lst` raises
    `ZeroDivisionError`.
    """
    return sum(lst) / len(lst)


37
38
39
40
def arabizibert(items):
    """
    Score (prediction, reference) pairs with BERTScore using the
    "SI2M-Lab/DarijaBERT-arabizi" checkpoint and return the mean F1.

    `items` is an iterable of two-element (prediction, reference) pairs. It
    must be non-empty: unpacking `zip(*items)` into two names fails with
    `ValueError` otherwise.

    Returns the average of the per-pair "f1" values reported by the metric.
    """
    model_name = "SI2M-Lab/DarijaBERT-arabizi"
    scorer = evaluate.load("bertscore")
    # Transpose the pairs into parallel prediction / reference sequences.
    predictions, references = zip(*items)
    # Called `scores` (not `bert`) so the module-level `bert` function is not
    # shadowed inside this function.
    scores = scorer.compute(
        predictions=predictions,
        references=references,
        model_type=model_name,
        num_layers=12,
    )
    return Average(scores["f1"])

49
50

def darijabert(items):
    """
    Score (prediction, reference) pairs with BERTScore using the
    "SI2M-Lab/DarijaBERT" checkpoint and return the mean F1.

    `items` is an iterable of two-element (prediction, reference) pairs. It
    must be non-empty: unpacking `zip(*items)` into two names fails with
    `ValueError` otherwise.

    Returns the average of the per-pair "f1" values reported by the metric.
    """
    model_name = "SI2M-Lab/DarijaBERT"
    scorer = evaluate.load("bertscore")
    # Transpose the pairs into parallel prediction / reference sequences.
    predictions, references = zip(*items)
    # Called `scores` (not `bert`) so the module-level `bert` function is not
    # shadowed inside this function.
    scores = scorer.compute(
        predictions=predictions,
        references=references,
        model_type=model_name,
        num_layers=12,
    )
    return Average(scores["f1"])

62
63

def mbert(items):
    """
    Score (prediction, reference) pairs with BERTScore using the
    "google-bert/bert-base-multilingual-cased" checkpoint and return the
    mean F1.

    `items` is an iterable of two-element (prediction, reference) pairs. It
    must be non-empty: unpacking `zip(*items)` into two names fails with
    `ValueError` otherwise.

    Returns the average of the per-pair "f1" values reported by the metric.
    """
    model_name = "google-bert/bert-base-multilingual-cased"
    scorer = evaluate.load("bertscore")
    # Transpose the pairs into parallel prediction / reference sequences.
    predictions, references = zip(*items)
    # Called `scores` (not `bert`) so the module-level `bert` function is not
    # shadowed inside this function.
    scores = scorer.compute(
        predictions=predictions,
        references=references,
        model_type=model_name,
        num_layers=12,
    )
    return Average(scores["f1"])