import datasets

import evaluate

def strip(resps, docs):
    """Keep only the first model response per doc, stripped of surrounding whitespace.

    Args:
        resps: iterable where each entry is a list of model responses for one doc.
        docs: unused here; kept so the (resps, docs) filter signature stays intact.

    Returns:
        list[str]: the first response of each entry, ``.strip()``-ed. Returned as
        a list (not a lazy ``map`` object) so the result can be iterated more
        than once without silently coming up empty on the second pass.
    """
    return [responses[0].strip() for responses in resps]

def dr_fr(dataset: datasets.Dataset):
    """Select only the rows whose "direction" column equals "dr_fr"."""

    def _keep(row):
        return row["direction"] == "dr_fr"

    return dataset.filter(_keep)

def dr_en(dataset: datasets.Dataset):
    """Select only the rows whose "direction" column equals "dr_en"."""

    def _is_dr_en(row):
        return row["direction"] == "dr_en"

    return dataset.filter(_is_dr_en)

def dr_msa(dataset: datasets.Dataset):
    """Select only the rows whose "direction" column equals "dr_msa"."""
    wanted = "dr_msa"
    return dataset.filter(lambda row: row["direction"] == wanted)

def fr_dr(dataset: datasets.Dataset):
    """Select only the rows whose "direction" column equals "fr_dr"."""
    wanted = "fr_dr"
    return dataset.filter(lambda row: row["direction"] == wanted)


def en_dr(dataset: datasets.Dataset):
    """Select only the rows whose "direction" column equals "en_dr"."""

    def _keep(row):
        return row["direction"] == "en_dr"

    return dataset.filter(_keep)

def msa_dr(dataset: datasets.Dataset):
    """Select only the rows whose "direction" column equals "msa_dr"."""

    def _is_msa_dr(row):
        return row["direction"] == "msa_dr"

    return dataset.filter(_is_msa_dr)


# Translation instruction prompts, keyed by direction "<source>_<target>".
# The Arabic text literally reads "translate from <source> to <target>:";
# the languages named are الفرنساوية (French), الدارجة (Darija),
# الإنجليزية (English) and الفصحى (Modern Standard Arabic).
# Each template has a single `{0}` placeholder for the source sentence.
prompt_templates = {
    "fr_dr": "ترجم من الفرنساوية للدارجة:\n{0}",
    "dr_fr": "ترجم من الدارجة للفرنساوية:\n{0}",
    "en_dr": "ترجم من الإنجليزية للدارجة:\n{0}",
    "dr_en": "ترجم من الدارجة للإنجليزية:\n{0}",
    "msa_dr": "ترجم من الفصحى للدارجة:\n{0}",
    "dr_msa": "ترجم من الدارجة للفصحى:\n{0}",
}


def doc_to_text(doc):
    """Return the content of the first chat message in *doc* (the prompt side)."""
    return doc["messages"][0]["content"]

def doc_to_target(doc):
    """Return the content of the second chat message in *doc* (the reference side)."""
    second_message = doc["messages"][1]
    return second_message["content"]

def bert(items):
    """Identity pass-through.

    Hands the (prediction, reference) pairs on unchanged — presumably so an
    aggregation function downstream can score them in bulk (TODO confirm
    against the harness config that wires this up).
    """
    return items

def Average(lst):
    """Return the arithmetic mean of a sequence of numbers.

    Args:
        lst: non-empty sequence of numbers.

    Returns:
        float: sum(lst) / len(lst).

    Raises:
        ValueError: if *lst* is empty (clearer than the bare
            ZeroDivisionError the naive division would raise).
    """
    if not lst:
        raise ValueError("Average() requires a non-empty sequence")
    return sum(lst) / len(lst)


def camembert(items):
    """Mean BERTScore F1 over (prediction, reference) pairs, scored with the
    "almanach/camembert-base" checkpoint at layer 12."""
    predictions, references = zip(*items)
    scorer = evaluate.load("bertscore")
    result = scorer.compute(
        predictions=predictions,
        references=references,
        model_type="almanach/camembert-base",
        num_layers=12,
    )
    f1_scores = result["f1"]
    return sum(f1_scores) / len(f1_scores)


def darijabert(items):
    """Mean BERTScore F1 over (prediction, reference) pairs, scored with the
    "SI2M-Lab/DarijaBERT" checkpoint at layer 12."""
    preds, refs = zip(*items)
    metric = evaluate.load("bertscore")
    f1 = metric.compute(
        predictions=preds,
        references=refs,
        model_type="SI2M-Lab/DarijaBERT",
        num_layers=12,
    )["f1"]
    return sum(f1) / len(f1)


def arabert(items):
    """Mean BERTScore F1 over (prediction, reference) pairs, scored with the
    "aubmindlab/bert-base-arabert" checkpoint at layer 12."""
    predictions, references = zip(*items)
    scores = evaluate.load("bertscore").compute(
        predictions=predictions,
        references=references,
        model_type="aubmindlab/bert-base-arabert",
        num_layers=12,
    )
    f1 = scores["f1"]
    return sum(f1) / len(f1)


def bertbase(items):
    """Mean BERTScore F1 over (prediction, reference) pairs, scored with the
    "google-bert/bert-base-uncased" checkpoint at layer 12."""
    preds, refs = zip(*items)
    scorer = evaluate.load("bertscore")
    outcome = scorer.compute(
        predictions=preds,
        references=refs,
        model_type="google-bert/bert-base-uncased",
        num_layers=12,
    )
    f1_list = outcome["f1"]
    return sum(f1_list) / len(f1_list)


def mbert(items):
    """Mean BERTScore F1 over (prediction, reference) pairs, scored with the
    "google-bert/bert-base-multilingual-cased" checkpoint at layer 12."""
    predictions, references = zip(*items)
    metric = evaluate.load("bertscore")
    computed = metric.compute(
        predictions=predictions,
        references=references,
        model_type="google-bert/bert-base-multilingual-cased",
        num_layers=12,
    )
    f1_values = computed["f1"]
    return sum(f1_values) / len(f1_values)