balance.py 1.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn.functional as F

def _coefficient_variation(c_e):
    """Return ``std(c_e) / mean(c_e)``.

    Uses the default estimator of ``Tensor.std`` (Bessel-corrected).
    """
    return c_e.std() / c_e.mean()


def _max_div_min_smoothed(c_e):
    """Return ``(max(c_e) + 1) / (min(c_e) + 1)``.

    The ``+ 1`` on both sides keeps the ratio finite when the minimum is 0,
    e.g. when some expert received no assignments.
    """
    return (c_e.max() + 1) / (c_e.min() + 1)


def _max_div_mean(c_e):
    """Return ``max(c_e) / mean(c_e)``."""
    return c_e.max() / c_e.mean()


# Load-balance metrics, keyed by the name under which each is recorded.
# Every callable takes ``c_e`` -- in this module, the per-expert assignment
# counts built by ``update_balance_profile`` -- and returns a 0-dim tensor.
metrics = {
    "coefficient-variation": _coefficient_variation,
    "Lmax_div_Lmin": _max_div_min_smoothed,
    "Lmax_div_Lmean": _max_div_mean,
}


def reset_balance_profile(balance_dict, num_layers, balance_strategy):
    """Reinitialise ``balance_dict`` with one empty slot per layer.

    Each metric name in ``metrics`` is bound to a fresh list of ``num_layers``
    ``None`` entries. When ``balance_strategy`` is truthy, an extra
    ``"<balance_strategy>_loss"`` entry is created the same way.
    ``balance_dict`` is mutated in place; nothing is returned.
    """

    def empty_slots():
        # A brand-new list on every call, so no two keys share storage.
        return [None for _layer in range(num_layers)]

    for metric_name in metrics:
        balance_dict[metric_name] = empty_slots()
    if not balance_strategy:
        return
    balance_dict[f"{balance_strategy}_loss"] = empty_slots()


def update_balance_profile(
    balance_dict,
    gate_top_k_idx,
    _gate_score_top_k,
    gate_state_dict,
    layer_idx,
    num_expert,
    balance_strategy,
):
    """Record load-balance statistics for layer ``layer_idx`` in ``balance_dict``.

    Per-expert assignment counts ``c_e`` are built from ``gate_top_k_idx``.
    Every entry of ``metrics`` is evaluated on them, and each result is
    stored at ``balance_dict[key][layer_idx]``. Depending on
    ``balance_strategy``, an auxiliary loss is stored as well:

    * ``"gshard"`` -- ``gate_state_dict`` is treated as gate scores. The
      value ``sum(c_e * m_e) / num_expert / S`` is stored under
      ``"gshard_loss"``, where ``m_e`` is the row-wise softmax summed over
      rows and divided by ``S``.
    * ``"noisy"`` -- ``gate_state_dict`` is stored unchanged under
      ``"noisy_loss"`` (presumably an already-computed loss; confirm with
      the gate implementation).
    * any other value -- no loss entry is written.

    ``balance_dict`` is expected to have been prepared by
    ``reset_balance_profile``, so every key it touches maps to a
    per-layer list.

    Args:
        balance_dict: mapping, mutated in place.
        gate_top_k_idx: integer tensor of selected expert ids, each in
            ``[0, num_expert)``. It is flattened before counting, so a flat
            index and the same assignments laid out as ``(rows, top_k)``
            give identical results.
        _gate_score_top_k: unused.
        gate_state_dict: strategy-specific gate state (see above).
        layer_idx: index of the per-layer slot to fill.
        num_expert: number of experts, i.e. the length of ``c_e``.
        balance_strategy: ``"gshard"``, ``"noisy"`` or any other value.
    """
    # scatter_add into the 1-D count buffer needs a 1-D int64 index, so
    # flatten and cast here (both are no-ops for a flat int64 index).
    flat_idx = gate_top_k_idx.reshape(-1).long()
    c_e = torch.scatter_add(
        torch.zeros(num_expert, device=flat_idx.device),
        0,
        flat_idx,
        torch.ones_like(flat_idx, dtype=torch.float),
    )
    for key, metric in metrics.items():
        balance_dict[key][layer_idx] = metric(c_e)
    # Total number of routing assignments; equals gate_top_k_idx.shape[0]
    # when the index is already flat.
    S = flat_idx.numel()
    if balance_strategy == "gshard":
        # NOTE(review): assumes gate_state_dict holds (rows, num_expert) gate
        # logits. S counts assignments rather than rows, so with top_k > 1
        # m_e is scaled by 1/top_k -- confirm this normalisation is intended.
        gate_score_all = gate_state_dict
        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
        balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
    elif balance_strategy == "noisy":
        balance_dict["noisy_loss"][layer_idx] = gate_state_dict