# generate_prompt_dataset.py: randomly sample records from a JSON dataset and
# save their instructions (with a sequential id) as a prompt dataset.

import argparse
import json
import random

# Seed the global RNG at import time so that the `random.sample` call in
# `sample()` picks the same subset on every run (for the same input file and
# sample size).
random.seed(42)


def sample(args):
    """Randomly sample instructions from a JSON dataset and save them as prompts.

    Reads a JSON array of records from ``args.dataset_path`` and draws up to
    ``args.sample_size`` of them without replacement, using the module-level
    ``random`` generator. It then writes a JSON array of
    ``{"instruction": ..., "id": ...}`` objects to ``args.save_path``.
    ``id`` is the 0-based position in the sampled output, not the index in the
    source dataset.

    If ``args.sample_size`` is larger than the dataset, every record is used
    (in shuffled order) instead of raising.

    Args:
        args: Object (e.g. ``argparse.Namespace``) with ``dataset_path`` (str),
            ``save_path`` (str) and ``sample_size`` (int) attributes.

    Raises:
        KeyError: If a sampled record has no ``"instruction"`` key.
        ValueError: If ``args.sample_size`` is negative.
    """
    # Use explicit UTF-8: the output is written with ensure_ascii=False, so a
    # platform default codec (e.g. cp1252 on Windows) could fail on non-ASCII text.
    with open(args.dataset_path, mode="r", encoding="utf-8") as f:
        dataset_list = json.load(f)

    # random.sample raises ValueError when k exceeds the population. Clamp so a
    # default sample_size larger than a small dataset still produces output.
    # A negative size still reaches random.sample and raises ValueError.
    sample_size = min(args.sample_size, len(dataset_list))
    sampled_dataset = [
        {"instruction": record["instruction"], "id": idx}
        for idx, record in enumerate(random.sample(dataset_list, sample_size))
    ]

    # All values originate from json.load (plus int ids), so no default= hook is needed.
    with open(args.save_path, mode="w", encoding="utf-8") as f:
        json.dump(sampled_dataset, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Randomly sample instructions from a JSON dataset into a prompt dataset."
    )
    # required=True means a default would never be used, so none is given.
    parser.add_argument("--dataset_path", type=str, required=True, help="path to the pretrain dataset")
    parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
    parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
    parser.add_argument("--seed", type=int, default=42, help="random seed used for sampling")
    args = parser.parse_args()

    # Report bad input as a usage error rather than a ValueError traceback
    # from random.sample.
    if args.sample_size < 0:
        parser.error("--sample_size must be non-negative")

    # Re-seed from the CLI so the sample is reproducible for a chosen seed.
    # The default (42) matches the seed applied at module import.
    random.seed(args.seed)
    sample(args)