Commit e36c3c4c authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

test bias grad

parent 6acc3e41
...@@ -132,19 +132,23 @@ def test_fmoe_linear( ...@@ -132,19 +132,23 @@ def test_fmoe_linear(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
) )
moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
if world_size > 1: if world_size > 1:
_, htoh4_grad, h4toh_grad = raw_out_list _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
torch.distributed.all_reduce(htoh4_grad) torch.distributed.all_reduce(htoh4_w_grad)
torch.distributed.all_reduce(h4toh_grad) torch.distributed.all_reduce(h4toh_w_grad)
torch.distributed.all_reduce(htoh4_b_grad)
torch.distributed.all_reduce(h4toh_b_grad)
mp_size = mp_group.size() if mp_group else 1 mp_size = mp_group.size() if mp_group else 1
htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size htoh4_w_grad = htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size h4toh_w_grad = h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
raw_out_list = _, htoh4_grad, h4toh_grad htoh4_b_grad = htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
h4toh_b_grad = h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
names = ["output", "htoh4 weight grad", "h4toh weight grad"] names = ["output", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
_assert_numercial(names, moe_out_list, raw_out_list, rank) _assert_numercial(names, moe_out_list, raw_out_list, rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment