"...text-generation-inference.git" did not exist on "7c2e0af2a61745a36b26fae2c817f608be757a4c"
Commit 41cfe06c authored by Rich Ho's avatar Rich Ho
Browse files

add data_type in test ddp

parent b0704f1d
...@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict): ...@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict):
@pytest.mark.parametrize("d_model", [16]) @pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32]) @pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2]) @pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear_distributed( def test_fmoe_linear_distributed(
num_expert, top_k, batch_size, d_model, d_hidden, mp_size num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
): ):
_run_distributed( _run_distributed(
"_test_fmoe_linear", "_test_fmoe_linear",
...@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed( ...@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed(
"d_model": d_model, "d_model": d_model,
"d_hidden": d_hidden, "d_hidden": d_hidden,
"mp_size": mp_size, "mp_size": mp_size,
"data_type": data_type
}, },
) )
...@@ -120,5 +122,6 @@ if __name__ == "__main__": ...@@ -120,5 +122,6 @@ if __name__ == "__main__":
else: else:
test_fmoe_local_ddp(mp_size=2) test_fmoe_local_ddp(mp_size=2)
test_fmoe_linear_distributed( test_fmoe_linear_distributed(
num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2 num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
data_type="torch.HalfTensor"
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment