# t5_utils.py
def mean_3class_f1(predictions, references):
    """Map one example's string labels to integer class ids.

    This is a per-example passthrough step: it does not compute a score, it
    only returns the ``(prediction, reference)`` id pair (presumably so that
    ``agg_mean_3class_f1`` can aggregate the pairs later).

    Args:
        predictions: Sequence whose first element is the predicted label.
        references: Sequence whose first element is the gold label.

    Returns:
        A ``(prediction_id, reference_id)`` tuple, where ids index into
        ``["entailment", "contradiction", "neutral"]``.

    Raises:
        ValueError: If the reference label is not one of the three labels.
    """
    string_label = ["entailment", "contradiction", "neutral"]

    # A prediction outside the label set falls back to class 0, so the
    # example is still counted instead of aborting the evaluation.
    try:
        prediction_id = string_label.index(predictions[0])
    except ValueError:
        prediction_id = 0

    # The reference is expected to be a valid label; an unknown one raises.
    reference_id = string_label.index(references[0])

    return (prediction_id, reference_id)


# Aggregation over the collected (prediction, reference) pairs.
def agg_mean_3class_f1(items):
    """Compute the unweighted (macro) average of the per-class F1 score.

    Args:
        items: Iterable of ``(prediction, reference)`` pairs of class ids.
            The score is computed over the labels ``0``, ``1`` and ``2``.

    Returns:
        The value returned by ``sklearn.metrics.fbeta_score`` called with
        ``beta=1``, ``labels=range(3)`` and ``average="macro"``.

    Raises:
        ValueError: If ``items`` is empty.
    """
    pairs = list(items)
    if not pairs:
        # Without this guard, unpacking ``zip(*[])`` fails with an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError(
            "agg_mean_3class_f1 requires at least one (prediction, reference) pair"
        )
    predictions, references = zip(*pairs)

    # Imported at call time (as the original did) so that importing this
    # module does not itself require scikit-learn.
    import sklearn.metrics

    # fbeta_score takes the ground truth first, then the predictions.
    return sklearn.metrics.fbeta_score(
        references,
        predictions,
        beta=1,
        labels=range(3),
        average="macro",
    )