Commit 1b7817a5 authored by Sengxian, committed by Jiezhong Qiu

Revise variable name

parent 038b31ea
@@ -3,8 +3,8 @@ import torch.nn.functional as F
 metrics = {
     "coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
-    "Lmax_div_Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
-    "Lmax_div_Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
+    "Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
+    "Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
 }
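For context, these metrics summarize how evenly tokens are spread across experts given a per-expert token-count tensor c_e. A minimal sketch of evaluating them on made-up counts (the tensor values below are illustrative and not from this repository):

import torch

# Hypothetical per-expert token counts for a 4-expert layer (illustrative values only).
c_e = torch.tensor([12.0, 3.0, 1.0, 8.0])

cv = torch.std(c_e) / torch.mean(c_e)                         # coefficient of variation
lmax_over_lmin = (torch.max(c_e) + 1) / (torch.min(c_e) + 1)  # +1 keeps the ratio finite when an expert gets no tokens
lmax_over_lmean = torch.max(c_e) / torch.mean(c_e)
print(cv.item(), lmax_over_lmin.item(), lmax_over_lmean.item())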
@@ -19,7 +19,7 @@ def update_balance_profile(
     balance_dict,
     gate_top_k_idx,
     _gate_score_top_k,
-    gate_state_dict,
+    gate_context,
     layer_idx,
     num_expert,
     balance_strategy,
@@ -34,8 +34,8 @@ def update_balance_profile(
         balance_dict[key][layer_idx] = metrics[key](c_e)
     S = gate_top_k_idx.shape[0]
     if balance_strategy == "gshard":
-        gate_score_all = gate_state_dict
+        gate_score_all = gate_context
         m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
         balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
     elif balance_strategy == "noisy":
-        balance_dict["noisy_loss"][layer_idx] = gate_state_dict
+        balance_dict["noisy_loss"][layer_idx] = gate_context
@@ -96,13 +96,13 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
     balance_strategy = get_args().balance_strategy
-    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_state_dict):
+    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
         global balance_dict
         update_balance_profile(
             balance_dict,
             gate_top_k_idx,
             gate_score_top_k,
-            gate_state_dict,
+            gate_context,
             layer_idx,
             num_expert_global,
             balance_strategy,
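The same hook pattern, stripped of the Megatron globals, can be sketched as a closure that records one balance metric per layer. Every name here is illustrative rather than the repository's API:

import torch

def generate_gate_hook(layer_idx, num_expert, balance_dict):
    # The gate calls the hook with its routing decision and an opaque, strategy-specific
    # payload (gate_context, matching the rename in this commit).
    def gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
        c_e = torch.bincount(gate_top_k_idx.flatten(), minlength=num_expert).float()
        balance_dict.setdefault("Lmax-over-Lmean", {})[layer_idx] = torch.max(c_e) / torch.mean(c_e)
    return gate_hook

balance_dict = {}
hook = generate_gate_hook(layer_idx=0, num_expert=4, balance_dict=balance_dict)
hook(torch.tensor([0, 1, 1, 3, 3, 3]), None, None)
print(balance_dict)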