Commit 40841453 authored by Sengxian's avatar Sengxian
Browse files

Fix top_k=3 testcases

parent baf2b118
...@@ -44,7 +44,7 @@ def _assert_numercial(names, moe_out_list, raw_out_list): ...@@ -44,7 +44,7 @@ def _assert_numercial(names, moe_out_list, raw_out_list):
@pytest.mark.parametrize("num_expert", [4, 8]) @pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2]) @pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16]) @pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32]) @pytest.mark.parametrize("d_hidden", [32])
...@@ -80,6 +80,7 @@ def test_fmoe_linear( ...@@ -80,6 +80,7 @@ def test_fmoe_linear(
d_model=d_model, d_model=d_model,
d_hidden=d_hidden, d_hidden=d_hidden,
world_size=world_size, world_size=world_size,
top_k=top_k,
).cuda() ).cuda()
if world_size == 1: if world_size == 1:
...@@ -118,7 +119,7 @@ def test_fmoe_linear( ...@@ -118,7 +119,7 @@ def test_fmoe_linear(
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8]) @pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16]) @pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2]) @pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"]) @pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe( def test_fmoe(
batch_size, num_expert, d_model, top_k, expert: Union[Type[nn.Module], str] batch_size, num_expert, d_model, top_k, expert: Union[Type[nn.Module], str]
...@@ -140,7 +141,11 @@ def test_fmoe( ...@@ -140,7 +141,11 @@ def test_fmoe(
).cuda() ).cuda()
moe_raw = BruteForceMoE( moe_raw = BruteForceMoE(
expert=expert, num_expert=num_expert, d_model=d_model, world_size=world_size, expert=expert,
num_expert=num_expert,
d_model=d_model,
world_size=world_size,
top_k=top_k,
).cuda() ).cuda()
if world_size == 1: if world_size == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment