utils.py 4.43 KB
Newer Older
liangjing's avatar
liangjing committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from functools import partial
from types import SimpleNamespace
from unittest import mock

import torch

from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed
from megatron.core.transformer import TransformerConfig
from megatron.training.training import get_model
from megatron.training.utils import unwrap_model

# Dimensions of the small GPT model built by ``initialize_gpt_model``; the same
# values are also written onto the mock args in ``init_checkpointing_mock_args``.
NUM_LAYERS = 8
HIDDEN_SIZE = 16
NUM_ATTENTION_HEADS = 8


def initialize_gpt_model(
    pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs
):
    """Build a small GPT model cast to bfloat16 with randomly filled parameters.

    Args:
        pre_process: Forwarded to ``GPTModel`` as ``pre_process``.
        post_process: Forwarded to ``GPTModel`` as ``post_process``.
        seed: Seed passed to ``torch.manual_seed`` and
            ``model_parallel_cuda_manual_seed`` before the model is built.
        use_glu: Forwarded to ``TransformerConfig`` as ``gated_linear_unit``.
        **config_kwargs: Extra ``TransformerConfig`` fields. They take
            precedence over the defaults defined in this function.

    Returns:
        The constructed ``GPTModel``.
    """
    # Seed before construction so that the random fill below is reproducible
    # for a given ``seed``.
    torch.manual_seed(seed)
    model_parallel_cuda_manual_seed(seed)

    # Defaults first, caller-supplied overrides last.
    config_fields = {
        'num_layers': NUM_LAYERS,
        'hidden_size': HIDDEN_SIZE,
        'num_attention_heads': NUM_ATTENTION_HEADS,
        'use_cpu_initialization': True,
        **config_kwargs,
    }

    gpt = GPTModel(
        config=TransformerConfig(**config_fields, gated_linear_unit=use_glu),
        transformer_layer_spec=get_gpt_layer_local_spec(),
        vocab_size=128,
        max_sequence_length=4,
        pre_process=pre_process,
        post_process=post_process,
    )
    gpt.bfloat16()

    # Overwrite every parameter in place with random values; no autograd
    # tracking is wanted for this initialization.
    with torch.no_grad():
        for weight in gpt.parameters():
            weight.random_()

    return gpt


def init_basic_mock_args(args, tp, pp, bf16=True):
    """Populate ``args`` with the basic mocked training settings and return it.

    Args:
        args: Attribute container that is mutated in place.
        tp: Value stored as ``tensor_model_parallel_size``.
        pp: Value stored as ``pipeline_model_parallel_size``.
        bf16: Value stored as ``bf16`` (``fp16`` is always set to False).

    Returns:
        The same ``args`` object that was passed in.
    """
    # Insertion order of this dict is the order in which attributes are set.
    settings = {
        'data_parallel_random_init': False,
        'virtual_pipeline_model_parallel_size': None,
        'fp16': False,
        'bf16': bf16,
        'accumulate_allreduce_grads_in_fp32': False,
        'overlap_grad_reduce': False,
        'overlap_param_gather_with_optimizer_step': False,
        'fp8_param_gather': False,
        'use_distributed_optimizer': True,
        'ddp_bucket_size': None,
        'check_for_nan_in_loss_and_grad': False,
        'ddp_average_in_collective': False,
        'tensor_model_parallel_size': tp,
        'pipeline_model_parallel_size': pp,
        'enable_ft_package': False,
    }
    for attr_name, attr_value in settings.items():
        setattr(args, attr_name, attr_value)
    return args


def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False):
    """Populate ``args`` with the mocked checkpointing settings and return it.

    Args:
        args: Attribute container that is mutated in place.
        ckpt_dir: Stored as both ``save`` and ``load``.
        fully_parallel: Stored as both ``ckpt_fully_parallel_save`` and
            ``ckpt_fully_parallel_load``.

    Returns:
        The same ``args`` object that was passed in, mirroring
        ``init_basic_mock_args`` so the two helpers can be used the same way.
    """
    args.non_persistent_global_ckpt_dir = None
    args.non_persistent_ckpt_type = None
    # One directory serves as both checkpoint destination and source.
    args.save = ckpt_dir
    args.load = ckpt_dir
    args.pretrained_checkpoint = None
    args.ckpt_fully_parallel_save = fully_parallel
    args.ckpt_fully_parallel_load = fully_parallel
    args.async_save = False
    args.use_dist_ckpt = True
    args.ckpt_format = 'torch_dist'
    args.no_save_optim = False
    args.no_save_rng = False
    args.ckpt_assume_constant_structure = False
    args.log_progress = False
    args.auto_detect_ckpt_format = False
    args.exit_on_missing_checkpoint = False
    args.finetune = False
    args.consumed_train_samples = 0
    args.skipped_train_samples = 0
    args.consumed_valid_samples = 0
    args.retro_add_retriever = False
    args.no_load_optim = False
    args.no_load_rng = False
    args.dist_ckpt_strictness = 'assume_ok_unexpected'
    args.add_position_embedding = True
    args.vocab_file = False
    # Model dimensions match the module-level constants of this file.
    args.num_layers = NUM_LAYERS
    args.hidden_size = HIDDEN_SIZE
    args.num_attention_heads = NUM_ATTENTION_HEADS
    return args


def setup_model_and_optimizer(
    seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True
):
    """Create a model plus a Megatron optimizer whose state is randomly filled.

    Args:
        seed: Seed handed to ``initialize_fn``; ``seed + 1`` is used to seed
            the RNGs before the optimizer state is generated.
        tp: Tensor model parallel size, set on the mock args and passed to
            ``initialize_fn``.
        pp: Pipeline model parallel size, set on the mock args and passed to
            ``initialize_fn``.
        initialize_fn: Model provider callable given to ``get_model``.
        bf16: Stored on the mock args and on ``OptimizerConfig``; also selects
            ``params_dtype`` (bfloat16 when true, float otherwise).
        dist_opt: Forwarded to ``OptimizerConfig`` as
            ``use_distributed_optimizer``.

    Returns:
        A tuple ``(unwrap_model(model), optimizer)``.
    """
    fake_args = SimpleNamespace()
    model_provider = partial(
        initialize_fn,
        seed=seed,
        tensor_model_parallel_size=tp,
        pipeline_model_parallel_size=pp,
        pipeline_dtype=torch.bfloat16,
    )

    # While the model is built, ``get_args`` in megatron.training.training
    # returns the mock namespace instead of real command-line arguments.
    with mock.patch('megatron.training.training.get_args', new=lambda: fake_args):
        init_basic_mock_args(fake_args, tp, pp, bf16=bf16)
        model = get_model(model_provider)

    params_dtype = torch.bfloat16 if bf16 else torch.float
    optimizer_config = OptimizerConfig(
        bf16=bf16,
        params_dtype=params_dtype,
        use_distributed_optimizer=dist_opt,
    )
    optimizer = get_megatron_optimizer(optimizer_config, model)

    # Use a different seed than the model's so the optimizer state is not
    # drawn from the same random stream position as the parameters.
    state_seed = seed + 1
    torch.manual_seed(state_seed)
    model_parallel_cuda_manual_seed(state_seed)

    inner_optimizer = optimizer.optimizer
    for param_group in inner_optimizer.param_groups:
        for param in param_group['params']:
            param_state = inner_optimizer.state[param]
            if param_state:
                # State already present for this parameter; leave it alone.
                continue
            param_state['exp_avg'] = torch.rand_like(param.data)
            param_state['exp_avg_sq'] = torch.rand_like(param.data)

    # NOTE(review): presumably re-syncs the optimizer's parameter copies with
    # the model parameters; confirm against the optimizer implementation.
    optimizer.reload_model_params()

    return unwrap_model(model), optimizer