Commit bb596f6e authored by xiaowei.zhang's avatar xiaowei.zhang
Browse files

1. Update MOE; 2. Update sglang mHC; 3. Update test scripts; 4 Add new

   ops.
parent d9ebb683
// SPDX-License-Identifier: MIT
#include "grouped_gemm_ck.h"
#include "rocm_ops.hpp"
#include <pybind11/stl.h>
// Python extension entry point: registers the two CK grouped-GEMM functions on module `m`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// Exposed to Python with the named arguments `a_tensors` and `b_tensors`.
m.def("ck_grouped_gemm", &ck_grouped_gemm, py::arg("a_tensors"), py::arg("b_tensors"));
// Same binding style with an extra named argument `c_tensors`
// (presumably the caller-provided outputs -- confirm against grouped_gemm_ck.h).
m.def("ck_grouped_gemm_out",
&ck_grouped_gemm_out,
py::arg("a_tensors"),
py::arg("b_tensors"),
py::arg("c_tensors"));
}
// SPDX-License-Identifier: MIT
#include "rocm_ops.hpp"
#include "mhc.h"
// Python extension entry point whose bindings are supplied entirely by the
// MHC_PYBIND macro. The macro is not defined in this file (presumably it comes
// from rocm_ops.hpp or mhc.h and expands to m.def(...) calls -- confirm).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
MHC_PYBIND;
}
......@@ -87,7 +87,9 @@ if __name__ == "__main__":
print(f">>> ERROR: {args.input_file} does not exist. Exiting")
exit(1)
shapes = pd.read_csv(args.input_file).fillna("")
shapes = pd.read_csv(args.input_file).dropna(how='all').fillna(0)
int_cols = ['token', 'inter_dim', 'model_dim', 'expert', 'topk', 'q_size_n', 'q_size_k']
shapes[int_cols] = shapes[int_cols].astype(int)
for i in range(len(shapes)):
ds = shapes.iloc[i]
moe_tuner.add_moe(
......
......@@ -11,12 +11,17 @@ class MoeTuner:
def __init__(self, indtype, tuned_file=None, mp=1):
self.arch = get_gfx()
self.moe_pro_df = pd.DataFrame(columns=["quant_type", "indtype", "token", "inter_dim", "model_dim", "expert", "topk", "q_size_n", "q_size_k"])
self._int_cols = ["token", "inter_dim", "model_dim", "expert", "topk", "q_size_n", "q_size_k"]
self.indtype = indtype
self.tuned_file = tuned_file
self.mp = mp
if Path(tuned_file).is_file():
self.tuned_shapes = pd.read_csv(tuned_file).fillna("")
self.tuned_shapes = pd.read_csv(tuned_file).dropna(how='all').fillna(0)
int_cols = ['token', 'inter_dim', 'model_dim', 'expert', 'topk', 'q_size_n', 'q_size_k']
for c in int_cols:
if c in self.tuned_shapes.columns:
self.tuned_shapes[c] = self.tuned_shapes[c].astype(int)
else:
self.tuned_shapes = None
......@@ -40,13 +45,13 @@ class MoeTuner:
entry = {
"quant_type": [quant_type],
"indtype": [indtype_str],
"token": [token],
"inter_dim": [inter_dim],
"model_dim": [model_dim],
"expert": [expert],
"topk": [topk],
"q_size_n": [q_size_n],
"q_size_k": [q_size_k]
"token": [int(token)],
"inter_dim": [int(inter_dim)],
"model_dim": [int(model_dim)],
"expert": [int(expert)],
"topk": [int(topk)],
"q_size_n": [int(q_size_n)],
"q_size_k": [int(q_size_k)]
}
df = pd.DataFrame(entry)
self.moe_pro_df = pd.concat([self.moe_pro_df, df], ignore_index=True)
......
......@@ -77,6 +77,7 @@ def test_get_config(m, k, n, e, topk, dtype):
assert moe_cfg.solution_type in (
MoeSolutionType.ASM,
MoeSolutionType.TRITON,
MoeSolutionType.CK,
), f"Unexpected solution_type: {moe_cfg.solution_type}"
assert moe_cfg.quant_type == MoeQuantType.W16A16
aiter.logger.info(
......@@ -126,7 +127,11 @@ def _run_aiter_moe_perf(hidden_states,
expert_map,
routed_scaling_factor,
):
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
......@@ -192,8 +197,9 @@ def test_aiter_moe_w16a16(m, k, n, e, topk, dtype, inplace, routed_scaling_facto
# Non-quantized bf16 matmul accumulation order differs between torch and
# the fused triton/asm kernels, so we need a relaxed atol (matching
# test_moe_w16a16.py which uses atol=1 for torch vs triton).
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
return {"m": m, "backend": backend, "us": aiter_us}
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
# ---------------------------------------------------------------------------
......@@ -247,14 +253,19 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype):
msg = (f"[w16a16_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
check_ret = checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0
return {
"m": m,
"k": k,
"n": n,
"e": e,
"topk": topk,
"asm_us": asm_us,
"shuffle_us": shuffle_us,
"shuffle_uplift": f"{uplift:.1%}",
"accuracy": ret_output
}
......@@ -265,7 +276,9 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype):
if __name__ == "__main__":
dtype = dtypes.bf16
PART2_CSV_OUTPUT = "w16a16_part2_aiter_moe.csv"
PART3_CSV_OUTPUT = "w16a16_part3_shuffle.csv"
# for asm solution in gfx936, only support the following configurations,
# check tuned_fmoe_asm.csv for details.
# otherwise, the interface will return triton solution.
......@@ -298,6 +311,8 @@ if __name__ == "__main__":
if df:
df = pd.DataFrame(df)
aiter.logger.info(f"aiter_moe summary:\n{df}")
df.to_csv(PART2_CSV_OUTPUT, index=False)
aiter.logger.info(f"aiter_moe summary csv saved to {PART2_CSV_OUTPUT}")
# --- Part 3: test ASM shuffle vs non-shuffle (w16a16) ---
aiter.logger.info("=" * 60)
......@@ -312,11 +327,5 @@ if __name__ == "__main__":
if df_shuffle:
df_shuffle = pd.DataFrame(df_shuffle)
aiter.logger.info(f"shuffle summary:\n{df_shuffle}")
# --- Combined summary ---
# if len(df) > 0 and len(df_shuffle) > 0:
# df_combined = pd.merge(
# pd.DataFrame(df), pd.DataFrame(df_shuffle),
# on="m", how="outer", suffixes=("", "_shuffle"),
# )
# aiter.logger.info(f"combined summary:\n{df_combined}")
df_shuffle.to_csv(PART3_CSV_OUTPUT, index=False)
aiter.logger.info(f"shuffle summary csv saved to {PART3_CSV_OUTPUT}")
......@@ -63,7 +63,7 @@ def torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids):
# ---------------------------------------------------------------------------
# Weight preparation helpers (W16A16 non-gated)
# ---------------------------------------------------------------------------
def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype):
def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype, asm_backend=False):
"""Build all tensors needed to run a non-gated w16a16 MOE test.
Key difference from gated: w1 shape is [E, n, k] instead of [E, 2*n, k].
......@@ -74,8 +74,12 @@ def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype):
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 2
score = torch.randn((m, e), device="cuda", dtype=dtype)
w1_shuffle = asm_shuffle_weight_b8(w1, stage=1)
w2_shuffle = asm_shuffle_weight_b8(w2, stage=2)
if asm_backend:
w1_shuffle = asm_shuffle_weight_b8(w1, stage=1)
w2_shuffle = asm_shuffle_weight_b8(w2, stage=2)
else:
w1_shuffle = w1
w2_shuffle = w2
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
......@@ -165,7 +169,11 @@ def _run_aiter_moe_perf(hidden_states,
expert_map,
routed_scaling_factor,
):
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
......@@ -197,7 +205,7 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin
f"backend={backend}"
)
data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype)
data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype, asm_backend=(backend == MoeSolutionType.ASM))
# Torch reference
ref_out, _ = _run_torch_ref(
......@@ -230,8 +238,9 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin
msg = (f"[aiter_moe_w16a16_nogate] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
f"backend={backend}")
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
return {"m": m, "backend": backend, "us": aiter_us}
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
# ---------------------------------------------------------------------------
......@@ -279,27 +288,33 @@ def test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype):
msg = (f"[w16a16_nogate_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
check_ret = checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0
return {
"m": m,
"k": k,
"n": n,
"e": e,
"topk": topk,
"asm_us": asm_us,
"shuffle_us": shuffle_us,
"shuffle_uplift": f"{uplift:.1%}",
"accuracy": ret_output
}
if __name__ == "__main__":
dtype = dtypes.bf16
PART2_CSV_OUTPUT = "w16a16_nogate_part2_aiter_moe.csv"
PART3_CSV_OUTPUT = "w16a16_nogate_part3_shuffle.csv"
# Nemotron-style MoE parameters (non-gated, ReLU²)
e = 256
topk = 8
k = 3072 # model_dim / hidden_size
n = 128 # intermediate_size (NOT multiplied by 2)
inplace = False
e = 512
topk = 22
k = 1024 # model_dim / hidden_size
n = 336 # intermediate_size (NOT multiplied by 2)
inplace = True
routed_scaling_factor = 1.0
# --- Part 1: test get_aiter_moe_config (w16a16 non-gated relu2) ---
......@@ -324,17 +339,24 @@ if __name__ == "__main__":
if df:
df = pd.DataFrame(df)
aiter.logger.info(f"aiter_moe non-gated relu2 summary:\n{df}")
df.to_csv(PART2_CSV_OUTPUT, index=False)
aiter.logger.info(f"aiter_moe summary csv saved to {PART2_CSV_OUTPUT}")
# --- Part 3: test ASM shuffle vs non-shuffle (w16a16 non-gated relu2) ---
aiter.logger.info("=" * 60)
aiter.logger.info("Part 3: Testing ASM shuffle vs non-shuffle for w16a16 non-gated relu2")
aiter.logger.info("=" * 60)
df_shuffle = []
for m in test_tokens:
ret = test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype)
if ret is not None:
df_shuffle.append(ret)
if df_shuffle:
df_shuffle = pd.DataFrame(df_shuffle)
aiter.logger.info(f"shuffle summary (non-gated relu2):\n{df_shuffle}")
if df.empty or not any(df["backend"] == MoeSolutionType.ASM):
aiter.logger.info("Skipping Part 3 since ASM backend was not selected in Part 2")
else:
df_shuffle = []
for m in test_tokens:
ret = test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype)
if ret is not None:
df_shuffle.append(ret)
if df_shuffle:
df_shuffle = pd.DataFrame(df_shuffle)
aiter.logger.info(f"shuffle summary (non-gated relu2):\n{df_shuffle}")
df_shuffle.to_csv(PART3_CSV_OUTPUT, index=False)
aiter.logger.info(f"shuffle summary csv saved to {PART3_CSV_OUTPUT}")
\ No newline at end of file
# Test for get_aiter_moe_config_w4a16 and aiter_moe_w4a16
import torch
import itertools
import pandas as pd
from typing import Optional, List
from op_tests.utility.scalar_type import scalar_types
from op_tests.utility.utils import quantize_weights
from aiter.fused_moe import fused_topk, torch_moe
from aiter import ActivationType, dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.ops.shuffle import w4a16_marlin_weight_1, w4a16_marlin_weight_2
import aiter
torch.set_default_device("cuda")
# ---------------------------------------------------------------------------
# Weight quantization helpers (adapted from test_moe_wna16.py)
# ---------------------------------------------------------------------------
def _quantize_w4a16_weights(w_fp, group_size, has_zp, pack_for_backend):
    """Quantize one expert's weight matrix to packed int4.

    Args:
        w_fp: Floating-point weight ``[out_features, in_features]``.
        group_size: Quantization group size along K.
        has_zp: When true, ``scalar_types.uint4`` is used and zero-points are
            returned; otherwise ``scalar_types.uint4b8`` is used and the
            zero-point slot is ``None``.
        pack_for_backend: Only the value ``"asm"`` is treated specially: it
            packs zero-points along columns. Any other value packs them along
            rows.

    Returns:
        ``(weight_ref, qweight, scales, qzeros_or_None)``. ``qweight`` holds
        two 4-bit values per byte (odd column in the high nibble, even column
        in the low nibble).
    """
    if has_zp:
        quant_type = scalar_types.uint4
    else:
        quant_type = scalar_types.uint4b8

    # quantize_weights is fed the transposed matrix; every result is
    # transposed back before use.
    ref_t, q_t, scales_t, zeros_t = quantize_weights(
        w_fp.T, quant_type, group_size, has_zp, False)

    weight_ref = ref_t.T
    scales = scales_t.T
    nibbles = q_t.T.contiguous().to(torch.uint8)
    # int4: fuse neighbouring columns into a single byte.
    packed_weight = nibbles[:, 1::2] * 16 + nibbles[:, ::2]

    if not has_zp:
        return weight_ref, packed_weight, scales, None

    zeros = zeros_t.T.contiguous().to(torch.uint8)
    if pack_for_backend == "asm":
        packed_zeros = zeros[:, 1::2] * 16 + zeros[:, ::2]
    else:
        packed_zeros = zeros[1::2, :] * 16 + zeros[::2, :]
    return weight_ref, packed_weight, scales, packed_zeros
def prepare_w4a16_inputs(m, k, n, e, topk, group_size, has_zp, dtype,
                         backend):
    """Create every tensor a w4a16 MoE test needs and return them in a dict.

    Random activations, gate/up (``w1``) and down (``w2``) weights and router
    scores are generated on ``"cuda"``; each expert is quantized with
    ``_quantize_w4a16_weights`` and the dequantized reference is kept next to
    the packed data. ``backend`` is compared against the strings ``"asm"``
    (zero-point layout) and ``"moe_c"`` (marlin weight shuffle).
    NOTE(review): the caller in this file passes ``moe_cfg.solution_type`` --
    confirm that value compares equal to these strings.

    Returns:
        dict with keys ``input``, ``w1_ref``, ``w2_ref``, ``w1_qweight``,
        ``w2_qweight``, ``w1_scales``, ``w2_scales``, ``w1_qzeros``,
        ``w2_qzeros``, ``topk_weights``, ``topk_ids`` and ``score``.
    """
    pack_factor = 2  # two int4 values share one uint8

    # The order of the random draws is kept as-is so seeded runs match.
    input_tensor = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1_fp = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2_fp = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    def _alloc(shape, elem_dtype):
        # Uninitialised device buffer, filled expert by expert below.
        return torch.empty(shape, device="cuda", dtype=elem_dtype)

    # Packed weight and scale storage.
    w1_qweight = _alloc((e, 2 * n, k // pack_factor), torch.uint8)
    w2_qweight = _alloc((e, k, n // pack_factor), torch.uint8)
    w1_scales = _alloc((e, 2 * n, k // group_size), dtype)
    w2_scales = _alloc((e, k, n // group_size), dtype)

    w1_qzeros = None
    w2_qzeros = None
    if has_zp:
        if backend == "asm":
            # asm layout: zero-points packed along the group axis.
            zp1_shape = (e, 2 * n, k // group_size // pack_factor)
            zp2_shape = (e, k, n // group_size // pack_factor)
        else:
            # other backends: zero-points packed along the output axis.
            zp1_shape = (e, 2 * n // pack_factor, k // group_size)
            zp2_shape = (e, k // pack_factor, n // group_size)
        w1_qzeros = _alloc(zp1_shape, torch.uint8)
        w2_qzeros = _alloc(zp2_shape, torch.uint8)

    w1_ref = w1_fp.clone()
    w2_ref = w2_fp.clone()

    # All w1 experts are processed first, then all w2 experts.
    stages = (
        (w1_fp, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2_fp, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    )
    for src, ref_buf, qw_buf, sc_buf, zp_buf in stages:
        for expert_id in range(e):
            weight, qweight, scales, qzeros = _quantize_w4a16_weights(
                src[expert_id], group_size, has_zp, backend)
            ref_buf[expert_id] = weight
            qw_buf[expert_id] = qweight
            sc_buf[expert_id] = scales
            if has_zp and zp_buf is not None:
                zp_buf[expert_id] = qzeros

    # For moe_c backend, apply marlin weight shuffle and restore the byte view.
    if backend == "moe_c":
        shuffled_w1 = w4a16_marlin_weight_1(w1_qweight)
        shuffled_w2 = w4a16_marlin_weight_2(w2_qweight)
        w1_qweight_final = shuffled_w1.view(-1).view(
            torch.uint8).view(*w1_qweight.shape)
        w2_qweight_final = shuffled_w2.view(-1).view(
            torch.uint8).view(*w2_qweight.shape)
    else:
        w1_qweight_final = w1_qweight
        w2_qweight_final = w2_qweight

    topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)

    return {
        "input": input_tensor,
        "w1_ref": w1_ref,
        "w2_ref": w2_ref,
        "w1_qweight": w1_qweight_final,
        "w2_qweight": w2_qweight_final,
        "w1_scales": w1_scales,
        "w2_scales": w2_scales,
        "w1_qzeros": w1_qzeros,
        "w2_qzeros": w2_qzeros,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "score": score,
    }
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w4a16)
# ---------------------------------------------------------------------------
def test_get_config(m, k, n, e, topk, group_size, dtype):
    """Check the result of ``get_aiter_moe_config`` for a w4a16 problem.

    When the lookup succeeds, the returned config must carry a solution type
    from the accepted set, a non-``None`` config and the W4A16 quant type.
    When it fails, both ``solution_type`` and ``config`` must be ``None``.

    Returns:
        The ``(status, moe_cfg)`` pair exactly as returned by the lookup.
    """
    N1 = 2 * n  # gate + up
    N2 = k  # down / hidden_size
    K = k  # model dimension (uncompressed)

    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=N1, N2=N2, K=K,
        top_k=topk, block_size=group_size, dtype=dtype,
        quant_type=MoeQuantType.W4A16,
    )
    # Both log messages start with the same problem description.
    log_prefix = f"[get_config_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "

    if not status:
        assert moe_cfg.solution_type is None, \
            "status=False but solution_type is not None"
        assert moe_cfg.config is None, \
            "status=False but config is not None"
        aiter.logger.info(
            log_prefix + "no solution found (expected on unsupported configs)"
        )
        return status, moe_cfg

    assert moe_cfg.solution_type is not None, \
        "status=True but solution_type is None"
    assert moe_cfg.config is not None, \
        "status=True but config is None"
    accepted_solutions = (
        MoeSolutionType.MOE_C,
        MoeSolutionType.ASM,
        MoeSolutionType.TRITON,
    )
    assert moe_cfg.solution_type in accepted_solutions, \
        f"Unexpected solution_type: {moe_cfg.solution_type}"
    assert moe_cfg.quant_type == MoeQuantType.W4A16
    aiter.logger.info(
        log_prefix
        + f"solution={moe_cfg.solution_type}, "
        + f"config keys={list(moe_cfg.config.keys())}"
    )
    return status, moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w4a16
# ---------------------------------------------------------------------------
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
    """Run the ``torch_moe`` reference path.

    The function is wrapped by ``perftest`` (1 warm-up, 2 iterations); callers
    in this file unpack the decorated call into two values and keep the first
    one as the reference output.
    """
    reference_out = torch_moe(hidden_states, w1, w2, topk_weights, topk_ids)
    return reference_out
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(hidden_states,
                        w1,
                        w2,
                        topk_weights,
                        topk_ids,
                        moe_config,
                        inplace,
                        w1_scale,
                        w2_scale,
                        w1_zp,
                        w2_zp,
                        a1_scale,
                        a2_scale,
                        block_shape,
                        global_num_experts,
                        expert_map,
                        routed_scaling_factor,
                        activation):
    """Forward all arguments to ``aiter_moe`` under the ``perftest`` wrapper.

    ``activation`` is the last parameter here but is passed to ``aiter_moe``
    right after ``inplace``; the remaining arguments keep their order.

    The input is cloned only when ``inplace`` is set, so an in-place run
    cannot overwrite the caller's tensor across repeated iterations. Running
    non-inplace on the original tensor avoids an extra device copy inside the
    timed region and matches the other ``_run_aiter_moe_perf`` helpers in
    this change.
    """
    if inplace:
        mortal_input = hidden_states.clone()  # keep in-place runs from touching the caller's data
    else:
        mortal_input = hidden_states
    return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
                     a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
def test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor):
    """End-to-end w4a16 check: look up a config, run ``aiter_moe`` and compare
    the result with the torch reference.

    Returns:
        ``None`` when no backend is available for the shape; otherwise a dict
        with ``m``, the selected ``backend`` and ``us`` (the second value
        returned by the perf-wrapped run).
    """
    N1 = 2 * n
    N2 = k
    K = k
    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=N1, N2=N2, K=K,
        top_k=topk, block_size=group_size, dtype=dtype,
        quant_type=MoeQuantType.W4A16,
    )
    if not status:
        aiter.logger.info(
            f"[aiter_moe_w4a16] SKIP {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}: "
            f"no backend available"
        )
        return None

    backend = moe_cfg.solution_type
    # The same description is logged up front and reused as the check message.
    case_desc = (f"[aiter_moe_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
                 f"backend={backend}")
    aiter.logger.info(case_desc)

    tensors = prepare_w4a16_inputs(
        m, k, n, e, topk, group_size, has_zp, dtype, backend)
    routing_weights = tensors["topk_weights"]
    routing_ids = tensors["topk_ids"]

    # Torch reference on the dequantized weights.
    ref_out, _ = _run_torch_ref(
        tensors["input"], tensors["w1_ref"], tensors["w2_ref"],
        routing_weights, routing_ids,
    )

    # Generic aiter_moe dispatch with the w4a16 config.
    if group_size:
        block_shape = [0, group_size]
    else:
        block_shape = None

    aiter_out, aiter_us = _run_aiter_moe_perf(
        hidden_states=tensors["input"],
        w1=tensors["w1_qweight"],
        w2=tensors["w2_qweight"],
        topk_weights=routing_weights,
        topk_ids=routing_ids,
        moe_config=moe_cfg,
        inplace=inplace,
        activation="silu",
        w1_scale=tensors["w1_scales"],
        w2_scale=tensors["w2_scales"],
        w1_zp=tensors["w1_qzeros"],
        w2_zp=tensors["w2_qzeros"],
        a1_scale=None,
        a2_scale=None,
        block_shape=block_shape,
        global_num_experts=e,
        expert_map=None,
        routed_scaling_factor=routed_scaling_factor,
    )

    checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.01, msg=case_desc)
    return {"m": m, "backend": backend, "us": aiter_us}
# ---------------------------------------------------------------------------
# Main: run tests
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Backend constraints noted by the original author:
    #   ASM requires:    top_k == 8 and n == 256 and k == 7168
    #   Triton requires: (n == 2048 and E in {8, 16, 32}) or (n == 256 and E == 256)
    #                    (E == 2 or E == 4 gave abnormal results)
    dtype = dtypes.bf16
    group_size = 32
    has_zp = True

    e = 256
    topk = 8
    k = 7168  # model_dim
    n = 256  # intermediate_size
    inplace = True
    routed_scaling_factor = 1.0

    banner = "=" * 60
    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

    # --- Part 1: test get_aiter_moe_config (w4a16) ---
    aiter.logger.info(banner)
    aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a16")
    aiter.logger.info(banner)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, group_size, dtype)

    # --- Part 2: test aiter_moe end-to-end (w4a16) ---
    aiter.logger.info(banner)
    aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w4a16")
    aiter.logger.info(banner)
    rows = []
    for m in test_tokens:
        ret = test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor)
        if ret is None:
            # Shape skipped: no backend available.
            continue
        rows.append(ret)

    # Only summarise when at least one shape actually ran.
    if rows:
        summary = pd.DataFrame(rows)
        aiter.logger.info(f"summary:\n{summary}")
# Test for get_aiter_moe_config_w4a16 and aiter_moe_w4a16
import torch
import itertools
import pandas as pd
from typing import Optional, List
try:
from op_tests.utility.scalar_type import scalar_types
from op_tests.utility.utils import quantize_weights
except ModuleNotFoundError:
import sys
from pathlib import Path
_ROOT = Path(__file__).resolve().parents[2]
if str(_ROOT) not in sys.path:
sys.path.insert(0, str(_ROOT))
from op_tests.utility.scalar_type import scalar_types
from op_tests.utility.utils import quantize_weights
from aiter.fused_moe import fused_topk, torch_moe
from aiter import ActivationType, dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.ops.shuffle import w4a16_marlin_weight_1, w4a16_marlin_weight_2
import aiter
torch.set_default_device("cuda")
# ---------------------------------------------------------------------------
# Weight quantization helpers (adapted from test_moe_wna16.py)
# ---------------------------------------------------------------------------
def _quantize_w4a16_weights(w_fp, group_size, has_zp, pack_for_backend):
    """Quantize one expert's weight matrix to packed int4.

    Args:
        w_fp: Floating-point weight ``[out_features, in_features]``.
        group_size: Quantization group size along K.
        has_zp: When true, ``scalar_types.uint4`` is used and zero-points are
            returned; otherwise ``scalar_types.uint4b8`` is used and the
            zero-point slot is ``None``.
        pack_for_backend: Only the value ``"asm"`` is treated specially: it
            packs zero-points along columns. Any other value packs them along
            rows.

    Returns:
        ``(weight_ref, qweight, scales, qzeros_or_None)``. ``qweight`` holds
        two 4-bit values per byte (odd column in the high nibble, even column
        in the low nibble).
    """
    if has_zp:
        quant_type = scalar_types.uint4
    else:
        quant_type = scalar_types.uint4b8

    # quantize_weights is fed the transposed matrix; every result is
    # transposed back before use.
    ref_t, q_t, scales_t, zeros_t = quantize_weights(
        w_fp.T, quant_type, group_size, has_zp, False)

    weight_ref = ref_t.T
    scales = scales_t.T
    nibbles = q_t.T.contiguous().to(torch.uint8)
    # int4: fuse neighbouring columns into a single byte.
    packed_weight = nibbles[:, 1::2] * 16 + nibbles[:, ::2]

    if not has_zp:
        return weight_ref, packed_weight, scales, None

    zeros = zeros_t.T.contiguous().to(torch.uint8)
    if pack_for_backend == "asm":
        packed_zeros = zeros[:, 1::2] * 16 + zeros[:, ::2]
    else:
        packed_zeros = zeros[1::2, :] * 16 + zeros[::2, :]
    return weight_ref, packed_weight, scales, packed_zeros
def prepare_w4a16_inputs(m, k, n, e, topk, group_size, has_zp, dtype,
                         backend):
    """Create every tensor a w4a16 MoE test needs and return them in a dict.

    Random activations, gate/up (``w1``) and down (``w2``) weights and router
    scores are generated on ``"cuda"``; each expert is quantized with
    ``_quantize_w4a16_weights`` and the dequantized reference is kept next to
    the packed data. ``backend`` is compared against the strings ``"asm"``
    (zero-point layout) and ``"moe_c"`` (marlin weight shuffle).
    NOTE(review): the caller in this file passes ``moe_cfg.solution_type`` --
    confirm that value compares equal to these strings.

    Returns:
        dict with keys ``input``, ``w1_ref``, ``w2_ref``, ``w1_qweight``,
        ``w2_qweight``, ``w1_scales``, ``w2_scales``, ``w1_qzeros``,
        ``w2_qzeros``, ``topk_weights``, ``topk_ids`` and ``score``.
    """
    pack_factor = 2  # two int4 values share one uint8

    # The order of the random draws is kept as-is so seeded runs match.
    input_tensor = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1_fp = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2_fp = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    def _alloc(shape, elem_dtype):
        # Uninitialised device buffer, filled expert by expert below.
        return torch.empty(shape, device="cuda", dtype=elem_dtype)

    # Packed weight and scale storage.
    w1_qweight = _alloc((e, 2 * n, k // pack_factor), torch.uint8)
    w2_qweight = _alloc((e, k, n // pack_factor), torch.uint8)
    w1_scales = _alloc((e, 2 * n, k // group_size), dtype)
    w2_scales = _alloc((e, k, n // group_size), dtype)

    w1_qzeros = None
    w2_qzeros = None
    if has_zp:
        if backend == "asm":
            # asm layout: zero-points packed along the group axis.
            zp1_shape = (e, 2 * n, k // group_size // pack_factor)
            zp2_shape = (e, k, n // group_size // pack_factor)
        else:
            # other backends: zero-points packed along the output axis.
            zp1_shape = (e, 2 * n // pack_factor, k // group_size)
            zp2_shape = (e, k // pack_factor, n // group_size)
        w1_qzeros = _alloc(zp1_shape, torch.uint8)
        w2_qzeros = _alloc(zp2_shape, torch.uint8)

    w1_ref = w1_fp.clone()
    w2_ref = w2_fp.clone()

    # All w1 experts are processed first, then all w2 experts.
    stages = (
        (w1_fp, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2_fp, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    )
    for src, ref_buf, qw_buf, sc_buf, zp_buf in stages:
        for expert_id in range(e):
            weight, qweight, scales, qzeros = _quantize_w4a16_weights(
                src[expert_id], group_size, has_zp, backend)
            ref_buf[expert_id] = weight
            qw_buf[expert_id] = qweight
            sc_buf[expert_id] = scales
            if has_zp and zp_buf is not None:
                zp_buf[expert_id] = qzeros

    # For moe_c backend, apply marlin weight shuffle and restore the byte view.
    if backend == "moe_c":
        shuffled_w1 = w4a16_marlin_weight_1(w1_qweight)
        shuffled_w2 = w4a16_marlin_weight_2(w2_qweight)
        w1_qweight_final = shuffled_w1.view(-1).view(
            torch.uint8).view(*w1_qweight.shape)
        w2_qweight_final = shuffled_w2.view(-1).view(
            torch.uint8).view(*w2_qweight.shape)
    else:
        w1_qweight_final = w1_qweight
        w2_qweight_final = w2_qweight

    topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)

    return {
        "input": input_tensor,
        "w1_ref": w1_ref,
        "w2_ref": w2_ref,
        "w1_qweight": w1_qweight_final,
        "w2_qweight": w2_qweight_final,
        "w1_scales": w1_scales,
        "w2_scales": w2_scales,
        "w1_qzeros": w1_qzeros,
        "w2_qzeros": w2_qzeros,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "score": score,
    }
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w4a16)
# ---------------------------------------------------------------------------
def test_get_config(m, k, n, e, topk, group_size, dtype):
    """Check the result of ``get_aiter_moe_config`` for a w4a16 problem.

    When the lookup succeeds, the returned config must carry a solution type
    from the accepted set, a non-``None`` config and the W4A16 quant type.
    When it fails, both ``solution_type`` and ``config`` must be ``None``.

    Returns:
        The ``(status, moe_cfg)`` pair exactly as returned by the lookup.
    """
    N1 = 2 * n  # gate + up
    N2 = k  # down / hidden_size
    K = k  # model dimension (uncompressed)

    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=N1, N2=N2, K=K,
        top_k=topk, block_size=group_size, dtype=dtype,
        quant_type=MoeQuantType.W4A16,
    )
    # Both log messages start with the same problem description.
    log_prefix = f"[get_config_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "

    if not status:
        assert moe_cfg.solution_type is None, \
            "status=False but solution_type is not None"
        assert moe_cfg.config is None, \
            "status=False but config is not None"
        aiter.logger.info(
            log_prefix + "no solution found (expected on unsupported configs)"
        )
        return status, moe_cfg

    assert moe_cfg.solution_type is not None, \
        "status=True but solution_type is None"
    assert moe_cfg.config is not None, \
        "status=True but config is None"
    accepted_solutions = (
        MoeSolutionType.MOE_C,
        MoeSolutionType.ASM,
        MoeSolutionType.TRITON,
    )
    assert moe_cfg.solution_type in accepted_solutions, \
        f"Unexpected solution_type: {moe_cfg.solution_type}"
    assert moe_cfg.quant_type == MoeQuantType.W4A16
    aiter.logger.info(
        log_prefix
        + f"solution={moe_cfg.solution_type}, "
        + f"config keys={list(moe_cfg.config.keys())}"
    )
    return status, moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w4a16
# ---------------------------------------------------------------------------
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
    """Run the ``torch_moe`` reference path.

    The function is wrapped by ``perftest`` (1 warm-up, 2 iterations); callers
    in this file unpack the decorated call into two values and keep the first
    one as the reference output.
    """
    reference_out = torch_moe(hidden_states, w1, w2, topk_weights, topk_ids)
    return reference_out
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(hidden_states,
                        w1,
                        w2,
                        topk_weights,
                        topk_ids,
                        moe_config,
                        inplace,
                        w1_scale,
                        w2_scale,
                        w1_zp,
                        w2_zp,
                        a1_scale,
                        a2_scale,
                        block_shape,
                        global_num_experts,
                        expert_map,
                        routed_scaling_factor,
                        activation):
    """Forward all arguments to ``aiter_moe`` under the ``perftest`` wrapper.

    ``activation`` is the last parameter here but is passed to ``aiter_moe``
    right after ``inplace``. The output dtype is requested to match the
    dtype of ``hidden_states``.
    """
    # A private copy is used only for in-place runs, so the caller's tensor is
    # never the one handed to an in-place call.
    kernel_input = hidden_states.clone() if inplace else hidden_states
    return aiter_moe(
        kernel_input, w1, w2, topk_weights, topk_ids, moe_config, inplace,
        activation, w1_scale, w2_scale, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape, global_num_experts, expert_map,
        routed_scaling_factor, output_dtype=hidden_states.dtype)
def test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor):
    """End-to-end w4a16 check: look up a config, run ``aiter_moe`` and compare
    the result with the torch reference.

    Returns:
        ``None`` when no backend is available for the shape; otherwise a dict
        with the problem sizes (``m``, ``N1``, ``N2``, ``K``, ``e``,
        ``topk``), the selected ``backend``, ``us`` (the second value returned
        by the perf-wrapped run) and ``accuracy``. ``accuracy`` is the string
        ``"passed"`` when ``checkAllclose`` returns 0, otherwise one minus its
        return value (presumably a mismatch ratio -- confirm against
        ``aiter.test_common``).
    """
    N1 = 2 * n
    N2 = k
    K = k
    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=N1, N2=N2, K=K,
        top_k=topk, block_size=group_size, dtype=dtype,
        quant_type=MoeQuantType.W4A16,
    )
    if not status:
        aiter.logger.info(
            f"[aiter_moe_w4a16] SKIP {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}: "
            f"no backend available"
        )
        return None

    backend = moe_cfg.solution_type
    # The same description is logged up front and reused as the check message.
    case_desc = (f"[aiter_moe_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
                 f"backend={backend}")
    aiter.logger.info(case_desc)

    tensors = prepare_w4a16_inputs(
        m, k, n, e, topk, group_size, has_zp, dtype, backend)
    routing_weights = tensors["topk_weights"]
    routing_ids = tensors["topk_ids"]

    # Torch reference on the dequantized weights.
    ref_out, _ = _run_torch_ref(
        tensors["input"], tensors["w1_ref"], tensors["w2_ref"],
        routing_weights, routing_ids,
    )

    # Generic aiter_moe dispatch with the w4a16 config.
    if group_size:
        block_shape = [0, group_size]
    else:
        block_shape = None

    aiter_out, aiter_us = _run_aiter_moe_perf(
        hidden_states=tensors["input"],
        w1=tensors["w1_qweight"],
        w2=tensors["w2_qweight"],
        topk_weights=routing_weights,
        topk_ids=routing_ids,
        moe_config=moe_cfg,
        inplace=inplace,
        activation="silu",
        w1_scale=tensors["w1_scales"],
        w2_scale=tensors["w2_scales"],
        w1_zp=tensors["w1_qzeros"],
        w2_zp=tensors["w2_qzeros"],
        a1_scale=None,
        a2_scale=None,
        block_shape=block_shape,
        global_num_experts=e,
        expert_map=None,
        routed_scaling_factor=routed_scaling_factor,
    )

    check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.01, msg=case_desc)
    if check_ret == 0:
        ret_output = "passed"
    else:
        ret_output = 1 - check_ret

    return {
        "m": m,
        "N1": N1,
        "N2": N2,
        "K": K,
        "e": e,
        "topk": topk,
        "backend": backend,
        "us": aiter_us,
        "accuracy": ret_output,
    }
# ---------------------------------------------------------------------------
# Main: run tests
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Backend constraints noted by the original author:
    #   ASM requires:    top_k == 8 and n == 256 and k == 7168
    #   Triton requires: (n == 2048 and E in {8, 16, 32}) or (n == 256 and E == 256)
    #                    (E == 2 or E == 4 gave abnormal results)
    dtype = dtypes.bf16
    group_size = 32
    has_zp = True

    e = 256
    topk = 8
    k = 7168  # model_dim
    n = 256  # intermediate_size
    inplace = True
    routed_scaling_factor = 1.0

    banner = "=" * 60
    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

    # --- Part 1: test get_aiter_moe_config (w4a16) ---
    aiter.logger.info(banner)
    aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a16")
    aiter.logger.info(banner)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, group_size, dtype)

    # --- Part 2: test aiter_moe end-to-end (w4a16) ---
    aiter.logger.info(banner)
    aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w4a16")
    aiter.logger.info(banner)
    rows = []
    for m in test_tokens:
        ret = test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor)
        if ret is None:
            # Shape skipped: no backend available.
            continue
        rows.append(ret)

    # Only write and log a summary when at least one shape actually ran.
    if rows:
        summary = pd.DataFrame(rows)
        summary.to_csv("test_aiter_moe_with_config_w4a16.csv", index=False)
        aiter.logger.info(f"summary:\n{summary}")
......@@ -437,8 +437,8 @@ def perchannel_w8a8_triton(input,
triton_moe_sum(output_triton2.view(*output_triton2.shape),
out_hidden_states)
else:
aiter.moe_c_moe_sum(output_triton2.view(*output_triton2.shape),
out_hidden_states,topk_ids)
aiter.moe_c_moe_sum_opt_v2(output_triton2.view(*output_triton2.shape),
out_hidden_states,1.0)
# print("**************************************triton")
# print(out_hidden_states)
......@@ -469,7 +469,7 @@ def _run_triton_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_
)
return perchannel_w8a8_triton(hidden_states,w1,w2,w1_scale,w2_scale,topk_weights, topk_ids,sorted_token_ids, expert_ids, num_tokens_post_padded,dtype,True)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1,testGraph = True)
def _run_aiter_moe_perf(
hidden_states,
w1,
......@@ -490,7 +490,10 @@ def _run_aiter_moe_perf(
expert_map,
):
mortal_input = hidden_states.clone() #保证inplace操作的正确性
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(
mortal_input,
......@@ -826,7 +829,8 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype):
expert_map=None,
)
msg = f"[aiter_triton_w4a8] {m=}"
checkAllclose(ref_out, aiter_triton_out, rtol=0.01, atol=100, msg=msg)
triton_check_ret = checkAllclose(ref_out, aiter_triton_out, rtol=0.01, atol=100, msg=msg)
ret["triton_accuracy"] = "passed" if triton_check_ret == 0 else (1-triton_check_ret)
ret["aiter_triton_us"] = aiter_triton_us
aiter_us = 1.0
......@@ -854,10 +858,13 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype):
expert_map=None,
)
print("aiter_out",aiter_out)
print("aiter_triton_out",aiter_triton_out)
msg = f"[aiter_moe_w4a8] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
ret["moe_accuracy"] = "passed" if check_ret == 0 else (1-check_ret)
ret["aiter_moe_us"] = aiter_us
return ret
......@@ -868,14 +875,14 @@ if __name__ == "__main__":
e = 256
topk = 8
k = 6144
n = 2048
k = 7168
n = 256
aiter.logger.info("=" * 60)
aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a8")
aiter.logger.info("=" * 60)
test_tokens = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048,4096,6144,8192,16384]
# test_tokens = [1]
# test_tokens = [2048]
for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype)
......
# Test for get_aiter_moe_config and aiter_moe with W8A16 (int8 weight, fp16/bf16 activation)
#
# Mirrors test_aiter_moe_with_config_w16a16.py but for the INT8_W8A16 quant
# type. The W8A16 ASM backend is not currently supported, so part 3
# (ASM shuffle vs non-shuffle) is omitted intentionally.
import torch
import pandas as pd
from aiter.fused_moe import fused_topk, torch_moe
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.ops.shuffle import w8a16_marlin_weight_1, w8a16_marlin_weight_2
import aiter
torch.set_default_device("cuda")
# ---------------------------------------------------------------------------
# Weight preparation helpers (W8A16 - int8 weights, fp16/bf16 activations)
# ---------------------------------------------------------------------------
def prepare_w8a16_inputs(m, k, n, e, topk, dtype):
    """Build all tensors needed to run a w8a16 (int8 weight) MOE test.

    Layout:
      - input:     [m, k]      `dtype` (fp16/bf16)
      - w1:        [E, 2n, k]  int8 (gate + up, gated activation)
      - w2:        [E, k, n]   int8 (down)
      - w1_scale:  [E, 2n, 1]  `dtype` (per-output-channel)
      - w2_scale:  [E, k, 1]   `dtype` (per-output-channel)
      - w1_marlin / w2_marlin: marlin-shuffled int8 weights consumed by the
        moe_c W8A16 kernel.

    The non-shuffled int8 weights together with the per-channel scales are
    used to build the torch reference.

    Args:
        m: number of tokens.
        k: model (hidden) dimension.
        n: intermediate dimension.
        e: number of experts.
        topk: experts selected per token.
        dtype: activation dtype; also used for the scales and router scores.

    Returns:
        dict with keys "input", "w1", "w2", "w1_marlin", "w2_marlin",
        "w1_scale", "w2_scale", "topk_weights", "topk_ids" and "score".
    """
    torch.manual_seed(0)
    # Keep activations small to limit accumulation error in int8 dequant.
    input_tensor = torch.randn((m, k), device="cuda", dtype=dtype) / 100
    # torch.randint excludes the upper bound, so 128 is needed to sample the
    # symmetric int8 range [-127, 127].
    w1 = torch.randint(-127, 128, (e, 2 * n, k), device="cuda", dtype=torch.int8)
    w2 = torch.randint(-127, 128, (e, k, n), device="cuda", dtype=torch.int8)
    # Per-channel scales. The moe_c W8A16 marlin kernel expects scales in
    # the activation dtype (fp16/bf16), so we keep them in `dtype`. Small
    # magnitude keeps dequantized weights in a reasonable range.
    w1_scale = torch.randn((e, 2 * n, 1), device="cuda", dtype=dtype).abs() * 1e-3
    w2_scale = torch.randn((e, k, 1), device="cuda", dtype=dtype).abs() * 1e-3
    # Marlin-shuffled weights for the moe_c W8A16 kernel.
    w1_marlin = w8a16_marlin_weight_1(w1)
    w2_marlin = w8a16_marlin_weight_2(w2)
    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
    return {
        "input": input_tensor,
        "w1": w1,
        "w2": w2,
        "w1_marlin": w1_marlin,
        "w2_marlin": w2_marlin,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "score": score,
    }
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w8a16)
# ---------------------------------------------------------------------------
def test_get_config(m, k, n, e, topk, dtype):
    """Check the W8A16 result of get_aiter_moe_config for one problem shape.

    Asserts that a successful lookup carries a solution type (MOE_C or
    TRITON), a config and the INT8_W8A16 quant type, and that a failed
    lookup carries neither a solution type nor a config.

    Returns:
        The (status, moe_cfg) pair returned by get_aiter_moe_config.
    """
    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=2 * n,  # gate + up
        N2=k,      # down / hidden_size
        K=k,       # model dimension
        top_k=topk,
        block_size=0,
        dtype=dtype,
        quant_type=MoeQuantType.INT8_W8A16,
    )
    log_prefix = f"[get_config_w8a16] {m=}, {k=}, {n=}, {e=}, {topk=}, "

    if not status:
        assert moe_cfg.solution_type is None, \
            "status=False but solution_type is not None"
        assert moe_cfg.config is None, \
            "status=False but config is not None"
        aiter.logger.info(log_prefix + "no solution found")
        return status, moe_cfg

    assert moe_cfg.solution_type is not None, \
        "status=True but solution_type is None"
    assert moe_cfg.config is not None, \
        "status=True but config is None"
    # ASM backend is intentionally unsupported for W8A16.
    allowed_backends = (MoeSolutionType.MOE_C, MoeSolutionType.TRITON)
    assert moe_cfg.solution_type in allowed_backends, \
        f"Unexpected solution_type for W8A16: {moe_cfg.solution_type}"
    assert moe_cfg.quant_type == MoeQuantType.INT8_W8A16
    config_keys = list(moe_cfg.config.keys())
    aiter.logger.info(
        log_prefix
        + f"solution={moe_cfg.solution_type}, "
        + f"config keys={config_keys}"
    )
    return status, moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w8a16
# ---------------------------------------------------------------------------
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids,
                   fc1_scale, fc2_scale):
    """Run the torch_moe reference with the given per-channel scales."""
    scale_kwargs = {"fc1_scale": fc1_scale, "fc2_scale": fc2_scale}
    reference = torch_moe(
        hidden_states, w1, w2, topk_weights, topk_ids, **scale_kwargs
    )
    return reference
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(hidden_states,
                        w1,
                        w2,
                        topk_weights,
                        topk_ids,
                        moe_config,
                        inplace,
                        activation,
                        w1_scale,
                        w2_scale,
                        w1_zp,
                        w2_zp,
                        a1_scale,
                        a2_scale,
                        block_shape,
                        global_num_experts,
                        expert_map,
                        routed_scaling_factor,
                        ):
    """Forward every argument to aiter_moe under the perftest harness.

    When `inplace` is true the activations are cloned first, which keeps the
    in-place operation correct; otherwise the caller's tensor is passed as-is.
    """
    # Clone only on the in-place path to keep the in-place operation correct.
    mortal_input = hidden_states.clone() if inplace else hidden_states
    moe_args = (
        mortal_input, w1, w2, topk_weights, topk_ids, moe_config,
        inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape, global_num_experts,
        expert_map, routed_scaling_factor,
    )
    return aiter_moe(*moe_args)
def test_aiter_moe_w8a16(m, k, n, e, topk, dtype, inplace, routed_scaling_factor):
    """End-to-end W8A16 check: look up a config, run aiter_moe, compare.

    Returns:
        None when no backend is available for the shape; otherwise a dict
        with the shape, the chosen backend, the measured time ("us") and
        the accuracy outcome.
    """
    N1, N2, K = 2 * n, k, k
    shape_desc = f"{m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}"

    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=N1, N2=N2, K=K,
        top_k=topk, block_size=0, dtype=dtype,
        quant_type=MoeQuantType.INT8_W8A16,
    )
    if not status:
        aiter.logger.info(
            f"[aiter_moe_w8a16] SKIP {shape_desc}: no backend available"
        )
        return None

    backend = moe_cfg.solution_type
    run_desc = f"[aiter_moe_w8a16] {shape_desc}, backend={backend}"
    aiter.logger.info(run_desc)

    data = prepare_w8a16_inputs(m, k, n, e, topk, dtype)

    # Weight layout depends on the backend:
    #   - moe_c expects marlin-shuffled int8 weights
    #   - triton expects plain (E, N, K) int8 weights
    use_marlin = backend == MoeSolutionType.MOE_C
    w1_run = data["w1_marlin" if use_marlin else "w1"]
    w2_run = data["w2_marlin" if use_marlin else "w2"]

    # Torch reference always uses the plain int8 weights plus the scales.
    ref_out, _ = _run_torch_ref(
        data["input"], data["w1"], data["w2"],
        data["topk_weights"], data["topk_ids"],
        data["w1_scale"], data["w2_scale"],
    )

    aiter_out, aiter_us = _run_aiter_moe_perf(
        hidden_states=data["input"],
        w1=w1_run,
        w2=w2_run,
        topk_weights=data["topk_weights"],
        topk_ids=data["topk_ids"],
        moe_config=moe_cfg,
        inplace=inplace,
        activation="silu",
        w1_scale=data["w1_scale"],
        w2_scale=data["w2_scale"],
        w1_zp=None,
        w2_zp=None,
        a1_scale=None,
        a2_scale=None,
        block_shape=None,  # per-channel
        global_num_experts=e,
        expert_map=None,
        routed_scaling_factor=routed_scaling_factor,
    )

    # int8 dequant + bf16/fp16 accumulation order differs between the
    # reference and the fused kernels, hence the relaxed tolerance.
    check_ret = checkAllclose(ref_out, aiter_out, rtol=0.05, atol=0.5, msg=run_desc)
    if check_ret == 0:
        ret_output = "passed"
    else:
        ret_output = 1 - check_ret
    return {
        "m": m,
        "N1": N1,
        "N2": N2,
        "K": K,
        "e": e,
        "topk": topk,
        "backend": backend,
        "us": aiter_us,
        "accuracy": ret_output,
    }
# ---------------------------------------------------------------------------
# Main: run tests
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dtype = dtypes.fp16  # W8A16 moe_c kernel was validated with fp16
    # Use shape that matches the tuned W8A16 moe_c configs under
    # aiter/moe_c_configs (E=128, N=2048).
    e = 128
    topk = 8
    k = 6144  # model_dim
    n = 2048  # intermediate_size
    inplace = False
    routed_scaling_factor = 1.0
    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    banner = "=" * 60

    # --- Part 1: config lookup only (w8a16) ---
    for line in (banner, "Part 1: Testing get_aiter_moe_config for w8a16", banner):
        aiter.logger.info(line)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, dtype)

    # --- Part 2: end-to-end run (w8a16) ---
    for line in (banner, "Part 2: Testing aiter_moe end-to-end for w8a16", banner):
        aiter.logger.info(line)
    outcomes = (
        test_aiter_moe_w8a16(m, k, n, e, topk, dtype, inplace, routed_scaling_factor)
        for m in test_tokens
    )
    # Shapes without an available backend return None and are dropped.
    rows = [outcome for outcome in outcomes if outcome is not None]
    if rows:
        summary = pd.DataFrame(rows)
        aiter.logger.info(f"aiter_moe summary:\n{summary}")
        summary.to_csv("test_aiter_moe_with_config_w8a16.csv", index=False)
\ No newline at end of file
# Test for get_aiter_moe_config and aiter_moe with w8a8
import torch
import pandas as pd
from aiter.fused_moe import fused_topk
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
import aiter
torch.set_default_device("cuda")
def torch_moe_blockscale(
    hidden_states,
    w1,
    w2,
    topk_weight,
    topk_ids,
    dtype,
    scale_blks,
    fc1_scale,
    fc2_scale,
):
    """Reference MoE forward pass with block-wise weight scales, in float32.

    Each weight element is multiplied by the scale of the (blk_n x blk_k)
    tile it falls in, then every token is pushed through its selected
    experts (gate/up projection, SiLU gating, down projection) and the
    per-expert outputs are combined with `topk_weight`.

    Args:
        hidden_states: activations, viewed as [token_num, model_dim].
        w1: gate+up weights, [expert, 2 * inter_dim, model_dim].
        w2: down weights, [expert, model_dim, inter_dim].
        topk_weight: routing weights, token_num * topk elements.
        topk_ids: selected expert ids, [token_num, topk].
        dtype: dtype of the returned tensor.
        scale_blks: (blk_n, blk_k) tile size of the scales.
        fc1_scale: one scale per w1 tile, ceil(2 * inter_dim / blk_n) *
            ceil(model_dim / blk_k) values per expert.
        fc2_scale: one scale per w2 tile, ceil(model_dim / blk_n) *
            ceil(inter_dim / blk_k) values per expert.

    Returns:
        Tensor of shape [token_num, model_dim] in `dtype`.
    """
    compute_type = torch.float32
    hidden_states = hidden_states.to(compute_type)
    w1 = w1.to(compute_type)
    w2 = w2.to(compute_type)
    token_num, topk = topk_ids.shape
    expert, model_dim, inter_dim = w2.shape
    blk_n, blk_k = scale_blks

    def _expand_scale(scale, rows, cols):
        # Tile counts use ceiling division so a partial trailing tile is
        # allowed; the expanded grid is cropped back to the weight shape.
        n_tiles = (rows + blk_n - 1) // blk_n
        k_tiles = (cols + blk_k - 1) // blk_k
        tiles = scale.reshape(expert, n_tiles, k_tiles)
        full = tiles.repeat_interleave(blk_n, dim=1).repeat_interleave(blk_k, dim=2)
        return full[:, :rows, :cols]

    w1 = w1 * _expand_scale(fc1_scale, 2 * inter_dim, model_dim)
    w2 = w2 * _expand_scale(fc2_scale, model_dim, inter_dim)

    # One copy of each token per selected expert: [token_num, topk, model_dim].
    hidden_states = hidden_states.view(token_num, 1, model_dim).repeat(1, topk, 1)
    out = torch.zeros(
        (token_num, topk, model_dim),
        dtype=compute_type,
        device=hidden_states.device,
    )
    for expert_id in range(w1.shape[0]):
        mask = topk_ids == expert_id
        if not mask.any():
            continue
        sub_tokens = hidden_states[mask]
        act_input = sub_tokens @ w1[expert_id].transpose(0, 1)
        gate, up = act_input.split([inter_dim, inter_dim], dim=-1)
        act_out = torch.nn.functional.silu(gate) * up
        out[mask] = act_out @ w2[expert_id].transpose(0, 1)
    return (out * topk_weight.view(token_num, -1, 1)).sum(dim=1).to(dtype)
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_shape, w1_scale, w2_scale):
    """Run the block-scaled torch reference under the perftest harness."""
    ref_args = (
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        dtype,
        block_shape,
        w1_scale,
        w2_scale,
    )
    return torch_moe_blockscale(*ref_args)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    moe_config,
    inplace,
    activation,
    w1_scale,
    w2_scale,
    w1_zp,
    w2_zp,
    a1_scale,
    a2_scale,
    block_shape,
    global_num_experts,
    expert_map,
):
    """Forward every argument to aiter_moe under the perftest harness.

    The activations are cloned only when `inplace` is true, which keeps the
    in-place operation correct. On the out-of-place path the caller's tensor
    is passed through directly so the measured time does not include a copy.
    """
    if inplace:
        # Clone to keep the in-place operation correct.
        mortal_input = hidden_states.clone()
    else:
        mortal_input = hidden_states
    return aiter_moe(
        mortal_input,
        w1,
        w2,
        topk_weights,
        topk_ids,
        moe_config,
        inplace,
        activation,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        global_num_experts,
        expert_map,
    )
def prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype):
    """Create seeded random activations, int8 weights, block scales and routing.

    Returns:
        dict with keys "input", "w1_ref", "w2_ref", "w1_qweight",
        "w2_qweight", "w1_scales", "w2_scales", "topk_weights", "topk_ids".
        The "*_ref" and "*_qweight" entries refer to the same int8 tensors.
    """
    torch.manual_seed(0)
    factor_for_scale = 1e-2
    limits = torch.iinfo(torch.int8)
    int8_max, int8_min = limits.max, limits.min

    input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 10

    # NOTE(review): w1 is sampled in `dtype` while w2 is sampled in float32;
    # confirm whether this asymmetry is intended.
    w1_fp = (torch.rand((e, 2 * n, k), dtype=dtype, device="cuda") - 0.5) * 2 * int8_max
    w2_fp = (torch.rand((e, k, n), dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
    w1_qweight = w1_fp.clamp(min=int8_min, max=int8_max).to(torch.int8)
    w2_qweight = w2_fp.clamp(min=int8_min, max=int8_max).to(torch.int8)

    block_n, block_k = block_shape

    def _tiles(extent, block):
        # Ceiling division: number of blocks needed to cover `extent`.
        return (extent + block - 1) // block

    w1_scale_shape = (e, _tiles(2 * n, block_n), _tiles(k, block_k))
    w2_scale_shape = (e, _tiles(k, block_n), _tiles(n, block_k))
    w1_scales = torch.rand(w1_scale_shape, dtype=torch.float32, device="cuda") * factor_for_scale
    w2_scales = torch.rand(w2_scale_shape, dtype=torch.float32, device="cuda") * factor_for_scale

    score = torch.randn((m, e), dtype=dtype, device="cuda")
    topk_weights, topk_ids = fused_topk(input_tensor, score, topk, False)

    prepared = {"input": input_tensor}
    prepared["w1_ref"] = w1_qweight
    prepared["w2_ref"] = w2_qweight
    prepared["w1_qweight"] = w1_qweight
    prepared["w2_qweight"] = w2_qweight
    prepared["w1_scales"] = w1_scales
    prepared["w2_scales"] = w2_scales
    prepared["topk_weights"] = topk_weights
    prepared["topk_ids"] = topk_ids
    return prepared
def test_get_config(m, k, n, e, topk, block_shape, dtype):
    """Check the W8A8 result of get_aiter_moe_config for one problem shape.

    Returns:
        The (status, moe_cfg) pair returned by get_aiter_moe_config.
    """
    lookup_kwargs = dict(
        M=m,
        E=e,
        N1=2 * n,
        N2=k,
        K=k,
        top_k=topk,
        block_size=block_shape[1],
        dtype=dtype,
        quant_type=MoeQuantType.W8A8,
    )
    status, moe_cfg = get_aiter_moe_config(**lookup_kwargs)

    if not status:
        # A failed lookup must carry neither a backend nor a config.
        assert moe_cfg.solution_type is None
        assert moe_cfg.config is None
        aiter.logger.info(f"[get_config_w8a8] {m=}, no solution found")
        return status, moe_cfg

    known_backends = (
        MoeSolutionType.ASM,
        MoeSolutionType.MOE_C,
        MoeSolutionType.TRITON,
        MoeSolutionType.CK,
    )
    assert moe_cfg.quant_type == MoeQuantType.W8A8
    assert moe_cfg.solution_type in known_backends
    assert moe_cfg.config is not None
    config_keys = list(moe_cfg.config.keys())
    aiter.logger.info(
        f"[get_config_w8a8] {m=}, solution={moe_cfg.solution_type}, config keys={config_keys}"
    )
    return status, moe_cfg
def test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype):
    """End-to-end W8A8 check: look up a config, run aiter_moe, compare.

    Returns:
        None when no backend is available for the shape; otherwise a dict
        with "m", "backend", "us" (measured time) and "accuracy".
    """
    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=2 * n,
        N2=k,
        K=k,
        top_k=topk,
        block_size=block_shape[1],
        dtype=dtype,
        quant_type=MoeQuantType.W8A8,
    )
    if not status:
        aiter.logger.info(f"[aiter_moe_w8a8] SKIP {m=}: no backend available")
        return None

    data = prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype)
    ref_out, _ = _run_torch_ref(
        data["input"],
        data["w1_ref"],
        data["w2_ref"],
        data["topk_weights"],
        data["topk_ids"],
        dtype,
        block_shape,
        data["w1_scales"],
        data["w2_scales"],
    )
    aiter_out, aiter_us = _run_aiter_moe_perf(
        hidden_states=data["input"],
        w1=data["w1_qweight"],
        w2=data["w2_qweight"],
        topk_weights=data["topk_weights"],
        topk_ids=data["topk_ids"],
        moe_config=moe_cfg,
        inplace=False,
        activation="silu",
        w1_scale=data["w1_scales"],
        w2_scale=data["w2_scales"],
        w1_zp=None,
        w2_zp=None,
        a1_scale=None,
        a2_scale=None,
        block_shape=list(block_shape),
        global_num_experts=e,
        expert_map=None,
    )
    msg = f"[aiter_moe_w8a8] {m=}, backend={moe_cfg.solution_type}"
    # Keep the comparison outcome so a mismatch shows up in the summary.
    # NOTE(review): a return of 0 is treated as success and anything else as
    # a mismatch ratio, mirroring the sibling tests - confirm against
    # aiter.test_common.checkAllclose.
    check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
    accuracy = "passed" if check_ret == 0 else (1 - check_ret)
    return {
        "m": m,
        "backend": moe_cfg.solution_type,
        "us": aiter_us,
        "accuracy": accuracy,
    }
if __name__ == "__main__":
    dtype = dtypes.fp16
    block_shape = (128, 128)
    e = 256
    topk = 8
    k = 7168
    n = 256
    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
    banner = "=" * 60

    # Part 1: config lookup only.
    for line in (banner, "Part 1: Testing get_aiter_moe_config for w8a8", banner):
        aiter.logger.info(line)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, block_shape, dtype)

    # Part 2: end-to-end run; shapes without a backend return None.
    for line in (banner, "Part 2: Testing aiter_moe end-to-end for w8a8", banner):
        aiter.logger.info(line)
    outcomes = (
        test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype)
        for m in test_tokens
    )
    rows = [outcome for outcome in outcomes if outcome is not None]
    if rows:
        summary = pd.DataFrame(rows)
        aiter.logger.info(f"summary:\n{summary}")
# Test for get_aiter_moe_config and aiter_moe with w8a8
import torch
import pandas as pd
from aiter.fused_moe import fused_topk
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
import aiter
torch.set_default_device("cuda")
def torch_moe_blockscale(
    hidden_states,
    w1,
    w2,
    topk_weight,
    topk_ids,
    dtype,
    scale_blks,
    fc1_scale,
    fc2_scale,
):
    """Reference MoE forward pass with block-wise weight scales, in float32.

    Each weight element is multiplied by the scale of the (blk_n x blk_k)
    tile it falls in, then every token is pushed through its selected
    experts (gate/up projection, SiLU gating, down projection) and the
    per-expert outputs are combined with `topk_weight`.

    Args:
        hidden_states: activations, viewed as [token_num, model_dim].
        w1: gate+up weights, [expert, 2 * inter_dim, model_dim].
        w2: down weights, [expert, model_dim, inter_dim].
        topk_weight: routing weights, token_num * topk elements.
        topk_ids: selected expert ids, [token_num, topk].
        dtype: dtype of the returned tensor.
        scale_blks: (blk_n, blk_k) tile size of the scales.
        fc1_scale: one scale per w1 tile, ceil(2 * inter_dim / blk_n) *
            ceil(model_dim / blk_k) values per expert.
        fc2_scale: one scale per w2 tile, ceil(model_dim / blk_n) *
            ceil(inter_dim / blk_k) values per expert.

    Returns:
        Tensor of shape [token_num, model_dim] in `dtype`.
    """
    compute_type = torch.float32
    hidden_states = hidden_states.to(compute_type)
    w1 = w1.to(compute_type)
    w2 = w2.to(compute_type)
    token_num, topk = topk_ids.shape
    expert, model_dim, inter_dim = w2.shape
    blk_n, blk_k = scale_blks

    def _expand_scale(scale, rows, cols):
        # Tile counts use ceiling division so a partial trailing tile is
        # allowed; the expanded grid is cropped back to the weight shape.
        n_tiles = (rows + blk_n - 1) // blk_n
        k_tiles = (cols + blk_k - 1) // blk_k
        tiles = scale.reshape(expert, n_tiles, k_tiles)
        full = tiles.repeat_interleave(blk_n, dim=1).repeat_interleave(blk_k, dim=2)
        return full[:, :rows, :cols]

    w1 = w1 * _expand_scale(fc1_scale, 2 * inter_dim, model_dim)
    w2 = w2 * _expand_scale(fc2_scale, model_dim, inter_dim)

    # One copy of each token per selected expert: [token_num, topk, model_dim].
    hidden_states = hidden_states.view(token_num, 1, model_dim).repeat(1, topk, 1)
    out = torch.zeros(
        (token_num, topk, model_dim),
        dtype=compute_type,
        device=hidden_states.device,
    )
    for expert_id in range(w1.shape[0]):
        mask = topk_ids == expert_id
        if not mask.any():
            continue
        sub_tokens = hidden_states[mask]
        act_input = sub_tokens @ w1[expert_id].transpose(0, 1)
        gate, up = act_input.split([inter_dim, inter_dim], dim=-1)
        act_out = torch.nn.functional.silu(gate) * up
        out[mask] = act_out @ w2[expert_id].transpose(0, 1)
    return (out * topk_weight.view(token_num, -1, 1)).sum(dim=1).to(dtype)
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_shape, w1_scale, w2_scale):
    """Run the block-scaled torch reference under the perftest harness."""
    reference = torch_moe_blockscale(
        hidden_states, w1, w2,
        topk_weights, topk_ids,
        dtype, block_shape,
        w1_scale, w2_scale,
    )
    return reference
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    moe_config,
    inplace,
    activation,
    w1_scale,
    w2_scale,
    w1_zp,
    w2_zp,
    a1_scale,
    a2_scale,
    block_shape,
    global_num_experts,
    expert_map,
):
    """Forward every argument to aiter_moe under the perftest harness.

    When `inplace` is true the activations are cloned first, which keeps the
    in-place operation correct; otherwise the caller's tensor is passed as-is.
    """
    # Clone only on the in-place path to keep the in-place operation correct.
    mortal_input = hidden_states.clone() if inplace else hidden_states
    moe_args = (
        mortal_input,
        w1,
        w2,
        topk_weights,
        topk_ids,
        moe_config,
        inplace,
        activation,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        global_num_experts,
        expert_map,
    )
    return aiter_moe(*moe_args)
def prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype):
    """Create seeded random activations, int8 weights, block scales and routing.

    Returns:
        dict with keys "input", "w1_ref", "w2_ref", "w1_qweight",
        "w2_qweight", "w1_scales", "w2_scales", "topk_weights", "topk_ids".
        The "*_ref" and "*_qweight" entries refer to the same int8 tensors.
    """
    torch.manual_seed(0)
    factor_for_scale = 1e-2
    limits = torch.iinfo(torch.int8)
    int8_max, int8_min = limits.max, limits.min

    input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 10

    # NOTE(review): w1 is sampled in `dtype` while w2 is sampled in float32;
    # confirm whether this asymmetry is intended.
    w1_fp = (torch.rand((e, 2 * n, k), dtype=dtype, device="cuda") - 0.5) * 2 * int8_max
    w2_fp = (torch.rand((e, k, n), dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
    w1_qweight = w1_fp.clamp(min=int8_min, max=int8_max).to(torch.int8)
    w2_qweight = w2_fp.clamp(min=int8_min, max=int8_max).to(torch.int8)

    block_n, block_k = block_shape

    def _tiles(extent, block):
        # Ceiling division: number of blocks needed to cover `extent`.
        return (extent + block - 1) // block

    w1_scale_shape = (e, _tiles(2 * n, block_n), _tiles(k, block_k))
    w2_scale_shape = (e, _tiles(k, block_n), _tiles(n, block_k))
    w1_scales = torch.rand(w1_scale_shape, dtype=torch.float32, device="cuda") * factor_for_scale
    w2_scales = torch.rand(w2_scale_shape, dtype=torch.float32, device="cuda") * factor_for_scale

    score = torch.randn((m, e), dtype=dtype, device="cuda")
    topk_weights, topk_ids = fused_topk(input_tensor, score, topk, False)

    prepared = {"input": input_tensor}
    prepared["w1_ref"] = w1_qweight
    prepared["w2_ref"] = w2_qweight
    prepared["w1_qweight"] = w1_qweight
    prepared["w2_qweight"] = w2_qweight
    prepared["w1_scales"] = w1_scales
    prepared["w2_scales"] = w2_scales
    prepared["topk_weights"] = topk_weights
    prepared["topk_ids"] = topk_ids
    return prepared
def test_get_config(m, k, n, e, topk, block_shape, dtype):
    """Check the W8A8 result of get_aiter_moe_config for one problem shape.

    Returns:
        The (status, moe_cfg) pair returned by get_aiter_moe_config.
    """
    lookup_kwargs = dict(
        M=m,
        E=e,
        N1=2 * n,
        N2=k,
        K=k,
        top_k=topk,
        block_size=block_shape[1],
        dtype=dtype,
        quant_type=MoeQuantType.W8A8,
    )
    status, moe_cfg = get_aiter_moe_config(**lookup_kwargs)

    if not status:
        # A failed lookup must carry neither a backend nor a config.
        assert moe_cfg.solution_type is None
        assert moe_cfg.config is None
        aiter.logger.info(f"[get_config_w8a8] {m=}, no solution found")
        return status, moe_cfg

    known_backends = (
        MoeSolutionType.ASM,
        MoeSolutionType.MOE_C,
        MoeSolutionType.TRITON,
        MoeSolutionType.CK,
    )
    assert moe_cfg.quant_type == MoeQuantType.W8A8
    assert moe_cfg.solution_type in known_backends
    assert moe_cfg.config is not None
    config_keys = list(moe_cfg.config.keys())
    aiter.logger.info(
        f"[get_config_w8a8] {m=}, solution={moe_cfg.solution_type}, config keys={config_keys}"
    )
    return status, moe_cfg
def test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype):
    """End-to-end W8A8 check: look up a config, run aiter_moe, compare.

    Returns:
        None when no backend is available for the shape; otherwise a dict
        with "m", "backend", "us" (measured time) and "accuracy"
        ("passed" when checkAllclose returns 0, else 1 - its return value).
    """
    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=2 * n,
        N2=k,
        K=k,
        top_k=topk,
        block_size=block_shape[1],
        dtype=dtype,
        quant_type=MoeQuantType.W8A8,
    )
    if not status:
        aiter.logger.info(f"[aiter_moe_w8a8] SKIP {m=}: no backend available")
        return None

    data = prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype)
    ref_out, _ = _run_torch_ref(
        data["input"],
        data["w1_ref"],
        data["w2_ref"],
        data["topk_weights"],
        data["topk_ids"],
        dtype,
        block_shape,
        data["w1_scales"],
        data["w2_scales"],
    )
    aiter_out, aiter_us = _run_aiter_moe_perf(
        hidden_states=data["input"],
        w1=data["w1_qweight"],
        w2=data["w2_qweight"],
        topk_weights=data["topk_weights"],
        topk_ids=data["topk_ids"],
        moe_config=moe_cfg,
        inplace=False,
        activation="silu",
        w1_scale=data["w1_scales"],
        w2_scale=data["w2_scales"],
        w1_zp=None,
        w2_zp=None,
        a1_scale=None,
        a2_scale=None,
        block_shape=list(block_shape),
        global_num_experts=e,
        expert_map=None,
    )
    msg = f"[aiter_moe_w8a8] {m=}, backend={moe_cfg.solution_type}"
    # NOTE(review): atol=100 is very permissive - confirm it is intended.
    check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
    passed = "passed" if check_ret == 0 else (1 - check_ret)
    return {
        "m": m,
        "backend": moe_cfg.solution_type,
        "us": aiter_us,
        "accuracy": passed,
    }
if __name__ == "__main__":
    dtype = dtypes.fp16
    block_shape = (128, 128)
    e = 256
    topk = 8
    k = 7168
    n = 256
    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
    banner = "=" * 60

    # Part 1: config lookup only.
    for line in (banner, "Part 1: Testing get_aiter_moe_config for w8a8", banner):
        aiter.logger.info(line)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, block_shape, dtype)

    # Part 2: end-to-end run; shapes without a backend return None.
    for line in (banner, "Part 2: Testing aiter_moe end-to-end for w8a8", banner):
        aiter.logger.info(line)
    outcomes = (
        test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype)
        for m in test_tokens
    )
    rows = [outcome for outcome in outcomes if outcome is not None]
    if rows:
        summary = pd.DataFrame(rows)
        aiter.logger.info(f"summary:\n{summary}")
        summary.to_csv("test_aiter_moe_with_config_w8a8_blockwise.csv", index=False)
......@@ -32,7 +32,7 @@ def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1,testGraph=True)
def _run_aiter_moe_perf(
hidden_states,
w1,
......@@ -51,9 +51,13 @@ def _run_aiter_moe_perf(
block_shape,
global_num_experts,
expert_map,
out_dtype,
):
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(
mortal_input,
......@@ -73,6 +77,7 @@ def _run_aiter_moe_perf(
block_shape,
global_num_experts,
expert_map,
output_dtype = out_dtype
)
......@@ -86,12 +91,20 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
"""
torch.manual_seed(0)
input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 10
w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda")
if dtype == dtypes.fp8:
input_tensor = torch.randn((m, k), dtype=dtypes.fp32, device="cuda") / 100
w1_fp = torch.randn((e, 2 * n, k), dtype=dtypes.fp32, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtypes.fp32, device="cuda")
else:
input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 100
w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda")
if quant_type == MoeQuantType.FP8_W8A8:
# FP8 channel-wise quantization via pertoken_quant
if dtype == dtypes.fp8:
input_tensor_q, a1_scales = pertoken_quant(input_tensor, quant_dtype=dtypes.fp8)
w1_qweight, w1_scales = pertoken_quant(w1_fp, quant_dtype=dtypes.fp8)
w2_qweight, w2_scales = pertoken_quant(w2_fp, quant_dtype=dtypes.fp8)
else:
......@@ -106,7 +119,10 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
w2_scales = max_vals_w2 / 127.0 # (e, k, 1)
w2_qweight = (w2_fp / max_vals_w2 * 127.0).round().clamp(min=-128, max=127).to(torch.int8)
score = torch.randn((m, e), dtype=dtype, device="cuda")
if dtype == dtypes.fp8:
score = torch.randn((m, e), dtype=dtypes.fp32, device="cuda")
else:
score = torch.randn((m, e), dtype=dtype, device="cuda")
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
# moe_c backend needs layout-shuffled weights
......@@ -114,7 +130,8 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
w2_qweight_shuffle = moe_layout_shuffle_gemm2(w2_qweight).view(*w2_qweight.shape)
return {
"input": input_tensor,
"input": input_tensor_q if dtype == dtypes.fp8 else input_tensor,
"a1_scales": a1_scales if dtype == dtypes.fp8 else None,
"w1_ref": w1_fp,
"w2_ref": w2_fp,
"w1_qweight": w1_qweight,
......@@ -125,6 +142,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
"w2_scales": w2_scales,
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"input_ref": input_tensor, # original fp32/bf16 input for reference
}
......@@ -164,7 +182,7 @@ def test_get_config(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
return status, moe_cfg
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, in_dtype, out_dtype, quant_type=MoeQuantType.W8A8, inplace=False):
"""End-to-end test of aiter_moe with channel-wise w8a8 (int8 or fp8)."""
status, moe_cfg = get_aiter_moe_config(
M=m,
......@@ -174,7 +192,7 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
K=k,
top_k=topk,
block_size=0,
dtype=dtype,
dtype=in_dtype,
quant_type=quant_type,
)
......@@ -183,11 +201,11 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
aiter.logger.info(f"[{tag}] SKIP {m=}: no backend available")
return None
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type)
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, in_dtype, quant_type)
# Torch reference uses original fp weights directly (no scales needed)
# Torch reference uses pre-quantization input (fp32/bf16) since fp8 is unsupported
ref_out, _ = _run_torch_ref(
data["input"],
data["input_ref"],
data["w1_ref"],
data["w2_ref"],
data["topk_weights"],
......@@ -209,27 +227,28 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
topk_weights=data["topk_weights"],
topk_ids=data["topk_ids"],
moe_config=moe_cfg,
inplace=True,
inplace=inplace,
activation="silu",
w1_scale=data["w1_scales"],
w2_scale=data["w2_scales"],
w1_zp=None,
w2_zp=None,
a1_scale=None,
a1_scale=data["a1_scales"],
a2_scale=None,
block_shape=None,
global_num_experts=e,
expert_map=None,
out_dtype=out_dtype,
)
print("aiter_out",aiter_out)
print("ref_out",ref_out)
# print("aiter_out",aiter_out)
# print("ref_out",ref_out)
# Compare in aiter_out's dtype since reference may be fp32 (from pre-quant input)
msg = f"[{tag}] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us}
check_ret = checkAllclose(ref_out.to(aiter_out.dtype), aiter_out, rtol=0.01, atol=100, msg=msg)
passed = "passed" if check_ret == 0 else (1-check_ret)
return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us, "accuracy": passed, }
if __name__ == "__main__":
......@@ -239,36 +258,39 @@ if __name__ == "__main__":
parser.add_argument(
"--quant",
choices=["int8", "fp8"],
default="int8",
default="fp8",
help="Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)",
)
args = parser.parse_args()
quant_type = MoeQuantType.FP8_W8A8 if args.quant == "fp8" else MoeQuantType.W8A8
inplace = False # in_dtype != out_dtype时,不能为True
# for moe_c backend, it does not support n=320 for now;
# for triton backend, it can run with n=320 in NMZ;
dtype = dtypes.bf16
e = 256
in_dtype = dtypes.fp8
out_dtype = dtypes.bf16
e = 192
topk = 8
k = 6144
n = 320
k = 4096
n = 384
aiter.logger.info("=" * 60)
aiter.logger.info(f"Part 1: Testing get_aiter_moe_config for {quant_type} channel-wise")
aiter.logger.info("=" * 60)
test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 , 4096, 6144 , 8192 , 16384]
for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype, quant_type)
test_get_config(m, k, n, e, topk, in_dtype, quant_type)
aiter.logger.info("=" * 60)
aiter.logger.info(f"Part 2: Testing aiter_moe end-to-end for {quant_type} channel-wise")
aiter.logger.info("=" * 60)
df = []
for m in test_tokens:
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type)
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, in_dtype, out_dtype, quant_type, inplace)
if ret is not None:
df.append(ret)
if df:
df = pd.DataFrame(df)
df.to_csv("w8a8_channelwise.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment