merge.py 749 Bytes
Newer Older
Yuanchen's avatar
Yuanchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import argparse
import os

from utils import jload, jdump


def generate(args):
    """Merge per-rank answer shards into a single file ordered by ``'id'``.

    For every rank ``r`` in ``range(args.shards)`` this loads
    ``<answer_path>/<model_name>_answers_rank<r>.json`` via ``jload`` and
    appends its items to one combined list. The combined list is then sorted
    on each item's ``'id'`` key and handed to ``jdump`` with the path
    ``<answer_path>/<model_name>_answers.json``.

    Args:
        args: parsed command-line namespace; must provide ``model_name``,
            ``shards`` and ``answer_path``.
    """
    out_dir = args.answer_path
    prefix = args.model_name

    merged = []
    for rank in range(args.shards):
        shard_file = os.path.join(out_dir,
                                  f'{prefix}_answers_rank{rank}.json')
        # presumably each shard file holds a JSON list of answer records
        # (dict-like, carrying an 'id' key) — confirm against utils.jload
        merged += jload(shard_file)

    # sorted() is stable, so records sharing an id keep their shard order
    merged = sorted(merged, key=lambda record: record['id'])
    jdump(merged, os.path.join(out_dir, f'{prefix}_answers.json'))


if __name__ == '__main__':
    # Command-line entry point: merge the per-rank answer shards of one model
    # into a single answers file.
    parser = argparse.ArgumentParser(
        description='Merge per-rank answer shards into a single answers '
                    'file sorted by id.')
    parser.add_argument('--model_name', type=str, default='model',
                        help='prefix of the shard files, i.e. '
                             '<model_name>_answers_rank<i>.json')
    parser.add_argument('--shards', type=int, default=4,
                        help='number of shard files to merge '
                             '(ranks 0 .. shards-1)')
    parser.add_argument('--answer_path', type=str, default="answer",
                        help='directory containing the shard files; the '
                             'merged file is written there as well')
    args = parser.parse_args()
    # generate() iterates range(args.shards); with a value <= 0 it would read
    # no shards and silently write an empty merged file, so reject it here.
    if args.shards < 1:
        parser.error('--shards must be a positive integer')
    generate(args)