# balance.py: per-layer load-balance metrics and balance losses for expert gates.
import torch
import torch.nn.functional as F

# Load-balance metrics, keyed by the name under which each one is recorded in
# the per-layer balance profile. Each callable receives ``c_e``, a 1-D float
# tensor holding one routed-token count per expert, and returns a 0-d tensor.
metrics = {
    # Standard deviation of the per-expert load divided by its mean. torch.std
    # applies Bessel's correction by default, so this is NaN for one expert.
    "coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
    # Busiest expert's load over the idlest expert's load. The +1 on both
    # sides keeps the ratio finite when some expert received no tokens.
    "Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
    # Busiest expert's load relative to the mean load.
    "Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
}


def reset_balance_profile(balance_dict, num_layers, balance_strategy):
    """Clear the per-layer balance statistics held in ``balance_dict``.

    Every metric name in ``metrics`` and, when ``balance_strategy`` is truthy,
    the name ``f"{balance_strategy}_loss"`` is (re)bound to a fresh list of
    ``num_layers`` ``None`` placeholders. Other keys of ``balance_dict`` are
    left untouched. The dict is modified in place; nothing is returned.
    """
    slot_names = list(metrics)
    if balance_strategy:
        slot_names.append(f"{balance_strategy}_loss")
    for name in slot_names:
        # A new list per key so that layers/keys never share storage.
        balance_dict[name] = [None] * num_layers


def update_balance_profile(
    balance_dict,
    gate_top_k_idx,
    _gate_score_top_k,
    gate_context,
    layer_idx,
    num_expert,
    balance_strategy,
):
    """Record load-balance statistics for one layer into ``balance_dict``.

    Counts how many entries of ``gate_top_k_idx`` select each expert, stores
    every metric from ``metrics`` evaluated on that count vector at
    ``balance_dict[key][layer_idx]``, and, for the ``"gshard"`` and
    ``"noisy"`` strategies, also stores a loss at
    ``balance_dict[f"{balance_strategy}_loss"][layer_idx]``. The lists are
    presumably prepared by ``reset_balance_profile`` beforehand.

    Args:
        balance_dict: Mapping from metric/loss name to a per-layer list;
            updated in place.
        gate_top_k_idx: Integer tensor of selected expert indices. It must be
            1-D, because it is scattered into a 1-D count vector of length
            ``num_expert``. Presumably one entry per routed token slot
            (TODO confirm against the gate).
        _gate_score_top_k: Unused.
        gate_context: For ``"gshard"``, raw gate scores that are
            softmax-normalised over dim 1 and summed over dim 0 (assumed to
            be ``(tokens, num_expert)``; confirm). For ``"noisy"``, a value
            stored as the loss unchanged. Ignored for any other strategy.
        layer_idx: Index of the per-layer slot to write.
        num_expert: Number of experts, i.e. the length of the count vector.
        balance_strategy: ``"gshard"``, ``"noisy"``, or anything else to
            record only the metrics.
    """
    # c_e[e] = number of entries of gate_top_k_idx equal to e.
    c_e = torch.scatter_add(
        torch.zeros(num_expert, device=gate_top_k_idx.device),
        0,
        gate_top_k_idx,
        torch.ones_like(gate_top_k_idx, dtype=torch.float),
    )
    for key, metric in metrics.items():
        balance_dict[key][layer_idx] = metric(c_e)
    S = gate_top_k_idx.shape[0]
    if balance_strategy == "gshard":
        # m_e: per-expert sum of softmax gate probabilities, scaled by 1/S.
        m_e = torch.sum(F.softmax(gate_context, dim=1), dim=0) / S
        balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
    elif balance_strategy == "noisy":
        # The caller supplies the noisy-gate loss already computed.
        balance_dict["noisy_loss"][layer_idx] = gate_context