# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain GPT."""

import datetime
import os
import torch

from functools import partial
from typing import List, Optional, Tuple, Union
from megatron.core import parallel_state
from megatron.training import get_args
from megatron.training import inprocess_restart
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
    get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.transformer.spec_utils import import_module
from megatron.core.utils import StragglerDetector
from megatron.training import pretrain
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    get_blend_and_blend_per_split,
)
from megatron.training.yaml_arguments import core_transformer_config_from_yaml

import megatron.legacy.model  # isort: skip

# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import

try:
    from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
    from megatron.post_training.loss_func import loss_func as loss_func_modelopt
    from megatron.post_training.model_provider import model_provider as model_provider_modelopt

    has_nvidia_modelopt = True
except ImportError:
    has_nvidia_modelopt = False

from dcu_megatron import megatron_adaptor
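
# Global StragglerDetector instance; forward_step() uses it to time batch generation and the forward compute.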

stimer = StragglerDetector()


def model_provider(
    pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True.
        post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True.


    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
    """
    args = get_args()

    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
        return model_provider_modelopt(pre_process, post_process)

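    # USE_FLUX_OVERLAP (environment variable) presumably enables Flux-style compute/communication
    # overlap kernels; as the assert below enforces, it only works with the Transformer Engine impl.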
    if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
        assert args.transformer_impl == "transformer_engine"
    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
        torch.cuda.memory._record_memory_history(
            True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            # record stack information for the trace events
            trace_alloc_record_context=True,
        )

        def oom_observer(device, alloc, device_alloc, device_free):
            # snapshot right after an OOM happened
            print('saving allocated state during OOM')
            snapshot = torch.cuda.memory._snapshot()
            from pickle import dump

            dump(
                snapshot,
                open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
            )

        torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                # Define the decoder block spec
                transformer_layer_spec = get_gpt_decoder_block_spec(
                    config,
                    use_transformer_engine=use_te,
                    normalization=args.normalization,
                    qk_l2_norm=args.qk_l2_norm,
                    vp_stage=vp_stage,
                )
            elif args.heterogeneous_layers_config_path is not None:
                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
            else:
                # Define the decoder layer spec
                if use_te:
                    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        args.moe_use_legacy_grouped_gemm,
                        qk_l2_norm=args.qk_l2_norm,
                    )
                else:
                    transformer_layer_spec = get_gpt_layer_local_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        args.moe_use_legacy_grouped_gemm,
                        normalization=args.normalization,
                    )
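        # Optionally attach a Multi-Token Prediction (MTP) block when --mtp-num-layers is set.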
        mtp_block_spec = None
        if args.mtp_num_layers is not None:
            mtp_block_spec = get_gpt_mtp_block_spec(
                config, transformer_layer_spec, use_transformer_engine=use_te, vp_stage=vp_stage
            )

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
            vp_stage=vp_stage,
        )

    return model


def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
        not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ):
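        # Intermediate pipeline stages receive activations from the previous stage, so they need no data.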
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()


# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10


def loss_func(
    loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
    """Loss function.

    Args:
        loss_mask (torch.Tensor): Used to mask out some portions of the loss
        output_tensor (torch.Tensor): The tensor with the losses
        model (GPTModel, optional): The model (can be wrapped)

    Returns:
        the loss scalar for this micro-batch
        the number of non-padded tokens in this microbatch
        a dict containing reporting metrics on the loss and number of tokens across
            the data parallel ranks
    """
    args = get_args()

    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
        return loss_func_modelopt(loss_mask, output_tensor, model=model)

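    # Sum the per-token losses over unmasked positions; token-count normalization happens in the trainer.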
    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask)

    # Check individual rank losses are not NaN prior to DP all-reduce.
    rerun_state_machine = get_rerun_state_machine()
    if args.check_for_nan_in_loss_and_grad:
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isnan,
            message="found NaN in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isinf,
            message="found Inf in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
    # Check for spiky loss
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=partial(
                rerun_state_machine.is_unexpectedly_large,
                threshold=SPIKY_LOSS_FACTOR,
                context="loss",
            ),
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=False,
        )

    num_tokens = loss_mask.sum().clone().detach().to(torch.int)
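    # Pack (summed local loss, token count) so the trainer can compute a token-weighted average across DP ranks.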
    reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])

    return (loss, num_tokens, {'lm loss': reporting_loss})


def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            output_tensor = model(
                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
            )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)


def is_dataset_built_on_rank():
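    """Build datasets only on the first/last pipeline stages, and only on TP rank 0."""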
    return (
        parallel_state.is_pipeline_first_stage(ignore_virtual=True)
        or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    ) and parallel_state.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
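    """Translate the relevant command-line arguments into a GPTDatasetConfig."""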
    tokenizer = get_tokenizer()

    # Sometimes --data-path is too long, instead we parse it from a file.
    blend: Optional[Tuple[List[str], Optional[List[float]]]]
    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
    blend, blend_per_split = get_blend_and_blend_per_split(args)

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        object_storage_cache_path=args.object_storage_cache_path,
        mid_level_dataset_surplus=args.mid_level_dataset_surplus,
    )


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train test and validation datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train test and validation.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.mock_data:
        dataset_type = MockGPTDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    # Optionally enable inprocess restart on pretrain
    pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)
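    # The (possibly wrapped) pretrain function and the returned store are both forwarded to pretrain() below.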

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
        store=store,
    )