evaluate_vqav2.py 2.6 KB
Newer Older
liangjing's avatar
liangjing committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import glob
import json

from open_flamingo.eval.vqa_metric import VQAEval


def merge_input_files(input_path):
    """Merge input files to a format compatible with the evaluator.

    Collects every ``<input_path>-VQAv2-<N>*.jsonl`` partition file (``<N>``
    starting with a digit, so multi-digit indices such as ``-VQAv2-12.jsonl``
    are included), copies each record's ``"sample_id"`` into
    ``"question_id"``, and writes all records as a single JSON list.

    Args:
        input_path: Common path prefix of the per-partition ``.jsonl`` files.

    Returns:
        Path of the merged file, ``<input_path>-VQAv2-merged.json``.
    """
    output_file_path = input_path + "-VQAv2-merged.json"

    # Escape the prefix so characters like "[" in a directory name are not
    # interpreted as glob syntax. In glob "." is a literal character, so the
    # wildcard must come right after the first digit ("[0-9]*") for
    # multi-digit partition indices to match.
    pattern = glob.escape(input_path) + "-VQAv2-[0-9]*.jsonl"
    # Sort for a deterministic record order in the merged file.
    input_file_paths = sorted(glob.glob(pattern))

    results = []
    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                # Tolerate blank lines (e.g. a trailing empty line) in JSONL files.
                if not line.strip():
                    continue
                res = json.loads(line)
                res["question_id"] = res["sample_id"]
                results.append(res)

    with open(output_file_path, "w") as output_file:
        json.dump(results, output_file)

    return output_file_path


def is_number(n: str):
    """Return True if *n* can be parsed by ``float()``, otherwise False.

    Only ``ValueError`` counts as "not a number"; any other exception that
    ``float()`` raises (e.g. ``TypeError`` for a non-string, non-numeric
    argument) propagates to the caller. Strings such as ``"nan"`` and
    ``"inf"`` are accepted because ``float()`` accepts them.
    """
    try:
        float(n)
    except ValueError:
        return False
    else:
        return True


def compute_vqa_accuracy(result_file, use_chartqa_metric=False):
    """Compute VQA accuracy, print it, and return it.

    Args:
        result_file: Path to a JSON file holding a list of result records.
            Each record provides an ``"answer"`` (the prediction) and a
            ``"gt_answer"`` list of groundtruth answers.
        use_chartqa_metric: If True, score with ChartQA relaxed accuracy
            (exactly one groundtruth answer per record); otherwise a record
            scores ``min(1, number_of_matching_groundtruths / 3)``.

    Returns:
        The average accuracy as a percentage.

    Raises:
        ValueError: If the results file contains no records.
        AssertionError: In ChartQA mode, if a record does not have exactly
            one groundtruth answer.
    """
    with open(result_file, "r") as f:
        merged_results = json.load(f)

    if not merged_results:
        raise ValueError(f"No results found in {result_file}; cannot compute accuracy.")

    # VQAEval is constructed without data; only its text-processing helpers are used.
    vqa = VQAEval(vqa=None, vqaRes=None)

    def normalize(text):
        # Same processing order as before: punctuation first, then digits/articles.
        return vqa.processDigitArticle(vqa.processPunctuation(text))

    all_acc = []
    for res in merged_results:
        pred = normalize(res["answer"])
        gt = [normalize(ans) for ans in res["gt_answer"]]

        # ChartQA uses relaxed accuracy:
        # "We consider an answer to be correct if it is within 5% of the gold answer.
        #  For non-numeric answers, we still need an exact match to consider an answer to be correct."
        if use_chartqa_metric:
            acc = 0.
            assert len(gt) == 1, "expected exactly one groundtruth answer."
            gt = gt[0]

            if is_number(pred) and is_number(gt):
                pred_val = float(pred)
                gt_val = float(gt)
                # Sort the bounds: for a negative gold answer gt*0.95 > gt*1.05,
                # and the unsorted window would reject every prediction.
                low, high = sorted((gt_val * 0.95, gt_val * 1.05))
                if low <= pred_val <= high:
                    acc = 1.0
            elif pred == gt:
                acc = 1.0
        else:
            num_match = sum(pred == ans for ans in gt)
            acc = min(1.0, num_match / 3.0)

        all_acc.append(acc)

    acc_avg = sum(all_acc) / len(all_acc) * 100
    print(f"===== Accuracy {acc_avg:.2f}% =====")
    return acc_avg


def vqav2_eval(input_path):
    """Run VQAv2 evaluation.

    Merges the partition result files sharing the prefix *input_path* into a
    single JSON file, then scores that file with the default (non-ChartQA)
    metric; the accuracy is printed by ``compute_vqa_accuracy``.
    """
    compute_vqa_accuracy(merge_input_files(input_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate VQAv2 predictions.")
    # Required: without it, input_path would be None and merging would fail
    # with an unhelpful TypeError when building file paths.
    parser.add_argument(
        '--input-path',
        type=str,
        required=True,
        help="Path prefix of the per-partition result file(s) to evaluate",
    )
    args = parser.parse_args()

    vqav2_eval(args.input_path)