"docs/source/en/api/models.mdx" did not exist on "c15cda03ca36a5e344c8f26179c9ef48d3a88c69"
patch.py 2.25 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
r"""
Patching some of Megatron-LM's functions to create an MoE model
"""
import torch


def patch_forward_step(forward_step_func):
    r"""
    Patch model's forward_step_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron.mpu import get_tensor_model_parallel_group
    from megatron import get_args

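    # Without a balance strategy there is no gate loss to collect, so the
    # original forward_step is returned unchanged.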
    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

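        # Only the last pipeline stage returns (loss, state_dict); earlier
        # stages hand their activations to the next stage, so there is
        # nothing to add for them.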
        if not is_pipeline_last_stage() or not args.balance_strategy:
            return output

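        # Unwrap containers such as DDP / FP16 wrappers (anything exposing
        # `.module`) to reach the underlying language model.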
        while hasattr(model, 'module'):
            model = model.module

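        # Collect the load-balancing loss recorded by each MoE gate, if present.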
        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.transformer.layers
                if l.mlp.gate.has_loss]
        if len(loss_list) == 0:
            return output

        loss_name = args.balance_strategy + "_loss"
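
        # forward_step on the last stage returns (loss, state_dict); add the
        # weighted mean of the collected gate losses to the loss.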
        loss, state_dict = output
        bal_loss = torch.cat(loss_list).mean() * args.balance_loss_weight

        # average the balance loss across the MoE group (tensor model parallel group)
        moe_group = get_tensor_model_parallel_group()
        world_size = torch.distributed.get_world_size(group=moe_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
        averaged_bal_loss /= world_size

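        # The local (un-averaged) balance loss is added to the training loss for
        # backpropagation; the all-reduced copy is detached and only reported
        # through the returned dict.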
        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict

    return forward_step_with_balance_loss


def patch_model_provider(model_provider, gate=None):
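    r"""
    Patch a Megatron-LM model provider so that the model it constructs is
    converted into an MoE model via fmoefy
    """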
    from megatron import get_args

    def fmoefied_model_provider():
        from .layers import fmoefy
        args = get_args()
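        # Per-expert FFN hidden size: start from Megatron's default 4 * hidden_size,
        # then divide by top_k and by the tensor-model-parallel size (presumably so
        # the active FFN computation per token stays comparable to the dense
        # baseline). Illustrative numbers: hidden_size=1024, top_k=2,
        # tensor_model_parallel_size=2 gives 4096 // 2 // 2 = 1024.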
        hhs = args.hidden_size * 4
        assert hhs % args.top_k == 0
        hhs = hhs // args.top_k
        assert hhs % args.tensor_model_parallel_size == 0
        hhs = hhs // args.tensor_model_parallel_size
        return fmoefy(
            model_provider(),
            num_experts=args.num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate
        )

    return fmoefied_model_provider