"platforms/vscode:/vscode.git/clone" did not exist on "6f24ee8a965137f83409d0f57b4a0f96eb1f2fb4"
Commit 2a75c6bc authored by zhuwenwen's avatar zhuwenwen
Browse files

fix tests of kernels

parent 3dd7fd64
......@@ -49,27 +49,27 @@ class MRoPETestInfo(NamedTuple):
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-7B-Instruct")),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-72B-Instruct")),
# MRoPETestInfo(model_name=os.path.join("Qwen/Qwen2.5-VL-72B-Instruct")),
MRoPETestInfo(
model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-4B-Instruct"),
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
MRoPETestInfo(
model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-30B-A3B-Instruct"),
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
# MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
# MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
# MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
# MRoPETestInfo(
# model_name="Qwen/Qwen3-VL-4B-Instruct",
# marks=[
# pytest.mark.skipif(
# Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
# reason="Qwen3-VL only available after Transformers v4.57",
# )
# ]),
# MRoPETestInfo(
# model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
# marks=[
# pytest.mark.skipif(
# Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
# reason="Qwen3-VL only available after Transformers v4.57",
# )
# ]),
]
num_tokens_list = [11, 8192]
......@@ -78,7 +78,7 @@ num_tokens_list = [11, 8192]
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
pytest.param(test_config, os.path.join(models_path_prefix, test_config.model_name), marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.parametrize("tp_size", [1, 2])
......@@ -90,7 +90,7 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
atol = model_info.atol
rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config = config.get_text_config()
# get the model config
......
......@@ -90,7 +90,7 @@ class BatchedMMTensors:
@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 512]) # 224
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16] if not current_platform.is_rocm() else [torch.bfloat16])
......
......@@ -528,305 +528,305 @@ def marlin_moe_generate_valid_test_cases():
return cases
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases())
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
group_size: int,
act_order: bool,
quant_type: ScalarType,
is_k_full: bool,
):
torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
# @pytest.mark.flaky(reruns=2)
# @pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
# "act_order, quant_type, is_k_full"),
# marlin_moe_generate_valid_test_cases())
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# def test_fused_marlin_moe(
# m: int,
# n: int,
# k: int,
# e: int,
# topk: int,
# ep_size: int,
# dtype: torch.dtype,
# group_size: int,
# act_order: bool,
# quant_type: ScalarType,
# is_k_full: bool,
# ):
# torch.cuda.manual_seed(0)
# has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
w_ref1_l = []
qweight1_l = []
scales1_l = []
global_scale1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
# if ep_size > 1:
# local_e = e // ep_size
# e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
# e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
# e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
# w1 = w1[e_ids]
# w2 = w2[e_ids]
# else:
# e_map = None
for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_nvfp4_like(w1[i], group_size)
else:
w_ref1, qweight1, scales1 = \
rand_marlin_weight_mxfp4_like(w1[i], group_size)
global_scale1 = None
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
if global_scale1 is not None:
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
global_scale2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_nvfp4_like(w2[i], group_size)
else:
w_ref2, qweight2, scales2 = \
rand_marlin_weight_mxfp4_like(w2[i], group_size)
global_scale2 = None
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
if global_scale2 is not None:
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
# w_ref1_l = []
# qweight1_l = []
# scales1_l = []
# global_scale1_l = []
# zeros1_l = []
# g_idx1_l = []
# sort_indices1_l = []
# for i in range(w1.shape[0]):
# if quant_type == scalar_types.float4_e2m1f:
# if group_size == 16:
# w_ref1, qweight1, scales1, global_scale1 = \
# rand_marlin_weight_nvfp4_like(w1[i], group_size)
# else:
# w_ref1, qweight1, scales1 = \
# rand_marlin_weight_mxfp4_like(w1[i], group_size)
# global_scale1 = None
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# if global_scale1 is not None:
# global_scale1_l.append(global_scale1)
# elif quant_type == scalar_types.float8_e4m3fn:
# w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
# w1[i], group_size)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# elif has_zp:
# w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
# w1[i].transpose(1, 0), quant_type, group_size)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# zeros1_l.append(zeros1)
# else:
# test_perm = torch.randperm(k)
# w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
# marlin_quantize(w1[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# g_idx1_l.append(g_idx1)
# sort_indices1_l.append(sort_indices1)
# w_ref1 = stack_and_dev(w_ref1_l)
# qweight1 = stack_and_dev(qweight1_l).contiguous()
# scales1 = stack_and_dev(scales1_l)
# global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
# g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
# zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
# sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
# w_ref2_l = []
# qweight2_l = []
# scales2_l = []
# global_scale2_l = []
# zeros2_l = []
# g_idx2_l = []
# sort_indices2_l = []
# for i in range(w2.shape[0]):
# if quant_type == scalar_types.float4_e2m1f:
# if group_size == 16:
# w_ref2, qweight2, scales2, global_scale2 = \
# rand_marlin_weight_nvfp4_like(w2[i], group_size)
# else:
# w_ref2, qweight2, scales2 = \
# rand_marlin_weight_mxfp4_like(w2[i], group_size)
# global_scale2 = None
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# if global_scale2 is not None:
# global_scale2_l.append(global_scale2)
# elif quant_type == scalar_types.float8_e4m3fn:
# w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
# w2[i], group_size)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# elif has_zp:
# w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
# w2[i].transpose(1, 0), quant_type, group_size)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# zeros2_l.append(zeros2)
# else:
# test_perm = torch.randperm(n)
# w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
# marlin_quantize(w2[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# g_idx2_l.append(g_idx2)
# sort_indices2_l.append(sort_indices2)
# w_ref2 = stack_and_dev(w_ref2_l)
# qweight2 = stack_and_dev(qweight2_l).contiguous()
# scales2 = stack_and_dev(scales2_l)
# global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
# g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
# zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
# sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
# score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
# topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a,
w_ref1,
w_ref2,
score,
topk,
expert_map=e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
None,
None,
scales1,
scales2,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
torch.cuda.manual_seed(0)
e, topk = 32, 4
n, k = 2048, 2048
group_size = 128
act_order = False
is_k_full = True
quant_type = scalar_types.uint4b8
dtype = torch.half
# with set_current_vllm_config(vllm_config):
# torch_output = torch_moe(a,
# w_ref1,
# w_ref2,
# score,
# topk,
# expert_map=e_map)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
b_bias1_l = []
w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
b_bias2_l = []
w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
# marlin_output = torch.ops.vllm.fused_marlin_moe(
# a,
# qweight1,
# qweight2,
# None,
# None,
# scales1,
# scales2,
# score,
# topk_weights,
# topk_ids,
# global_num_experts=e,
# expert_map=e_map,
# global_scale1=global_scale1,
# global_scale2=global_scale2,
# g_idx1=g_idx1,
# g_idx2=g_idx2,
# sort_indices1=sort_indices1,
# sort_indices2=sort_indices2,
# w1_zeros=zeros1,
# w2_zeros=zeros2,
# quant_type_id=quant_type.id,
# is_k_full=is_k_full)
# torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
# @pytest.mark.flaky(reruns=2)
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# @pytest.mark.parametrize("m", [1, 256])
# def test_fused_marlin_moe_with_bias(m):
# torch.cuda.manual_seed(0)
# e, topk = 32, 4
# n, k = 2048, 2048
# group_size = 128
# act_order = False
# is_k_full = True
# quant_type = scalar_types.uint4b8
# dtype = torch.half
score = torch.randn((m, e), device="cuda", dtype=dtype)
# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
# b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
# b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
# b_bias1_l = []
# w_ref1_l = []
# qweight1_l = []
# scales1_l = []
# g_idx1_l = []
# sort_indices1_l = []
# for i in range(w1.shape[0]):
# test_perm = torch.randperm(k)
# w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
# marlin_quantize(w1[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# g_idx1_l.append(g_idx1)
# sort_indices1_l.append(sort_indices1)
# b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
# w_ref1 = stack_and_dev(w_ref1_l)
# qweight1 = stack_and_dev(qweight1_l).contiguous()
# scales1 = stack_and_dev(scales1_l)
# global_scale1 = None
# g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
# zeros1 = None
# sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
# marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
# b_bias2_l = []
# w_ref2_l = []
# qweight2_l = []
# scales2_l = []
# g_idx2_l = []
# sort_indices2_l = []
# for i in range(w2.shape[0]):
# test_perm = torch.randperm(n)
# w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
# marlin_quantize(w2[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# g_idx2_l.append(g_idx2)
# sort_indices2_l.append(sort_indices2)
# b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
# w_ref2 = stack_and_dev(w_ref2_l)
# qweight2 = stack_and_dev(qweight2_l).contiguous()
# scales2 = stack_and_dev(scales2_l)
# global_scale2 = None
# g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
# zeros2 = None
# sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
# marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
# score = torch.randn((m, e), device="cuda", dtype=dtype)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
b_bias2)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
marlin_bias1,
marlin_bias2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=None,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
# topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
# with set_current_vllm_config(vllm_config):
# torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
# b_bias2)
# marlin_output = torch.ops.vllm.fused_marlin_moe(
# a,
# qweight1,
# qweight2,
# marlin_bias1,
# marlin_bias2,
# scales1,
# scales2,
# score,
# topk_weights,
# topk_ids,
# global_num_experts=e,
# expert_map=None,
# global_scale1=global_scale1,
# global_scale2=global_scale2,
# g_idx1=g_idx1,
# g_idx2=g_idx2,
# sort_indices1=sort_indices1,
# sort_indices2=sort_indices2,
# w1_zeros=zeros1,
# w2_zeros=zeros2,
# quant_type_id=quant_type.id,
# is_k_full=is_k_full)
# torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck():
......@@ -855,19 +855,19 @@ def test_moe_align_block_size_opcheck():
num_tokens_post_pad))
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
actual = torch.empty((m, k), device="cuda", dtype=dtype)
# @pytest.mark.parametrize("m", [1, 33, 64, 222])
# @pytest.mark.parametrize("topk", TOP_KS)
# @pytest.mark.parametrize("k", [128, 511, 1024])
# @pytest.mark.parametrize("dtype",
# [torch.float32, torch.float16, torch.bfloat16])
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
# input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
# actual = torch.empty((m, k), device="cuda", dtype=dtype)
expected = input.sum(dim=1)
torch.ops._moe_C.moe_sum(input, actual)
# expected = input.sum(dim=1)
# torch.ops._moe_C.moe_sum(input, actual)
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
# torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
opcheck(torch.ops._moe_C.moe_sum, (input, actual))
# opcheck(torch.ops._moe_C.moe_sum, (input, actual))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment