evaluate_mbxp.py 4.4 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import json
import multiprocessing
import pickle
import queue
import re
import timeit

import pandas as pd
from tqdm import tqdm

from mxeval.execution import check_correctness as check_correctness_python
from mxeval.execution import (
    check_correctness_cpp,
    check_correctness_csharp,
    check_correctness_go,
    check_correctness_java,
    check_correctness_javascript,
    check_correctness_kotlin,
    check_correctness_perl,
    check_correctness_php,
    check_correctness_ruby,
    check_correctness_scala,
    check_correctness_swift,
    check_correctness_typescript,
)


def postprocess_golang(code: str) -> str:
    """Strip the package clause, imports and ``func main`` from a Go completion.

    Removes:
      * every ``package main`` occurrence,
      * parenthesised import blocks (``import ( ... )``, any number of lines),
      * single-line imports, optionally aliased (``import f "fmt"``),
      * the ``func main`` declaration up to its closing column-0 ``}``.

    All other top-level declarations are kept. Presumably the evaluation
    harness supplies its own package clause, imports and ``main`` — confirm
    against the mxeval Go checker.
    """
    # Non-greedy up to the first ")" at the start of a line, so one-import
    # blocks are matched and nothing after the block is swallowed.
    multi_line_imports = re.compile(
        r"^import\s*\(.*?^\)", re.MULTILINE | re.DOTALL)
    # MULTILINE is required: once "package main" is removed the import is no
    # longer at the very start of the string.
    line_imports = re.compile(
        r'^import\s+(?:[\w.]+\s+)?"[^"]*"', re.MULTILINE)
    # Non-greedy so only main's own body is removed, not every function that
    # follows it; \b avoids matching e.g. "func mainHelper".
    func_main = re.compile(r"^func main\b.*?^}", re.MULTILINE | re.DOTALL)

    code = code.replace("package main", "")  # Remove package main
    code = multi_line_imports.sub("", code)
    code = line_imports.sub("", code)
    code = func_main.sub("", code)

    return code


def postprocess_scala(code: str) -> str:
    """Unwrap a Scala completion from an ``object Main extends App { ... }`` wrapper.

    The wrapper header is removed along with the brace that closes it, taken
    to be the last ``}`` once trailing whitespace is ignored. Code that does
    not contain the wrapper header is returned unchanged.
    """
    header = "object Main extends App {"
    if header not in code:
        # Nothing to unwrap; dropping a line here would delete real code.
        return code

    body = code.replace(header, "").rstrip()
    # Only remove the final brace itself: trailing blank lines no longer
    # shift which line gets dropped, and code sharing the brace's line is kept.
    if body.endswith("}"):
        body = body[:-1]
    return body


def postprocess_python(code: str) -> str:
    """Return *code* with all leading whitespace (spaces, tabs, newlines) removed.

    The completion is checked without its prompt, so any indentation or
    blank lines before the first statement would otherwise make the first
    line an unexpected indent.
    """
    without_leading_ws = code.lstrip()
    return without_leading_ws


def worker(inp_queue, out_queue):
    """Evaluate problems from ``inp_queue`` until it stays empty for 5 seconds.

    For every problem taken off ``inp_queue`` exactly one tuple
    ``(key, lang, passed, result, response)`` is put on ``out_queue``, even
    when post-processing or the checker fails (``passed=False``,
    ``result=""``). This matters because the parent counts results and would
    block forever on ``out_queue.get()`` if one went missing.

    Each problem dict is expected to carry ``lang``, ``entry_point``,
    ``test_code`` and ``response``. It is mutated in place: ``task_id``,
    ``test`` and ``prompt`` are overwritten before it is handed to the
    mxeval checker.
    """
    # Explicit dispatch instead of eval(): the language string comes from the
    # input data, and an unsupported language now fails that one problem
    # instead of crashing the whole worker process.
    checkers = {
        "cpp": check_correctness_cpp,
        "csharp": check_correctness_csharp,
        "go": check_correctness_go,
        "java": check_correctness_java,
        "javascript": check_correctness_javascript,
        "kotlin": check_correctness_kotlin,
        "perl": check_correctness_perl,
        "php": check_correctness_php,
        "python": check_correctness_python,
        "ruby": check_correctness_ruby,
        "scala": check_correctness_scala,
        "swift": check_correctness_swift,
        "typescript": check_correctness_typescript,
    }
    postprocessors = {
        "go": postprocess_golang,
        "python": postprocess_python,
        "scala": postprocess_scala,
    }

    while True:
        try:
            problem = inp_queue.get(timeout=5)
        except queue.Empty:
            break

        lang = problem.get("lang")
        key = f"{lang}_{problem.get('entry_point')}"
        response = problem.get("response")

        try:
            checker = checkers[lang]

            problem["task_id"] = key
            problem["test"] = problem["test_code"]

            # Keep only the text before the closing code fence; if the block
            # was never closed, keep the whole response.
            solution = response.split("```", 1)[0]

            postprocess = postprocessors.get(lang)
            if postprocess is not None:
                solution = postprocess(solution)

            # Mixtral likes escaping underscores for some reason, so let's
            # remove these
            solution = solution.replace("\\_", "_")

            # The evaluation script evaluates `code = prompt + solution + tests`
            # But Mixtral regenerates the prompt in its output, so we should
            # remove this
            problem["prompt"] = ""

            result = checker(problem, solution, timeout=20.0)
            passed, detail = result["passed"], result["result"]
        except Exception as e:
            # Best-effort: record a failure for this problem and keep going.
            print(f"{key}: {e!r}")
            passed, detail = False, ""

        out_queue.put((key, lang, passed, detail, response))


def evaluate_mbxp(
    results,
    n_workers,
    langs=("cpp", "python", "php", "javascript", "ruby", "typescript"),
    output_path="evaluated_test.json",
):
    """Evaluate MBXP completions in worker processes and report pass@1.

    Args:
        results: Iterable of problem dicts. Each needs a ``lang`` key, plus
            whatever ``worker`` reads (``entry_point``, ``test_code``,
            ``response``).
        n_workers: Number of worker processes to start.
        langs: Languages to evaluate; problems in any other language are
            skipped. The default is the set this script always evaluated.
        output_path: Path of the JSON file receiving per-problem results.

    Returns:
        Overall pass@1 as a percentage; 0.0 when no problem was evaluated.
    """
    selected = set(langs)
    problems = [problem for problem in results if problem["lang"] in selected]
    n_problems = len(problems)

    inp_queue = multiprocessing.Queue()
    out_queue = multiprocessing.Queue()
    for problem in problems:
        inp_queue.put(problem)

    start = timeit.default_timer()
    workers = [
        multiprocessing.Process(target=worker, args=(inp_queue, out_queue))
        for _ in range(n_workers)
    ]
    for w in workers:
        w.start()

    passes = {}
    n_passed = 0
    lang_passed = {}
    lang_counts = {}
    for _ in tqdm(range(n_problems)):
        key, lang, passed, result, response = out_queue.get()
        # NOTE(review): problems sharing lang + entry_point overwrite each
        # other here but are still counted below.
        passes[key] = {
            "passed": passed,
            "result": result,
            "response": response}
        n_passed += passed
        lang_passed[lang] = lang_passed.get(lang, 0) + passed
        lang_counts[lang] = lang_counts.get(lang, 0) + 1

    end = timeit.default_timer()

    # Every result has been drained from out_queue, so joining cannot
    # deadlock on unflushed queue data; workers exit after their 5 s idle
    # timeout.
    for w in workers:
        w.join()

    pass_rate = 100 * n_passed / n_problems if n_problems else 0.0
    print(f"Processed {n_problems} in {end - start}s")
    print(f"{pass_rate : .02f}% pass@1")
    print(lang_passed, lang_counts)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(passes, f, indent=2)

    return pass_rate