Commit bb596f6e authored by xiaowei.zhang
Browse files

1. Update MOE; 2. Update sglang mHC; 3. Update test scripts; 4. Add new ops.
parent d9ebb683
// SPDX-License-Identifier: MIT
#include "grouped_gemm_ck.h"
#include "rocm_ops.hpp"
#include <pybind11/stl.h>
// Python extension module that exposes the grouped-GEMM entry points declared
// in grouped_gemm_ck.h. The module name comes from the TORCH_EXTENSION_NAME
// macro, which is defined outside this file; `m` is the module handle
// introduced by PYBIND11_MODULE.
// NOTE(review): `py` is presumably a namespace alias for pybind11 provided by
// one of the included headers -- confirm.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// Registers ck_grouped_gemm under the same Python name, with keyword
// arguments `a_tensors` and `b_tensors`.
m.def("ck_grouped_gemm", &ck_grouped_gemm, py::arg("a_tensors"), py::arg("b_tensors"));
// Registers ck_grouped_gemm_out, which takes an additional `c_tensors`
// keyword argument (presumably caller-provided outputs -- confirm against
// grouped_gemm_ck.h).
m.def("ck_grouped_gemm_out",
&ck_grouped_gemm_out,
py::arg("a_tensors"),
py::arg("b_tensors"),
py::arg("c_tensors"));
}
// SPDX-License-Identifier: MIT
#include "rocm_ops.hpp"
#include "mhc.h"
// Python extension module for the mHC ops. The module name comes from the
// TORCH_EXTENSION_NAME macro, which is defined outside this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// All bindings are registered by the MHC_PYBIND macro, which is not defined
// in this file (presumably in rocm_ops.hpp, and presumably expanding to
// registrations on `m` -- confirm).
MHC_PYBIND;
}
...@@ -87,7 +87,9 @@ if __name__ == "__main__": ...@@ -87,7 +87,9 @@ if __name__ == "__main__":
print(f">>> ERROR: {args.input_file} does not exist. Exiting") print(f">>> ERROR: {args.input_file} does not exist. Exiting")
exit(1) exit(1)
shapes = pd.read_csv(args.input_file).fillna("") shapes = pd.read_csv(args.input_file).dropna(how='all').fillna(0)
int_cols = ['token', 'inter_dim', 'model_dim', 'expert', 'topk', 'q_size_n', 'q_size_k']
shapes[int_cols] = shapes[int_cols].astype(int)
for i in range(len(shapes)): for i in range(len(shapes)):
ds = shapes.iloc[i] ds = shapes.iloc[i]
moe_tuner.add_moe( moe_tuner.add_moe(
......
...@@ -11,12 +11,17 @@ class MoeTuner: ...@@ -11,12 +11,17 @@ class MoeTuner:
def __init__(self, indtype, tuned_file=None, mp=1): def __init__(self, indtype, tuned_file=None, mp=1):
self.arch = get_gfx() self.arch = get_gfx()
self.moe_pro_df = pd.DataFrame(columns=["quant_type", "indtype", "token", "inter_dim", "model_dim", "expert", "topk", "q_size_n", "q_size_k"]) self.moe_pro_df = pd.DataFrame(columns=["quant_type", "indtype", "token", "inter_dim", "model_dim", "expert", "topk", "q_size_n", "q_size_k"])
self._int_cols = ["token", "inter_dim", "model_dim", "expert", "topk", "q_size_n", "q_size_k"]
self.indtype = indtype self.indtype = indtype
self.tuned_file = tuned_file self.tuned_file = tuned_file
self.mp = mp self.mp = mp
if Path(tuned_file).is_file(): if Path(tuned_file).is_file():
self.tuned_shapes = pd.read_csv(tuned_file).fillna("") self.tuned_shapes = pd.read_csv(tuned_file).dropna(how='all').fillna(0)
int_cols = ['token', 'inter_dim', 'model_dim', 'expert', 'topk', 'q_size_n', 'q_size_k']
for c in int_cols:
if c in self.tuned_shapes.columns:
self.tuned_shapes[c] = self.tuned_shapes[c].astype(int)
else: else:
self.tuned_shapes = None self.tuned_shapes = None
...@@ -40,13 +45,13 @@ class MoeTuner: ...@@ -40,13 +45,13 @@ class MoeTuner:
entry = { entry = {
"quant_type": [quant_type], "quant_type": [quant_type],
"indtype": [indtype_str], "indtype": [indtype_str],
"token": [token], "token": [int(token)],
"inter_dim": [inter_dim], "inter_dim": [int(inter_dim)],
"model_dim": [model_dim], "model_dim": [int(model_dim)],
"expert": [expert], "expert": [int(expert)],
"topk": [topk], "topk": [int(topk)],
"q_size_n": [q_size_n], "q_size_n": [int(q_size_n)],
"q_size_k": [q_size_k] "q_size_k": [int(q_size_k)]
} }
df = pd.DataFrame(entry) df = pd.DataFrame(entry)
self.moe_pro_df = pd.concat([self.moe_pro_df, df], ignore_index=True) self.moe_pro_df = pd.concat([self.moe_pro_df, df], ignore_index=True)
......
...@@ -77,6 +77,7 @@ def test_get_config(m, k, n, e, topk, dtype): ...@@ -77,6 +77,7 @@ def test_get_config(m, k, n, e, topk, dtype):
assert moe_cfg.solution_type in ( assert moe_cfg.solution_type in (
MoeSolutionType.ASM, MoeSolutionType.ASM,
MoeSolutionType.TRITON, MoeSolutionType.TRITON,
MoeSolutionType.CK,
), f"Unexpected solution_type: {moe_cfg.solution_type}" ), f"Unexpected solution_type: {moe_cfg.solution_type}"
assert moe_cfg.quant_type == MoeQuantType.W16A16 assert moe_cfg.quant_type == MoeQuantType.W16A16
aiter.logger.info( aiter.logger.info(
...@@ -126,7 +127,11 @@ def _run_aiter_moe_perf(hidden_states, ...@@ -126,7 +127,11 @@ def _run_aiter_moe_perf(hidden_states,
expert_map, expert_map,
routed_scaling_factor, routed_scaling_factor,
): ):
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp, if inplace:
mortal_input = hidden_states.clone() # ensure correctness of the in-place operation
else:
mortal_input = hidden_states
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor) a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
...@@ -192,8 +197,9 @@ def test_aiter_moe_w16a16(m, k, n, e, topk, dtype, inplace, routed_scaling_facto ...@@ -192,8 +197,9 @@ def test_aiter_moe_w16a16(m, k, n, e, topk, dtype, inplace, routed_scaling_facto
# Non-quantized bf16 matmul accumulation order differs between torch and # Non-quantized bf16 matmul accumulation order differs between torch and
# the fused triton/asm kernels, so we need a relaxed atol (matching # the fused triton/asm kernels, so we need a relaxed atol (matching
# test_moe_w16a16.py which uses atol=1 for torch vs triton). # test_moe_w16a16.py which uses atol=1 for torch vs triton).
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg) check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
return {"m": m, "backend": backend, "us": aiter_us} ret_output = "passed" if check_ret == 0 else (1 - check_ret)
return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -247,14 +253,19 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype): ...@@ -247,14 +253,19 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype):
msg = (f"[w16a16_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, " msg = (f"[w16a16_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}") f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg) check_ret = checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0 uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0
return { return {
"m": m, "m": m,
"k": k,
"n": n,
"e": e,
"topk": topk,
"asm_us": asm_us, "asm_us": asm_us,
"shuffle_us": shuffle_us, "shuffle_us": shuffle_us,
"shuffle_uplift": f"{uplift:.1%}", "shuffle_uplift": f"{uplift:.1%}",
"accuracy": ret_output
} }
...@@ -265,7 +276,9 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype): ...@@ -265,7 +276,9 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype):
if __name__ == "__main__": if __name__ == "__main__":
dtype = dtypes.bf16 dtype = dtypes.bf16
PART2_CSV_OUTPUT = "w16a16_part2_aiter_moe.csv"
PART3_CSV_OUTPUT = "w16a16_part3_shuffle.csv"
# for asm solution in gfx936, only support the following configurations, # for asm solution in gfx936, only support the following configurations,
# check tuned_fmoe_asm.csv for details. # check tuned_fmoe_asm.csv for details.
# otherwise, the interface will return triton solution. # otherwise, the interface will return triton solution.
...@@ -298,6 +311,8 @@ if __name__ == "__main__": ...@@ -298,6 +311,8 @@ if __name__ == "__main__":
if df: if df:
df = pd.DataFrame(df) df = pd.DataFrame(df)
aiter.logger.info(f"aiter_moe summary:\n{df}") aiter.logger.info(f"aiter_moe summary:\n{df}")
df.to_csv(PART2_CSV_OUTPUT, index=False)
aiter.logger.info(f"aiter_moe summary csv saved to {PART2_CSV_OUTPUT}")
# --- Part 3: test ASM shuffle vs non-shuffle (w16a16) --- # --- Part 3: test ASM shuffle vs non-shuffle (w16a16) ---
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
...@@ -312,11 +327,5 @@ if __name__ == "__main__": ...@@ -312,11 +327,5 @@ if __name__ == "__main__":
if df_shuffle: if df_shuffle:
df_shuffle = pd.DataFrame(df_shuffle) df_shuffle = pd.DataFrame(df_shuffle)
aiter.logger.info(f"shuffle summary:\n{df_shuffle}") aiter.logger.info(f"shuffle summary:\n{df_shuffle}")
df_shuffle.to_csv(PART3_CSV_OUTPUT, index=False)
# --- Combined summary --- aiter.logger.info(f"shuffle summary csv saved to {PART3_CSV_OUTPUT}")
# if len(df) > 0 and len(df_shuffle) > 0:
# df_combined = pd.merge(
# pd.DataFrame(df), pd.DataFrame(df_shuffle),
# on="m", how="outer", suffixes=("", "_shuffle"),
# )
# aiter.logger.info(f"combined summary:\n{df_combined}")
...@@ -63,7 +63,7 @@ def torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids): ...@@ -63,7 +63,7 @@ def torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Weight preparation helpers (W16A16 non-gated) # Weight preparation helpers (W16A16 non-gated)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype): def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype, asm_backend=False):
"""Build all tensors needed to run a non-gated w16a16 MOE test. """Build all tensors needed to run a non-gated w16a16 MOE test.
Key difference from gated: w1 shape is [E, n, k] instead of [E, 2*n, k]. Key difference from gated: w1 shape is [E, n, k] instead of [E, 2*n, k].
...@@ -74,8 +74,12 @@ def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype): ...@@ -74,8 +74,12 @@ def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype):
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 2 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 2
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
w1_shuffle = asm_shuffle_weight_b8(w1, stage=1) if asm_backend:
w2_shuffle = asm_shuffle_weight_b8(w2, stage=2) w1_shuffle = asm_shuffle_weight_b8(w1, stage=1)
w2_shuffle = asm_shuffle_weight_b8(w2, stage=2)
else:
w1_shuffle = w1
w2_shuffle = w2
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True) topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
...@@ -165,7 +169,11 @@ def _run_aiter_moe_perf(hidden_states, ...@@ -165,7 +169,11 @@ def _run_aiter_moe_perf(hidden_states,
expert_map, expert_map,
routed_scaling_factor, routed_scaling_factor,
): ):
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp, if inplace:
mortal_input = hidden_states.clone() # ensure correctness of the in-place operation
else:
mortal_input = hidden_states
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor) a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
...@@ -197,7 +205,7 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin ...@@ -197,7 +205,7 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin
f"backend={backend}" f"backend={backend}"
) )
data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype) data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype, asm_backend=(backend == MoeSolutionType.ASM))
# Torch reference # Torch reference
ref_out, _ = _run_torch_ref( ref_out, _ = _run_torch_ref(
...@@ -230,8 +238,9 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin ...@@ -230,8 +238,9 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin
msg = (f"[aiter_moe_w16a16_nogate] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, " msg = (f"[aiter_moe_w16a16_nogate] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
f"backend={backend}") f"backend={backend}")
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg) check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
return {"m": m, "backend": backend, "us": aiter_us} ret_output = "passed" if check_ret == 0 else (1 - check_ret)
return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -279,27 +288,33 @@ def test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype): ...@@ -279,27 +288,33 @@ def test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype):
msg = (f"[w16a16_nogate_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, " msg = (f"[w16a16_nogate_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}") f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg) check_ret = checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0 uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0
return { return {
"m": m, "m": m,
"k": k,
"n": n,
"e": e,
"topk": topk,
"asm_us": asm_us, "asm_us": asm_us,
"shuffle_us": shuffle_us, "shuffle_us": shuffle_us,
"shuffle_uplift": f"{uplift:.1%}", "shuffle_uplift": f"{uplift:.1%}",
"accuracy": ret_output
} }
if __name__ == "__main__": if __name__ == "__main__":
dtype = dtypes.bf16 dtype = dtypes.bf16
PART2_CSV_OUTPUT = "w16a16_nogate_part2_aiter_moe.csv"
PART3_CSV_OUTPUT = "w16a16_nogate_part3_shuffle.csv"
# Nemotron-style MoE parameters (non-gated, ReLU²) # Nemotron-style MoE parameters (non-gated, ReLU²)
e = 256 e = 512
topk = 8 topk = 22
k = 3072 # model_dim / hidden_size k = 1024 # model_dim / hidden_size
n = 128 # intermediate_size (NOT multiplied by 2) n = 336 # intermediate_size (NOT multiplied by 2)
inplace = False inplace = True
routed_scaling_factor = 1.0 routed_scaling_factor = 1.0
# --- Part 1: test get_aiter_moe_config (w16a16 non-gated relu2) --- # --- Part 1: test get_aiter_moe_config (w16a16 non-gated relu2) ---
...@@ -324,17 +339,24 @@ if __name__ == "__main__": ...@@ -324,17 +339,24 @@ if __name__ == "__main__":
if df: if df:
df = pd.DataFrame(df) df = pd.DataFrame(df)
aiter.logger.info(f"aiter_moe non-gated relu2 summary:\n{df}") aiter.logger.info(f"aiter_moe non-gated relu2 summary:\n{df}")
df.to_csv(PART2_CSV_OUTPUT, index=False)
aiter.logger.info(f"aiter_moe summary csv saved to {PART2_CSV_OUTPUT}")
# --- Part 3: test ASM shuffle vs non-shuffle (w16a16 non-gated relu2) --- # --- Part 3: test ASM shuffle vs non-shuffle (w16a16 non-gated relu2) ---
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
aiter.logger.info("Part 3: Testing ASM shuffle vs non-shuffle for w16a16 non-gated relu2") aiter.logger.info("Part 3: Testing ASM shuffle vs non-shuffle for w16a16 non-gated relu2")
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
df_shuffle = [] if df.empty or not any(df["backend"] == MoeSolutionType.ASM):
for m in test_tokens: aiter.logger.info("Skipping Part 3 since ASM backend was not selected in Part 2")
ret = test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype) else:
if ret is not None: df_shuffle = []
df_shuffle.append(ret) for m in test_tokens:
if df_shuffle: ret = test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype)
df_shuffle = pd.DataFrame(df_shuffle) if ret is not None:
aiter.logger.info(f"shuffle summary (non-gated relu2):\n{df_shuffle}") df_shuffle.append(ret)
if df_shuffle:
df_shuffle = pd.DataFrame(df_shuffle)
aiter.logger.info(f"shuffle summary (non-gated relu2):\n{df_shuffle}")
df_shuffle.to_csv(PART3_CSV_OUTPUT, index=False)
aiter.logger.info(f"shuffle summary csv saved to {PART3_CSV_OUTPUT}")
\ No newline at end of file
# Test for get_aiter_moe_config_w4a16 and aiter_moe_w4a16 # Test for get_aiter_moe_config_w4a16 and aiter_moe_w4a16
import torch import torch
import itertools import itertools
import pandas as pd import pandas as pd
from typing import Optional, List from typing import Optional, List
from op_tests.utility.scalar_type import scalar_types try:
from op_tests.utility.utils import quantize_weights from op_tests.utility.scalar_type import scalar_types
from aiter.fused_moe import fused_topk, torch_moe from op_tests.utility.utils import quantize_weights
from aiter import ActivationType, dtypes except ModuleNotFoundError:
from aiter.test_common import checkAllclose, perftest import sys
from aiter.moe import ( from pathlib import Path
get_aiter_moe_config,
aiter_moe, _ROOT = Path(__file__).resolve().parents[2]
MoeSolutionType, if str(_ROOT) not in sys.path:
MoeQuantType, sys.path.insert(0, str(_ROOT))
)
from aiter.ops.shuffle import w4a16_marlin_weight_1, w4a16_marlin_weight_2 from op_tests.utility.scalar_type import scalar_types
import aiter from op_tests.utility.utils import quantize_weights
from aiter.fused_moe import fused_topk, torch_moe
torch.set_default_device("cuda") from aiter import ActivationType, dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
# --------------------------------------------------------------------------- get_aiter_moe_config,
# Weight quantization helpers (adapted from test_moe_wna16.py) aiter_moe,
# --------------------------------------------------------------------------- MoeSolutionType,
MoeQuantType,
def _quantize_w4a16_weights(w_fp, group_size, has_zp, pack_for_backend): )
"""Quantize a single expert weight matrix to int4. from aiter.ops.shuffle import w4a16_marlin_weight_1, w4a16_marlin_weight_2
import aiter
Args:
w_fp: Floating-point weight ``[out_features, in_features]``. torch.set_default_device("cuda")
group_size: Quantization group size along K.
has_zp: Whether to produce zero-points.
pack_for_backend: ``"triton"`` / ``"asm"`` / ``"moe_c"`` – determines # ---------------------------------------------------------------------------
the packing and layout convention for qweight / qzeros. # Weight quantization helpers (adapted from test_moe_wna16.py)
# ---------------------------------------------------------------------------
Returns:
(weight_ref, qweight, scales, qzeros_or_None) def _quantize_w4a16_weights(w_fp, group_size, has_zp, pack_for_backend):
""" """Quantize a single expert weight matrix to int4.
quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
weight, qweight, scales, qzeros = quantize_weights( Args:
w_fp.T, quant_type, group_size, has_zp, False) w_fp: Floating-point weight ``[out_features, in_features]``.
weight = weight.T group_size: Quantization group size along K.
qweight = qweight.T.contiguous().to(torch.uint8) has_zp: Whether to produce zero-points.
scales = scales.T pack_for_backend: ``"triton"`` / ``"asm"`` / ``"moe_c"`` – determines
the packing and layout convention for qweight / qzeros.
if has_zp:
qzeros = qzeros.T.contiguous().to(torch.uint8) Returns:
(weight_ref, qweight, scales, qzeros_or_None)
# int4: pack two nibbles into one byte """
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
if has_zp: weight, qweight, scales, qzeros = quantize_weights(
if pack_for_backend == "asm": w_fp.T, quant_type, group_size, has_zp, False)
qzeros = qzeros[:, 1::2] * 16 + qzeros[:, ::2] weight = weight.T
else: qweight = qweight.T.contiguous().to(torch.uint8)
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] scales = scales.T
return weight, qweight, scales, qzeros if has_zp else None if has_zp:
qzeros = qzeros.T.contiguous().to(torch.uint8)
def prepare_w4a16_inputs(m, k, n, e, topk, group_size, has_zp, dtype, # int4: pack two nibbles into one byte
backend): qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
"""Build all tensors needed to run a w4a16 MOE test. if has_zp:
if pack_for_backend == "asm":
Returns a dict of tensors keyed by name. qzeros = qzeros[:, 1::2] * 16 + qzeros[:, ::2]
""" else:
pack_factor = 2 # int4 qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
input_tensor = torch.randn((m, k), device="cuda", dtype=dtype) / 10 return weight, qweight, scales, qzeros if has_zp else None
w1_fp = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2_fp = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype) def prepare_w4a16_inputs(m, k, n, e, topk, group_size, has_zp, dtype,
backend):
# Allocate packed weight storage """Build all tensors needed to run a w4a16 MOE test.
w1_qweight = torch.empty((e, 2 * n, k // pack_factor), device="cuda",
dtype=torch.uint8) Returns a dict of tensors keyed by name.
w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", """
dtype=torch.uint8) pack_factor = 2 # int4
w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda",
dtype=dtype) input_tensor = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w2_scales = torch.empty((e, k, n // group_size), device="cuda", w1_fp = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
dtype=dtype) w2_fp = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
if has_zp:
if backend == "asm": # Allocate packed weight storage
w1_qzeros = torch.empty( w1_qweight = torch.empty((e, 2 * n, k // pack_factor), device="cuda",
(e, 2 * n, k // group_size // pack_factor), device="cuda", dtype=torch.uint8)
dtype=torch.uint8) w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda",
w2_qzeros = torch.empty( dtype=torch.uint8)
(e, k, n // group_size // pack_factor), device="cuda", w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda",
dtype=torch.uint8) dtype=dtype)
else: w2_scales = torch.empty((e, k, n // group_size), device="cuda",
w1_qzeros = torch.empty( dtype=dtype)
(e, 2 * n // pack_factor, k // group_size), device="cuda",
dtype=torch.uint8) if has_zp:
w2_qzeros = torch.empty( if backend == "asm":
(e, k // pack_factor, n // group_size), device="cuda", w1_qzeros = torch.empty(
dtype=torch.uint8) (e, 2 * n, k // group_size // pack_factor), device="cuda",
else: dtype=torch.uint8)
w1_qzeros = None w2_qzeros = torch.empty(
w2_qzeros = None (e, k, n // group_size // pack_factor), device="cuda",
dtype=torch.uint8)
w1_ref = w1_fp.clone() else:
w2_ref = w2_fp.clone() w1_qzeros = torch.empty(
(e, 2 * n // pack_factor, k // group_size), device="cuda",
for i in range(e * 2): dtype=torch.uint8)
expert_id = i % e w2_qzeros = torch.empty(
if i // e == 0: (e, k // pack_factor, n // group_size), device="cuda",
w_fp_e, w_ref, w_qw, w_sc, w_zp = ( dtype=torch.uint8)
w1_fp, w1_ref, w1_qweight, w1_scales, w1_qzeros) else:
else: w1_qzeros = None
w_fp_e, w_ref, w_qw, w_sc, w_zp = ( w2_qzeros = None
w2_fp, w2_ref, w2_qweight, w2_scales, w2_qzeros)
weight, qweight, scales, qzeros = _quantize_w4a16_weights( w1_ref = w1_fp.clone()
w_fp_e[expert_id], group_size, has_zp, backend) w2_ref = w2_fp.clone()
w_ref[expert_id] = weight
w_qw[expert_id] = qweight for i in range(e * 2):
w_sc[expert_id] = scales expert_id = i % e
if has_zp and w_zp is not None: if i // e == 0:
w_zp[expert_id] = qzeros w_fp_e, w_ref, w_qw, w_sc, w_zp = (
w1_fp, w1_ref, w1_qweight, w1_scales, w1_qzeros)
# For moe_c backend, apply marlin weight shuffle else:
if backend == "moe_c": w_fp_e, w_ref, w_qw, w_sc, w_zp = (
w1_qweight_final = w4a16_marlin_weight_1(w1_qweight) w2_fp, w2_ref, w2_qweight, w2_scales, w2_qzeros)
w2_qweight_final = w4a16_marlin_weight_2(w2_qweight) weight, qweight, scales, qzeros = _quantize_w4a16_weights(
w1_qweight_final = w1_qweight_final.view(-1).view( w_fp_e[expert_id], group_size, has_zp, backend)
torch.uint8).view(*w1_qweight.shape) w_ref[expert_id] = weight
w2_qweight_final = w2_qweight_final.view(-1).view( w_qw[expert_id] = qweight
torch.uint8).view(*w2_qweight.shape) w_sc[expert_id] = scales
else: if has_zp and w_zp is not None:
w1_qweight_final = w1_qweight w_zp[expert_id] = qzeros
w2_qweight_final = w2_qweight
# For moe_c backend, apply marlin weight shuffle
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True) if backend == "moe_c":
w1_qweight_final = w4a16_marlin_weight_1(w1_qweight)
return { w2_qweight_final = w4a16_marlin_weight_2(w2_qweight)
"input": input_tensor, w1_qweight_final = w1_qweight_final.view(-1).view(
"w1_ref": w1_ref, torch.uint8).view(*w1_qweight.shape)
"w2_ref": w2_ref, w2_qweight_final = w2_qweight_final.view(-1).view(
"w1_qweight": w1_qweight_final, torch.uint8).view(*w2_qweight.shape)
"w2_qweight": w2_qweight_final, else:
"w1_scales": w1_scales, w1_qweight_final = w1_qweight
"w2_scales": w2_scales, w2_qweight_final = w2_qweight
"w1_qzeros": w1_qzeros,
"w2_qzeros": w2_qzeros, topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
"topk_weights": topk_weights,
"topk_ids": topk_ids, return {
"score": score, "input": input_tensor,
} "w1_ref": w1_ref,
"w2_ref": w2_ref,
"w1_qweight": w1_qweight_final,
# --------------------------------------------------------------------------- "w2_qweight": w2_qweight_final,
# Test: get_aiter_moe_config (w4a16) "w1_scales": w1_scales,
# --------------------------------------------------------------------------- "w2_scales": w2_scales,
"w1_qzeros": w1_qzeros,
def test_get_config(m, k, n, e, topk, group_size, dtype): "w2_qzeros": w2_qzeros,
"""Test that get_aiter_moe_config returns a valid w4a16 config or "topk_weights": topk_weights,
gracefully reports no-solution.""" "topk_ids": topk_ids,
N1 = 2 * n # gate + up "score": score,
N2 = k # down / hidden_size }
K = k # model dimension (uncompressed)
status, moe_cfg = get_aiter_moe_config( # ---------------------------------------------------------------------------
M=m, E=e, N1=N1, N2=N2, K=K, # Test: get_aiter_moe_config (w4a16)
top_k=topk, block_size=group_size, dtype=dtype, # ---------------------------------------------------------------------------
quant_type=MoeQuantType.W4A16,
) def test_get_config(m, k, n, e, topk, group_size, dtype):
"""Test that get_aiter_moe_config returns a valid w4a16 config or
if status: gracefully reports no-solution."""
assert moe_cfg.solution_type is not None, \ N1 = 2 * n # gate + up
"status=True but solution_type is None" N2 = k # down / hidden_size
assert moe_cfg.config is not None, \ K = k # model dimension (uncompressed)
"status=True but config is None"
assert moe_cfg.solution_type in ( status, moe_cfg = get_aiter_moe_config(
MoeSolutionType.MOE_C, M=m, E=e, N1=N1, N2=N2, K=K,
MoeSolutionType.ASM, top_k=topk, block_size=group_size, dtype=dtype,
MoeSolutionType.TRITON, quant_type=MoeQuantType.W4A16,
), f"Unexpected solution_type: {moe_cfg.solution_type}" )
assert moe_cfg.quant_type == MoeQuantType.W4A16
aiter.logger.info( if status:
f"[get_config_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, " assert moe_cfg.solution_type is not None, \
f"solution={moe_cfg.solution_type}, " "status=True but solution_type is None"
f"config keys={list(moe_cfg.config.keys())}" assert moe_cfg.config is not None, \
) "status=True but config is None"
else: assert moe_cfg.solution_type in (
assert moe_cfg.solution_type is None, \ MoeSolutionType.MOE_C,
"status=False but solution_type is not None" MoeSolutionType.ASM,
assert moe_cfg.config is None, \ MoeSolutionType.TRITON,
"status=False but config is not None" ), f"Unexpected solution_type: {moe_cfg.solution_type}"
aiter.logger.info( assert moe_cfg.quant_type == MoeQuantType.W4A16
f"[get_config_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, " aiter.logger.info(
f"no solution found (expected on unsupported configs)" f"[get_config_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
) f"solution={moe_cfg.solution_type}, "
f"config keys={list(moe_cfg.config.keys())}"
return status, moe_cfg )
else:
assert moe_cfg.solution_type is None, \
# --------------------------------------------------------------------------- "status=False but solution_type is not None"
# Test: aiter_moe end-to-end for w4a16 assert moe_cfg.config is None, \
# --------------------------------------------------------------------------- "status=False but config is not None"
aiter.logger.info(
@perftest(num_warmup=1, num_iters=2) f"[get_config_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids): f"no solution found (expected on unsupported configs)"
return torch_moe(hidden_states, w1, w2, topk_weights, topk_ids) )
return status, moe_cfg
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(hidden_states,
w1, # ---------------------------------------------------------------------------
w2, # Test: aiter_moe end-to-end for w4a16
topk_weights, # ---------------------------------------------------------------------------
topk_ids,
moe_config, @perftest(num_warmup=1, num_iters=2)
inplace, def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
w1_scale, return torch_moe(hidden_states, w1, w2, topk_weights, topk_ids)
w2_scale,
w1_zp,
w2_zp, @perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
a1_scale, def _run_aiter_moe_perf(hidden_states,
a2_scale, w1,
block_shape, w2,
global_num_experts, topk_weights,
expert_map, topk_ids,
routed_scaling_factor, moe_config,
activation): inplace,
w1_scale,
mortal_input = hidden_states.clone() # 保证inplace操作的正确性 w2_scale,
w1_zp,
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor) a1_scale,
a2_scale,
block_shape,
def test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor): global_num_experts,
"""End-to-end: get config -> run aiter_moe -> compare with torch expert_map,
reference.""" routed_scaling_factor,
N1 = 2 * n activation):
N2 = k
K = k if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
status, moe_cfg = get_aiter_moe_config( else:
M=m, E=e, N1=N1, N2=N2, K=K, mortal_input = hidden_states
top_k=topk, block_size=group_size, dtype=dtype,
quant_type=MoeQuantType.W4A16, return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
) a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor, output_dtype=hidden_states.dtype)
if not status:
aiter.logger.info( def test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor):
f"[aiter_moe_w4a16] SKIP {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}: " """End-to-end: get config -> run aiter_moe -> compare with torch
f"no backend available" reference."""
) N1 = 2 * n
return None N2 = k
K = k
backend = moe_cfg.solution_type
aiter.logger.info( status, moe_cfg = get_aiter_moe_config(
f"[aiter_moe_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, " M=m, E=e, N1=N1, N2=N2, K=K,
f"backend={backend}" top_k=topk, block_size=group_size, dtype=dtype,
) quant_type=MoeQuantType.W4A16,
)
data = prepare_w4a16_inputs(
m, k, n, e, topk, group_size, has_zp, dtype, backend) if not status:
aiter.logger.info(
# Torch reference f"[aiter_moe_w4a16] SKIP {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}: "
ref_out, _ = _run_torch_ref( f"no backend available"
data["input"], data["w1_ref"], data["w2_ref"], )
data["topk_weights"], data["topk_ids"], return None
)
backend = moe_cfg.solution_type
# generic aiter_moe dispatch with w4a16 config aiter.logger.info(
block_shape = [0, group_size] if group_size else None f"[aiter_moe_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
aiter_us = 1.0 f"backend={backend}"
# aiter_out = aiter_moe( )
aiter_out, aiter_us = _run_aiter_moe_perf(
hidden_states=data["input"], data = prepare_w4a16_inputs(
w1=data["w1_qweight"], m, k, n, e, topk, group_size, has_zp, dtype, backend)
w2=data["w2_qweight"],
topk_weights=data["topk_weights"], # Torch reference
topk_ids=data["topk_ids"], ref_out, _ = _run_torch_ref(
moe_config=moe_cfg, data["input"], data["w1_ref"], data["w2_ref"],
inplace=inplace, data["topk_weights"], data["topk_ids"],
activation="silu", )
w1_scale=data["w1_scales"],
w2_scale=data["w2_scales"], # generic aiter_moe dispatch with w4a16 config
w1_zp=data["w1_qzeros"], block_shape = [0, group_size] if group_size else None
w2_zp=data["w2_qzeros"], aiter_us = 1.0
a1_scale=None, # aiter_out = aiter_moe(
a2_scale=None, aiter_out, aiter_us = _run_aiter_moe_perf(
block_shape=block_shape, hidden_states=data["input"],
global_num_experts=e, w1=data["w1_qweight"],
expert_map=None, w2=data["w2_qweight"],
routed_scaling_factor=routed_scaling_factor, topk_weights=data["topk_weights"],
) topk_ids=data["topk_ids"],
moe_config=moe_cfg,
msg = (f"[aiter_moe_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, " inplace=inplace,
f"backend={backend}") activation="silu",
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.01, msg=msg) w1_scale=data["w1_scales"],
return {"m": m, "backend": backend, "us": aiter_us} w2_scale=data["w2_scales"],
w1_zp=data["w1_qzeros"],
w2_zp=data["w2_qzeros"],
# --------------------------------------------------------------------------- a1_scale=None,
# Main: run tests a2_scale=None,
# --------------------------------------------------------------------------- block_shape=block_shape,
global_num_experts=e,
if __name__ == "__main__": expert_map=None,
routed_scaling_factor=routed_scaling_factor,
# ASM requires: (top_k == 8 && n == 256 && k == 7168) )
# Triton requires: (n == 2048 && [E ==8 || E == 16 || E == 32) ## (E == 2 || E == 4) 结果异常
# or (n == 256 && E == 256) msg = (f"[aiter_moe_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
dtype = dtypes.bf16 f"backend={backend}")
group_size = 32 check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.01, msg=msg)
has_zp = True ret_output = "passed" if check_ret == 0 else (1-check_ret)
e = 256 return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
topk = 8
k = 7168 # model_dim # ---------------------------------------------------------------------------
n = 256 # intermediate_size # Main: run tests
inplace = True # ---------------------------------------------------------------------------
routed_scaling_factor = 1.0
if __name__ == "__main__":
# --- Part 1: test get_aiter_moe_config (w4a16) ---
aiter.logger.info("=" * 60) # ASM requires: (top_k == 8 && n == 256 && k == 7168)
aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a16") # Triton requires: (n == 2048 && [E ==8 || E == 16 || E == 32) ## (E == 2 || E == 4) 结果异常
aiter.logger.info("=" * 60) # or (n == 256 && E == 256)
dtype = dtypes.bf16
test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] group_size = 32
for m in test_tokens: has_zp = True
test_get_config(m, k, n, e, topk, group_size, dtype) e = 256
topk = 8
# --- Part 2: test aiter_moe end-to-end (w4a16) --- k = 7168 # model_dim
aiter.logger.info("=" * 60) n = 256 # intermediate_size
aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w4a16") inplace = True
aiter.logger.info("=" * 60) routed_scaling_factor = 1.0
df = [] # --- Part 1: test get_aiter_moe_config (w4a16) ---
for m in test_tokens: aiter.logger.info("=" * 60)
ret = test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor) aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a16")
if ret is not None: aiter.logger.info("=" * 60)
df.append(ret)
if df: test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
df = pd.DataFrame(df) for m in test_tokens:
aiter.logger.info(f"summary:\n{df}") test_get_config(m, k, n, e, topk, group_size, dtype)
# --- Part 2: test aiter_moe end-to-end (w4a16) ---
aiter.logger.info("=" * 60)
aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w4a16")
aiter.logger.info("=" * 60)
df = []
for m in test_tokens:
ret = test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor)
if ret is not None:
df.append(ret)
if df:
df = pd.DataFrame(df)
df.to_csv("test_aiter_moe_with_config_w4a16.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
...@@ -437,8 +437,8 @@ def perchannel_w8a8_triton(input, ...@@ -437,8 +437,8 @@ def perchannel_w8a8_triton(input,
triton_moe_sum(output_triton2.view(*output_triton2.shape), triton_moe_sum(output_triton2.view(*output_triton2.shape),
out_hidden_states) out_hidden_states)
else: else:
aiter.moe_c_moe_sum(output_triton2.view(*output_triton2.shape), aiter.moe_c_moe_sum_opt_v2(output_triton2.view(*output_triton2.shape),
out_hidden_states,topk_ids) out_hidden_states,1.0)
# print("**************************************triton") # print("**************************************triton")
# print(out_hidden_states) # print(out_hidden_states)
...@@ -469,7 +469,7 @@ def _run_triton_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_ ...@@ -469,7 +469,7 @@ def _run_triton_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_
) )
return perchannel_w8a8_triton(hidden_states,w1,w2,w1_scale,w2_scale,topk_weights, topk_ids,sorted_token_ids, expert_ids, num_tokens_post_padded,dtype,True) return perchannel_w8a8_triton(hidden_states,w1,w2,w1_scale,w2_scale,topk_weights, topk_ids,sorted_token_ids, expert_ids, num_tokens_post_padded,dtype,True)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1) @perftest(num_warmup=10, num_iters=100, num_rotate_args=1,testGraph = True)
def _run_aiter_moe_perf( def _run_aiter_moe_perf(
hidden_states, hidden_states,
w1, w1,
...@@ -490,7 +490,10 @@ def _run_aiter_moe_perf( ...@@ -490,7 +490,10 @@ def _run_aiter_moe_perf(
expert_map, expert_map,
): ):
mortal_input = hidden_states.clone() #保证inplace操作的正确性 if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe( return aiter_moe(
mortal_input, mortal_input,
...@@ -826,7 +829,8 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype): ...@@ -826,7 +829,8 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype):
expert_map=None, expert_map=None,
) )
msg = f"[aiter_triton_w4a8] {m=}" msg = f"[aiter_triton_w4a8] {m=}"
checkAllclose(ref_out, aiter_triton_out, rtol=0.01, atol=100, msg=msg) triton_check_ret = checkAllclose(ref_out, aiter_triton_out, rtol=0.01, atol=100, msg=msg)
ret["triton_accuracy"] = "passed" if triton_check_ret == 0 else (1-triton_check_ret)
ret["aiter_triton_us"] = aiter_triton_us ret["aiter_triton_us"] = aiter_triton_us
aiter_us = 1.0 aiter_us = 1.0
...@@ -854,10 +858,13 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype): ...@@ -854,10 +858,13 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype):
expert_map=None, expert_map=None,
) )
print("aiter_out",aiter_out)
print("aiter_triton_out",aiter_triton_out)
msg = f"[aiter_moe_w4a8] {m=}, backend={moe_cfg.solution_type}" msg = f"[aiter_moe_w4a8] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg) check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
ret["moe_accuracy"] = "passed" if check_ret == 0 else (1-check_ret)
ret["aiter_moe_us"] = aiter_us ret["aiter_moe_us"] = aiter_us
return ret return ret
...@@ -868,14 +875,14 @@ if __name__ == "__main__": ...@@ -868,14 +875,14 @@ if __name__ == "__main__":
e = 256 e = 256
topk = 8 topk = 8
k = 6144 k = 7168
n = 2048 n = 256
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a8") aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a8")
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
test_tokens = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048,4096,6144,8192,16384] test_tokens = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048,4096,6144,8192,16384]
# test_tokens = [1] # test_tokens = [2048]
for m in test_tokens: for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype) test_get_config(m, k, n, e, topk, dtype)
......
# Test for get_aiter_moe_config and aiter_moe with W8A16 (int8 weight, fp16/bf16 activation)
#
# Mirrors test_aiter_moe_with_config_w16a16.py but for the INT8_W8A16 quant
# type. The W8A16 ASM backend is not currently supported, so part 3
# (ASM shuffle vs non-shuffle) is omitted intentionally.
import torch
import pandas as pd
from aiter.fused_moe import fused_topk, torch_moe
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.ops.shuffle import w8a16_marlin_weight_1, w8a16_marlin_weight_2
import aiter
torch.set_default_device("cuda")
# ---------------------------------------------------------------------------
# Weight preparation helpers (W8A16 - int8 weights, fp16/bf16 activations)
# ---------------------------------------------------------------------------
def prepare_w8a16_inputs(m, k, n, e, topk, dtype):
    """Create every tensor needed for one W8A16 (int8 weight) MoE test case.

    The global torch RNG is re-seeded with 0 on each call, so identical
    arguments always yield identical data.

    Args:
        m: Number of tokens (rows of the activation matrix).
        k: Model dimension.
        n: Intermediate size; the gate+up weight has ``2 * n`` rows.
        e: Number of experts.
        topk: Number of experts per token, forwarded to ``fused_topk``.
        dtype: Activation dtype; also used for the scales and router scores.

    Returns:
        A dict with the following entries:
          - "input":        [m, k] activations in ``dtype`` (randn / 100).
          - "w1":           [e, 2n, k] int8 gate+up weights.
          - "w2":           [e, k, n] int8 down weights.
          - "w1_marlin":    ``w8a16_marlin_weight_1(w1)``.
          - "w2_marlin":    ``w8a16_marlin_weight_2(w2)``.
          - "w1_scale":     [e, 2n, 1] per-output-channel scales in ``dtype``.
          - "w2_scale":     [e, k, 1] per-output-channel scales in ``dtype``.
          - "topk_weights", "topk_ids": the two outputs of ``fused_topk``.
          - "score":        [m, e] router scores in ``dtype``.
    """
    device = "cuda"
    torch.manual_seed(0)

    def _random_int8(shape):
        # torch.randint excludes the upper bound, so values lie in [-127, 126].
        return torch.randint(-127, 127, shape, device=device, dtype=torch.int8)

    def _random_scale(shape):
        # Scales stay in the activation dtype (the moe_c W8A16 marlin kernel
        # expects that, per the original author's note); the 1e-3 factor keeps
        # the dequantized weights in a reasonable range.
        return torch.randn(shape, device=device, dtype=dtype).abs() * 1e-3

    # NOTE: the order of the random draws below is part of the test data
    # definition; reordering them changes every generated tensor.

    # Small activations limit accumulation error in the int8 dequant path.
    activations = torch.randn((m, k), device=device, dtype=dtype) / 100

    gate_up_q = _random_int8((e, 2 * n, k))
    down_q = _random_int8((e, k, n))

    gate_up_scale = _random_scale((e, 2 * n, 1))
    down_scale = _random_scale((e, k, 1))

    # Marlin-shuffled copies consumed by the moe_c W8A16 kernel.
    gate_up_marlin = w8a16_marlin_weight_1(gate_up_q)
    down_marlin = w8a16_marlin_weight_2(down_q)

    router_score = torch.randn((m, e), device=device, dtype=dtype)
    # Last positional flag is True here — presumably "renormalize"; confirm
    # against fused_topk's signature.
    routing_weights, routing_ids = fused_topk(activations, router_score, topk, True)

    return {
        "input": activations,
        "w1": gate_up_q,
        "w2": down_q,
        "w1_marlin": gate_up_marlin,
        "w2_marlin": down_marlin,
        "w1_scale": gate_up_scale,
        "w2_scale": down_scale,
        "topk_weights": routing_weights,
        "topk_ids": routing_ids,
        "score": router_score,
    }
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w8a16)
# ---------------------------------------------------------------------------
def test_get_config(m, k, n, e, topk, dtype):
    """Check the consistency of ``get_aiter_moe_config`` for INT8_W8A16.

    When a solution is reported, the config must carry a solution type
    (MOE_C or TRITON only; ASM is intentionally unsupported for W8A16), a
    non-None config and the INT8_W8A16 quant type. When no solution is
    reported, both ``solution_type`` and ``config`` must be None.

    Args:
        m: Number of tokens.
        k: Model dimension (used for both N2 and K).
        n: Intermediate size (N1 is ``2 * n`` for gate + up).
        e: Number of experts.
        topk: Number of experts per token.
        dtype: Activation dtype.

    Returns:
        The ``(status, moe_cfg)`` pair returned by ``get_aiter_moe_config``.
    """
    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=2 * n,  # gate + up
        N2=k,      # down / hidden_size
        K=k,       # model dimension
        top_k=topk,
        block_size=0,
        dtype=dtype,
        quant_type=MoeQuantType.INT8_W8A16,
    )
    log_prefix = f"[get_config_w8a16] {m=}, {k=}, {n=}, {e=}, {topk=}, "

    if not status:
        assert moe_cfg.solution_type is None, \
            "status=False but solution_type is not None"
        assert moe_cfg.config is None, \
            "status=False but config is not None"
        aiter.logger.info(log_prefix + "no solution found")
        return status, moe_cfg

    assert moe_cfg.solution_type is not None, \
        "status=True but solution_type is None"
    assert moe_cfg.config is not None, \
        "status=True but config is None"
    # ASM backend is intentionally unsupported for W8A16.
    supported_backends = (MoeSolutionType.MOE_C, MoeSolutionType.TRITON)
    assert moe_cfg.solution_type in supported_backends, \
        f"Unexpected solution_type for W8A16: {moe_cfg.solution_type}"
    assert moe_cfg.quant_type == MoeQuantType.INT8_W8A16
    aiter.logger.info(
        log_prefix
        + f"solution={moe_cfg.solution_type}, "
        + f"config keys={list(moe_cfg.config.keys())}"
    )
    return status, moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w8a16
