mmbench_evaluation_tricky.py 1.88 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pandas as pd
import json
import random

'''
This script provides metric calculation for mmbench_dev with the same accuracy algorithm as the OpenCompass server
'''

# Load the model predictions. Each record must provide an 'index' key and a
# 'prediction' key; the prediction is later used as a position into that
# question's list of choices.
# The context manager guarantees the file handle is closed, and the explicit
# encoding keeps decoding independent of the platform's locale default.
with open('mmbench_dev_20230712.json', encoding='utf-8') as predictions_file:
    predictions = json.load(predictions_file)

# Map every record's 'index' to its 'prediction'. If an index occurs more than
# once, the last record wins.
index2predictions = {pred['index']: pred['prediction'] for pred in predictions}


from collections import Counter

def most_common_elements(lst):
    """Return one of the most frequent elements of ``lst``.

    If several distinct elements share the highest frequency, one of them is
    picked at random with ``random.choice``.

    Args:
        lst: A non-empty iterable of hashable elements.

    Returns:
        An element of ``lst`` whose number of occurrences is maximal.

    Raises:
        ValueError: If ``lst`` is empty (``max`` receives an empty sequence).
    """
    frequencies = Counter(lst)
    highest = max(frequencies.values())
    # Every element tied for the top count, in first-seen order.
    tied = [element for element in frequencies if frequencies[element] == highest]
    # Ties are broken randomly.
    return random.choice(tied)

# Ground-truth table: one row per question, read from a tab-separated file.
# The code below reads the columns 'index', 'answer' and the option columns
# 'A'..'D' from every row.
datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t')

glb_opts = ['A', 'B', 'C', 'D']
index2answer = {}     # question index -> position of the answer letter in glb_opts
index2choices = {}    # question index -> list of the option values that are present
index2rawanswer = {}  # question index -> option value found at the answer position
for _, data in datas.iterrows():
    # Options whose cell is NaN are treated as absent and skipped.
    choices = [data[opt] for opt in glb_opts if not pd.isna(data[opt])]
    # Position of the answer letter; raises ValueError for a letter outside A-D.
    answer_pos = glb_opts.index(data['answer'])

    index2choices[data['index']] = choices
    index2answer[data['index']] = answer_pos
    # NOTE(review): indexing `choices` by the letter position assumes the
    # present options are contiguous starting at 'A' - confirm against the data.
    index2rawanswer[data['index']] = choices[answer_pos]

# Scoring. A question with base index `b` may have predictions stored under
# the indexes b, b + 1e6, b + 2e6 and b + 3e6 (one per pass). The predictions
# of all available passes are mapped back to their option values and combined
# by majority vote; the question counts as correct when the voted value equals
# the ground-truth option value.
INDEX_STRIDE = 10 ** 6  # offset between two passes of the same question
NUM_PASSES = 4          # passes looked up per question

# Base indexes of all questions that have at least one prediction entry.
# Integer arithmetic avoids the round-trip through float that `% 1e6` implies.
identity_indexes = list({int(key % INDEX_STRIDE) for key in index2predictions})

correct = 0
total = 0
for index in identity_indexes:
    raw_preds = []
    raw_answer = None
    for pass_no in range(NUM_PASSES):
        cycle_index = pass_no * INDEX_STRIDE + index
        prediction = index2predictions.get(cycle_index)
        if prediction is None:
            # Pass missing (or prediction recorded as None): nothing to vote with.
            continue
        # Overwritten on every available pass, so the last one wins;
        # presumably the option value is the same for all passes - TODO confirm.
        raw_answer = index2rawanswer[cycle_index]
        raw_preds.append(index2choices[cycle_index][prediction])

    # A question without any usable prediction is counted as incorrect instead
    # of crashing inside most_common_elements (max() of an empty sequence).
    if raw_preds:
        if len(set(raw_preds)) == 1:
            # Unanimous passes need no tie-breaking.
            voted = raw_preds[0]
        else:
            voted = most_common_elements(raw_preds)
        if voted == raw_answer:
            correct += 1

    total += 1

# Report: number correct, number of questions, accuracy in percent.
# Guard against an empty prediction file, which would divide by zero.
accuracy = correct / total * 100. if total else 0.0
print(correct, total, accuracy)