# pretrain_gpt.py
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
"""Pretrain GPT."""

4
5
6
7
import os, sys
current_dir = os.path.dirname(os.path.abspath(__file__))
megatron_path = os.path.join(current_dir, "Megatron-LM")
sys.path.append(megatron_path)
dongcl's avatar
dongcl committed
8

dongcl's avatar
dongcl committed
9
from functools import partial
dongcl's avatar
dongcl committed
10
from typing import List, Optional, Tuple, Union
dongcl's avatar
dongcl committed
11
12
13
14

import torch

from megatron.core import parallel_state
dongcl's avatar
dongcl committed
15
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
dongcl's avatar
dongcl committed
16
17
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.enums import ModelType
dongcl's avatar
dongcl committed
18
from megatron.core.models.gpt import GPTModel
dongcl's avatar
dongcl committed
19
20
21
22
23
24
25
26
27
28
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
    get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.rerun_state_machine import get_rerun_state_machine
dongcl's avatar
dongcl committed
29
from megatron.core.transformer.spec_utils import import_module
dongcl's avatar
dongcl committed
30
31
32
from megatron.core.utils import StragglerDetector
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
dongcl's avatar
dongcl committed
33
34
35
36
37
38
39
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    get_blend_and_blend_per_split,
)
from megatron.training.yaml_arguments import core_transformer_config_from_yaml

dongcl's avatar
dongcl committed
40
41
42
43
44
45
46
47
48
49
50
51
52
import megatron.legacy.model  # isort: skip

# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import

try:
    from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
    from megatron.post_training.loss_func import loss_func as loss_func_modelopt
    from megatron.post_training.model_provider import model_provider as model_provider_modelopt

    has_nvidia_modelopt = True
except ImportError:
    has_nvidia_modelopt = False

dongcl's avatar
dongcl committed
53
54
55
56
57
from dcu_megatron import megatron_adaptor


stimer = StragglerDetector()

dongcl's avatar
dongcl committed
58
59
60
61

def model_provider(
    pre_process=True, post_process=True
) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
    """
    args = get_args()

    # [ModelOpt] post-training flows (quantization/distillation) build their own model.
    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
        return model_provider_modelopt(pre_process, post_process)

    # Flux comm/compute overlap only works with the Transformer Engine backend.
    if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
        assert args.transformer_impl == "transformer_engine"

    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
        torch.cuda.memory._record_memory_history(
            True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            # record stack information for the trace events
            trace_alloc_record_context=True,
        )

        def oom_observer(device, alloc, device_alloc, device_free):
            # snapshot right after an OOM happened
            print('saving allocated state during OOM')
            snapshot = torch.cuda.memory._snapshot()
            from pickle import dump

            # Use a context manager so the snapshot file is flushed and closed
            # deterministically (the previous bare open() leaked the handle,
            # risking a truncated snapshot right when memory is exhausted).
            with open(
                f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'
            ) as snapshot_file:
                dump(snapshot, snapshot_file)

        torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        # Layer spec priority: explicit --spec > MoE block spec > heterogeneous
        # spec > homogeneous TE/local layer spec.
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                # Define the decoder block spec
                transformer_layer_spec = get_gpt_decoder_block_spec(
                    config, use_transformer_engine=use_te, normalization=args.normalization
                )
            elif args.heterogeneous_layers_config_path is not None:
                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
            else:
                # Define the decoder layer spec
                if use_te:
                    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        args.moe_use_legacy_grouped_gemm,
                    )
                else:
                    transformer_layer_spec = get_gpt_layer_local_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        args.moe_use_legacy_grouped_gemm,
                        normalization=args.normalization,
                    )
        # Optional multi-token-prediction block, reusing the decoder layer spec.
        mtp_block_spec = None
        if args.mtp_num_layers is not None:
            mtp_block_spec = get_gpt_mtp_block_spec(
                config, transformer_layer_spec, use_transformer_engine=use_te
            )

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
        )

    return model


def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    on_first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
    on_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    if not (on_first_stage or on_last_stage):
        # Intermediate pipeline stages receive activations, not data:
        # hand back placeholders for tokens/labels/loss_mask/attn_mask/position_ids.
        return None, None, None, None, None

    # Fetch (or broadcast) the batch for this tensor-parallel rank, then slice
    # it along the sequence dimension for context parallelism.
    tp_batch = get_batch_on_this_tp_rank(data_iterator)
    cp_batch = get_batch_on_this_cp_rank(tp_batch)

    return cp_batch.values()


# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10


def loss_func(
    loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
    """Loss function.

    Args:
        loss_mask (torch.Tensor): Used to mask out some portions of the loss
        output_tensor (torch.Tensor): The tensor with the losses
        model (GPTModel, optional): The model (can be wrapped)

    Returns:
        the loss scalar for this micro-batch
        the number of non-padded tokens in this microbatch
        a dict containing reporting metrics on the loss and number of tokens across
            the data parallel ranks
    """
    args = get_args()

    # [ModelOpt] distillation/quantization flows compute their own loss.
    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
        return loss_func_modelopt(loss_mask, output_tensor, model=model)

    # Pack [masked loss sum, token count] into one tensor so a single
    # all-reduce handles both values.
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    total_tokens = loss_mask.sum()
    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

    # Each CP rank only saw a slice of the sequence; sum partial losses/tokens.
    if args.context_parallel_size > 1:
        torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())

    # Check individual rank losses are not NaN prior to DP all-reduce.
    rerun_state_machine = get_rerun_state_machine()
    if args.check_for_nan_in_loss_and_grad:
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=torch.isnan,
            message="found NaN in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=torch.isinf,
            message="found Inf in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
    # Check for spiky loss
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=partial(
                rerun_state_machine.is_unexpectedly_large,
                threshold=SPIKY_LOSS_FACTOR,
                context="loss",
            ),
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=False,  # spiky loss is logged, not fatal
        )
    # Reduce loss for logging.
    reporting_loss = loss.clone().detach()
    torch.distributed.all_reduce(reporting_loss, group=parallel_state.get_data_parallel_group())

    # loss[0] is a view of loss, so it has ._base not None, which triggers assert error
    # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
    # on loss[0] fixes this
    local_num_tokens = loss[1].clone().detach().to(torch.int)
    return (loss[0].clone(), local_num_tokens, {'lm loss': (reporting_loss[0], reporting_loss[1])})
dongcl's avatar
dongcl committed
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282


def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model

    Returns:
        The model's output tensor and a partial of ``loss_func`` with the
        loss mask (and model handle) already bound.
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    # bdata=True marks this straggler-detector region as data loading.
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    # Time the forward pass itself with the straggler detector.
    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            # mcore GPTModel additionally accepts loss_mask (e.g. for MTP losses).
            output_tensor = model(
                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
            )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)
dongcl's avatar
dongcl committed
296
297
298
299


def is_dataset_built_on_rank():
    """Return True on the ranks that actually consume dataset samples:
    tensor-parallel rank 0 of the first or last pipeline stage."""
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        return False
    first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
    last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    return first_stage or last_stage
dongcl's avatar
dongcl committed
303
304
305
306
307
308
309
310
311
312
313
314


def core_gpt_dataset_config_from_args(args):
    """Translate parsed training arguments into a ``GPTDatasetConfig``."""
    tokenizer = get_tokenizer()

    # Sometimes --data-path is too long, instead we parse it from a file.
    blend: Optional[Tuple[List[str], Optional[List[float]]]]
    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
    blend, blend_per_split = get_blend_and_blend_per_split(args)

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        # blend/blend_per_split are mutually exclusive ways to mix data files;
        # both come from get_blend_and_blend_per_split above.
        blend=blend,
        blend_per_split=blend_per_split,
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        object_storage_cache_path=args.object_storage_cache_path,
        mid_level_dataset_surplus=args.mid_level_dataset_surplus,
    )


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train test and validation datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train test and validation.
    """
    args = get_args()

    dataset_config = core_gpt_dataset_config_from_args(args)

    # Synthetic data when --mock-data is set, the real GPT dataset otherwise.
    dataset_cls = MockGPTDataset if args.mock_data else GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    builder = BlendedMegatronDatasetBuilder(
        dataset_cls, train_val_test_num_samples, is_dataset_built_on_rank, dataset_config
    )
    train_ds, valid_ds, test_ds = builder.build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
        # [ModelOpt] register extra CLI arguments only when nvidia-modelopt
        # is importable; otherwise pretrain() gets no extra provider.
        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
    )