# ---------------------------------------------------------------------------
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids,
                   fc1_scale, fc2_scale):
    """Run the ``torch_moe`` reference with per-channel weight scales.

    Wrapped in ``perftest``; callers in this file unpack the decorated call
    as ``(output, elapsed)``.
    """
    reference_out = torch_moe(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
    )
    return reference_out
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(hidden_states,
                        w1,
                        w2,
                        topk_weights,
                        topk_ids,
                        moe_config,
                        inplace,
                        activation,
                        w1_scale,
                        w2_scale,
                        w1_zp,
                        w2_zp,
                        a1_scale,
                        a2_scale,
                        block_shape,
                        global_num_experts,
                        expert_map,
                        routed_scaling_factor,
                        ):
    """Invoke ``aiter_moe`` under ``perftest`` with the given configuration.

    All arguments are forwarded positionally to ``aiter_moe`` in the order
    that function is called with throughout this test file.
    """
    # Keep in-place runs correct: hand the kernel a private copy so that the
    # caller's tensor is never overwritten (presumably it is reused across
    # the perftest iterations — confirm against perftest).
    mortal_input = hidden_states.clone() if inplace else hidden_states

    forwarded = (
        mortal_input, w1, w2, topk_weights, topk_ids, moe_config,
        inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape, global_num_experts,
        expert_map, routed_scaling_factor,
    )
    return aiter_moe(*forwarded)
def test_aiter_moe_w8a16(m, k, n, e, topk, dtype, inplace, routed_scaling_factor):
    """Run one end-to-end W8A16 case and compare it with the torch reference.

    Steps: query ``get_aiter_moe_config``, build inputs, run the ``torch_moe``
    reference with per-channel scales, run ``aiter_moe`` with the weights
    laid out for the selected backend, then compare with ``checkAllclose``.

    Args:
        m: Number of tokens.
        k: Model dimension.
        n: Intermediate size.
        e: Number of experts.
        topk: Number of experts per token.
        dtype: Activation dtype.
        inplace: Forwarded to ``aiter_moe``.
        routed_scaling_factor: Forwarded to ``aiter_moe``.

    Returns:
        None when no backend is available; otherwise a dict with the problem
        shape, the backend, the measured time ("us") and "accuracy", which is
        "passed" when ``checkAllclose`` returns 0 and ``1 - <return value>``
        otherwise.
    """
    N1 = 2 * n
    N2 = k
    K = k

    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=N1,
        N2=N2,
        K=K,
        top_k=topk,
        block_size=0,
        dtype=dtype,
        quant_type=MoeQuantType.INT8_W8A16,
    )
    if not status:
        aiter.logger.info(
            f"[aiter_moe_w8a16] SKIP {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}: "
            f"no backend available"
        )
        return None

    backend = moe_cfg.solution_type
    # The same text is logged up front and reused as the comparison message.
    msg = (f"[aiter_moe_w8a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
           f"backend={backend}")
    aiter.logger.info(msg)

    data = prepare_w8a16_inputs(m, k, n, e, topk, dtype)

    # Weight layout depends on the backend:
    #   - moe_c expects marlin-shuffled int8 weights
    #   - triton expects plain (E, N, K) int8 weights
    use_marlin = backend == MoeSolutionType.MOE_C
    w1_run = data["w1_marlin"] if use_marlin else data["w1"]
    w2_run = data["w2_marlin"] if use_marlin else data["w2"]

    # Torch reference: int8 weights dequantized through per-channel scales.
    ref_out, _ = _run_torch_ref(
        data["input"],
        data["w1"],
        data["w2"],
        data["topk_weights"],
        data["topk_ids"],
        data["w1_scale"],
        data["w2_scale"],
    )

    run_kwargs = {
        "hidden_states": data["input"],
        "w1": w1_run,
        "w2": w2_run,
        "topk_weights": data["topk_weights"],
        "topk_ids": data["topk_ids"],
        "moe_config": moe_cfg,
        "inplace": inplace,
        "activation": "silu",
        "w1_scale": data["w1_scale"],
        "w2_scale": data["w2_scale"],
        "w1_zp": None,
        "w2_zp": None,
        "a1_scale": None,
        "a2_scale": None,
        "block_shape": None,  # per-channel
        "global_num_experts": e,
        "expert_map": None,
        "routed_scaling_factor": routed_scaling_factor,
    }
    aiter_out, aiter_us = _run_aiter_moe_perf(**run_kwargs)

    # int8 dequant and bf16/fp16 accumulation order differ between the
    # reference and the fused kernels, hence the relaxed tolerance (similar
    # to the W16A16 test).
    check_ret = checkAllclose(ref_out, aiter_out, rtol=0.05, atol=0.5, msg=msg)
    if check_ret == 0:
        ret_output = "passed"
    else:
        ret_output = 1 - check_ret

    return {
        "m": m,
        "N1": N1,
        "N2": N2,
        "K": K,
        "e": e,
        "topk": topk,
        "backend": backend,
        "us": aiter_us,
        "accuracy": ret_output,
    }
# ---------------------------------------------------------------------------
# Main: run tests
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # W8A16 moe_c kernel was validated with fp16.
    dtype = dtypes.fp16

    # Shape matching the tuned W8A16 moe_c configs under
    # aiter/moe_c_configs (E=128, N=2048).
    e, topk = 128, 8
    k = 6144  # model_dim
    n = 2048  # intermediate_size
    inplace = False
    routed_scaling_factor = 1.0

    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    banner = "=" * 60

    # --- Part 1: test get_aiter_moe_config (w8a16) ---
    aiter.logger.info(banner)
    aiter.logger.info("Part 1: Testing get_aiter_moe_config for w8a16")
    aiter.logger.info(banner)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, dtype)

    # --- Part 2: test aiter_moe end-to-end (w8a16) ---
    aiter.logger.info(banner)
    aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w8a16")
    aiter.logger.info(banner)
    # Cases without an available backend return None and are left out.
    df = [
        ret
        for ret in (
            test_aiter_moe_w8a16(m, k, n, e, topk, dtype, inplace,
                                 routed_scaling_factor)
            for m in test_tokens
        )
        if ret is not None
    ]
    if df:
        df = pd.DataFrame(df)
        aiter.logger.info(f"aiter_moe summary:\n{df}")
        df.to_csv("test_aiter_moe_with_config_w8a16.csv", index=False)
