Commit 5cb4e63c authored by Rick Ho
Browse files

add inspection of input.grad in tests

parent 27c89b5a
......@@ -95,6 +95,7 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
if __name__ == '__main__':
os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0')
if int(os.environ['WORLD_SIZE']) > 1:
torch.distributed.init_process_group(backend='nccl')
rank = torch.distributed.get_rank()
......
......@@ -29,15 +29,19 @@ def _perform_forward(
moe.gate.gate.bias.data, group_sender, group=mp_group
)
gate_idx, gate_score = moe.gate(inp)
inp_repeated = inp.repeat_interleave(repeats=top_k, dim=0)
moe_out = moe(inp).mean()
raw_out = moe_raw(inp_repeated, gate_idx, gate_score).mean()
inp_raw = inp.clone()
inp.requires_grad = True
moe_out.backward()
raw_out.backward()
inp_raw.requires_grad = True
gate_idx, gate_score = moe.gate(inp_raw)
inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
moe_out = moe(inp)
raw_out = moe_raw(inp_repeated, gate_idx, gate_score)
return moe_out, raw_out
raw_out.mean().backward()
moe_out.mean().backward()
return moe_out, raw_out, inp.grad, inp_raw.grad
def _assert_numercial(names, moe_out_list, raw_out_list, rank):
......@@ -128,12 +132,12 @@ def test_fmoe_linear(
moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
moe_out, raw_out = _perform_forward(
moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
)
moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
moe_out_list = moe_out, moe_grad_in, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
raw_out_list = raw_out, raw_grad_in, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
if world_size > 1:
_, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
......@@ -148,7 +152,7 @@ def test_fmoe_linear(
h4toh_b_grad = h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
names = ["output", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
names = ["output", "input grad", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
_assert_numercial(names, moe_out_list, raw_out_list, rank)
......@@ -215,7 +219,7 @@ def test_fmoe(
idx
].data = para_tensor_gathered[expertID]
moe_out, raw_out = _perform_forward(
moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
)
......@@ -242,31 +246,31 @@ def test_fmoe(
mp_size = mp_group.size() if mp_group else 1
raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
moe_out_list = [moe_out, moe_grad]
raw_out_list = [raw_out, raw_grad]
names = ["forward", "backward"]
moe_out_list = [moe_out, moe_grad, moe_grad_in]
raw_out_list = [raw_out, raw_grad, raw_grad_in]
names = ["forward", "backward", "grad_in"]
_assert_numercial(names, moe_out_list, raw_out_list, rank)
if __name__ == "__main__":
test_fmoe_linear(
batch_size=4,
num_expert=4,
d_model=8,
top_k=2,
d_hidden=16,
rank=0,
world_size=1,
mp_group=None,
)
batch_size=4,
num_expert=4,
d_model=8,
top_k=2,
d_hidden=16,
rank=0,
world_size=1,
mp_group=None,
)
test_fmoe(
batch_size=4,
num_expert=4,
d_model=8,
top_k=2,
expert=NaiveExpert,
rank=0,
world_size=1,
mp_group=None,
)
batch_size=4,
num_expert=4,
d_model=8,
top_k=2,
expert=NaiveExpert,
rank=0,
world_size=1,
mp_group=None,
)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment