r"""
Patching some of Megatron-LM's functions to create an MoE model
"""
import torch

def patch_loss_func_v2_5(loss_func):
    r"""
    Patch model's loss_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron.mpu import get_tensor_model_parallel_group
    from megatron import get_args
    from megatron import get_num_microbatches

    if not get_args().balance_strategy:
        return loss_func

    def loss_func_with_balance_loss(model, output_tensor):
        args = get_args()
        assert args.balance_strategy, "Use the patched loss_func only when a balance_strategy is set."
        assert is_pipeline_last_stage(), "Only call loss_func at the pipeline's last stage."

        output = loss_func(output_tensor)

        while hasattr(model, 'module'):
            model = model.module

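        # Each fmoefied MLP's gate records a load-balance loss during the
        # forward pass; collect it from every encoder layer that has one.
        # `clear=False` leaves the recorded value in place on the gate.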
        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.encoder.layers
                if l.mlp.gate.has_loss]
        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
            loss_list_decoder = [l.mlp.gate.get_loss(clear=False).view(1)
                    for l in model.language_model.decoder.layers
                    if l.mlp.gate.has_loss]
            # `extend`, not `append`: appending the list object itself would
            # make the `torch.cat` below fail.
            loss_list.extend(loss_list_decoder)

        if len(loss_list) == 0:
            return output

        loss_name = args.balance_strategy + "_loss"
        loss, state_dict = output
        bal_loss = (torch.cat(loss_list).mean()
                * args.balance_loss_weight
                / args.pipeline_model_parallel_size)

        bal_loss = bal_loss / get_num_microbatches()

        # Average the balance loss across the MoE (tensor-model-parallel) group.
        # Only the averaged copy is logged; the local value is what gets added
        # to the training loss below.
        moe_group = get_tensor_model_parallel_group()
        world_size = torch.distributed.get_world_size(group=moe_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
        averaged_bal_loss /= world_size

        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict

    return loss_func_with_balance_loss

def patch_forward_step(forward_step_func, Megatron_Version="v2.2"):
    r"""
    Patch model's forward_step_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron.mpu import get_tensor_model_parallel_group
    from megatron import get_args

    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss_v2_2(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        if not is_pipeline_last_stage() or not args.balance_strategy:
            return output

        while hasattr(model, 'module'):
            model = model.module

        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.transformer.layers
                if l.mlp.gate.has_loss]
        if len(loss_list) == 0:
            return output

        loss_name = args.balance_strategy + "_loss"
        loss, state_dict = output
        bal_loss = torch.cat(loss_list).mean() * args.balance_loss_weight

        # Average the balance loss across the MoE (tensor-model-parallel) group.
        # Only the averaged copy is logged; the local value is what gets added
        # to the training loss below.
        moe_group = get_tensor_model_parallel_group()
        world_size = torch.distributed.get_world_size(group=moe_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
        averaged_bal_loss /= world_size

        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict

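    # The v2.5 / v3.0.2 variant below differs from v2.2 above: instead of
    # folding the balance loss into the returned loss here, it hands the loss
    # back as a third return value for the pipeline schedule to consume.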
    def forward_step_with_balance_loss_v2_5(data_iterator, model):
        from functools import partial
        output, loss_func = forward_step_func(data_iterator, model)

        while hasattr(model, 'module'):
            model = model.module

        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.encoder.layers
                if l.mlp.gate.has_loss]

        # Mirror the v2.2 guard: if no gate reported a loss, fall back to a
        # zero balance loss (assumes the schedule accepts a constant zero).
        if len(loss_list) == 0:
            bal_loss = output.new_zeros(())
        else:
            bal_loss = (torch.cat(loss_list).mean()
                    * get_args().balance_loss_weight
                    / get_args().pipeline_model_parallel_size)
        return output, partial(patch_loss_func_v2_5(loss_func), model), bal_loss
    if Megatron_Version == "v2.2":
        return forward_step_with_balance_loss_v2_2
    elif Megatron_Version == "v2.5":
        return forward_step_with_balance_loss_v2_5
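    # v3.0.2 shares the v2.5 forward-step code path.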
    elif Megatron_Version == "v3.0.2":
        return forward_step_with_balance_loss_v2_5
    else:
        assert False, f"megatron version {Megatron_Version} not known."


def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
    from megatron import get_args

    def fmoefied_model_provider_v2_2():
        from .layers import fmoefy
        args = get_args()
        # The dense FFN hidden size (4 * hidden_size) is divided among the
        # top-k active experts and the tensor-model-parallel ranks.
        hhs = args.hidden_size * 4
        assert hhs % args.top_k == 0
        hhs = hhs // args.top_k
        assert hhs % args.tensor_model_parallel_size == 0
        hhs = hhs // args.tensor_model_parallel_size
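        # For example (hypothetical sizes): hidden_size=1024, top_k=2 and
        # tensor_model_parallel_size=4 give hhs = 4096 // 2 // 4 = 512 as each
        # expert's per-rank hidden size.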
        return fmoefy(
            model_provider(),
            fmoe_num_experts=args.fmoe_num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate,
            megatron_version="v2.2"
        )
    def fmoefied_model_provider_v2_5(pre_process, post_process):
        from .layers import fmoefy
        args = get_args()
        hhs = args.hidden_size * 4
        assert hhs % args.top_k == 0
        hhs = hhs // args.top_k
        assert hhs % args.tensor_model_parallel_size == 0
        hhs = hhs // args.tensor_model_parallel_size
        return fmoefy(
            model_provider(pre_process=pre_process, post_process=post_process),
            fmoe_num_experts=args.fmoe_num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate,
            megatron_version="v2.5"
        )
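
    # The v3.0.2 provider below is identical to the v2.5 one above, except for
    # the version tag passed to fmoefy.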
    def fmoefied_model_provider_v3_0_2(pre_process, post_process):
        from .layers import fmoefy
        args = get_args()
        hhs = args.hidden_size * 4
        assert hhs % args.top_k == 0
        hhs = hhs // args.top_k
        assert hhs % args.tensor_model_parallel_size == 0
        hhs = hhs // args.tensor_model_parallel_size
        return fmoefy(
            model_provider(pre_process=pre_process, post_process=post_process),
            fmoe_num_experts=args.fmoe_num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate,
            megatron_version="v3.0.2"
        )
    if Megatron_Version == 'v2.2':
        return fmoefied_model_provider_v2_2
    elif Megatron_Version == 'v2.5':
        return fmoefied_model_provider_v2_5
    elif Megatron_Version == 'v3.0.2':
        return fmoefied_model_provider_v3_0_2
    else:
        assert False, f"Megatron Version {Megatron_Version} unknown."