r"""
Patching some of Megatron-LM's functions to create an MoE model
"""
import torch
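
# Hedged usage sketch (comment only): in a standard Megatron-LM pretraining
# script -- `pretrain`, `model_provider`, and `forward_step` are Megatron's
# entry points, not defined in this file -- the patches below would be
# applied before training starts, e.g.
#
#   model_provider = patch_model_provider(model_provider, Megatron_Version="v2.2")
#   forward_step = patch_forward_step(forward_step, Megatron_Version="v2.2")
#   pretrain(train_valid_test_datasets_provider, model_provider, forward_step)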

def patch_loss_func_v2_5(loss_func):
    r"""
    Patch model's loss_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron.mpu import get_tensor_model_parallel_group
    from megatron import get_args
    from megatron import get_num_microbatches

    if not get_args().balance_strategy:
        return loss_func

    def loss_func_with_balance_loss(model, output_tensor):
        args = get_args()
        assert args.balance_strategy, "Only use the patched loss_func when balance_strategy is set."
        assert is_pipeline_last_stage(), "Only call loss_func at the last pipeline stage."

        output = loss_func(output_tensor)

        while hasattr(model, 'module'):
            model = model.module

        # Collect the balance loss of every gate that recorded one.
        # `clear=False` keeps the stored loss so other consumers can read it.
        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.encoder.layers
                if l.mlp.gate.has_loss]

        if hasattr(model.language_model, "decoder"):
            loss_list_decoder = [l.mlp.gate.get_loss(clear=False).view(1)
                    for l in model.language_model.decoder.layers
                    if l.mlp.gate.has_loss]
            # extend, not append: appending the list itself would nest it
            # and break the torch.cat below
            loss_list.extend(loss_list_decoder)

        if len(loss_list) == 0:
            return output

        loss_name = args.balance_strategy + "_loss"
        (loss, state_dict), bal_loss = (
            output,
            torch.cat(loss_list).mean() * args.balance_loss_weight / args.pipeline_model_parallel_size
        )

        bal_loss = bal_loss / get_num_microbatches()

        # average across the MoE group for logging; the local (undetached)
        # bal_loss is what actually backpropagates on this rank
        moe_group = get_tensor_model_parallel_group()
        world_size = torch.distributed.get_world_size(group=moe_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
        averaged_bal_loss /= world_size

        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict

    return loss_func_with_balance_loss
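
# Calling-contract sketch (comment only; `unwrapped_model` is a placeholder
# name): Megatron v2.5 invokes `loss_func(output_tensor)`, so the extra
# `model` argument of the patched function must be bound first, as
# forward_step_with_balance_loss_v2_5 does below:
#
#   patched = partial(patch_loss_func_v2_5(loss_func), unwrapped_model)
#   loss, state_dict = patched(output_tensor)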

def patch_forward_step(forward_step_func, Megatron_Version="v2.2"):
    r"""
    Patch model's forward_step_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron.mpu import get_tensor_model_parallel_group
    from megatron import get_args

    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss_v2_2(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        if not is_pipeline_last_stage() or not args.balance_strategy:
            return output

        while hasattr(model, 'module'):
            model = model.module

        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.transformer.layers
                if l.mlp.gate.has_loss]
        if len(loss_list) == 0:
            return output

        loss_name = args.balance_strategy + "_loss"
        (loss, state_dict), bal_loss = (
            output,
            torch.cat(loss_list).mean() * args.balance_loss_weight
        )

        # average across the MoE group for logging; the local (undetached)
        # bal_loss is what actually backpropagates on this rank
        moe_group = get_tensor_model_parallel_group()
        world_size = torch.distributed.get_world_size(group=moe_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
        averaged_bal_loss /= world_size

        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict
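
    # Version note (comment only): v2.2's forward_step returns the loss
    # directly, so the balance loss is folded in above; v2.5's returns
    # (output_tensor, loss_func), so the balance loss is attached by patching
    # loss_func and handed back as a third value below.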

    def forward_step_with_balance_loss_v2_5(data_iterator, model):
        from functools import partial
        output, loss_func = forward_step_func(data_iterator, model)

        while hasattr(model, 'module'):
            model = model.module

        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.encoder.layers
                if l.mlp.gate.has_loss]

        # Scale by the balance loss weight, spread evenly across pipeline
        # stages, and bind the unwrapped model to the patched loss_func.
        bal_loss = (torch.cat(loss_list).mean() * get_args().balance_loss_weight
                / get_args().pipeline_model_parallel_size)
        return output, partial(patch_loss_func_v2_5(loss_func), model), bal_loss

    if Megatron_Version == "v2.2":
        return forward_step_with_balance_loss_v2_2
    elif Megatron_Version == "v2.5":
        return forward_step_with_balance_loss_v2_5
    else:
        assert False, f"Megatron version {Megatron_Version} unknown."



def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
    from megatron import get_args

    def fmoefied_model_provider_v2_2():
        from .layers import fmoefy
        args = get_args()
        # Megatron's dense FFN hidden size is 4 * hidden_size; divide it
        # across the top-k activated experts and the tensor-parallel ranks.
        hhs = args.hidden_size * 4
        assert hhs % args.top_k == 0
        hhs = hhs // args.top_k
        assert hhs % args.tensor_model_parallel_size == 0
        hhs = hhs // args.tensor_model_parallel_size
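        # Worked example (illustrative numbers only): hidden_size=1024,
        # top_k=2, tensor_model_parallel_size=2 gives
        # hhs = 4096 // 2 // 2 = 1024 per expert, so the top_k experts a
        # token activates together match the dense FFN's per-rank width.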
        return fmoefy(
            model_provider(),
            num_experts=args.num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate,
            megatron_version="v2.2"
        )

    def fmoefied_model_provider_v2_5(pre_process, post_process):
        from .layers import fmoefy
        args = get_args()
        hhs = args.hidden_size * 4
        assert hhs % args.top_k == 0
        hhs = hhs // args.top_k
        assert hhs % args.tensor_model_parallel_size == 0
        hhs = hhs // args.tensor_model_parallel_size
        return fmoefy(
            model_provider(pre_process=pre_process, post_process=post_process),
            num_experts=args.num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate,
            megatron_version="v2.5"
        )

    if Megatron_Version == 'v2.2':
        return fmoefied_model_provider_v2_2
    elif Megatron_Version == 'v2.5':
        return fmoefied_model_provider_v2_5
    else:
        assert False, f"Megatron version {Megatron_Version} unknown."