Commit 092c8d67 authored by Sengxian's avatar Sengxian
Browse files

fix unpack bug

parent 05728524
......@@ -52,9 +52,9 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
err = (mo - ro).abs().sum()
print("Rank {} {} abs err {}".format(rank, name, err))
if err > 1e-3:
sys.stderr.write("=========== moe out ==============\n")
sys.stderr.write(f"=========== {name} moe out ==============\n")
sys.stderr.write("{}\n".format(mo))
sys.stderr.write("=========== raw out ==============\n")
sys.stderr.write(f"=========== {name} raw out ==============\n")
sys.stderr.write("{}\n".format(ro))
assert False
......@@ -149,7 +149,7 @@ def test_fmoe_linear(
raw_out_list = raw_out, raw_grad_in, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
if world_size > 1:
_, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
_, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
torch.distributed.all_reduce(htoh4_w_grad)
torch.distributed.all_reduce(h4toh_w_grad)
torch.distributed.all_reduce(htoh4_b_grad)
......@@ -167,7 +167,7 @@ def test_fmoe_linear(
h4toh_b_grad = (
h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
names = ["output", "input grad", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
......
Markdown is supported
0% or drag files here.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment