"sims/mem/memnic/memnic.c" did not exist on "2ad19dc5ae055c664a2fdb68e6f192645ff1dcd6"
utils.py 1.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import string

import evaluate


def clean_text(text: str) -> str:
    """
    Normalize text for comparison.

    Drops every character in ``string.punctuation``, collapses each run of
    whitespace (newlines included) into a single space, trims both ends, and
    lowercases the result.
    """
    without_punctuation = text.translate(
        str.maketrans("", "", string.punctuation)
    )

    # split() with no separator breaks on any whitespace run and ignores
    # leading/trailing whitespace, so re-joining with one space handles
    # newlines, repeated spaces and trimming in a single step.
    words = without_punctuation.split()

    return " ".join(words).lower()


def rouge1(items):
    """
    Pass-through: hand ``items`` back unchanged.

    No scoring happens here (the original note says this is "for
    efficiency"); presumably the real computation is left to ``rouge1_agg``,
    which works on the collected items -- confirm against the caller.
    """
    return items


def average_len(items):
    """
    Pass-through: hand ``items`` back unchanged.

    No length is computed here (the original note says this is "for
    efficiency"); presumably the real computation is left to
    ``average_len_agg``, which works on the collected items -- confirm
    against the caller.
    """
    return items


def rouge1_agg(items):
    """
    Compute the corpus-level ROUGE-1 score. Higher is better.

    Each entry of ``items`` is indexed as ``(reference, prediction)``:
    element 0 is treated as the reference and element 1 as the prediction.
    Both are normalized with ``clean_text`` before scoring, and every
    reference is wrapped in a one-element list before being passed to the
    metric.

    Returns the ``"rouge1"`` entry of the result produced by the ``rouge``
    metric loaded through ``evaluate``.
    """
    # Transpose the pairs into columns: column 0 = references,
    # column 1 = predictions.
    columns = list(zip(*items))

    references = [[clean_text(target)] for target in columns[0]]
    predictions = [clean_text(output) for output in columns[1]]

    scores = evaluate.load("rouge").compute(
        predictions=predictions, references=references
    )
    return scores["rouge1"]


def average_len_agg(items):
    """
    Return the mean number of words per prediction.

    Element 1 of every entry in ``items`` is taken as the prediction. Each
    prediction is normalized with ``clean_text`` and its words are counted
    by splitting on whitespace; the result is the total word count divided
    by the number of predictions.

    NOTE(review): the original docstring read "Higher is better" -- confirm
    that this is the intended reading for a length statistic.
    """
    # Transpose the entries into columns; column 1 holds the predictions.
    columns = list(zip(*items))
    cleaned = [clean_text(output) for output in columns[1]]

    total_words = 0
    for prediction in cleaned:
        total_words += len(prediction.split())

    return total_words / len(cleaned)