# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import sys

import pytest
import torch
import torch.distributed

from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec as gpt_te_spec,
)
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.moe import upcycling_utils
from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args
from megatron.training.global_vars import (
    destroy_global_vars,
    get_args,
    set_args,
    set_global_variables,
)
from megatron.training.training import get_model, setup_model_and_optimizer
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    unwrap_model,
)
from tests.unit_tests.test_utilities import Utils

_SEED = 42


def model_provider(pre_process=True, post_process=True, layer_spec_fn=gpt_te_spec, **config_kwargs):
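    """Build a GPT model for the test; dense or MoE depending on args.num_experts."""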
    model_parallel_cuda_manual_seed(_SEED)
    args = get_args()

    config = core_transformer_config_from_args(args)

    model = GPTModel(
        config=config,
        transformer_layer_spec=layer_spec_fn(
            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm
        ),
        vocab_size=args.vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
    )

    return model


def create_test_args(
    tensor_model_parallel_size, pipeline_model_parallel_size, enable_vp, enable_grouped_gemm
):
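    """Reset global state and build a minimal argument namespace for a small dense GPT model."""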
    destroy_global_vars()
    destroy_num_microbatches_calculator()

    sys.argv = ['test_upcycling.py']
    args = parse_args()
    args.num_layers = 2
    args.vocab_size = 256
    args.hidden_size = 128
    args.num_attention_heads = 8
    args.max_position_embeddings = 256
    args.micro_batch_size = 1
    args.create_attention_mask_in_dataloader = True
    args.seq_length = 256
    args.pipeline_model_parallel_size = pipeline_model_parallel_size
    args.tensor_model_parallel_size = tensor_model_parallel_size
    args.context_parallel_size = 1
    args.num_experts = None
    args.train_iters = 1
    if enable_vp:
        args.num_layers_per_virtual_pipeline_stage = 1
    args.ckpt_format = 'torch_dist'
    args.moe_router_topk = 2
    args.moe_router_pre_softmax = False
    args.moe_token_dispatcher_type = "alltoall"
    args.lr = 3e-5
    args.attention_dropout = 0.0
    args.hidden_dropout = 0.0
    args.async_tensor_model_parallel_allreduce = False
    args.no_save_optim = True
    args.no_load_optim = True
    args.no_load_rng = True
    args.moe_grouped_gemm = enable_grouped_gemm
    args.add_bias_linear = False

    validate_args(args)
    set_global_variables(args, False)
    return args


def set_upcycling_args(enable_grouped_gemm, ep):
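    """Switch the global args to MoE upcycling mode so the next model build is an MoE model."""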
    args = get_args()
    args.moe_use_upcycling = True
    args.num_experts = 2
    args.moe_grouped_gemm = enable_grouped_gemm
    args.expert_model_parallel_size = ep
    set_args(args)


def get_batch(data_iterator):
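    """Return the batch for this TP/CP rank; intermediate pipeline stages get no data."""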
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    batch = get_batch_on_this_tp_rank(data_iterator)
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()


class TestGPTModel:
    def setup_method(self, method):
        Utils.destroy_model_parallel()

    def teardown_method(self, method):
        Utils.destroy_model_parallel()
        destroy_global_vars()
        destroy_num_microbatches_calculator()

    @pytest.mark.internal
    @pytest.mark.skipif(True, reason="The test is flaky")  # TODO: Fix the test
    @pytest.mark.parametrize(
        ('tp_pp_ep', 'enable_vp', 'enable_grouped_gemm'), [((1, 1, 2), False, False)]
    )
    def test_upcycling(self, tp_pp_ep, enable_vp, enable_grouped_gemm):
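        """Upcycle a dense GPT model into an MoE model and check that both produce
        approximately the same logits on the same inputs."""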
        tp, pp, ep = tp_pp_ep
        args = create_test_args(tp, pp, enable_vp, enable_grouped_gemm)
        set_args(args)

        torch.manual_seed(_SEED)
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=tp,
            pipeline_model_parallel_size=pp,
            virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
        )

        # Build the dense (non-MoE) reference model and optimizer first;
        # args.num_experts is still None at this point.
        dense_model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
            model_provider, ModelType.encoder_or_decoder
        )

        # Re-initialize parallel state, this time with expert model parallelism,
        # before building the MoE model.
        Utils.destroy_model_parallel()
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=tp,
            pipeline_model_parallel_size=pp,
            expert_model_parallel_size=ep,
            virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
        )
        set_upcycling_args(enable_grouped_gemm, ep)
        # model_parallel_cuda_manual_seed(_SEED+1)
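        # Build the MoE model with randomly initialized experts (no optimizer needed).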
        moe_model = get_model(model_provider, ModelType.encoder_or_decoder)

        # Unwrap both models before upcycling the dense weights into the MoE model
        moe_model = unwrap_model(moe_model)
        dense_model = unwrap_model(dense_model)

        # Build a synthetic batch: token and position ids are simply 0..seq_length-1,
        # with an all-True attention mask.
        data = list(range(args.seq_length))
        input_ids = torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda()
        position_ids = (
            torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda()
        )
        attention_mask = torch.ones(
            (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool
        ).cuda()

        # Reference logits from the original dense model.
        dense_logits = dense_model[0].forward(
            input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask
        )

        # Convert the dense weights into an MoE-compatible state dict and load it
        # into the upcycled model (per virtual pipeline stage when VP is enabled).
        state_dict = upcycling_utils.upcycle_state_dict(moe_model, dense_model)
        if len(moe_model) == 1:
            moe_model[0].load_state_dict(state_dict['model'], strict=True)
        else:
            for i in range(len(moe_model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                moe_model[i].load_state_dict(state_dict['model%d' % i], strict=True)

        # Logits from the upcycled MoE model on the same inputs.
        moe_logits = moe_model[0].forward(
            input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask
        )

        # The upcycled MoE model should closely reproduce the dense model's outputs.
        assert torch.allclose(
            dense_logits, moe_logits, rtol=1e-03, atol=1e-03
        ), "Upcycled MoE logits do not match the dense model logits"