pretrain_gpt.py 12.2 KB
Newer Older
liangjing's avatar
v1  
liangjing committed
1
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
xingjinliang's avatar
xingjinliang committed
2
"""Pretrain GPT."""
3

xingjinliang's avatar
xingjinliang committed
4
import os
5
import torch
6
from functools import partial
xingjinliang's avatar
xingjinliang committed
7
8
9
10
11
12
13
14
15
from contextlib import nullcontext
import inspect

from typing import List, Optional, Tuple, Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
16
from megatron.core.enums import ModelType
xingjinliang's avatar
xingjinliang committed
17
18
19
20
21
22
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.core.rerun_state_machine import get_rerun_state_machine
import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
Mohammad's avatar
Mohammad committed
23
from megatron.training import pretrain
xingjinliang's avatar
xingjinliang committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    get_blend_and_blend_per_split,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)
38
39
40

from megatron.core.transformer.mtp.mtp_spec import get_mtp_spec

xingjinliang's avatar
xingjinliang committed
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch._dynamo
# Ask TorchDynamo to log and fall back to eager execution instead of raising
# when graph capture/compilation fails. Relevant because model_provider wraps
# the model in torch.compile.
# NOTE(review): this can silently hide compilation failures (and the expected
# speed-up) -- consider surfacing the fallback in logs.
torch._dynamo.config.suppress_errors = True

# Module-level straggler detector; forward_step uses it as a context manager
# around batch loading (stimer(bdata=True)) and around the model forward pass.
stimer = StragglerDetector()

def _get_layer_spec_from_args(args, use_te: bool):
    """Return a single GPT transformer-layer spec built from ``args``.

    Uses the TransformerEngine variant when ``use_te`` is true and the local
    variant otherwise. Shared by the decoder-layer spec and the MTP layer spec
    in ``model_provider`` so the two constructions cannot drift apart.
    """
    spec_args = (
        args.num_experts,
        args.moe_grouped_gemm,
        args.qk_layernorm,
        args.multi_latent_attention,
        args.moe_use_legacy_grouped_gemm,
    )
    if use_te:
        return get_gpt_layer_with_transformer_engine_spec(*spec_args)
    return get_gpt_layer_local_spec(*spec_args)


def _get_build_model_context(args):
    """Return ``(context_manager_factory, kwargs)`` used to wrap GPTModel construction.

    With ``args.fp8_param_gather`` set this is TransformerEngine's
    ``fp8_model_init``; otherwise it is a no-op ``nullcontext``.

    Raises:
        RuntimeError: if ``fp8_param_gather`` is requested but ``fp8_model_init``
            cannot be imported from TransformerEngine.
    """
    if not args.fp8_param_gather:
        return nullcontext, {}

    try:
        from transformer_engine.pytorch import fp8_model_init
    except ImportError as err:
        raise RuntimeError(
            "--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found."
        ) from err

    context_args = {"enabled": True}
    # Only pass preserve_high_precision_init_val when the installed TE version supports it.
    if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
        context_args["preserve_high_precision_init_val"] = True
    return fp8_model_init, context_args


def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model, wrapped by
        ``torch.compile`` (mode ``'max-autotune-no-cudagraphs'``).

    Raises:
        RuntimeError: if ``--fp8-param-gather`` is set but TransformerEngine's
            ``fp8_model_init`` is unavailable (mcore path only).
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
        torch.cuda.memory._record_memory_history(
            True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            # record stack information for the trace events
            trace_alloc_record_context=True,
        )

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        elif args.num_experts:
            # Define the decoder block spec
            transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te)
        else:
            # Define the decoder layer spec
            transformer_layer_spec = _get_layer_spec_from_args(args, use_te)

        build_model_context, build_model_context_args = _get_build_model_context(args)

        # The MTP layers always use the plain per-layer spec, independent of
        # --spec / --num-experts choices made for the main decoder above.
        mtp_spec = get_mtp_spec(_get_layer_spec_from_args(args, use_te), use_te=use_te)

        with build_model_context(**build_model_context_args):
            model = GPTModel(
                config=config,
                transformer_layer_spec=transformer_layer_spec,
                vocab_size=args.padded_vocab_size,
                max_sequence_length=args.max_position_embeddings,
                pre_process=pre_process,
                post_process=post_process,
                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
                parallel_output=True,
                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
                position_embedding_type=args.position_embedding_type,
                rotary_percent=args.rotary_percent,
                rotary_base=args.rotary_base,
                rope_scaling=args.use_rope_scaling,
                mtp_spec=mtp_spec,
                num_nextn_predict_layers=args.num_nextn_predict_layers,
                share_mtp_embedding_and_output_weight=args.share_mtp_embedding_and_output_weight,
                recompute_mtp_norm=args.recompute_mtp_norm,
                recompute_mtp_layer=args.recompute_mtp_layer,
                mtp_loss_scale=args.mtp_loss_scale,
            )

    # Compiled unconditionally: applies to both the legacy and the mcore model.
    model = torch.compile(model, mode='max-autotune-no-cudagraphs')
    print_rank_0(model)

    return model


Mohammad's avatar
Mohammad committed
157
def get_batch(data_iterator):
    """Generate a batch.

    Only the first and last pipeline stages consume input data. Any other stage
    receives five ``None`` placeholders so that callers can still unpack the
    result into five names.

    Args:
        data_iterator: iterator yielding raw batches (forwarded to
            ``get_batch_on_this_tp_rank``).

    Returns:
        The values of the batch dict after TP broadcast and CP slicing. The
        caller unpacks them as (tokens, labels, loss_mask, attention_mask,
        position_ids). Presumably this matches the dict's insertion order in
        ``get_batch_on_this_tp_rank`` -- TODO confirm.
    """
    # TODO: this is pretty hacky, find a better way
    on_edge_stage = mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    if not on_edge_stage:
        return (None,) * 5

    # Batches are fetched per tensor-parallel rank ...
    tp_batch = get_batch_on_this_tp_rank(data_iterator)
    # ... and then sliced along the sequence dimension for context parallelism.
    cp_batch = get_batch_on_this_cp_rank(tp_batch)
    return cp_batch.values()
171
172


xingjinliang's avatar
xingjinliang committed
173
174
# Spiky-loss threshold: a loss variation of 20% or more is treated as spiky.
# Passed as ``threshold`` to ``rerun_state_machine.is_spiky_loss`` in loss_func
# when ``args.check_for_spiky_loss`` is enabled.
SPIKY_LOSS_PERC = 0.2
175
176


xingjinliang's avatar
xingjinliang committed
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Reduce per-token losses into this micro-batch's loss.

    Args:
        loss_mask (torch.Tensor): Weights each token's loss after flattening;
            zero entries drop the token from the loss.
        output_tensor (torch.Tensor): The per-token losses, flattened to line
            up with ``loss_mask``.

    Returns:
        A 3-tuple:
          * the masked loss sum for this micro-batch, multiplied by the
            context-parallel size;
          * the number of tokens counted by ``loss_mask``, as ``torch.int``;
          * ``{'lm loss': (loss_sum, num_tokens)}`` summed across the
            data-parallel group, for reporting.
    """
    args = get_args()

    flat_losses = output_tensor.float().view(-1)
    flat_mask = loss_mask.view(-1).float()

    token_count = flat_mask.sum()
    masked_loss_sum = torch.sum(flat_losses * flat_mask)
    # Pack [loss_sum, token_count] into one tensor so a single collective reduces both.
    loss = torch.stack([masked_loss_sum, token_count])

    if args.context_parallel_size > 1:
        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

    # Validate each rank's local loss before any data-parallel reduction.
    rerun_state_machine = get_rerun_state_machine()
    if args.check_for_nan_in_loss_and_grad:
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=torch.isnan,
            message="found NaN in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss[0],
            rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=False,
        )

    # A detached copy is summed over data-parallel ranks purely for logging.
    reporting_loss = loss.clone().detach()
    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

    num_local_tokens = loss[1].clone().detach().to(torch.int)
    scaled_loss = loss[0] * args.context_parallel_size
    return scaled_loss, num_local_tokens, {'lm loss': (reporting_loss[0], reporting_loss[1])}
229
230


xingjinliang's avatar
xingjinliang committed
231
232
def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model

    Returns:
        A tuple ``(output_tensor, loss_fn)``. ``loss_fn`` is ``loss_func`` with
        this micro-batch's ``loss_mask`` already bound, so its only remaining
        argument is ``output_tensor``.
    """
    timers = get_timers()

    # Get the batch. The module-level ``stimer`` is only read here, so no
    # ``global`` declaration is needed.
    timers('batch-generator', log_level=2).start()
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(loss_func, loss_mask)
254
255


xingjinliang's avatar
xingjinliang committed
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def is_dataset_built_on_rank():
    """Return whether this rank should build the datasets.

    Only tensor-parallel rank 0 on the first or last pipeline stage builds them.
    """
    on_edge_stage = mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    if not on_edge_stage:
        return False
    return mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
    """Translate parsed arguments into a ``GPTDatasetConfig``.

    Args:
        args: The parsed Megatron argument namespace.

    Returns:
        GPTDatasetConfig: configuration consumed by the GPT dataset builder.
    """
    tokenizer = get_tokenizer()

    # --data-path can get too long, so the blend may instead be parsed from a file.
    blend: Optional[Tuple[List[str], Optional[List[float]]]]
    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
    blend, blend_per_split = get_blend_and_blend_per_split(args)

    # Each sample is lengthened by one token per next-N (MTP) prediction layer.
    # Presumably this gives the extra prediction heads their targets -- confirm.
    sample_length = args.seq_length + args.num_nextn_predict_layers

    config_kwargs = dict(
        random_seed=args.seed,
        sequence_length=sample_length,
        blend=blend,
        blend_per_split=blend_per_split,
        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        s3_cache_path=args.s3_cache_path,
    )
    return GPTDatasetConfig(**config_kwargs)


289
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, validation and test datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train, validation and test.

    Returns:
        A ``(train_ds, valid_ds, test_ds)`` tuple as produced by
        ``BlendedMegatronDatasetBuilder.build()``.
    """
    args = get_args()

    dataset_config = core_gpt_dataset_config_from_args(args)
    dataset_cls = MockGPTDataset if args.mock_data else GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    builder = BlendedMegatronDatasetBuilder(
        dataset_cls,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        dataset_config,
    )
    train_ds, valid_ds, test_ds = builder.build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
316
317
318


if __name__ == "__main__":
    # Marks the provider as distributed-aware; the file calls this temporary,
    # pending the transition to core datasets.
    train_valid_test_datasets_provider.is_distributed = True

    # Launch training: dataset provider, model provider, model type and
    # per-step forward function, with GPT-2 BPE as the default tokenizer.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults=dict(tokenizer_type='GPT2BPETokenizer'),
    )