#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare the SFT dataset for fine-tuning.
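
Example invocation (paths are illustrative):

    python prepare_sft_dataset.py \
        --data_input_dirs /path/to/sft_dir_1,/path/to/sft_dir_2 \
        --tokenizer_dir /path/to/llama2_tokenizer \
        --data_cache_dir cache \
        --data_jsonl_output_dir jsonl_output \
        --data_arrow_output_dir arrow_output \
        --max_length 4096 \
        --num_spliced_dataset_bins 10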
"""

import argparse
import json
import math
import os
from multiprocessing import cpu_count

from colossal_llama2.dataset.conversation import default_conversation
from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_input_dirs",
        type=str,
        required=True,
        default=None,
        help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
    )
    parser.add_argument(
        "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
    )
    parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
    parser.add_argument(
        "--data_jsonl_output_dir",
        type=str,
        default="jsonl_output",
        help="Output directory of spliced dataset with jsonl format",
    )
    parser.add_argument(
        "--data_arrow_output_dir",
        type=str,
        default="arrow_output",
        help="Output directory of spliced dataset with arrow format",
    )
    parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
    parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
    args = parser.parse_args()

    if args.num_spliced_dataset_bins >= 100000:
        raise ValueError("Too many spliced divisions, must be smaller than 100000")

    assert not os.path.exists(args.data_cache_dir), f"Found existing data cache dir {args.data_cache_dir}"
    assert not os.path.exists(
        args.data_jsonl_output_dir
    ), f"Found existing jsonl data output dir {args.data_jsonl_output_dir}"
    assert not os.path.exists(
        args.data_arrow_output_dir
    ), f"Found existing arrow data output dir {args.data_arrow_output_dir}"
    os.makedirs(args.data_jsonl_output_dir)
    os.makedirs(args.data_arrow_output_dir)

    # Collect all input data files.
    input_data_paths = []
    input_data_dirs = args.data_input_dirs.split(",")
    for ds_dir in input_data_dirs:
        ds_dir = os.path.abspath(ds_dir)
        assert os.path.exists(ds_dir), f"Cannot find data dir {ds_dir}"
        ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
        ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
        input_data_paths.extend(ds_paths)

    # Prepare data splits.
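    # Each split covers roughly 100 / num_spliced_dataset_bins percent of the merged raw data; the last bin may be smaller.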
    train_splits = []
    split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
    for i in range(0, 100, split_interval):
        start = i
        end = i + split_interval
        if end > 100:
            end = 100
        train_splits.append(f"train[{start}%:{end}%]")

    # Prepare the tokenizer.
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
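    # Disable automatic BOS/EOS insertion; special tokens are expected to be added by `supervised_tokenize_sft`.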
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False
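    # Fall back to the UNK token for padding when the tokenizer defines no pad token.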
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    list_dataset = load_dataset(
        path="json",
        data_files=input_data_paths,
        cache_dir=os.path.join(args.data_cache_dir, "raw"),
        keep_in_memory=False,
        split=train_splits,
        num_proc=cpu_count(),
    )
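    # With `split` given as a list, `load_dataset` returns one `Dataset` per requested percentage range.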
    for index, dataset in enumerate(list_dataset):
        assert isinstance(dataset, dataset_dict.Dataset)
        logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
        dataset = dataset.map(
            function=supervised_tokenize_sft,
            fn_kwargs={
                "tokenizer": tokenizer,
                "conversation_template": default_conversation,
                "max_length": args.max_length,
            },
            keep_in_memory=False,
            num_proc=min(len(dataset), cpu_count()),
        )

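        # Drop samples whose tokenization yielded no trainable labels.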
        dataset = dataset.filter(lambda data: data["labels"] is not None)
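        # Sort by sequence category and length so samples of similar size sit adjacent within the shard.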
        dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)

        # We don't concatenate data samples here.
        spliced_dataset = dataset
        # Save each jsonl spliced dataset.
        output_index = "0" * (5 - len(str(index))) + str(index)
        output_name = f"part-{output_index}"
        output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
        with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
            spliced_count = 0
            for spliced_data_point in spliced_dataset:
                if spliced_count % 500 == 0:
                    logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
                spliced_count += 1
                fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")

        # Save each arrow spliced dataset
        output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
        logger.info(f"Start to save {output_arrow_path}")
        spliced_dataset = load_dataset(
            path="json",
            data_files=[output_jsonl_path],
            cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
            keep_in_memory=False,
            num_proc=cpu_count(),
            split="train",
        )
        spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))


if __name__ == "__main__":
    main()