Commit 41cfe06c authored by Rich Ho's avatar Rich Ho
Browse files

add data_type in test ddp

parent b0704f1d
......@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict):
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear_distributed(
num_expert, top_k, batch_size, d_model, d_hidden, mp_size
num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
):
_run_distributed(
"_test_fmoe_linear",
......@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed(
"d_model": d_model,
"d_hidden": d_hidden,
"mp_size": mp_size,
"data_type": data_type
},
)
......@@ -120,5 +122,6 @@ if __name__ == "__main__":
else:
test_fmoe_local_ddp(mp_size=2)
test_fmoe_linear_distributed(
num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
data_type="torch.HalfTensor"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment