# errant.py
import argparse
import json
import os
import subprocess

import pandas as pd


def parse_args():
    """
    Parses arguments.
    Returns:
        Parsed arguments containing the path to the prediction file and the output directory for saving the evaluation results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fpath",
        type=str,
        help="path to a model output file in the lm-evaluation-harness format.",
    )
    parser.add_argument(
        "--out_fdir",
        type=str,
        help="path to an output directory for saving the results.",
    )
    args = parser.parse_args()
    return args
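
# Example invocation (paths are hypothetical; --fpath points at a per-sample
# JSONL file produced by lm-evaluation-harness):
#   python errant.py --fpath outputs/samples_gec.jsonl --out_fdir results/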


def read_examples(fpath: str):
    """
    Reads examples from the prediction file.
    Args:
        fpath: A path to the prediction file.
    Returns:
        Lists of the sources, targets, and predictions.
    """
    examples = pd.read_json(fpath, lines=True)
    sources, targets, predictions = [], [], []
    for _, example in examples.iterrows():
        sources.append(example["doc"]["source"])
        targets.append(example["doc"]["correction"])
        # Replace double newlines with single ones so a prediction does not
        # introduce blank lines into the one-example-per-line files below.
        predictions.append(example["resps"][0][0].replace("\n\n", "\n"))
    return sources, targets, predictions
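
# Each line of the input file is assumed to be an lm-evaluation-harness sample
# record shaped roughly as follows (field names taken from read_examples above,
# values are placeholders):
#   {"doc": {"source": "<original text>", "correction": "<reference correction>"},
#    "resps": [["<model prediction>"]]}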


def save_results(fpath: str, obj: dict):
    """
    Saves the evaluation results.
    Args:
        fpath: A path for the output file for saving the results.
        obj: The evaluation results.
    """
    with open(fpath, "w+", encoding="utf-8") as out:
        json.dump(obj, out, indent=3)
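
# The saved file holds a single metric, e.g. {"errant": <span-level F0.5>}
# (the key is set in evaluate() below; the value shown is a placeholder).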


def evaluate(fpath: str, out_fpath: str):
    """
    Runs the evaluation based on the ERRANT performance metric.
    Args:
        fpath: A path to the prediction file.
        out_fpath: A path for the output file for saving the results.
    """
    tmp_name = fpath.replace(".jsonl", "").replace("/", "-")
    os.makedirs("tmp", exist_ok=True)
    sources, targets, predictions = read_examples(fpath=fpath)
    with open(f"tmp/{tmp_name}_sources.txt", "w+") as f:
        f.write("\n".join(sources))
    with open(f"tmp/{tmp_name}_targets.txt", "w+") as f:
        f.write("\n".join(targets))
    with open(f"tmp/{tmp_name}_predictions.txt", "w+") as f:
        f.write("\n".join(predictions))
    subprocess.run(
        f"errant_parallel -orig tmp/{tmp_name}_sources.txt -cor tmp/{tmp_name}_targets.txt -out tmp/{tmp_name}_targets.m2 -lev -tok",
        shell=True,
    )
    subprocess.run(
        f"errant_parallel -orig tmp/{tmp_name}_sources.txt -cor tmp/{tmp_name}_predictions.txt -out tmp/{tmp_name}_predictions.m2 -lev -tok",
        shell=True,
    )
    output = subprocess.check_output(
        f"errant_compare -ref tmp/{tmp_name}_targets.m2 -hyp tmp/{tmp_name}_predictions.m2",
        shell=True,
    )
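    # errant_compare prints a summary table; the value row (TP, FP, FN, Prec,
    # Rec, F0.5) is expected to be the second-to-last line of its output, so
    # the last field of that line is read as the span-level F0.5 below.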
    f_05 = float(output.decode().strip().split("\n")[-2].split()[-1].strip())
    print(f"Prediction fpath: {fpath}\n\nERRANT: {f_05}", flush=True)
    print(f"Saving to: {out_fpath}", flush=True)
    save_results(obj={"errant": f_05}, fpath=out_fpath)
    subprocess.run(f"rm tmp/{tmp_name}_*", shell=True)


def main():
    args = parse_args()
    fpath = args.fpath
    print(f"Out: {args.out_fdir}", flush=True)
    # Save the results into the directory given by --out_fdir.
    out_fpath = os.path.join(
        args.out_fdir, os.path.basename(fpath).replace(".jsonl", "_errant.json")
    )
    evaluate(fpath=fpath, out_fpath=out_fpath)


if __name__ == "__main__":
    print(
        "\nWARNING: make sure you have ERRANT installed to run the evaluation! Available here: https://github.com/chrisjbryant/errant\n\n",
        flush=True,
    )
    main()