"...extensions/KeyTable/js/dataTables.keyTable.min.js" did not exist on "9eeb94e8b531278e0769618d57d5d4538a2fabb4"
balance.py 3.9 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
r"""
Support for monitoring loss in Megatron
"""
import torch
from fmoe.balance import reset_balance_profile
from fmoe.balance import update_balance_profile
from fmoe.utils import get_torch_default_comm
from .distributed import get_moe_group


# Per-metric accumulators for gate balance statistics, keyed by metric name.
# Reset by `reset_gate_hook` and consumed by `add_balance_log`.
balance_dict = {}
# Number of transformer layers being profiled; set by `reset_gate_hook`.
num_layers = 0


def reset_gate_hook(_num_layers=None):
    r"""
    Reset the accumulated balance metrics.

    When ``_num_layers`` is given, the module-level layer count is updated
    first; the profile dict is then cleared for that many layers, using the
    balance strategy taken from Megatron's argument parser.
    """
    from megatron import get_args

    global balance_dict, num_layers
    if _num_layers is not None:
        num_layers = _num_layers
    strategy = get_args().balance_strategy
    reset_balance_profile(balance_dict, num_layers, strategy)


def get_balance_profile():
    r"""
    Return the module-level dict of accumulated balance metrics.
    """
    # `global` is unnecessary for a read-only access; the module-level
    # dict is returned directly.
    return balance_dict


def generate_megatron_gate_hook(layer_idx, num_expert_global):
    r"""
    Build a gate hook bound to one transformer layer.

    The returned callable forwards the gate's top-k selection to
    ``update_balance_profile`` so per-layer statistics accumulate in the
    module-level ``balance_dict``.
    """
    from megatron import get_args

    strategy = get_args().balance_strategy

    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
        # Fold this layer's gate decisions into the shared profile.
        update_balance_profile(
            balance_dict,
            gate_top_k_idx,
            gate_score_top_k,
            gate_context,
            layer_idx,
            num_expert_global,
            strategy,
        )

    return megatron_gate_hook


def add_balance_log(model, writer, iteration):
    r"""
    Collect the per-layer gate balance losses, average them over all ranks,
    and write them to the TensorBoard ``writer`` on the last rank.

    Calling ``get_loss(clear=True)`` also resets each gate's accumulator.
    """
    from megatron import is_last_rank

    # Unwrap (possibly nested) DDP / FP16 wrappers.
    while hasattr(model, 'module'):
        model = model.module

    layers = model.language_model.transformer.layers
    balance_dict_tensor = torch.vstack(
        [l.mlp.gate.get_loss(clear=True) for l in layers]
    ).detach()

    # Average the metrics across every rank in the default group.
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
    balance_dict_tensor /= world_size

    if writer and is_last_rank():
        for idx, metric_name in enumerate(balance_dict):
            per_layer = balance_dict_tensor[idx]
            for layer_id, val in enumerate(per_layer):
                writer.add_scalar(
                    f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
                )
            writer.add_scalar(
                f"balance-{metric_name}/all",
                per_layer.mean().item(),
                iteration,
            )


def patch_forward_step(forward_step_func):
    r"""
    Patch model's forward_step_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron import get_args

    # Nothing to patch when no balance strategy is configured.
    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        strategy = args.balance_strategy
        # The loss term is only attached on the last pipeline stage, and
        # the 'naive' strategy carries no trainable balance loss.
        if not is_pipeline_last_stage() or not strategy or strategy == 'naive':
            return output

        loss_name = strategy + "_loss"

        # Unwrap (possibly nested) DDP / FP16 wrappers.
        unwrapped = model
        while hasattr(unwrapped, 'module'):
            unwrapped = unwrapped.module

        per_layer = [
            l.mlp.gate.get_loss(clear=False).view(1)
            for l in unwrapped.language_model.transformer.layers
        ]
        loss, state_dict = output
        bal_loss = torch.cat(per_layer).mean() * args.balance_loss_weight

        # Average the *reported* value across the MoE group; the local
        # bal_loss (unaveraged) is still what is added to the training loss.
        moe_group = get_moe_group()
        world_size = torch.distributed.get_world_size(group=moe_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
        averaged_bal_loss /= world_size

        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict

    return forward_step_with_balance_loss


def patch_model_provider(model_provider):
    r"""
    Wrap ``model_provider`` so the model it builds is converted into an
    FMoE model via ``fmoefy``, using expert counts and sizes taken from
    Megatron's argument parser.
    """
    from megatron import get_args

    def fmoefied_model_provider():
        from .layers import fmoefy
        args = get_args()
        # Keep total expert FFN capacity roughly constant w.r.t. top_k.
        hidden_hidden = 4 * args.hidden_size // args.top_k
        return fmoefy(
            model_provider(),
            num_experts=args.num_experts,
            hidden_hidden_size=hidden_hidden,
            top_k=args.top_k,
        )

    return fmoefied_model_provider