preprocess_data.py 9.25 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.

"""Preprocess data for Retro.

Stages (see argument '--retro-tasks'):
- Build chunk database (DB).
- Build index (train, add).
- Query pretraining neighbors.
"""

import json
import os
import sys
import torch

from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.datasets.retro.db import build_db
from megatron.core.datasets.retro.index import add_to_index, train_index
from megatron.core.datasets.retro.config import (
    RetroBertEmbedders,
    RetroGPTChunkDatasets,
    RetroPreprocessingConfig,
    RetroTokenizers,
)
from megatron.core.datasets.retro.query.gpt_chunk_dataset import build_gpt_chunk_datasets_from_gpt_datasets
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import (
    MultiSplitGPTDataset,
    MultiSplitGPTDatasetConfig,
)
from megatron.core.datasets.retro.query.query import query_neighbors
from megatron.core.datasets.retro.query.utils import get_query_dir
from megatron.core.datasets.retro.utils import retro_makedir
from megatron.core.models.retro.utils import (
    get_config_path,
    get_gpt_data_dir,
)
from megatron.training import get_args, initialize_megatron, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.tokenizer.tokenizer import (
    _BertWordPieceTokenizer,
    _GPT2BPETokenizer,
    _GPTSentencePieceTokenizer,
)
from megatron.training import get_train_valid_test_num_samples
from pretrain_gpt import is_dataset_built_on_rank
from tools.bert_embedding import BertEmbedder, DiskDataParallelBertEmbedder
from tools.retro.config_utils import add_config_args


def add_retro_args(parser):
    '''Register the Retro preprocessing arguments on `parser`.

    Creates an argument group titled "Retro preprocessing" and hands it to
    `add_config_args` together with `RetroPreprocessingConfig`.

    Args:
        parser: Argument parser exposing `add_argument_group`.

    Returns:
        The same `parser` object that was passed in.
    '''
    retro_group = parser.add_argument_group(title="Retro preprocessing")
    add_config_args(retro_group, RetroPreprocessingConfig)
    return parser


def initialize_megatron_retro():
    '''Initialize megatron & save Retro config.

    The `--retro-project-dir` flag and its value are removed from `sys.argv`
    before Megatron parses the command line, and the value is re-attached to
    the parsed args afterwards.

    Returns:
        The config produced by `get_retro_preprocessing_config()`.

    Raises:
        ValueError: If `--retro-project-dir` is not present in `sys.argv`.
        IndexError: If `--retro-project-dir` is the last command-line token.
    '''

    # Prevent arguments.py from overriding preprocessing args: read the
    # value first, then drop both the key and the value in one slice.
    flag_pos = sys.argv.index("--retro-project-dir")
    project_dir_value = sys.argv[flag_pos + 1]
    del sys.argv[flag_pos:flag_pos + 2]

    # Initialize Megatron with the Retro argument group registered.
    initialize_megatron(extra_args_provider=add_retro_args)

    # Restore the project dir on the parsed arguments.
    parsed_args = get_args()
    parsed_args.retro_project_dir = project_dir_value

    # Retro config.
    config = get_retro_preprocessing_config()

    # Only persist the config when no validation task was requested.
    if config.retro_task_validate is None:
        retro_makedir(config, config.retro_project_dir)
        save_config(config)

    return config


def get_bert_embedders(config):
    '''Construct the pair of Bert embedders stored on the Retro config.

    Args:
        config: Object providing `retro_bert_batch_size`,
            `retro_bert_max_chunk_length` and `retro_block_size`.

    Returns:
        A `RetroBertEmbedders` whose `mem` field is a `BertEmbedder` and
        whose `disk` field is a `DiskDataParallelBertEmbedder` wrapping that
        same `BertEmbedder` together with `config.retro_block_size`.
    '''
    in_memory = BertEmbedder(
        batch_size=config.retro_bert_batch_size,
        max_bert_seq_length=config.retro_bert_max_chunk_length,
        embedder_type="megatron",
    )
    # The disk embedder shares the in-memory embedder instance built above.
    on_disk = DiskDataParallelBertEmbedder(in_memory, config.retro_block_size)
    return RetroBertEmbedders(mem=in_memory, disk=on_disk)


