# evaluate_mmmu.py
import argparse
import glob
import json
import os
import sys
import re
import subprocess

# Make sibling modules (run_text_generation, config) importable when this
# script is run directly from the repository.
# Get the absolute path of the parent directory
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
# Add the parent directory to sys.path (front of the list so local modules win).
sys.path.insert(0, parent_dir)

from run_text_generation import get_output_path
from config import EvaluationConfig


def get_input_output_paths(input_path, task):
    """Get all input files and an output path for a merged file.

    Args:
        input_path: Either an existing .jsonl result file, or a path prefix
            used to glob partitioned / data-parallel generation outputs.
        task: Task name used both for the glob pattern and the merged filename.

    Returns:
        Tuple of (list of input file paths, merged output file path).
    """
    if not os.path.exists(input_path):
        # Multiple partitions and dp ranks: expand the generation output pattern.
        cfg = EvaluationConfig(task=task, output_path=input_path, partition_id="*")
        pattern = get_output_path(cfg, dp_rank="*")
        return glob.glob(pattern), input_path + f"-{task}-merged.json"

    # Single input file.
    return [input_path], input_path.replace(".jsonl", "-merged.json")


def convert_to_mmmu_format(input_path):
    """Convert input files to MMMU compatible format.

    Reads generation results (one JSON object per line), maps multiple-choice
    free-form predictions to a choice letter, and writes a single merged
    {sample_id: prediction} JSON file as expected by the MMMU eval script.

    Args:
        input_path: Path to a single .jsonl result file, or a path prefix used
            to glob partitioned / data-parallel outputs.

    Returns:
        Path to the merged output JSON file.
    """
    input_file_paths, output_file_path = get_input_output_paths(input_path, "MMMU")

    output = dict()

    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                res = json.loads(line)

                sample_id = res["sample_id"]

                # Skip possible duplicates (e.g. overlapping partitions) before
                # doing any answer-parsing work; the original checked only
                # after parsing, wasting that work on discarded entries.
                if sample_id in output:
                    continue

                prediction = res["prediction"]

                if res["question_type"] == "multiple-choice":
                    # Deferred import: the MMMU repo is only required for
                    # multiple-choice parsing. Repeat imports are cheap
                    # (sys.modules cache), so keeping it in the branch is fine.
                    from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response

                    prediction = parse_multi_choice_response(
                        prediction, res["all_choices"], res["index2ans"]
                    )

                # MMMU eval script expects just a sample_id to prediction mapping.
                output[sample_id] = prediction

    with open(output_file_path, "w") as output_file:
        json.dump(output, output_file)

    return output_file_path


def mmmu_eval(input_path, groundtruth_path):
    """Run MMMU evaluation.

    Args:
        input_path: Path to generation result file(s); forwarded to
            convert_to_mmmu_format.
        groundtruth_path: Path to the MMMU answer dict JSON file.

    Returns:
        Overall accuracy as a percentage (0-100).

    Raises:
        RuntimeError: If the accuracy cannot be parsed from the eval output.
    """
    result_file = convert_to_mmmu_format(input_path)

    # The MMMU repo has a script for running the actual evaluation but no API. So launching the script here.
    output = subprocess.run(
        [
            "python",
            "examples/multimodal/MMMU/mmmu/main_eval_only.py",
            "--output_path",
            result_file,
            "--answer_path",
            groundtruth_path,
        ],
        capture_output=True,
        text=True,
    )

    print(output.stderr)
    print(output.stdout)

    # Raw string with escaped dot: the original non-raw "\d.\d+" produced
    # invalid-escape warnings on modern Python and let "." match any character.
    m = re.search(r"'Overall': \{'num': \d+, 'acc': (\d\.\d+)\}", output.stdout)
    if m is None:
        # Original code crashed with an opaque AttributeError here; fail loudly
        # with the eval script's output instead.
        raise RuntimeError(
            f"could not parse overall accuracy from MMMU eval output:\n{output.stdout}"
        )

    return float(m.group(1)) * 100.0


def main():
    """Command-line entry point: evaluate generation results on MMMU."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, required=True, help="Path to input file(s)")
    parser.add_argument(
        "--groundtruth-path",
        type=str,
        # Using the validation groundtruth file from the MMMU repo by default.
        # This assumes you have cloned the MMMU github repo here.
        default="examples/multimodal/MMMU/mmmu/answer_dict_val.json",
        help="Path to groundtruth file. Defaults to the validation file in the MMMU repo.",
    )
    args = parser.parse_args()

    avg_acc = mmmu_eval(args.input_path, args.groundtruth_path)

    print(f"MMMU average accuracy: {avg_acc:.2f}")


if __name__ == "__main__":
    main()