aggregate_distributed_results.py 1.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from pathlib import Path

import fire


try:
    from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
except ImportError:
    from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file


def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Merge the per-rank ``rank*.json`` result files in ``result_dir``, score them, and save outputs.

    Every ``rank*.json`` file in ``result_dir`` is read with ``load_json``. The records it
    returns are concatenated in sorted-filename order. Each record must provide ``"pred"``
    and ``"label"`` keys; ``"source"`` is optional.

    Metrics are computed with ``calculate_bleu`` when ``calc_bleu`` is true, otherwise with
    ``calculate_rouge``. They are printed and written to ``save_dir/metrics.json``.

    Unless ``just_metrics`` is set, the labels and predictions are also written with
    ``write_txt_file`` to ``{save_prefix}.target`` and ``{save_prefix}.pred_target``. When every
    record has a ``"source"``, the sources go to ``{save_prefix}.source``.

    Args:
        result_dir: directory containing the ``rank*.json`` partial results.
        save_dir: output directory (created if missing). Defaults to ``result_dir``.
        save_prefix: filename prefix for the text outputs. Defaults to ``"generated"``.
        calc_bleu: score with BLEU instead of ROUGE.
        just_metrics: return right after writing ``metrics.json``.

    Raises:
        FileNotFoundError: if ``result_dir`` contains no ``rank*.json`` files.
    """
    src_dir = Path(result_dir)
    # glob() yields entries in filesystem order; sort so the merged order (and therefore
    # the line order of the written text files) is reproducible across runs/machines.
    paths_to_combine = sorted(src_dir.glob("rank*.json"))
    if not paths_to_combine:
        raise FileNotFoundError(f"No rank*.json files found in {src_dir}")

    # Path(None) would raise TypeError, so fall back to writing next to the inputs.
    save_dir = src_dir if save_dir is None else Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    # TODO: a name like {prefix}_{rouge|bleu}.json would be more descriptive.
    save_json(metrics, save_dir.joinpath("metrics.json"))
    print(metrics)
    if just_metrics:
        return

    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")

    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    # Only write sources when every record carries one. Checking just records[0] crashed on
    # empty input (IndexError) and on mixed input (KeyError for records lacking "source").
    if records and all("source" in x for x in records):
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)


if __name__ == "__main__":
    # Run as a script: python-fire maps command-line arguments onto the
    # parameters of combine_partial_results and invokes it.
    fire.Fire(component=combine_partial_results)