def get_gpt_chunk_datasets(config):
    '''Build the GPT chunk datasets for the train/valid/test splits.

    Steps:
    1. Copy `config.retro_gpt_data_path` and join selected entries onto the
       GPT data dir of the Retro project.
    2. Build train/valid/test `MultiSplitGPTDataset` instances through
       `BlendedMegatronDatasetBuilder`.
    3. Convert them with `build_gpt_chunk_datasets_from_gpt_datasets` and
       wrap the result in `RetroGPTChunkDatasets`.

    Args:
        config: Retro preprocessing config; `config.retro_tokenizers.gpt`
            must already be set because it is passed to the dataset config.

    Returns:
        A `RetroGPTChunkDatasets` instance.
    '''

    args = get_args()

    # Dataset config.
    gpt_data_root = get_gpt_data_dir(config.retro_project_dir)
    raw_blend = list(config.retro_gpt_data_path)
    tail = len(raw_blend) - 1
    # Starting from the last entry and stepping back by two, entries are
    # joined onto the data dir; the others are kept unchanged.
    # NOTE(review): presumably the list is [weight, path, weight, path, ...]
    # or a single path -- confirm against the argument definition.
    blend = [
        os.path.join(gpt_data_root, entry) if (tail - pos) % 2 == 0 else entry
        for pos, entry in enumerate(raw_blend)
    ]
    combined_blend = get_blend_from_list(blend)
    per_split_blends = [
        get_blend_from_list(args.train_data_path),
        get_blend_from_list(args.valid_data_path),
        get_blend_from_list(args.test_data_path),
    ]
    data_config = MultiSplitGPTDatasetConfig(
        random_seed=config.retro_gpt_seed,
        sequence_length=config.retro_gpt_seq_length,
        blend=combined_blend,
        blend_per_split=per_split_blends,
        renormalize_blend_weights=args.renormalize_blend_weights,
        split=config.retro_gpt_split,
        split_preprocessing=config.retro_gpt_split,
        path_to_cache=config.retro_gpt_data_cache_path,
        return_document_ids=True,
        tokenizer=config.retro_tokenizers.gpt,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
    )

    # GPT datasets.
    print_rank_0(" > multi-split gpt datasets.")
    num_samples = get_train_valid_test_num_samples()
    builder = BlendedMegatronDatasetBuilder(
        MultiSplitGPTDataset,
        num_samples,
        is_dataset_built_on_rank,
        data_config,
    )
    train_ds, valid_ds, test_ds = builder.build()

    # Pair every split's dataset with its requested sample count.
    named_splits = (("train", train_ds), ("valid", valid_ds), ("test", test_ds))
    gpt_datasets = {
        split_name: (dataset, num_samples[position])
        for position, (split_name, dataset) in enumerate(named_splits)
    }

    # Chunk datasets.
    chunk_dataset_fields = build_gpt_chunk_datasets_from_gpt_datasets(
        project_dir=config.retro_project_dir,
        gpt_datasets=gpt_datasets,
        sample_length=config.retro_gpt_seq_length,
        chunk_length=config.retro_gpt_chunk_length,
    )
    return RetroGPTChunkDatasets(**chunk_dataset_fields)


def get_gpt_tokenizer(config):
    '''Build the GPT tokenizer selected by `config.retro_gpt_tokenizer_type`.

    Tokenizer files are resolved relative to `config.retro_project_dir`.

    Args:
        config: Object providing `retro_gpt_tokenizer_type` and, depending on
            the type, `retro_project_dir` plus either
            `retro_gpt_vocab_file`/`retro_gpt_merge_file` or
            `retro_gpt_tokenizer_model`.

    Returns:
        A `_GPT2BPETokenizer` or a `_GPTSentencePieceTokenizer`.

    Raises:
        AssertionError: If the files required by the chosen type are unset.
        Exception: If the tokenizer type is not one of the two handled here.
    '''
    kind = config.retro_gpt_tokenizer_type

    if kind == "GPT2BPETokenizer":
        # BPE needs both a vocab file and a merge file.
        assert config.retro_gpt_vocab_file and config.retro_gpt_merge_file
        vocab_path = os.path.join(
            config.retro_project_dir,
            config.retro_gpt_vocab_file,
        )
        merge_path = os.path.join(
            config.retro_project_dir,
            config.retro_gpt_merge_file,
        )
        return _GPT2BPETokenizer(vocab_file=vocab_path, merge_file=merge_path)

    if kind == 'GPTSentencePieceTokenizer':
        # SentencePiece needs a single model file.
        assert config.retro_gpt_tokenizer_model is not None
        model_path = os.path.join(
            config.retro_project_dir,
            config.retro_gpt_tokenizer_model,
        )
        return _GPTSentencePieceTokenizer(model_path)

    raise Exception("unrecognized gpt tokenizer, '%s'." % kind)


def get_bert_tokenizer(config):
    '''Build the Bert (Wordpiece) tokenizer selected by the config.

    The vocab file is resolved relative to `config.retro_project_dir`.

    Args:
        config: Object providing `retro_bert_tokenizer_type`,
            `retro_project_dir` and `retro_bert_vocab_file`.

    Returns:
        A `_BertWordPieceTokenizer`; `lower_case` is True for
        "BertWordPieceLowerCase" and False for "BertWordPieceCase".

    Raises:
        KeyError: If `config.retro_bert_tokenizer_type` is not one of the two
            supported names. The message names the offending value.
    '''
    tokenizer_type = config.retro_bert_tokenizer_type
    lower_case_by_type = {
        "BertWordPieceLowerCase" : True,
        "BertWordPieceCase" : False,
    }
    # Still a KeyError (as a plain dict lookup would raise), but with a
    # message that says what went wrong instead of echoing only the key.
    if tokenizer_type not in lower_case_by_type:
        raise KeyError("unrecognized bert tokenizer, '%s'." % tokenizer_type)

    return _BertWordPieceTokenizer(
        vocab_file=os.path.join(
            config.retro_project_dir,
            config.retro_bert_vocab_file,
        ),
        lower_case=lower_case_by_type[tokenizer_type],
    )


def get_tokenizers(config):
    '''Bundle the GPT and Bert tokenizers into a `RetroTokenizers`.

    Args:
        config: Passed unchanged to `get_gpt_tokenizer` and
            `get_bert_tokenizer`.

    Returns:
        A `RetroTokenizers` with its `gpt` and `bert` fields populated.
    '''
    gpt_tokenizer = get_gpt_tokenizer(config)
    bert_tokenizer = get_bert_tokenizer(config)
    return RetroTokenizers(gpt=gpt_tokenizer, bert=bert_tokenizer)


def get_retro_preprocessing_config():
    '''Create the Retro preprocessing config and attach its tools.

    The config is derived from the parsed Megatron args via
    `core_transformer_config_from_args` using `RetroPreprocessingConfig` as
    the config class. Tokenizers, Bert embedders and GPT chunk datasets are
    then stored on it.

    Returns:
        The populated config object.
    '''

    # Retro config built from the global arguments.
    config = core_transformer_config_from_args(
        get_args(),
        config_class=RetroPreprocessingConfig,
    )

    # Add tools. Order matters: building the chunk datasets reads
    # config.retro_tokenizers.gpt, so the tokenizers are attached first.
    tool_factories = (
        ("retro_tokenizers", get_tokenizers),
        ("retro_bert_embedders", get_bert_embedders),
        ("retro_gpt_chunk_datasets", get_gpt_chunk_datasets),
    )
    for attr_name, factory in tool_factories:
        setattr(config, attr_name, factory(config))

    return config


def save_config(config):
    '''Save copy of config within retro project dir.

    Rank 0 writes a JSON file (indent 4, sorted keys) at the path returned by
    `get_config_path(config.retro_project_dir)`. Every rank then waits on a
    distributed barrier.

    Args:
        config: Retro preprocessing config with `retro_gpt_chunk_datasets`
            already attached.
    '''

    if torch.distributed.get_rank() == 0:

        # GPT config: every 'retro_gpt*' attribute except the chunk datasets.
        subset = {}
        for name, value in vars(config).items():
            if name == "retro_gpt_chunk_datasets":
                continue
            if name.startswith("retro_gpt"):
                subset[name] = value

        # Block size + Bert config.
        subset.update(
            retro_block_size=config.retro_block_size,
            retro_bert_tokenizer_type=config.retro_bert_tokenizer_type,
            retro_bert_vocab_file=config.retro_bert_vocab_file,
        )

        # Neighbor directories, stored relative to the query dir; entries
        # whose value is None are recorded as None.
        query_root = get_query_dir(config.retro_project_dir)
        neighbor_dirs = {}
        for split_name, info in vars(config.retro_gpt_chunk_datasets).items():
            if info is None:
                neighbor_dirs[split_name] = None
            else:
                neighbor_dirs[split_name] = os.path.relpath(
                    info["neighbor_dir"], query_root)
        subset["retro_neighbor_dirs"] = neighbor_dirs

        # Save.
        target_path = get_config_path(config.retro_project_dir)
        with open(target_path, "w") as out_file:
            json.dump(subset, out_file, indent=4, sort_keys=True)

    # All ranks synchronize here, whether or not they wrote the file.
    torch.distributed.barrier()


if __name__ == "__main__":

    # Initialize Megatron and obtain the Retro preprocessing config.
    config = initialize_megatron_retro()

    # Each user-facing task name expands to one or more primitive stages.
    task_expansions = {
        "build" : [ "db-build", "index-train", "index-add", "query-neighbors" ],
        "index-build" : [ "index-train", "index-add" ],
        "db-build" : [ "db-build" ],
        "index-train" : [ "index-train" ],
        "index-add" : [ "index-add" ],
        "query-neighbors" : [ "query-neighbors" ],
    }
    tasks = [
        stage
        for requested in config.retro_tasks
        for stage in task_expansions[requested]
    ]
    config.retro_tasks = tasks

    # Primitive stage name -> function that runs it.
    stage_runners = {
        "db-build" : build_db,              # DB (i.e., chunk db).
        "index-train" : train_index,        # Index.
        "index-add" : add_to_index,
        "query-neighbors" : query_neighbors,  # Query.
    }

    # Run the stages in order, synchronizing all ranks after each one.
    for task in tasks:

        validate_tag = "" if config.retro_task_validate is None else "[validate] "
        print_rank_0("start '%s%s'." % (validate_tag, task))

        run_stage = stage_runners.get(task)
        if run_stage is None:
            raise Exception("specialize for task '%s'." % task)
        run_stage(config)

        torch.distributed.barrier()

        # Re-read the validate setting after the stage has run.
        validate_tag = "" if config.retro_task_validate is None else "[validate] "
        print_rank_0("end '%s%s'." % (validate_tag, task))