Commit e028f2ec authored by Sengxian's avatar Sengxian
Browse files

Revise variable name

parent 69121432
...@@ -3,8 +3,8 @@ import torch.nn.functional as F ...@@ -3,8 +3,8 @@ import torch.nn.functional as F
metrics = { metrics = {
"coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e), "coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
"Lmax_div_Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1), "Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
"Lmax_div_Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e), "Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
} }
...@@ -19,7 +19,7 @@ def update_balance_profile( ...@@ -19,7 +19,7 @@ def update_balance_profile(
balance_dict, balance_dict,
gate_top_k_idx, gate_top_k_idx,
_gate_score_top_k, _gate_score_top_k,
gate_state_dict, gate_context,
layer_idx, layer_idx,
num_expert, num_expert,
balance_strategy, balance_strategy,
...@@ -34,8 +34,8 @@ def update_balance_profile( ...@@ -34,8 +34,8 @@ def update_balance_profile(
balance_dict[key][layer_idx] = metrics[key](c_e) balance_dict[key][layer_idx] = metrics[key](c_e)
S = gate_top_k_idx.shape[0] S = gate_top_k_idx.shape[0]
if balance_strategy == "gshard": if balance_strategy == "gshard":
gate_score_all = gate_state_dict gate_score_all = gate_context
m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
elif balance_strategy == "noisy": elif balance_strategy == "noisy":
balance_dict["noisy_loss"][layer_idx] = gate_state_dict balance_dict["noisy_loss"][layer_idx] = gate_context
...@@ -96,13 +96,13 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global): ...@@ -96,13 +96,13 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
balance_strategy = get_args().balance_strategy balance_strategy = get_args().balance_strategy
def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_state_dict): def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
global balance_dict global balance_dict
update_balance_profile( update_balance_profile(
balance_dict, balance_dict,
gate_top_k_idx, gate_top_k_idx,
gate_score_top_k, gate_score_top_k,
gate_state_dict, gate_context,
layer_idx, layer_idx,
num_expert_global, num_expert_global,
balance_strategy, balance_strategy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment