test_gqa.py 959 Bytes
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestScienceQAMetric:
    """Unit tests for the ``mmpretrain.GQAAcc`` evaluation metric.

    NOTE(review): the class name looks copied from the ScienceQA metric
    test; every case below builds a ``GQAAcc`` evaluator.
    """

    @staticmethod
    def _evaluate_gqa_acc(answer_pairs):
        """Run ``GQAAcc`` over the given answers and return its metrics.

        Args:
            answer_pairs: Iterable of ``(pred_answer, gt_answer)`` tuples;
                each tuple becomes one ``DataSample`` whose metainfo holds
                both answers.

        Returns:
            dict: The metrics dict produced by ``Evaluator.evaluate``.
        """
        data_samples = [
            DataSample(metainfo={
                'pred_answer': pred_answer,
                'gt_answer': gt_answer,
            }) for pred_answer, gt_answer in answer_pairs
        ]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        # Pass the true number of processed samples: mmengine truncates the
        # collected results to ``size``, so a smaller constant would make
        # the metric silently ignore part of the input.
        return evaluator.evaluate(len(data_samples))

    def test_evaluate(self):
        # Every prediction matches its ground truth.
        res = self._evaluate_gqa_acc([('dog', 'dog')] * 10)
        assert res['GQA/acc'] == 1.0

        # No prediction matches its ground truth.
        res = self._evaluate_gqa_acc([('dog', 'cat')] * 10)
        assert res['GQA/acc'] == 0.0

        # Half right, half wrong: checks that all samples are scored rather
        # than only a leading subset of them.
        answer_pairs = [('dog', 'dog')] * 5 + [('dog', 'cat')] * 5
        res = self._evaluate_gqa_acc(answer_pairs)
        assert res['GQA/acc'] == 0.5