Commit cff5452a authored by 王敏's avatar 王敏
Browse files

[fix]删掉错误添加代码

parent 8db76782
...@@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock ...@@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
torch_moe_single) torch_moe_single)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
...@@ -30,9 +29,6 @@ NUM_EXPERTS = [8, 64] ...@@ -30,9 +29,6 @@ NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4] EP_SIZE = [1, 4]
TOP_KS = [2, 6] TOP_KS = [2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
...@@ -71,7 +67,6 @@ def test_fused_moe( ...@@ -71,7 +67,6 @@ def test_fused_moe(
else: else:
e_map = None e_map = None
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map) torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a, iterative_output = iterative_moe(a,
w1, w1,
...@@ -97,7 +92,6 @@ def test_fused_moe( ...@@ -97,7 +92,6 @@ def test_fused_moe(
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
renormalize=False) renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(iterative_output, torch.testing.assert_close(iterative_output,
torch_output, torch_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment