Commit 69151519 authored by TiagoMAntunes's avatar TiagoMAntunes
Browse files

Moved test into numerical. Handled different precision for half tensors

parent f5408e3c
import torch
import pytest
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
class MyMoE(FMoE):
    """FMoE subclass that attaches a feed-forward ``_Expert`` network."""

    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
        """Build the MoE layer and plug in its experts.

        Args:
            num_expert: number of experts in the layer.
            d_model: input/output feature dimension.
            d_hidden: hidden dimension inside each expert.
            top_k: number of experts each token is routed to.
            activation: activation callable used inside each expert.
        """
        super().__init__(num_expert=num_expert, d_model=d_model, top_k=top_k)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_data_support(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    data_type,
    activation=torch.nn.functional.gelu,
):
    """
    Smoke-test that the CUDA forward/backward kernels accept
    float32, float64 and float16 inputs without crashing.
    """
    # Cast the whole module to the tensor type under test, then move to GPU.
    model = MyMoE(num_expert, d_model, d_hidden, top_k, activation)
    model = model.type(data_type).cuda()
    # Random batch in the matching dtype; a full backward pass exercises
    # both the forward and gradient kernels.
    batch = torch.rand(batch_size, d_model).type(data_type).cuda()
    batch_out = model(batch)
    batch_out.sum().backward()
\ No newline at end of file
...@@ -17,15 +17,15 @@ from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert ...@@ -17,15 +17,15 @@ from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
def _perform_forward( def _perform_forward(
moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
): ):
moe.zero_grad() moe.zero_grad()
moe_raw.zero_grad() moe_raw.zero_grad()
if not mp_group:
inp = torch.rand(batch_size, d_model).cuda() inp = torch.rand(batch_size, d_model).type(data_type).cuda()
else:
if mp_group:
group_sender = rank // mp_group.size() * mp_group.size() group_sender = rank // mp_group.size() * mp_group.size()
inp = torch.rand(batch_size, d_model).cuda()
torch.distributed.broadcast(inp, group_sender, group=mp_group) torch.distributed.broadcast(inp, group_sender, group=mp_group)
torch.distributed.broadcast( torch.distributed.broadcast(
moe.gate.gate.weight.data, group_sender, group=mp_group moe.gate.gate.weight.data, group_sender, group=mp_group
...@@ -49,15 +49,17 @@ def _perform_forward( ...@@ -49,15 +49,17 @@ def _perform_forward(
return moe_out, raw_out, inp.grad, inp_raw.grad return moe_out, raw_out, inp.grad, inp_raw.grad
def _assert_numercial(names, moe_out_list, raw_out_list, rank): def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
for name, mo, ro in zip(names, moe_out_list, raw_out_list): for name, mo, ro in zip(names, moe_out_list, raw_out_list):
err = (mo - ro).abs().sum() err = (mo - ro).abs().sum()
print("Rank {} {} abs err {}".format(rank, name, err)) print("Rank {} {} abs err {}".format(rank, name, err))
if err > 1e-3: if err > precision:
sys.stderr.write(f"=========== {name} moe out ==============\n") sys.stderr.write(f"=========== {name} moe out ==============\n")
sys.stderr.write("{}\n".format(mo)) sys.stderr.write("{}\n".format(mo))
sys.stderr.write(f"=========== {name} raw out ==============\n") sys.stderr.write(f"=========== {name} raw out ==============\n")
sys.stderr.write("{}\n".format(ro)) sys.stderr.write("{}\n".format(ro))
sys.stderr.write(f"=========== {name} diff ==============\n")
sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
assert False assert False
...@@ -90,6 +92,7 @@ class MyMoE(FMoE): ...@@ -90,6 +92,7 @@ class MyMoE(FMoE):
@pytest.mark.parametrize("mp_group", [None]) @pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None]) @pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None]) @pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear( def test_fmoe_linear(
num_expert, num_expert,
top_k, top_k,
...@@ -101,6 +104,7 @@ def test_fmoe_linear( ...@@ -101,6 +104,7 @@ def test_fmoe_linear(
mp_group, mp_group,
dp_group, dp_group,
world_group, world_group,
data_type,
activation=torch.nn.functional.gelu, activation=torch.nn.functional.gelu,
): ):
torch.manual_seed(42 + rank) torch.manual_seed(42 + rank)
...@@ -108,7 +112,7 @@ def test_fmoe_linear( ...@@ -108,7 +112,7 @@ def test_fmoe_linear(
moe = MyMoE( moe = MyMoE(
num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
).cuda() ).type(data_type).cuda()
moe_raw = BruteForceMoELinear( moe_raw = BruteForceMoELinear(
activation=activation, activation=activation,
...@@ -117,7 +121,7 @@ def test_fmoe_linear( ...@@ -117,7 +121,7 @@ def test_fmoe_linear(
d_hidden=d_hidden, d_hidden=d_hidden,
world_size=world_size, world_size=world_size,
top_k=top_k, top_k=top_k,
).cuda() ).type(data_type).cuda()
if world_size == 1: if world_size == 1:
moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone() moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
...@@ -148,7 +152,7 @@ def test_fmoe_linear( ...@@ -148,7 +152,7 @@ def test_fmoe_linear(
moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0) moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward( moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
) )
moe_out_list = ( moe_out_list = (
...@@ -198,7 +202,10 @@ def test_fmoe_linear( ...@@ -198,7 +202,10 @@ def test_fmoe_linear(
"h4toh bias grad", "h4toh bias grad",
] ]
_assert_numercial(names, moe_out_list, raw_out_list, rank)
precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
_assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
...@@ -299,7 +306,7 @@ def test_fmoe( ...@@ -299,7 +306,7 @@ def test_fmoe(
raw_out_list = [raw_out, raw_grad, raw_grad_in] raw_out_list = [raw_out, raw_grad, raw_grad_in]
names = ["forward", "backward", "grad_in"] names = ["forward", "backward", "grad_in"]
_assert_numercial(names, moe_out_list, raw_out_list, rank) _assert_numerical(names, moe_out_list, raw_out_list, rank)
class MyModule(nn.Module): class MyModule(nn.Module):
...@@ -375,7 +382,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group): ...@@ -375,7 +382,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
names = ["mp grad", "dp grad", "wp grad"] names = ["mp grad", "dp grad", "wp grad"]
_assert_numercial(names, ddp_out_list, raw_out_list, rank) _assert_numerical(names, ddp_out_list, raw_out_list, rank)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment