# evaluate_mmmu.py
import argparse
import glob
import json
import os
import sys
import re
import subprocess

from .mmmu_utils import parse_multi_choice_response
# Get the absolute path of the parent directory
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
# Add the parent directory to sys.path
sys.path.insert(0, parent_dir)

from run_text_generation import get_output_path
from config import EvaluationConfig



def get_input_output_paths(input_path, task):
    """Get all input files and an output path for a merged file.

    Args:
        input_path: Either the path of an existing prediction file, or an
            output-path prefix from which the per-partition / per-dp-rank
            files are located with a glob pattern.
        task: Task name. Used to build the glob pattern and the merged
            file name in the multi-file case.

    Returns:
        A tuple ``(input_file_paths, output_file_path)`` where
        ``input_file_paths`` is a list of files to read and
        ``output_file_path`` is where the merged result should be written.
    """
    # Single input file.
    if os.path.exists(input_path):
        input_file_paths = [input_path]

        # Only a trailing ".jsonl" is treated as the extension. The merged
        # suffix is always appended, so the output path can never equal the
        # input path (which would overwrite the predictions being read).
        suffix = ".jsonl"
        if input_path.endswith(suffix):
            stem = input_path[: -len(suffix)]
        else:
            stem = input_path
        output_file_path = stem + "-merged.json"
    # Select multiple partitions and dp ranks.
    else:
        cfg = EvaluationConfig(task=task, output_path=input_path, partition_id="*")
        # Wildcards for partition id and dp rank turn the path into a glob pattern.
        pattern = get_output_path(cfg, dp_rank="*")
        input_file_paths = glob.glob(pattern)

        output_file_path = input_path + f"-{task}-merged.json"

    return input_file_paths, output_file_path


def extract_answer(text):
    r"""Pull the answer out of a model response.

    The response is searched first for ``\answer{...}`` and then for
    ``\boxed{...}``; the text inside the braces of the first hit is
    returned. The capture is non-greedy, so it stops at the first ``}``.

    If neither macro is present, the response itself is returned with every
    ``"Answer:"`` rewritten to ``"Answer: "``.
    """
    # Order matters: \answer{...} takes precedence over \boxed{...}.
    macro_patterns = (
        r'\\answer\{(.*?)\}',
        r'\\boxed\{(.*?)\}',
    )
    for pattern in macro_patterns:
        found = re.search(pattern, text)
        if found:
            return found.group(1)

    # No macro found: hand back the response with a space after "Answer:".
    return text.replace("Answer:", "Answer: ")



def convert_to_mmmu_format(input_path):
    """Convert input files to MMMU compatible format.

    Every input file is read as JSON lines. Each record must provide
    ``sample_id``, ``prediction`` and ``question_type``; multiple-choice
    records must also provide ``all_choices`` and ``index2ans``. The merged
    result, a ``sample_id -> prediction`` mapping, is written as JSON.

    Args:
        input_path: Path to a single prediction file, or the output-path
            prefix of partitioned prediction files.

    Returns:
        Path of the merged JSON file that was written.
    """
    input_file_paths, output_file_path = get_input_output_paths(input_path, "MMMU")

    output = dict()

    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                # Tolerate blank lines (e.g. a trailing empty line), which
                # json.loads would otherwise reject.
                if not line.strip():
                    continue

                res = json.loads(line)

                sample_id = res["sample_id"]
                prediction = res["prediction"]

                # Keep the first prediction seen for a sample; later
                # duplicates (possibly from other files) are ignored.
                if sample_id in output:
                    continue

                if res["question_type"] == "multiple-choice":
                    # Strip any \answer{...} / \boxed{...} wrapper first, then
                    # let the MMMU helper parse the choice from the text.
                    prediction = extract_answer(prediction)
                    prediction = parse_multi_choice_response(
                        prediction, res["all_choices"], res["index2ans"]
                    )

                # MMMU eval script expects just a sample_id to prediction mapping.
                output[sample_id] = prediction

    with open(output_file_path, "w") as output_file:
        json.dump(output, output_file, indent=4, sort_keys=True)

    return output_file_path


def mmmu_eval(input_path, groundtruth_path):
    """Run MMMU evaluation.

    The predictions are first merged into the format the MMMU evaluation
    script expects, then that script is launched as a subprocess and the
    overall accuracy is parsed from its standard output.

    Args:
        input_path: Path to the prediction file(s); see
            ``convert_to_mmmu_format``.
        groundtruth_path: Path to the groundtruth answer file passed to the
            evaluation script as ``--answer_path``.

    Returns:
        Overall accuracy as a percentage (the reported ``acc`` times 100).

    Raises:
        RuntimeError: If no overall accuracy can be found in the script
            output, e.g. because the script failed.
    """
    result_file = convert_to_mmmu_format(input_path)

    # The MMMU repo has a script for running the actual evaluation but no API. So launching the script here.
    # Use the interpreter running this script so the child sees the same
    # environment; fall back to "python" if it cannot be determined.
    output = subprocess.run(
        [
            sys.executable or "python",
            "examples/multimodal/MMMU/mmmu/main_eval_only.py",
            "--output_path",
            result_file,
            "--answer_path",
            groundtruth_path,
        ],
        capture_output=True,
        text=True,
    )

    print(output.stderr)
    print(output.stdout)

    # Raw string with the braces and decimal point escaped so they match
    # literally. The fractional part is optional so an integer accuracy such
    # as 0 is parsed as well.
    m = re.search(r"'Overall': \{'num': \d+, 'acc': (\d+(?:\.\d+)?)\}", output.stdout)
    if m is None:
        raise RuntimeError(
            "Could not find the overall accuracy in the MMMU evaluation output "
            f"(exit code {output.returncode}). stderr:\n{output.stderr}"
        )

    return float(m.group(1)) * 100.0


def main():
    """Command-line entry point: evaluate predictions on MMMU and print the accuracy."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--input-path", type=str, required=True, help="Path to input file(s)")
    cli.add_argument(
        "--groundtruth-path",
        type=str,
        # Validation groundtruth shipped with the MMMU github repo; this
        # default assumes the repo has been cloned at this location.
        default="examples/multimodal/MMMU/mmmu/answer_dict_val.json",
        help="Path to groundtruth file. Defaults to the validation file in the MMMU repo.",
    )
    options = cli.parse_args()

    score = mmmu_eval(options.input_path, options.groundtruth_path)
    print(f"MMMU average accuracy: {score:.2f}")


# Only run the evaluation when executed as a script, not on import.
if __name__ == "__main__":
    main()