"""Model evaluation"""

from megatron import get_args
from megatron import print_rank_0
from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm
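
# For context, a minimal sketch of the metric this module relies on. It is an
# assumption that F1Metric.compute_all_pairs (imported above) computes a
# ParlAI-style unigram-overlap F1 averaged over all guess/answer pairs; the
# hypothetical helper below is illustrative only and is not used by this file.
from collections import Counter

def _reference_unigram_f1(guesses, answers):
    """Sketch: average unigram precision/recall/F1 over aligned pairs."""
    precisions, recalls, f1s = [], [], []
    for guess, answer in zip(guesses, answers):
        guess_tokens = guess.split()
        answer_tokens = answer.split()
        # Tokens shared between guess and answer, counted with multiplicity.
        overlap = sum((Counter(guess_tokens) & Counter(answer_tokens)).values())
        if overlap == 0:
            # Also covers empty guesses/answers, avoiding division by zero.
            precisions.append(0.0)
            recalls.append(0.0)
            f1s.append(0.0)
            continue
        precision = overlap / len(guess_tokens)
        recall = overlap / len(answer_tokens)
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(2 * precision * recall / (precision + recall))
    n = max(len(f1s), 1)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n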


def evaluate_f1(guess_file, answer_file):
    """Evaluating F1 Score"""

    # Read model outputs, stripping any end-of-text generation markers.
    guess_list = []
    print_rank_0('reading %s' % guess_file)
    with open(guess_file, "r") as f:
        for line in tqdm(f):
            line = line.strip()
            line = line.replace("<|endoftext|>", "")
            guess_list.append(line)

    # Read reference answers; "no_passages_used" denotes an empty reference.
    answer_list = []
    print_rank_0('reading %s' % answer_file)
    with open(answer_file, "r") as f:
        for line in tqdm(f):
            line = line.strip()
            if line == "no_passages_used":
                line = ""
            answer_list.append(line)

    assert len(guess_list) == len(answer_list), \
        "guess and answer files must contain the same number of lines!"

    precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
    print_rank_0('precision: %.4f; recall: %.4f; F1: %.4f' % (precision, recall, f1))

    print_rank_0('done :-)')


def main():
    args = get_args()
    evaluate_f1(args.guess_file, args.answer_file)
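
# A minimal usage sketch, assuming this module can be run directly. In the
# original repo, main() is presumably invoked by a task launcher that
# initializes Megatron and registers the --guess-file / --answer-file
# arguments; calling evaluate_f1 on plain file paths, as below, bypasses
# that wiring.
if __name__ == "__main__":
    import sys
    # Hypothetical direct invocation: python evaluate.py <guess_file> <answer_file>
    evaluate_f1(sys.argv[1], sys.argv[2])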