# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
"""Pretrain GPT."""
import os
import torch
from functools import partial
from contextlib import nullcontext
import inspect

from typing import List, Optional, Tuple, Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.rerun_state_machine import get_rerun_state_machine
import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    get_blend_and_blend_per_split,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
    get_gpt_mtp_block_spec,
)
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules

from dcu_megatron import megatron_adaptor
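# NOTE: imported only for its side effects; megatron_adaptor is assumed to apply
# DCU-specific patches/overrides to Megatron at import time.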


stimer = StragglerDetector()
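# stimer is used in forward_step below: `with stimer(bdata=True)` times batch
# generation and a bare `with stimer:` times the forward compute, so that
# straggling ranks can be detected and reported.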

def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If args.use_legacy_models is True, the legacy GPT model is returned; otherwise the Megatron Core (mcore) GPT model is returned.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.


    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
    """
    args = get_args()
    if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
        assert args.transformer_impl == "transformer_engine", \
            "USE_FLUX_OVERLAP=1 requires --transformer-impl transformer_engine"
    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
        torch.cuda.memory._record_memory_history(True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,

            # record stack information for the trace events
            trace_alloc_record_context=True)

        def oom_observer(device, alloc, device_alloc, device_free):
            # snapshot right after an OOM happened
            print('saving allocated state during OOM')
            snapshot = torch.cuda.memory._snapshot()
            from pickle import dump
            dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
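            # The pickled snapshot can be loaded offline and inspected, e.g. with
            # PyTorch's memory visualizer (https://pytorch.org/memory_viz).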

        torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                # Define the decoder block spec
                transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
            elif args.heterogeneous_layers_config_path is not None:
                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
            else:
                # Define the decoder layer spec
                if use_te:
                    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                        args.num_experts, args.moe_grouped_gemm,
                        args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
                else:
                    transformer_layer_spec = get_gpt_layer_local_spec(
                        args.num_experts, args.moe_grouped_gemm,
                        args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm,
                        normalization=args.normalization)
        mtp_block_spec = None
        if args.mtp_num_layers is not None:
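            # Build the Multi-Token Prediction (MTP) block spec only when
            # --mtp-num-layers is set.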
            mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
        )
        print_rank_0(model)

    return model


def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    # `batch` preserves insertion order: (tokens, labels, loss_mask,
    # attention_mask, position_ids), matching the unpacking in forward_step.
    return batch.values()


# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Loss function.

    Args:
        loss_mask (torch.Tensor): Used to mask out some portions of the loss
        output_tensor (torch.Tensor): The tensor with the losses

    Returns:
        the loss scalar for this micro-batch
        the number of non-padded tokens in this microbatch
        a dict containing reporting metrics on the loss and number of tokens across
            the data parallel ranks
    """
    args = get_args()

    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    total_tokens = loss_mask.sum()
    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

    if args.context_parallel_size > 1:
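        # Each context-parallel rank only sees a shard of the sequence, so the
        # partial loss sum and token count are reduced over the CP group.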
        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

    # Check individual rank losses are not NaN prior to DP all-reduce.
    rerun_state_machine = get_rerun_state_machine()
    if args.check_for_nan_in_loss_and_grad:
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=torch.isnan,
            message="found NaN in local forward loss calculation",
            tolerance=0.0,        # forward pass calculations are deterministic
            fatal=True,
        )
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=torch.isinf,
            message="found Inf in local forward loss calculation",
            tolerance=0.0,        # forward pass calculations are deterministic
            fatal=True,
        )
    # Check for spiky loss
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=partial(
                rerun_state_machine.is_unexpectedly_large,
                threshold=SPIKY_LOSS_FACTOR,
                context="loss",
            ),
            message="Spiky loss",
            tolerance=0.0,        # forward pass calculations are deterministic
            fatal=False,
        )
    # Reduce loss for logging.
    reporting_loss = loss.clone().detach()
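    # After the all-reduce below, reporting_loss[0] is the loss sum and
    # reporting_loss[1] the token count across the data-parallel group.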
    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

    # loss[0] is a view of loss, so its ._base is not None, which triggers an assert
    # in core/pipeline_parallel/schedule.py::deallocate_output_tensor; calling .clone()
    # on loss[0] avoids this.
    local_num_tokens = loss[1].clone().detach().to(torch.int)
    return (
        loss[0].clone(),
        local_num_tokens,
        {'lm loss': (reporting_loss[0], reporting_loss[1])},
    )


def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
            data_iterator)
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask,
                                labels=labels)
        else:
            output_tensor = model(tokens, position_ids, attention_mask,
                                labels=labels, loss_mask=loss_mask)

    return output_tensor, partial(loss_func, loss_mask)


def is_dataset_built_on_rank():
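    # Only the first pipeline stage (tokens/position_ids) and the last stage
    # (labels/loss_mask) consume the data, and get_batch_on_this_tp_rank
    # broadcasts from TP rank 0, so all other ranks skip dataset construction.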
    return (
        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    ) and mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
    tokenizer = get_tokenizer()

    # Sometimes --data-path is too long; in that case, parse it from a file.
    blend: Optional[Tuple[List[str], Optional[List[float]]]]
    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
    blend, blend_per_split = get_blend_and_blend_per_split(args)

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        s3_cache_path=args.s3_cache_path,
    )


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train test and validation datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train, test, and validation.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.mock_data:
        dataset_type = MockGPTDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
    )
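

# A typical launch of this script (illustrative only; the flags below are
# standard Megatron-LM arguments, but the parallel sizes, model sizes, and
# paths are placeholders rather than values required by this file):
#
#   torchrun --nproc_per_node=8 pretrain_gpt.py \
#       --tensor-model-parallel-size 2 --pipeline-model-parallel-size 1 \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 2048 --max-position-embeddings 2048 \
#       --micro-batch-size 1 --global-batch-size 32 \
#       --train-iters 1000 --lr 1.5e-4 \
#       --data-path <preprocessed-data-prefix> \
#       --vocab-file <gpt2-vocab.json> --merge-file <gpt2-merges.txt> \
#       --save <checkpoint-dir> --load <checkpoint-dir>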