\ No newline at end of file
# Test for get_aiter_moe_config and aiter_moe with w8a8 # Test for get_aiter_moe_config and aiter_moe with w8a8
import torch import torch
import pandas as pd import pandas as pd
from aiter.fused_moe import fused_topk from aiter.fused_moe import fused_topk
from aiter import dtypes from aiter import dtypes
from aiter.test_common import checkAllclose, perftest from aiter.test_common import checkAllclose, perftest
from aiter.moe import ( from aiter.moe import (
get_aiter_moe_config, get_aiter_moe_config,
aiter_moe, aiter_moe,
MoeSolutionType, MoeSolutionType,
MoeQuantType, MoeQuantType,
) )
import aiter import aiter
torch.set_default_device("cuda") torch.set_default_device("cuda")
def torch_moe_blockscale( def torch_moe_blockscale(
hidden_states, hidden_states,
w1, w1,
w2, w2,
topk_weight, topk_weight,
topk_ids, topk_ids,
dtype, dtype,
scale_blks, scale_blks,
fc1_scale, fc1_scale,
fc2_scale, fc2_scale,
): ):
compute_type = torch.float32 compute_type = torch.float32
hidden_states = hidden_states.to(compute_type) hidden_states = hidden_states.to(compute_type)
w1 = w1.to(compute_type) w1 = w1.to(compute_type)
w2 = w2.to(compute_type) w2 = w2.to(compute_type)
token_num, topk = topk_ids.shape token_num, topk = topk_ids.shape
expert, model_dim, inter_dim = w2.shape expert, model_dim, inter_dim = w2.shape
blk_n, blk_k = scale_blks blk_n, blk_k = scale_blks
nblk_n = inter_dim // blk_n nblk_n = inter_dim // blk_n
nblk_k = model_dim // blk_k nblk_k = model_dim // blk_k
fc1_scale_full = fc1_scale.view(-1, 1).repeat(1, blk_n * blk_k).view( fc1_scale_full = fc1_scale.view(-1, 1).repeat(1, blk_n * blk_k).view(
expert, -1, nblk_k, blk_n, blk_k) expert, -1, nblk_k, blk_n, blk_k)
fc1_scale_full = fc1_scale_full.permute(0, 1, 3, 2, 4).contiguous().view( fc1_scale_full = fc1_scale_full.permute(0, 1, 3, 2, 4).contiguous().view(
expert, 2 * inter_dim, model_dim) expert, 2 * inter_dim, model_dim)
fc2_scale_full = fc2_scale.view(-1, 1).repeat(1, blk_n * blk_k).view( fc2_scale_full = fc2_scale.view(-1, 1).repeat(1, blk_n * blk_k).view(
expert, model_dim // blk_n, inter_dim // blk_k, blk_n, blk_k) expert, model_dim // blk_n, inter_dim // blk_k, blk_n, blk_k)
fc2_scale_full = fc2_scale_full.permute(0, 1, 3, 2, 4).contiguous().view( fc2_scale_full = fc2_scale_full.permute(0, 1, 3, 2, 4).contiguous().view(
expert, model_dim, inter_dim) expert, model_dim, inter_dim)
w1 = w1 * fc1_scale_full w1 = w1 * fc1_scale_full
w2 = w2 * fc2_scale_full w2 = w2 * fc2_scale_full
hidden_states = hidden_states.view(token_num, 1, model_dim).repeat(1, topk, 1) hidden_states = hidden_states.view(token_num, 1, model_dim).repeat(1, topk, 1)
out = torch.zeros((token_num, topk, model_dim), dtype=compute_type, device=hidden_states.device) out = torch.zeros((token_num, topk, model_dim), dtype=compute_type, device=hidden_states.device)
for expert_id in range(w1.shape[0]): for expert_id in range(w1.shape[0]):
mask = topk_ids == expert_id mask = topk_ids == expert_id
if mask.sum() == 0: if mask.sum() == 0:
continue continue
sub_tokens = hidden_states[mask] sub_tokens = hidden_states[mask]
act_input = sub_tokens @ w1[expert_id].transpose(0, 1) act_input = sub_tokens @ w1[expert_id].transpose(0, 1)
gate, up = act_input.split([inter_dim, inter_dim], dim=-1) gate, up = act_input.split([inter_dim, inter_dim], dim=-1)
act_out = torch.nn.functional.silu(gate) * up act_out = torch.nn.functional.silu(gate) * up
out[mask] = act_out @ w2[expert_id].transpose(0, 1) out[mask] = act_out @ w2[expert_id].transpose(0, 1)
return (out * topk_weight.view(token_num, -1, 1)).sum(dim=1).to(dtype) return (out * topk_weight.view(token_num, -1, 1)).sum(dim=1).to(dtype)
@perftest(num_warmup=1, num_iters=2) @perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_shape, w1_scale, w2_scale): def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_shape, w1_scale, w2_scale):
return torch_moe_blockscale( return torch_moe_blockscale(
hidden_states, hidden_states,
w1, w1,
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
dtype, dtype,
block_shape, block_shape,
w1_scale, w1_scale,
w2_scale, w2_scale,
) )
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1) @perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf( def _run_aiter_moe_perf(
hidden_states, hidden_states,
w1, w1,
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
moe_config, moe_config,
inplace, inplace,
activation, activation,
w1_scale, w1_scale,
w2_scale, w2_scale,
w1_zp, w1_zp,
w2_zp, w2_zp,
a1_scale, a1_scale,
a2_scale, a2_scale,
block_shape, block_shape,
global_num_experts, global_num_experts,
expert_map, expert_map,
): ):
mortal_input = hidden_states.clone() # 保证inplace操作的正确性 if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
return aiter_moe( else:
mortal_input, mortal_input = hidden_states
w1,
w2, return aiter_moe(
topk_weights, mortal_input,
topk_ids, w1,
moe_config, w2,
inplace, topk_weights,
activation, topk_ids,
w1_scale, moe_config,
w2_scale, inplace,
w1_zp, activation,
w2_zp, w1_scale,
a1_scale, w2_scale,
a2_scale, w1_zp,
block_shape, w2_zp,
global_num_experts, a1_scale,
expert_map, a2_scale,
) block_shape,
global_num_experts,
expert_map,
def prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype): )
torch.manual_seed(0)
factor_for_scale = 1e-2
int8_info = torch.iinfo(torch.int8) def prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype):
int8_max, int8_min = int8_info.max, int8_info.min torch.manual_seed(0)
factor_for_scale = 1e-2
input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 10 int8_info = torch.iinfo(torch.int8)
w1_fp = (torch.rand((e, 2 * n, k), dtype=dtype, device="cuda") - 0.5) * 2 * int8_max int8_max, int8_min = int8_info.max, int8_info.min
w2_fp = (torch.rand((e, k, n), dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
w1_qweight = w1_fp.clamp(min=int8_min, max=int8_max).to(torch.int8) input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 10
w2_qweight = w2_fp.clamp(min=int8_min, max=int8_max).to(torch.int8) w1_fp = (torch.rand((e, 2 * n, k), dtype=dtype, device="cuda") - 0.5) * 2 * int8_max
w2_fp = (torch.rand((e, k, n), dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
block_n, block_k = block_shape w1_qweight = w1_fp.clamp(min=int8_min, max=int8_max).to(torch.int8)
n_tiles_w1 = (2 * n + block_n - 1) // block_n w2_qweight = w2_fp.clamp(min=int8_min, max=int8_max).to(torch.int8)
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k block_n, block_k = block_shape
k_tiles_w2 = (n + block_k - 1) // block_k n_tiles_w1 = (2 * n + block_n - 1) // block_n
w1_scales = torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale n_tiles_w2 = (k + block_n - 1) // block_n
w2_scales = torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale k_tiles_w1 = (k + block_k - 1) // block_k
score = torch.randn((m, e), dtype=dtype, device="cuda") k_tiles_w2 = (n + block_k - 1) // block_k
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, False) w1_scales = torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale
w2_scales = torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale
return { score = torch.randn((m, e), dtype=dtype, device="cuda")
"input": input_tensor, topk_weights, topk_ids = fused_topk(input_tensor, score, topk, False)
"w1_ref": w1_qweight,
"w2_ref": w2_qweight, return {
"w1_qweight": w1_qweight, "input": input_tensor,
"w2_qweight": w2_qweight, "w1_ref": w1_qweight,
"w1_scales": w1_scales, "w2_ref": w2_qweight,
"w2_scales": w2_scales, "w1_qweight": w1_qweight,
"topk_weights": topk_weights, "w2_qweight": w2_qweight,
"topk_ids": topk_ids, "w1_scales": w1_scales,
} "w2_scales": w2_scales,
"topk_weights": topk_weights,
"topk_ids": topk_ids,
def test_get_config(m, k, n, e, topk, block_shape, dtype): }
status, moe_cfg = get_aiter_moe_config(
M=m,
E=e, def test_get_config(m, k, n, e, topk, block_shape, dtype):
N1=2 * n, status, moe_cfg = get_aiter_moe_config(
N2=k, M=m,
K=k, E=e,
top_k=topk, N1=2 * n,
block_size=block_shape[1], N2=k,
dtype=dtype, K=k,
quant_type=MoeQuantType.W8A8, top_k=topk,
) block_size=block_shape[1],
dtype=dtype,
if status: quant_type=MoeQuantType.W8A8,
assert moe_cfg.quant_type == MoeQuantType.W8A8 )
assert moe_cfg.solution_type in (
MoeSolutionType.ASM, if status:
MoeSolutionType.MOE_C, assert moe_cfg.quant_type == MoeQuantType.W8A8
MoeSolutionType.TRITON, assert moe_cfg.solution_type in (
MoeSolutionType.CK, MoeSolutionType.ASM,
) MoeSolutionType.MOE_C,
assert moe_cfg.config is not None MoeSolutionType.TRITON,
aiter.logger.info( MoeSolutionType.CK,
f"[get_config_w8a8] {m=}, solution={moe_cfg.solution_type}, config keys={list(moe_cfg.config.keys())}" )
) assert moe_cfg.config is not None
else: aiter.logger.info(
assert moe_cfg.solution_type is None f"[get_config_w8a8] {m=}, solution={moe_cfg.solution_type}, config keys={list(moe_cfg.config.keys())}"
assert moe_cfg.config is None )
aiter.logger.info(f"[get_config_w8a8] {m=}, no solution found") else:
assert moe_cfg.solution_type is None
return status, moe_cfg assert moe_cfg.config is None
aiter.logger.info(f"[get_config_w8a8] {m=}, no solution found")
def test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype): return status, moe_cfg
status, moe_cfg = get_aiter_moe_config(
M=m,
E=e, def test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype):
N1=2 * n, status, moe_cfg = get_aiter_moe_config(
N2=k, M=m,
K=k, E=e,
top_k=topk, N1=2 * n,
block_size=block_shape[1], N2=k,
dtype=dtype, K=k,
quant_type=MoeQuantType.W8A8, top_k=topk,
) block_size=block_shape[1],
dtype=dtype,
if not status: quant_type=MoeQuantType.W8A8,
aiter.logger.info(f"[aiter_moe_w8a8] SKIP {m=}: no backend available") )
return None
if not status:
data = prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype) aiter.logger.info(f"[aiter_moe_w8a8] SKIP {m=}: no backend available")
ref_out, _ = _run_torch_ref( return None
data["input"],
data["w1_ref"], data = prepare_w8a8_inputs(m, k, n, e, topk, block_shape, dtype)
data["w2_ref"], ref_out, _ = _run_torch_ref(
data["topk_weights"], data["input"],
data["topk_ids"], data["w1_ref"],
dtype, data["w2_ref"],
block_shape, data["topk_weights"],
data["w1_scales"], data["topk_ids"],
data["w2_scales"], dtype,
) block_shape,
data["w1_scales"],
aiter_us = 1.0 data["w2_scales"],
# aiter_out = aiter_moe( )
aiter_out, aiter_us = _run_aiter_moe_perf(
hidden_states=data["input"], aiter_us = 1.0
w1=data["w1_qweight"], # aiter_out = aiter_moe(
w2=data["w2_qweight"], aiter_out, aiter_us = _run_aiter_moe_perf(
topk_weights=data["topk_weights"], hidden_states=data["input"],
topk_ids=data["topk_ids"], w1=data["w1_qweight"],
moe_config=moe_cfg, w2=data["w2_qweight"],
inplace=False, topk_weights=data["topk_weights"],
activation="silu", topk_ids=data["topk_ids"],
w1_scale=data["w1_scales"], moe_config=moe_cfg,
w2_scale=data["w2_scales"], inplace=False,
w1_zp=None, activation="silu",
w2_zp=None, w1_scale=data["w1_scales"],
a1_scale=None, w2_scale=data["w2_scales"],
a2_scale=None, w1_zp=None,
block_shape=list(block_shape), w2_zp=None,
global_num_experts=e, a1_scale=None,
expert_map=None, a2_scale=None,
) block_shape=list(block_shape),
global_num_experts=e,
msg = f"[aiter_moe_w8a8] {m=}, backend={moe_cfg.solution_type}" expert_map=None,
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg) )
return {"m": m, "backend": moe_cfg.solution_type, "us": aiter_us}
msg = f"[aiter_moe_w8a8] {m=}, backend={moe_cfg.solution_type}"
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
if __name__ == "__main__": passed = "passed" if check_ret == 0 else (1-check_ret)
dtype = dtypes.fp16 return {"m": m, "backend": moe_cfg.solution_type, "us": aiter_us, "accuracy": passed}
block_shape = (128, 128)
e = 256
topk = 8 if __name__ == "__main__":
k = 7168 dtype = dtypes.fp16
n = 256 block_shape = (128, 128)
e = 256
aiter.logger.info("=" * 60) topk = 8
aiter.logger.info("Part 1: Testing get_aiter_moe_config for w8a8") k = 7168
aiter.logger.info("=" * 60) n = 256
test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
for m in test_tokens: aiter.logger.info("=" * 60)
test_get_config(m, k, n, e, topk, block_shape, dtype) aiter.logger.info("Part 1: Testing get_aiter_moe_config for w8a8")
aiter.logger.info("=" * 60)
aiter.logger.info("=" * 60) test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w8a8") for m in test_tokens:
aiter.logger.info("=" * 60) test_get_config(m, k, n, e, topk, block_shape, dtype)
df = []
for m in test_tokens: aiter.logger.info("=" * 60)
ret = test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype) aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w8a8")
if ret is not None: aiter.logger.info("=" * 60)
df.append(ret) df = []
if df: for m in test_tokens:
df = pd.DataFrame(df) ret = test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype)
aiter.logger.info(f"summary:\n{df}") if ret is not None:
df.append(ret)
if df:
df = pd.DataFrame(df)
aiter.logger.info(f"summary:\n{df}")
df.to_csv("test_aiter_moe_with_config_w8a8_blockwise.csv", index=False)
...@@ -32,7 +32,7 @@ def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids): ...@@ -32,7 +32,7 @@ def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
) )
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1) @perftest(num_warmup=10, num_iters=100, num_rotate_args=1,testGraph=True)
def _run_aiter_moe_perf( def _run_aiter_moe_perf(
hidden_states, hidden_states,
w1, w1,
...@@ -51,9 +51,13 @@ def _run_aiter_moe_perf( ...@@ -51,9 +51,13 @@ def _run_aiter_moe_perf(
block_shape, block_shape,
global_num_experts, global_num_experts,
expert_map, expert_map,
out_dtype,
): ):
mortal_input = hidden_states.clone() # 保证inplace操作的正确性 if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe( return aiter_moe(
mortal_input, mortal_input,
...@@ -73,6 +77,7 @@ def _run_aiter_moe_perf( ...@@ -73,6 +77,7 @@ def _run_aiter_moe_perf(
block_shape, block_shape,
global_num_experts, global_num_experts,
expert_map, expert_map,
output_dtype = out_dtype
) )
...@@ -86,12 +91,20 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant ...@@ -86,12 +91,20 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
""" """
torch.manual_seed(0) torch.manual_seed(0)
input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 10 if dtype == dtypes.fp8:
w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda") input_tensor = torch.randn((m, k), dtype=dtypes.fp32, device="cuda") / 100
w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda") w1_fp = torch.randn((e, 2 * n, k), dtype=dtypes.fp32, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtypes.fp32, device="cuda")
else:
input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 100
w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda")
if quant_type == MoeQuantType.FP8_W8A8: if quant_type == MoeQuantType.FP8_W8A8:
# FP8 channel-wise quantization via pertoken_quant # FP8 channel-wise quantization via pertoken_quant
if dtype == dtypes.fp8:
input_tensor_q, a1_scales = pertoken_quant(input_tensor, quant_dtype=dtypes.fp8)
w1_qweight, w1_scales = pertoken_quant(w1_fp, quant_dtype=dtypes.fp8) w1_qweight, w1_scales = pertoken_quant(w1_fp, quant_dtype=dtypes.fp8)
w2_qweight, w2_scales = pertoken_quant(w2_fp, quant_dtype=dtypes.fp8) w2_qweight, w2_scales = pertoken_quant(w2_fp, quant_dtype=dtypes.fp8)
else: else:
...@@ -106,7 +119,10 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant ...@@ -106,7 +119,10 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
w2_scales = max_vals_w2 / 127.0 # (e, k, 1) w2_scales = max_vals_w2 / 127.0 # (e, k, 1)
w2_qweight = (w2_fp / max_vals_w2 * 127.0).round().clamp(min=-128, max=127).to(torch.int8) w2_qweight = (w2_fp / max_vals_w2 * 127.0).round().clamp(min=-128, max=127).to(torch.int8)
score = torch.randn((m, e), dtype=dtype, device="cuda") if dtype == dtypes.fp8:
score = torch.randn((m, e), dtype=dtypes.fp32, device="cuda")
else:
score = torch.randn((m, e), dtype=dtype, device="cuda")
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True) topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
# moe_c backend needs layout-shuffled weights # moe_c backend needs layout-shuffled weights
...@@ -114,7 +130,8 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant ...@@ -114,7 +130,8 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
w2_qweight_shuffle = moe_layout_shuffle_gemm2(w2_qweight).view(*w2_qweight.shape) w2_qweight_shuffle = moe_layout_shuffle_gemm2(w2_qweight).view(*w2_qweight.shape)
return { return {
"input": input_tensor, "input": input_tensor_q if dtype == dtypes.fp8 else input_tensor,
"a1_scales": a1_scales if dtype == dtypes.fp8 else None,
"w1_ref": w1_fp, "w1_ref": w1_fp,
"w2_ref": w2_fp, "w2_ref": w2_fp,
"w1_qweight": w1_qweight, "w1_qweight": w1_qweight,
...@@ -125,6 +142,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant ...@@ -125,6 +142,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
"w2_scales": w2_scales, "w2_scales": w2_scales,
"topk_weights": topk_weights, "topk_weights": topk_weights,
"topk_ids": topk_ids, "topk_ids": topk_ids,
"input_ref": input_tensor, # original fp32/bf16 input for reference
} }
...@@ -164,7 +182,7 @@ def test_get_config(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8): ...@@ -164,7 +182,7 @@ def test_get_config(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
return status, moe_cfg return status, moe_cfg
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8): def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, in_dtype, out_dtype, quant_type=MoeQuantType.W8A8, inplace=False):
"""End-to-end test of aiter_moe with channel-wise w8a8 (int8 or fp8).""" """End-to-end test of aiter_moe with channel-wise w8a8 (int8 or fp8)."""
status, moe_cfg = get_aiter_moe_config( status, moe_cfg = get_aiter_moe_config(
M=m, M=m,
...@@ -174,7 +192,7 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant ...@@ -174,7 +192,7 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
K=k, K=k,
top_k=topk, top_k=topk,
block_size=0, block_size=0,
dtype=dtype, dtype=in_dtype,
quant_type=quant_type, quant_type=quant_type,
) )
...@@ -183,11 +201,11 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant ...@@ -183,11 +201,11 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
aiter.logger.info(f"[{tag}] SKIP {m=}: no backend available") aiter.logger.info(f"[{tag}] SKIP {m=}: no backend available")
return None return None
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type) data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, in_dtype, quant_type)
# Torch reference uses original fp weights directly (no scales needed) # Torch reference uses pre-quantization input (fp32/bf16) since fp8 is unsupported
ref_out, _ = _run_torch_ref( ref_out, _ = _run_torch_ref(
data["input"], data["input_ref"],
data["w1_ref"], data["w1_ref"],
data["w2_ref"], data["w2_ref"],
data["topk_weights"], data["topk_weights"],
...@@ -209,27 +227,28 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant ...@@ -209,27 +227,28 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
topk_weights=data["topk_weights"], topk_weights=data["topk_weights"],
topk_ids=data["topk_ids"], topk_ids=data["topk_ids"],
moe_config=moe_cfg, moe_config=moe_cfg,
inplace=True, inplace=inplace,
activation="silu", activation="silu",
w1_scale=data["w1_scales"], w1_scale=data["w1_scales"],
w2_scale=data["w2_scales"], w2_scale=data["w2_scales"],
w1_zp=None, w1_zp=None,
w2_zp=None, w2_zp=None,
a1_scale=None, a1_scale=data["a1_scales"],
a2_scale=None, a2_scale=None,
block_shape=None, block_shape=None,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
out_dtype=out_dtype,
) )
print("aiter_out",aiter_out) # print("aiter_out",aiter_out)
print("ref_out",ref_out) # print("ref_out",ref_out)
# Compare in aiter_out's dtype since reference may be fp32 (from pre-quant input)
msg = f"[{tag}] {m=}, backend={moe_cfg.solution_type}" msg = f"[{tag}] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg) check_ret = checkAllclose(ref_out.to(aiter_out.dtype), aiter_out, rtol=0.01, atol=100, msg=msg)
return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us} passed = "passed" if check_ret == 0 else (1-check_ret)
return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us, "accuracy": passed, }
if __name__ == "__main__": if __name__ == "__main__":
...@@ -239,36 +258,39 @@ if __name__ == "__main__": ...@@ -239,36 +258,39 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--quant", "--quant",
choices=["int8", "fp8"], choices=["int8", "fp8"],
default="int8", default="fp8",
help="Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)", help="Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)",
) )
args = parser.parse_args() args = parser.parse_args()
quant_type = MoeQuantType.FP8_W8A8 if args.quant == "fp8" else MoeQuantType.W8A8 quant_type = MoeQuantType.FP8_W8A8 if args.quant == "fp8" else MoeQuantType.W8A8
inplace = False # in_dtype != out_dtype时,不能为True
# for moe_c backend, it does not support n=320 for now; # for moe_c backend, it does not support n=320 for now;
# for triton backend, it can run with n=320 in NMZ; # for triton backend, it can run with n=320 in NMZ;
dtype = dtypes.bf16 in_dtype = dtypes.fp8
e = 256 out_dtype = dtypes.bf16
e = 192
topk = 8 topk = 8
k = 6144 k = 4096
n = 320 n = 384
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
aiter.logger.info(f"Part 1: Testing get_aiter_moe_config for {quant_type} channel-wise") aiter.logger.info(f"Part 1: Testing get_aiter_moe_config for {quant_type} channel-wise")
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 , 4096, 6144 , 8192 , 16384] test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 , 4096, 6144 , 8192 , 16384]
for m in test_tokens: for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype, quant_type) test_get_config(m, k, n, e, topk, in_dtype, quant_type)
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
aiter.logger.info(f"Part 2: Testing aiter_moe end-to-end for {quant_type} channel-wise") aiter.logger.info(f"Part 2: Testing aiter_moe end-to-end for {quant_type} channel-wise")
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
df = [] df = []
for m in test_tokens: for m in test_tokens:
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type) ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, in_dtype, out_dtype, quant_type, inplace)
if ret is not None: if ret is not None:
df.append(ret) df.append(ret)
if df: if df:
df = pd.DataFrame(df) df = pd.DataFrame(df)
df.to_csv("w8a8_channelwise.csv", index=False)
aiter.logger.info(f"summary:\n{df}") aiter.logger.info(f"summary:\n{df}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment