# balance.py
import torch
import torch.nn.functional as F

def _as_float(c_e):
    """Return ``c_e`` as a floating-point tensor.

    ``torch.std`` and ``torch.mean`` reject integer tensors, so non-floating
    inputs (e.g. integer count tensors) are cast to the default float dtype.
    Floating-point inputs are returned unchanged, keeping their dtype.
    """
    if torch.is_floating_point(c_e):
        return c_e
    return c_e.to(torch.get_default_dtype())


def _coefficient_variation(c_e):
    """Return std(c_e) / mean(c_e), reduced over all elements.

    Uses ``torch.std``'s default (unbiased) estimator.
    """
    c_e = _as_float(c_e)
    return torch.std(c_e) / torch.mean(c_e)


def _lmax_over_lmin(c_e):
    """Return (max(c_e) + 1) / (min(c_e) + 1).

    The +1 on both sides keeps the ratio finite when the smallest entry is 0.
    """
    c_e = _as_float(c_e)
    return (torch.max(c_e) + 1) / (torch.min(c_e) + 1)


def _lmax_over_lmean(c_e):
    """Return max(c_e) / mean(c_e), reduced over all elements."""
    c_e = _as_float(c_e)
    return torch.max(c_e) / torch.mean(c_e)


# Balance metrics: each entry maps a tensor ``c_e`` to a scalar tensor.
# NOTE(review): ``c_e`` is presumably a per-expert load/count vector -- confirm
# against the caller that builds it.
# Key order matters: reset_balance_profile iterates this dict to create slots.
metrics = {
    "coefficient-variation": _coefficient_variation,
    "Lmax-over-Lmin": _lmax_over_lmin,
    "Lmax-over-Lmean": _lmax_over_lmean,
}


def reset_balance_profile(balance_dict, num_layers, balance_strategy):
    """Reset per-layer balance slots in ``balance_dict``.

    Every key of the module-level ``metrics`` dict is bound to a fresh list of
    ``num_layers`` ``None`` placeholders. When ``balance_strategy`` is truthy,
    an extra ``f"{balance_strategy}_loss"`` entry is reset the same way.
    ``balance_dict`` is mutated in place; nothing is returned.
    """
    slot_names = list(metrics)
    if balance_strategy:
        slot_names.append(f"{balance_strategy}_loss")
    for slot_name in slot_names:
        # A new list per key, so layers' slots are never shared between keys.
        balance_dict[slot_name] = [None] * num_layers


def update_balance_profile(
    balance_dict,
    gate_top_k_idx,
    _gate_score_top_k,
    gate_context,
    layer_idx,
    num_expert,
    balance_strategy,
):
    """Record balance statistics for one layer (currently a no-op stub).

    Args:
        balance_dict: presumably the dict prepared by ``reset_balance_profile``
            (per-layer lists keyed by metric name) -- confirm with callers.
        gate_top_k_idx: gate output indices; exact shape/dtype not shown here.
        _gate_score_top_k: gate scores; the leading underscore suggests it is
            intentionally unused -- confirm.
        gate_context: extra gate information; contents not shown here.
        layer_idx: presumably the index of the layer slot to fill.
        num_expert: presumably the number of experts.
        balance_strategy: name of the balance strategy, or a falsy value.

    Returns:
        None. The body is not implemented yet, so ``balance_dict`` is left
        untouched.
    """
    # TODO: Fill in this function to conduct balance related jobs.