r"""
Support for monitoring loss in Megatron
"""
import torch
from fmoe.balance import reset_balance_profile
from fmoe.balance import update_balance_profile
from fmoe.utils import get_torch_default_comm


balance_dict = {}
num_layers = 0


def reset_gate_hook(_num_layers=None):
    from megatron import get_args

    global balance_dict, num_layers
    if _num_layers is not None:
        num_layers = _num_layers
    reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy)


def get_balance_profile():
    global balance_dict
    return balance_dict


def generate_megatron_gate_hook(layer_idx, num_expert_global):
    from megatron import get_args

    balance_strategy = get_args().balance_strategy

    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
        global balance_dict
        update_balance_profile(
            balance_dict,
            gate_top_k_idx,
            gate_score_top_k,
            gate_context,
            layer_idx,
            num_expert_global,
            balance_strategy,
        )

    return megatron_gate_hook

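# A minimal sketch (hypothetical layer construction, not part of this module)
# of how the generated hook is meant to be consumed, assuming fmoe's
# ``FMoETransformerMLP`` accepts a ``gate_hook`` argument:
#
#     mlp = FMoETransformerMLP(
#         num_expert=args.num_experts,
#         d_model=args.hidden_size,
#         gate_hook=generate_megatron_gate_hook(layer_idx, num_expert_global),
#     )
#
# Each forward pass then feeds the gate's top-k choices into
# ``update_balance_profile``, which accumulates them in ``balance_dict``.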

def add_balance_log(model, writer, iteration):
    r"""
    All-reduce the per-layer balance metrics across the world group and
    log them through the TensorBoard writer on the last rank.
    """
    from megatron import is_last_rank

    # unwrap a (distributed) data-parallel wrapper, if present
    if hasattr(model, 'module'):
        model = model.module

    balance_dict_tensor = torch.vstack(
        [l.mlp.gate.get_loss(clear=True) for l in model.language_model.transformer.layers]
    ).detach()
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
    balance_dict_tensor /= world_size

    if writer and is_last_rank():
        for idx, metric_name in enumerate(balance_dict):
            for layer_id, val in enumerate(balance_dict_tensor[idx]):
                writer.add_scalar(
                    f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
                )
            writer.add_scalar(
                f"balance-{metric_name}/all",
                balance_dict_tensor[idx].mean().item(),
                iteration,
            )

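# A minimal usage sketch, assuming a Megatron-style training loop with a
# TensorBoard ``SummaryWriter``; ``args``, ``model``, ``writer`` and
# ``iteration`` are the caller's names, not defined in this module:
#
#     if iteration % args.log_interval == 0:
#         add_balance_log(model, writer, iteration)
#
# Because ``get_loss(clear=True)`` resets each gate's accumulator, every
# call logs only the interval since the previous one.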

def patch_forward_step(forward_step_func):
    r"""
    Patch model's forward_step_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron import get_args

    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        if (
            not is_pipeline_last_stage()
            or not args.balance_strategy
            or args.balance_strategy == 'naive'
        ):
            return output
        loss_name = args.balance_strategy + "_loss"

        # unwrap a (distributed) data-parallel wrapper, if present
        if hasattr(model, 'module'):
            model = model.module

        loss_list = [l.mlp.gate.get_loss(clear=False) for l in model.language_model.transformer.layers]
        loss, state_dict = output
        # ``torch.tensor`` on a list of 0-dim tensors copies and detaches
        # them, which would cut the balance loss out of the autograd graph;
        # ``torch.stack`` keeps the gradient path intact.
        bal_loss = (
            torch.stack(loss_list).mean() * args.balance_loss_weight
        ).float()

        # average across the world group
        world_group = get_torch_default_comm()
        world_size = torch.distributed.get_world_size(group=world_group)
        averaged_bal_loss = bal_loss.clone().detach()
        torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
        averaged_bal_loss /= world_size

        loss += bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict

    return forward_step_with_balance_loss

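# The wrapper keeps the contract of Megatron's forward_step: it returns the
# same ``(loss, state_dict)`` pair, with the balance loss folded into
# ``loss`` (so it is backpropagated) and its all-reduced average exposed for
# logging under ``state_dict["<strategy>_loss"]``. A hedged example, assuming
# ``--balance-strategy gshard``:
#
#     loss, state_dict = forward_step_with_balance_loss(it, model, inp)
#     print(state_dict["gshard_loss"].item())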

def patch_model_provider(model_provider):
    r"""
    Patch model_provider so that the created Megatron model is fmoefied
    """
    from megatron import get_args

    def fmoefied_model_provider():
        from .layers import fmoefy
        args = get_args()
        return fmoefy(
            model_provider(),
            num_experts=args.num_experts,
            hidden_hidden_size=4 * args.hidden_size // args.top_k,
            top_k=args.top_k,
        )

    return fmoefied_model_provider
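
# A minimal sketch of wiring both patches into a Megatron pretraining script.
# The exact ``pretrain`` signature differs between Megatron versions, so
# treat this as illustrative; ``train_valid_test_datasets_provider``,
# ``model_provider`` and ``forward_step`` are the script's own callables:
#
#     from megatron.training import pretrain
#
#     pretrain(
#         train_valid_test_datasets_provider,
#         patch_model_provider(model_provider),
#         patch_forward_step(forward_step),
#     )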