Commit bb596f6e authored by xiaowei.zhang's avatar xiaowei.zhang
Browse files

1. Update MOE; 2. Update sglang mHC; 3. Update test scripts; 4 Add new

   ops.
parent d9ebb683
// SPDX-License-Identifier: MIT
#include "grouped_gemm_ck.h"
#include "rocm_ops.hpp"
#include <pybind11/stl.h>
// Python extension entry point: defines the module named by TORCH_EXTENSION_NAME
// and exposes the CK grouped-GEMM functions on it.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// ck_grouped_gemm: Python arguments are named `a_tensors` and `b_tensors`.
m.def("ck_grouped_gemm", &ck_grouped_gemm, py::arg("a_tensors"), py::arg("b_tensors"));
// ck_grouped_gemm_out: same two inputs plus `c_tensors`
// (presumably caller-provided outputs, judging by the `_out` suffix -- TODO confirm).
m.def("ck_grouped_gemm_out",
&ck_grouped_gemm_out,
py::arg("a_tensors"),
py::arg("b_tensors"),
py::arg("c_tensors"));
}
// SPDX-License-Identifier: MIT
#include "rocm_ops.hpp"
#include "mhc.h"
// Python extension entry point for the module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// MHC_PYBIND is a macro defined outside this file (presumably in rocm_ops.hpp --
// TODO confirm); it is expected to expand to the registrations on `m`.
MHC_PYBIND;
}
......@@ -87,7 +87,9 @@ if __name__ == "__main__":
print(f">>> ERROR: {args.input_file} does not exist. Exiting")
exit(1)
shapes = pd.read_csv(args.input_file).fillna("")
shapes = pd.read_csv(args.input_file).dropna(how='all').fillna(0)
int_cols = ['token', 'inter_dim', 'model_dim', 'expert', 'topk', 'q_size_n', 'q_size_k']
shapes[int_cols] = shapes[int_cols].astype(int)
for i in range(len(shapes)):
ds = shapes.iloc[i]
moe_tuner.add_moe(
......
......@@ -11,12 +11,17 @@ class MoeTuner:
def __init__(self, indtype, tuned_file=None, mp=1):
self.arch = get_gfx()
self.moe_pro_df = pd.DataFrame(columns=["quant_type", "indtype", "token", "inter_dim", "model_dim", "expert", "topk", "q_size_n", "q_size_k"])
self._int_cols = ["token", "inter_dim", "model_dim", "expert", "topk", "q_size_n", "q_size_k"]
self.indtype = indtype
self.tuned_file = tuned_file
self.mp = mp
if Path(tuned_file).is_file():
self.tuned_shapes = pd.read_csv(tuned_file).fillna("")
self.tuned_shapes = pd.read_csv(tuned_file).dropna(how='all').fillna(0)
int_cols = ['token', 'inter_dim', 'model_dim', 'expert', 'topk', 'q_size_n', 'q_size_k']
for c in int_cols:
if c in self.tuned_shapes.columns:
self.tuned_shapes[c] = self.tuned_shapes[c].astype(int)
else:
self.tuned_shapes = None
......@@ -40,13 +45,13 @@ class MoeTuner:
entry = {
"quant_type": [quant_type],
"indtype": [indtype_str],
"token": [token],
"inter_dim": [inter_dim],
"model_dim": [model_dim],
"expert": [expert],
"topk": [topk],
"q_size_n": [q_size_n],
"q_size_k": [q_size_k]
"token": [int(token)],
"inter_dim": [int(inter_dim)],
"model_dim": [int(model_dim)],
"expert": [int(expert)],
"topk": [int(topk)],
"q_size_n": [int(q_size_n)],
"q_size_k": [int(q_size_k)]
}
df = pd.DataFrame(entry)
self.moe_pro_df = pd.concat([self.moe_pro_df, df], ignore_index=True)
......
......@@ -77,6 +77,7 @@ def test_get_config(m, k, n, e, topk, dtype):
assert moe_cfg.solution_type in (
MoeSolutionType.ASM,
MoeSolutionType.TRITON,
MoeSolutionType.CK,
), f"Unexpected solution_type: {moe_cfg.solution_type}"
assert moe_cfg.quant_type == MoeQuantType.W16A16
aiter.logger.info(
......@@ -126,7 +127,11 @@ def _run_aiter_moe_perf(hidden_states,
expert_map,
routed_scaling_factor,
):
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
......@@ -192,8 +197,9 @@ def test_aiter_moe_w16a16(m, k, n, e, topk, dtype, inplace, routed_scaling_facto
# Non-quantized bf16 matmul accumulation order differs between torch and
# the fused triton/asm kernels, so we need a relaxed atol (matching
# test_moe_w16a16.py which uses atol=1 for torch vs triton).
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
return {"m": m, "backend": backend, "us": aiter_us}
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
# ---------------------------------------------------------------------------
......@@ -247,14 +253,19 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype):
msg = (f"[w16a16_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
check_ret = checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0
return {
"m": m,
"k": k,
"n": n,
"e": e,
"topk": topk,
"asm_us": asm_us,
"shuffle_us": shuffle_us,
"shuffle_uplift": f"{uplift:.1%}",
"accuracy": ret_output
}
......@@ -265,6 +276,8 @@ def test_aiter_moe_w16a16_shuffle(m, k, n, e, topk, dtype):
if __name__ == "__main__":
dtype = dtypes.bf16
PART2_CSV_OUTPUT = "w16a16_part2_aiter_moe.csv"
PART3_CSV_OUTPUT = "w16a16_part3_shuffle.csv"
# for asm solution in gfx936, only support the following configurations,
# check tuned_fmoe_asm.csv for details.
......@@ -298,6 +311,8 @@ if __name__ == "__main__":
if df:
df = pd.DataFrame(df)
aiter.logger.info(f"aiter_moe summary:\n{df}")
df.to_csv(PART2_CSV_OUTPUT, index=False)
aiter.logger.info(f"aiter_moe summary csv saved to {PART2_CSV_OUTPUT}")
# --- Part 3: test ASM shuffle vs non-shuffle (w16a16) ---
aiter.logger.info("=" * 60)
......@@ -312,11 +327,5 @@ if __name__ == "__main__":
if df_shuffle:
df_shuffle = pd.DataFrame(df_shuffle)
aiter.logger.info(f"shuffle summary:\n{df_shuffle}")
# --- Combined summary ---
# if len(df) > 0 and len(df_shuffle) > 0:
# df_combined = pd.merge(
# pd.DataFrame(df), pd.DataFrame(df_shuffle),
# on="m", how="outer", suffixes=("", "_shuffle"),
# )
# aiter.logger.info(f"combined summary:\n{df_combined}")
df_shuffle.to_csv(PART3_CSV_OUTPUT, index=False)
aiter.logger.info(f"shuffle summary csv saved to {PART3_CSV_OUTPUT}")
......@@ -63,7 +63,7 @@ def torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids):
# ---------------------------------------------------------------------------
# Weight preparation helpers (W16A16 non-gated)
# ---------------------------------------------------------------------------
def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype):
def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype, asm_backend=False):
"""Build all tensors needed to run a non-gated w16a16 MOE test.
Key difference from gated: w1 shape is [E, n, k] instead of [E, 2*n, k].
......@@ -74,8 +74,12 @@ def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype):
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 2
score = torch.randn((m, e), device="cuda", dtype=dtype)
if asm_backend:
w1_shuffle = asm_shuffle_weight_b8(w1, stage=1)
w2_shuffle = asm_shuffle_weight_b8(w2, stage=2)
else:
w1_shuffle = w1
w2_shuffle = w2
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
......@@ -165,7 +169,11 @@ def _run_aiter_moe_perf(hidden_states,
expert_map,
routed_scaling_factor,
):
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
......@@ -197,7 +205,7 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin
f"backend={backend}"
)
data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype)
data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype, asm_backend=(backend == MoeSolutionType.ASM))
# Torch reference
ref_out, _ = _run_torch_ref(
......@@ -230,8 +238,9 @@ def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scalin
msg = (f"[aiter_moe_w16a16_nogate] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
f"backend={backend}")
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
return {"m": m, "backend": backend, "us": aiter_us}
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
# ---------------------------------------------------------------------------
......@@ -279,27 +288,33 @@ def test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype):
msg = (f"[w16a16_nogate_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
check_ret = checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
uplift = asm_us / shuffle_us - 1 if shuffle_us > 0 else 0
return {
"m": m,
"k": k,
"n": n,
"e": e,
"topk": topk,
"asm_us": asm_us,
"shuffle_us": shuffle_us,
"shuffle_uplift": f"{uplift:.1%}",
"accuracy": ret_output
}
if __name__ == "__main__":
dtype = dtypes.bf16
PART2_CSV_OUTPUT = "w16a16_nogate_part2_aiter_moe.csv"
PART3_CSV_OUTPUT = "w16a16_nogate_part3_shuffle.csv"
# Nemotron-style MoE parameters (non-gated, ReLU²)
e = 256
topk = 8
k = 3072 # model_dim / hidden_size
n = 128 # intermediate_size (NOT multiplied by 2)
inplace = False
e = 512
topk = 22
k = 1024 # model_dim / hidden_size
n = 336 # intermediate_size (NOT multiplied by 2)
inplace = True
routed_scaling_factor = 1.0
# --- Part 1: test get_aiter_moe_config (w16a16 non-gated relu2) ---
......@@ -324,12 +339,17 @@ if __name__ == "__main__":
if df:
df = pd.DataFrame(df)
aiter.logger.info(f"aiter_moe non-gated relu2 summary:\n{df}")
df.to_csv(PART2_CSV_OUTPUT, index=False)
aiter.logger.info(f"aiter_moe summary csv saved to {PART2_CSV_OUTPUT}")
# --- Part 3: test ASM shuffle vs non-shuffle (w16a16 non-gated relu2) ---
aiter.logger.info("=" * 60)
aiter.logger.info("Part 3: Testing ASM shuffle vs non-shuffle for w16a16 non-gated relu2")
aiter.logger.info("=" * 60)
if df.empty or not any(df["backend"] == MoeSolutionType.ASM):
aiter.logger.info("Skipping Part 3 since ASM backend was not selected in Part 2")
else:
df_shuffle = []
for m in test_tokens:
ret = test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype)
......@@ -338,3 +358,5 @@ if __name__ == "__main__":
if df_shuffle:
df_shuffle = pd.DataFrame(df_shuffle)
aiter.logger.info(f"shuffle summary (non-gated relu2):\n{df_shuffle}")
df_shuffle.to_csv(PART3_CSV_OUTPUT, index=False)
aiter.logger.info(f"shuffle summary csv saved to {PART3_CSV_OUTPUT}")
\ No newline at end of file
......@@ -5,8 +5,19 @@ import itertools
import pandas as pd
from typing import Optional, List
from op_tests.utility.scalar_type import scalar_types
from op_tests.utility.utils import quantize_weights
try:
from op_tests.utility.scalar_type import scalar_types
from op_tests.utility.utils import quantize_weights
except ModuleNotFoundError:
import sys
from pathlib import Path
_ROOT = Path(__file__).resolve().parents[2]
if str(_ROOT) not in sys.path:
sys.path.insert(0, str(_ROOT))
from op_tests.utility.scalar_type import scalar_types
from op_tests.utility.utils import quantize_weights
from aiter.fused_moe import fused_topk, torch_moe
from aiter import ActivationType, dtypes
from aiter.test_common import checkAllclose, perftest
......@@ -226,10 +237,13 @@ def _run_aiter_moe_perf(hidden_states,
routed_scaling_factor,
activation):
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(mortal_input, w1, w2, topk_weights, topk_ids, moe_config, inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor)
a1_scale, a2_scale, block_shape, global_num_experts, expert_map, routed_scaling_factor, output_dtype=hidden_states.dtype)
def test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, routed_scaling_factor):
......@@ -294,9 +308,9 @@ def test_aiter_moe_w4a16(m, k, n, e, topk, group_size, has_zp, dtype, inplace, r
msg = (f"[aiter_moe_w4a16] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
f"backend={backend}")
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.01, msg=msg)
return {"m": m, "backend": backend, "us": aiter_us}
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.01, msg=msg)
ret_output = "passed" if check_ret == 0 else (1-check_ret)
return {"m": m, "N1": N1, "N2": N2, "K": K, "e":e, "topk":topk,"backend": backend, "us": aiter_us, "accuracy": ret_output}
# ---------------------------------------------------------------------------
# Main: run tests
......@@ -338,4 +352,5 @@ if __name__ == "__main__":
df.append(ret)
if df:
df = pd.DataFrame(df)
df.to_csv("test_aiter_moe_with_config_w4a16.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
......@@ -437,8 +437,8 @@ def perchannel_w8a8_triton(input,
triton_moe_sum(output_triton2.view(*output_triton2.shape),
out_hidden_states)
else:
aiter.moe_c_moe_sum(output_triton2.view(*output_triton2.shape),
out_hidden_states,topk_ids)
aiter.moe_c_moe_sum_opt_v2(output_triton2.view(*output_triton2.shape),
out_hidden_states,1.0)
# print("**************************************triton")
# print(out_hidden_states)
......@@ -469,7 +469,7 @@ def _run_triton_ref(hidden_states, w1, w2, topk_weights, topk_ids, dtype, block_
)
return perchannel_w8a8_triton(hidden_states,w1,w2,w1_scale,w2_scale,topk_weights, topk_ids,sorted_token_ids, expert_ids, num_tokens_post_padded,dtype,True)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1,testGraph = True)
def _run_aiter_moe_perf(
hidden_states,
w1,
......@@ -490,7 +490,10 @@ def _run_aiter_moe_perf(
expert_map,
):
mortal_input = hidden_states.clone() #保证inplace操作的正确性
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(
mortal_input,
......@@ -826,7 +829,8 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype):
expert_map=None,
)
msg = f"[aiter_triton_w4a8] {m=}"
checkAllclose(ref_out, aiter_triton_out, rtol=0.01, atol=100, msg=msg)
triton_check_ret = checkAllclose(ref_out, aiter_triton_out, rtol=0.01, atol=100, msg=msg)
ret["triton_accuracy"] = "passed" if triton_check_ret == 0 else (1-triton_check_ret)
ret["aiter_triton_us"] = aiter_triton_us
aiter_us = 1.0
......@@ -854,10 +858,13 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, dtype):
expert_map=None,
)
print("aiter_out",aiter_out)
print("aiter_triton_out",aiter_triton_out)
msg = f"[aiter_moe_w4a8] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
ret["moe_accuracy"] = "passed" if check_ret == 0 else (1-check_ret)
ret["aiter_moe_us"] = aiter_us
return ret
......@@ -868,14 +875,14 @@ if __name__ == "__main__":
e = 256
topk = 8
k = 6144
n = 2048
k = 7168
n = 256
aiter.logger.info("=" * 60)
aiter.logger.info("Part 1: Testing get_aiter_moe_config for w4a8")
aiter.logger.info("=" * 60)
test_tokens = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048,4096,6144,8192,16384]
# test_tokens = [1]
# test_tokens = [2048]
for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype)
......
# Test for get_aiter_moe_config and aiter_moe with W8A16 (int8 weight, fp16/bf16 activation)
#
# Mirrors test_aiter_moe_with_config_w16a16.py but for the INT8_W8A16 quant
# type. The W8A16 ASM backend is not currently supported, so part 3
# (ASM shuffle vs non-shuffle) is omitted intentionally.
import torch
import pandas as pd
from aiter.fused_moe import fused_topk, torch_moe
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.ops.shuffle import w8a16_marlin_weight_1, w8a16_marlin_weight_2
import aiter
torch.set_default_device("cuda")
# ---------------------------------------------------------------------------
# Weight preparation helpers (W8A16 - int8 weights, fp16/bf16 activations)
# ---------------------------------------------------------------------------
def prepare_w8a16_inputs(m, k, n, e, topk, dtype):
    """Create the tensors for one W8A16 (int8 weight) MoE test case.

    The RNG is seeded with 0 on every call, so the same arguments always
    produce the same data.

    Returned dict entries:
      - "input":        [m, k] activations in ``dtype``, scaled down by 100
      - "w1":           [e, 2n, k] int8 (gate + up, gated activation)
      - "w2":           [e, k, n] int8 (down)
      - "w1_scale":     [e, 2n, 1] per-output-channel scales in ``dtype``
      - "w2_scale":     [e, k, 1] per-output-channel scales in ``dtype``
      - "w1_marlin" / "w2_marlin": marlin-shuffled int8 weights consumed by
        the moe_c W8A16 kernel
      - "score":        [m, e] random routing scores in ``dtype``
      - "topk_weights" / "topk_ids": the two outputs of ``fused_topk``

    The plain int8 weights plus the per-channel scales are what the torch
    reference uses (dequantized to fp16/bf16).
    """
    torch.manual_seed(0)
    device = "cuda"
    gate_up_rows = 2 * n

    # Small activations limit the accumulation error of the int8 dequant path.
    activations = torch.randn((m, k), device=device, dtype=dtype) / 100

    # torch.randint excludes the upper bound, so values fall in [-127, 126].
    int8_w1 = torch.randint(
        -127, 127, (e, gate_up_rows, k), device=device, dtype=torch.int8
    )
    int8_w2 = torch.randint(-127, 127, (e, k, n), device=device, dtype=torch.int8)

    # The moe_c W8A16 marlin kernel expects scales in the activation dtype
    # (fp16/bf16), hence `dtype` here. The 1e-3 factor keeps the dequantized
    # weights in a reasonable range.
    scale_w1 = torch.randn((e, gate_up_rows, 1), device=device, dtype=dtype).abs() * 1e-3
    scale_w2 = torch.randn((e, k, 1), device=device, dtype=dtype).abs() * 1e-3

    # Weight layout consumed by the moe_c W8A16 kernel.
    marlin_w1 = w8a16_marlin_weight_1(int8_w1)
    marlin_w2 = w8a16_marlin_weight_2(int8_w2)

    router_scores = torch.randn((m, e), device=device, dtype=dtype)
    topk_weights, topk_ids = fused_topk(activations, router_scores, topk, True)

    return {
        "input": activations,
        "w1": int8_w1,
        "w2": int8_w2,
        "w1_marlin": marlin_w1,
        "w2_marlin": marlin_w2,
        "w1_scale": scale_w1,
        "w2_scale": scale_w2,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "score": router_scores,
    }
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w8a16)
# ---------------------------------------------------------------------------
def test_get_config(m, k, n, e, topk, dtype):
    """Check the W8A16 config lookup for one problem shape.

    Queries ``get_aiter_moe_config`` with N1 = 2*n, N2 = k, K = k and
    ``MoeQuantType.INT8_W8A16``. When a solution is reported, the config must
    be populated and the backend must be MOE_C or TRITON; when none is
    reported, both ``solution_type`` and ``config`` must be None.

    Returns:
        The ``(status, moe_cfg)`` pair returned by ``get_aiter_moe_config``.
    """
    gate_up_dim = 2 * n  # gate + up
    down_dim = k  # down / hidden_size
    model_dim = k  # model dimension

    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=gate_up_dim,
        N2=down_dim,
        K=model_dim,
        top_k=topk,
        block_size=0,
        dtype=dtype,
        quant_type=MoeQuantType.INT8_W8A16,
    )

    log_prefix = f"[get_config_w8a16] {m=}, {k=}, {n=}, {e=}, {topk=}, "

    if not status:
        # No backend: the returned config object must be empty.
        assert moe_cfg.solution_type is None, \
            "status=False but solution_type is not None"
        assert moe_cfg.config is None, \
            "status=False but config is not None"
        aiter.logger.info(log_prefix + "no solution found")
        return status, moe_cfg

    assert moe_cfg.solution_type is not None, \
        "status=True but solution_type is None"
    assert moe_cfg.config is not None, \
        "status=True but config is None"
    # ASM backend is intentionally unsupported for W8A16.
    allowed_backends = (MoeSolutionType.MOE_C, MoeSolutionType.TRITON)
    assert moe_cfg.solution_type in allowed_backends, \
        f"Unexpected solution_type for W8A16: {moe_cfg.solution_type}"
    assert moe_cfg.quant_type == MoeQuantType.INT8_W8A16
    aiter.logger.info(
        log_prefix
        + f"solution={moe_cfg.solution_type}, "
        + f"config keys={list(moe_cfg.config.keys())}"
    )
    return status, moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w8a16
# ---------------------------------------------------------------------------
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids,
                   fc1_scale, fc2_scale):
    """Run the ``torch_moe`` reference, forwarding the two weight scales.

    Wrapped by ``perftest`` with 1 warmup and 2 iterations; callers in this
    file unpack the decorated call as ``(output, _)``.
    """
    reference = torch_moe(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
    )
    return reference
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    moe_config,
    inplace,
    activation,
    w1_scale,
    w2_scale,
    w1_zp,
    w2_zp,
    a1_scale,
    a2_scale,
    block_shape,
    global_num_experts,
    expert_map,
    routed_scaling_factor,
):
    """Call ``aiter_moe`` under ``perftest`` (10 warmup, 100 timed iterations).

    All arguments are forwarded positionally to ``aiter_moe`` in the order
    listed. The only exception is ``hidden_states``, which is replaced by a
    clone when ``inplace`` is truthy.
    """
    # Hand an in-place run its own copy so it stays correct (it does not
    # write into the caller's tensor).
    mortal_input = hidden_states.clone() if inplace else hidden_states
    return aiter_moe(
        mortal_input, w1, w2, topk_weights, topk_ids, moe_config,
        inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape, global_num_experts,
        expert_map, routed_scaling_factor,
    )
def test_aiter_moe_w8a16(m, k, n, e, topk, dtype, inplace, routed_scaling_factor):
    """Run one W8A16 case end to end and compare against the torch reference.

    Steps: look up a config, build inputs, run the torch reference on the
    plain int8 weights with their per-channel scales, run ``aiter_moe`` with
    the weight layout the selected backend expects, then compare the outputs.

    Returns:
        None when no backend is available for the shape. Otherwise a dict
        with the shape, backend, measured time ("us") and "accuracy", which
        is "passed" when ``checkAllclose`` returns 0 and ``1 - <return value>``
        otherwise.
    """
    N1 = 2 * n
    N2 = k
    K = k
    shape_desc = f"{m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}"

    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=N1,
        N2=N2,
        K=K,
        top_k=topk,
        block_size=0,
        dtype=dtype,
        quant_type=MoeQuantType.INT8_W8A16,
    )
    if not status:
        aiter.logger.info(
            f"[aiter_moe_w8a16] SKIP {shape_desc}: no backend available"
        )
        return None

    backend = moe_cfg.solution_type
    # The same text is logged now and reused as the comparison message below.
    msg = f"[aiter_moe_w8a16] {shape_desc}, backend={backend}"
    aiter.logger.info(msg)

    data = prepare_w8a16_inputs(m, k, n, e, topk, dtype)

    # moe_c expects marlin-shuffled int8 weights; triton expects the plain
    # (E, N, K) int8 weights.
    uses_marlin = backend == MoeSolutionType.MOE_C
    w1_run = data["w1_marlin"] if uses_marlin else data["w1"]
    w2_run = data["w2_marlin"] if uses_marlin else data["w2"]

    # Torch reference (dequantize int8 weights via per-channel scales).
    ref_out, _ = _run_torch_ref(
        data["input"],
        data["w1"],
        data["w2"],
        data["topk_weights"],
        data["topk_ids"],
        data["w1_scale"],
        data["w2_scale"],
    )

    aiter_out, aiter_us = _run_aiter_moe_perf(
        hidden_states=data["input"],
        w1=w1_run,
        w2=w2_run,
        topk_weights=data["topk_weights"],
        topk_ids=data["topk_ids"],
        moe_config=moe_cfg,
        inplace=inplace,
        activation="silu",
        w1_scale=data["w1_scale"],
        w2_scale=data["w2_scale"],
        w1_zp=None,
        w2_zp=None,
        a1_scale=None,
        a2_scale=None,
        block_shape=None,  # per-channel
        global_num_experts=e,
        expert_map=None,
        routed_scaling_factor=routed_scaling_factor,
    )

    # int8 dequant + bf16/fp16 accumulation order differs between the
    # reference and the fused kernels; use a relaxed tolerance similar to
    # the W16A16 test.
    check_ret = checkAllclose(ref_out, aiter_out, rtol=0.05, atol=0.5, msg=msg)
    if check_ret == 0:
        ret_output = "passed"
    else:
        ret_output = 1 - check_ret

    return {
        "m": m,
        "N1": N1,
        "N2": N2,
        "K": K,
        "e": e,
        "topk": topk,
        "backend": backend,
        "us": aiter_us,
        "accuracy": ret_output,
    }
# ---------------------------------------------------------------------------
# Main: run tests
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # The W8A16 moe_c kernel was validated with fp16.
    dtype = dtypes.fp16

    # Shape matching the tuned W8A16 moe_c configs under
    # aiter/moe_c_configs (E=128, N=2048).
    e, topk = 128, 8
    k = 6144  # model_dim
    n = 2048  # intermediate_size
    inplace = False
    routed_scaling_factor = 1.0

    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    banner = "=" * 60

    # --- Part 1: test get_aiter_moe_config (w8a16) ---
    for line in (banner, "Part 1: Testing get_aiter_moe_config for w8a16", banner):
        aiter.logger.info(line)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, dtype)

    # --- Part 2: test aiter_moe end-to-end (w8a16) ---
    for line in (banner, "Part 2: Testing aiter_moe end-to-end for w8a16", banner):
        aiter.logger.info(line)

    rows = []
    for m in test_tokens:
        outcome = test_aiter_moe_w8a16(
            m, k, n, e, topk, dtype, inplace, routed_scaling_factor
        )
        if outcome is None:
            # No backend for this token count; nothing to record.
            continue
        rows.append(outcome)

    # Only log and write the CSV when at least one shape actually ran.
    if rows:
        summary = pd.DataFrame(rows)
        aiter.logger.info(f"aiter_moe summary:\n{summary}")
        summary.to_csv("test_aiter_moe_with_config_w8a16.csv", index=False)
\ No newline at end of file
......@@ -104,7 +104,10 @@ def _run_aiter_moe_perf(
expert_map,
):
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(
mortal_input,
......@@ -248,8 +251,9 @@ def test_aiter_moe_w8a8(m, k, n, e, topk, block_shape, dtype):
)
msg = f"[aiter_moe_w8a8] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
return {"m": m, "backend": moe_cfg.solution_type, "us": aiter_us}
check_ret = checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
passed = "passed" if check_ret == 0 else (1-check_ret)
return {"m": m, "backend": moe_cfg.solution_type, "us": aiter_us, "accuracy": passed}
if __name__ == "__main__":
......@@ -278,3 +282,4 @@ if __name__ == "__main__":
if df:
df = pd.DataFrame(df)
aiter.logger.info(f"summary:\n{df}")
df.to_csv("test_aiter_moe_with_config_w8a8_blockwise.csv", index=False)
......@@ -32,7 +32,7 @@ def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1,testGraph=True)
def _run_aiter_moe_perf(
hidden_states,
w1,
......@@ -51,9 +51,13 @@ def _run_aiter_moe_perf(
block_shape,
global_num_experts,
expert_map,
out_dtype,
):
if inplace:
mortal_input = hidden_states.clone() # 保证inplace操作的正确性
else:
mortal_input = hidden_states
return aiter_moe(
mortal_input,
......@@ -73,6 +77,7 @@ def _run_aiter_moe_perf(
block_shape,
global_num_experts,
expert_map,
output_dtype = out_dtype
)
......@@ -86,12 +91,20 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
"""
torch.manual_seed(0)
input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 10
if dtype == dtypes.fp8:
input_tensor = torch.randn((m, k), dtype=dtypes.fp32, device="cuda") / 100
w1_fp = torch.randn((e, 2 * n, k), dtype=dtypes.fp32, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtypes.fp32, device="cuda")
else:
input_tensor = torch.randn((m, k), dtype=dtype, device="cuda") / 100
w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda")
if quant_type == MoeQuantType.FP8_W8A8:
# FP8 channel-wise quantization via pertoken_quant
if dtype == dtypes.fp8:
input_tensor_q, a1_scales = pertoken_quant(input_tensor, quant_dtype=dtypes.fp8)
w1_qweight, w1_scales = pertoken_quant(w1_fp, quant_dtype=dtypes.fp8)
w2_qweight, w2_scales = pertoken_quant(w2_fp, quant_dtype=dtypes.fp8)
else:
......@@ -106,6 +119,9 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
w2_scales = max_vals_w2 / 127.0 # (e, k, 1)
w2_qweight = (w2_fp / max_vals_w2 * 127.0).round().clamp(min=-128, max=127).to(torch.int8)
if dtype == dtypes.fp8:
score = torch.randn((m, e), dtype=dtypes.fp32, device="cuda")
else:
score = torch.randn((m, e), dtype=dtype, device="cuda")
topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)
......@@ -114,7 +130,8 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
w2_qweight_shuffle = moe_layout_shuffle_gemm2(w2_qweight).view(*w2_qweight.shape)
return {
"input": input_tensor,
"input": input_tensor_q if dtype == dtypes.fp8 else input_tensor,
"a1_scales": a1_scales if dtype == dtypes.fp8 else None,
"w1_ref": w1_fp,
"w2_ref": w2_fp,
"w1_qweight": w1_qweight,
......@@ -125,6 +142,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuant
"w2_scales": w2_scales,
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"input_ref": input_tensor, # original fp32/bf16 input for reference
}
......@@ -164,7 +182,7 @@ def test_get_config(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
return status, moe_cfg
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, in_dtype, out_dtype, quant_type=MoeQuantType.W8A8, inplace=False):
"""End-to-end test of aiter_moe with channel-wise w8a8 (int8 or fp8)."""
status, moe_cfg = get_aiter_moe_config(
M=m,
......@@ -174,7 +192,7 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
K=k,
top_k=topk,
block_size=0,
dtype=dtype,
dtype=in_dtype,
quant_type=quant_type,
)
......@@ -183,11 +201,11 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
aiter.logger.info(f"[{tag}] SKIP {m=}: no backend available")
return None
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type)
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, in_dtype, quant_type)
# Torch reference uses original fp weights directly (no scales needed)
# Torch reference uses pre-quantization input (fp32/bf16) since fp8 is unsupported
ref_out, _ = _run_torch_ref(
data["input"],
data["input_ref"],
data["w1_ref"],
data["w2_ref"],
data["topk_weights"],
......@@ -209,27 +227,28 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuant
topk_weights=data["topk_weights"],
topk_ids=data["topk_ids"],
moe_config=moe_cfg,
inplace=True,
inplace=inplace,
activation="silu",
w1_scale=data["w1_scales"],
w2_scale=data["w2_scales"],
w1_zp=None,
w2_zp=None,
a1_scale=None,
a1_scale=data["a1_scales"],
a2_scale=None,
block_shape=None,
global_num_experts=e,
expert_map=None,
out_dtype=out_dtype,
)
print("aiter_out",aiter_out)
print("ref_out",ref_out)
# print("aiter_out",aiter_out)
# print("ref_out",ref_out)
# Compare in aiter_out's dtype since reference may be fp32 (from pre-quant input)
msg = f"[{tag}] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us}
check_ret = checkAllclose(ref_out.to(aiter_out.dtype), aiter_out, rtol=0.01, atol=100, msg=msg)
passed = "passed" if check_ret == 0 else (1-check_ret)
return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us, "accuracy": passed, }
if __name__ == "__main__":
......@@ -239,36 +258,39 @@ if __name__ == "__main__":
parser.add_argument(
"--quant",
choices=["int8", "fp8"],
default="int8",
default="fp8",
help="Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)",
)
args = parser.parse_args()
quant_type = MoeQuantType.FP8_W8A8 if args.quant == "fp8" else MoeQuantType.W8A8
inplace = False # in_dtype != out_dtype时,不能为True
# for moe_c backend, it does not support n=320 for now;
# for triton backend, it can run with n=320 in NMZ;
dtype = dtypes.bf16
e = 256
in_dtype = dtypes.fp8
out_dtype = dtypes.bf16
e = 192
topk = 8
k = 6144
n = 320
k = 4096
n = 384
aiter.logger.info("=" * 60)
aiter.logger.info(f"Part 1: Testing get_aiter_moe_config for {quant_type} channel-wise")
aiter.logger.info("=" * 60)
test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 , 4096, 6144 , 8192 , 16384]
for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype, quant_type)
test_get_config(m, k, n, e, topk, in_dtype, quant_type)
aiter.logger.info("=" * 60)
aiter.logger.info(f"Part 2: Testing aiter_moe end-to-end for {quant_type} channel-wise")
aiter.logger.info("=" * 60)
df = []
for m in test_tokens:
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type)
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, in_dtype, out_dtype, quant_type, inplace)
if ret is not None:
df.append(ret)
if df:
df = pd.DataFrame(df)
df.to_csv("w8a8_channelwise.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
# Test for get_aiter_moe_config and aiter_moe with w8a8 tensor-wise quantization
import argparse
import pandas as pd
import torch
from aiter.fused_moe import fused_topk
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2
from aiter.ops.quant import pertoken_quant
import aiter
torch.set_default_device("cuda")
def compare_tensors(
tensor1: torch.Tensor,
tensor2: torch.Tensor,
atol: float = 1e-2,
rtol: float = 1e-2
) -> None:
"""
比较两个任意维度的PyTorch张量的差异,支持绝对误差和相对误差阈值。
直接输出详细比较结果,包括每个元素在原始张量中的多维坐标。
无返回值。
参数:
tensor1: 第一个张量(如Triton模型输出)
tensor2: 第二个张量(如PyTorch模型输出)
atol: 绝对误差阈值,默认1e-5
rtol: 相对误差阈值,默认1e-8
"""
# -------------------------- 1. 输入合法性校验 --------------------------
if not isinstance(tensor1, torch.Tensor) or not isinstance(tensor2, torch.Tensor):
raise TypeError("输入必须是PyTorch张量(torch.Tensor)")
if tensor1.shape != tensor2.shape:
raise ValueError(f"张量形状不匹配! tensor1形状: {tensor1.shape}, tensor2形状: {tensor2.shape}")
if tensor1.device != tensor2.device:
tensor2 = tensor2.to(tensor1.device)
print(f"警告:张量设备不一致,已将tensor2转移到{tensor1.device}")
# -------------------------- 2. 核心差异计算 --------------------------
abs_diff = torch.abs(tensor1 - tensor2)
denom = torch.maximum(torch.abs(tensor1), torch.abs(tensor2))
rel_diff = abs_diff / (denom + 1e-12)
match_mask = (abs_diff <= atol) | (rel_diff <= rtol)
# -------------------------- 3. 展平张量 --------------------------
tensor1_flat = tensor1.flatten()
tensor2_flat = tensor2.flatten()
abs_diff_flat = abs_diff.flatten()
match_mask_flat = match_mask.flatten()
# -------------------------- 4. 收集一维索引 --------------------------
def get_indices_1d(mask: torch.Tensor) -> list:
indices = torch.nonzero(mask).squeeze(dim=1)
return indices.tolist() if indices.numel() > 0 else []
match_indices_1d = get_indices_1d(match_mask_flat)
mismatch_indices_1d = get_indices_1d(~match_mask_flat)
# -------------------------- 5. 总体统计信息 --------------------------
total = tensor1_flat.numel()
matched = len(match_indices_1d)
mismatched = len(mismatch_indices_1d)
match_rate = matched / total if total > 0 else 0.0
max_abs_diff = abs_diff.max().item() if total > 0 else 0.0
avg_abs_diff = abs_diff.mean().item() if total > 0 else 0.0
# -------------------------- 6. 格式化输出 --------------------------
print("=" * 60)
print("张量比较结果汇总")
print("=" * 60)
print(f"张量形状: {tensor1.shape} | 总元素数: {total}")
print(f"阈值设置: 绝对误差(atol)={atol:.2e}, 相对误差(rtol)={rtol:.2e}")
print("-" * 60)
print(f"匹配元素数: {matched} ({match_rate:.2%})")
print(f"不匹配元素数: {mismatched} ({1 - match_rate:.2%})")
print(f"最大绝对差异: {max_abs_diff:.6f}")
print(f"平均绝对差异: {avg_abs_diff:.6f}")
print("=" * 60)
# -------------------------- 7. 输出匹配/不匹配示例 --------------------------
def print_sample(name: str, indices_1d: list, max_samples: int = 3, elem_per_sample: int = 10) -> None:
if not indices_1d:
print(f"\n{name}样本】无数据")
return
print(f"\n{name}样本】(最多展示{max_samples}组,每组{elem_per_sample}个元素)")
print("-" * 50)
num_samples = min(max_samples, (len(indices_1d) + elem_per_sample - 1) // elem_per_sample)
for i in range(num_samples):
start = i * elem_per_sample
end = start + elem_per_sample
sample_indices_1d = indices_1d[start:end]
# <<< 关键改动: 使用 torch.unravel_index 转换为多维坐标 >>>
# 此函数能处理任意维度
sample_coords = torch.unravel_index(torch.tensor(sample_indices_1d), tensor1.shape)
# 将结果从张量元组转换为坐标元组列表
sample_coords_list = list(zip(*[coord.tolist() for coord in sample_coords]))
print(f"\n{i+1}组:")
print(f" 原始多维坐标: {sample_coords_list}")
print(f" tensor1: {[round(tensor1_flat[idx].item(), 6) for idx in sample_indices_1d]}")
print(f" tensor2: {[round(tensor2_flat[idx].item(), 6) for idx in sample_indices_1d]}")
print(f" 绝对差异: {[round(abs_diff_flat[idx].item(), 6) for idx in sample_indices_1d]}")
print_sample("匹配", match_indices_1d, max_samples=2)
print_sample("不匹配", mismatch_indices_1d, max_samples=3)
print("\n" + "=" * 60)
def _run_aiter_moe(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    moe_config,
    inplace,
    activation,
    w1_scale,
    w2_scale,
    w1_zp,
    w2_zp,
    a1_scale,
    a2_scale,
    block_shape,
    global_num_experts,
    expert_map,
    out_dtype,
):
    """Call ``aiter_moe`` once with the given arguments and return its result.

    When ``inplace`` is truthy the activations are cloned first and the
    clone is what gets passed on, so ``hidden_states`` itself is never the
    tensor handed to an in-place run. ``out_dtype`` is forwarded as the
    ``output_dtype`` keyword; everything else is passed positionally.
    """
    # Clone for in-place runs so the in-place operation stays correct.
    moe_input = hidden_states.clone() if inplace else hidden_states
    return aiter_moe(
        moe_input,
        w1,
        w2,
        topk_weights,
        topk_ids,
        moe_config,
        inplace,
        activation,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        global_num_experts,
        expert_map,
        output_dtype=out_dtype,
    )
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1, testGraph=True)
def _run_aiter_moe_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    moe_config,
    inplace,
    activation,
    w1_scale,
    w2_scale,
    w1_zp,
    w2_zp,
    a1_scale,
    a2_scale,
    block_shape,
    global_num_experts,
    expert_map,
    out_dtype,
):
    """``aiter_moe`` call wrapped by ``perftest`` for timing.

    Same argument forwarding as the undecorated runner: ``hidden_states``
    is cloned before the call when ``inplace`` is truthy, and ``out_dtype``
    is passed as ``output_dtype``. What the decorated call returns is
    decided by ``perftest``; callers in this file unpack it as
    ``(result, microseconds)``.
    """
    mortal_input = hidden_states
    if inplace:
        # Work on a copy so the in-place operation stays correct.
        mortal_input = hidden_states.clone()
    positional = (
        mortal_input,
        w1,
        w2,
        topk_weights,
        topk_ids,
        moe_config,
        inplace,
        activation,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        global_num_experts,
        expert_map,
    )
    return aiter_moe(*positional, output_dtype=out_dtype)
def _quantize_tensorwise_int8(weight):
max_vals = torch.abs(weight.to(torch.float32)).amax(dim=(1, 2), keepdim=True)
max_vals = max_vals.clamp(min=1e-5)
scales = max_vals / 127.0
qweight = (weight / max_vals * 127.0).round().clamp(min=-128, max=127).to(torch.int8)
return qweight, scales
def _quantize_tensorwise_fp8(weight):
    """Quantize ``weight`` to ``dtypes.fp8`` with one scale per index of dim 0.

    The absolute maximum is reduced over dims 1 and 2 in float32 and
    floored at 1e-5; the scale is that maximum divided by the fp8 format's
    largest finite value. Returns ``(qweight, scales)`` and prints both
    shapes.
    """
    fp8_info = torch.finfo(dtypes.fp8)
    abs_max = (
        weight.to(torch.float32)
        .abs()
        .amax(dim=(1, 2), keepdim=True)
        .clamp(min=1e-5)
    )
    scales = abs_max / fp8_info.max
    # Clamp to the representable fp8 range before the cast.
    qweight = torch.clamp(weight / scales, min=fp8_info.min, max=fp8_info.max).to(dtypes.fp8)
    print("==================qweight, scales")
    print(qweight.shape)
    print(scales.shape)
    return qweight, scales
def prepare_w8a8_tensorwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
    """Build tensor-wise quantized w8a8 inputs for one MoE problem size.

    Random activations of shape (m, k) and weights of shape (e, 2n, k) and
    (e, k, n) are drawn with a fixed seed. For ``MoeQuantType.FP8_W8A8``
    the weights go through the fp8 quantizer, otherwise through the int8
    one; both yield one scale per expert with shape (e, 1, 1).
    ``block_shape`` is None for this path.

    Returns:
        dict with the (possibly quantized) activations, their per-token
        scales or None, plain and layout-shuffled quantized weights, the
        (e, 1, 1) scales, the same scales expanded to one per output
        channel, and the top-k routing weights and ids.
    """
    torch.manual_seed(0)

    # Random data cannot be drawn in fp8 directly, so use fp32 in that case.
    fp8_input = dtype == dtypes.fp8
    gen_dtype = dtypes.fp32 if fp8_input else dtype

    input_tensor = torch.randn((m, k), dtype=gen_dtype, device="cuda") / 10
    w1_fp = torch.randn((e, 2 * n, k), dtype=gen_dtype, device="cuda")
    w2_fp = torch.randn((e, k, n), dtype=gen_dtype, device="cuda")

    input_for_aiter, a1_scales = input_tensor, None
    if quant_type == MoeQuantType.FP8_W8A8:
        # Activations stay per-token quantized; tensor-wise only applies to
        # the weight (B) scales.
        if fp8_input:
            input_for_aiter, a1_scales = pertoken_quant(input_tensor, quant_dtype=dtypes.fp8)
        quantize = _quantize_tensorwise_fp8
    else:
        quantize = _quantize_tensorwise_int8
    w1_qweight, w1_scales = quantize(w1_fp)
    w2_qweight, w2_scales = quantize(w2_fp)

    score = torch.randn((m, e), dtype=gen_dtype, device="cuda")
    topk_weights, topk_ids = fused_topk(input_tensor, score, topk, True)

    # The moe_c backend needs layout-shuffled weights.
    w1_shuffled = moe_layout_shuffle_gemm1(w1_qweight).view(*w1_qweight.shape)
    w2_shuffled = moe_layout_shuffle_gemm2(w2_qweight).view(*w2_qweight.shape)

    prepared = {}
    prepared["input"] = input_for_aiter
    prepared["a1_scales"] = a1_scales
    prepared["w1_qweight"] = w1_qweight
    prepared["w2_qweight"] = w2_qweight
    prepared["w1_qweight_shuffle"] = w1_shuffled
    prepared["w2_qweight_shuffle"] = w2_shuffled
    prepared["w1_scales"] = w1_scales
    prepared["w2_scales"] = w2_scales
    prepared["w1_scales_channelwise"] = w1_scales.expand(e, 2 * n, 1).contiguous()
    prepared["w2_scales_channelwise"] = w2_scales.expand(e, k, 1).contiguous()
    prepared["topk_weights"] = topk_weights
    prepared["topk_ids"] = topk_ids
    return prepared
def test_get_config(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
    """Query ``get_aiter_moe_config`` for tensor-wise w8a8 (block_size=0).

    Asserts that a successful lookup echoes ``quant_type``, names one of
    the known solution types and carries a config, and that a failed
    lookup leaves both the solution type and the config unset.

    Returns:
        The ``(status, moe_cfg)`` pair from ``get_aiter_moe_config``.
    """
    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=2 * n,
        N2=k,
        K=k,
        top_k=topk,
        block_size=0,
        dtype=dtype,
        quant_type=quant_type,
    )
    tag = f"get_config_{quant_type}_tw"

    if not status:
        assert moe_cfg.solution_type is None
        assert moe_cfg.config is None
        aiter.logger.info(f"[{tag}] {m=}, no solution found")
        return status, moe_cfg

    known_solutions = (
        MoeSolutionType.ASM,
        MoeSolutionType.MOE_C,
        MoeSolutionType.TRITON,
        MoeSolutionType.CK,
    )
    assert moe_cfg.quant_type == quant_type
    assert moe_cfg.solution_type in known_solutions
    assert moe_cfg.config is not None
    aiter.logger.info(
        f"[{tag}] {m=}, solution={moe_cfg.solution_type}, "
        f"config keys={list(moe_cfg.config.keys())}"
    )
    return status, moe_cfg
def test_aiter_moe_w8a8_tensorwise(m, k, n, e, topk, in_dtype, out_dtype, quant_type=MoeQuantType.W8A8, inplace=False):
    """End-to-end check of ``aiter_moe`` with tensor-wise w8a8 (int8 or fp8).

    The case is skipped (returns None) when no backend is available or the
    selected backend is not ``MoeSolutionType.MOE_C``. Otherwise the MoE is
    run with (E, 1, 1) scales, compared against a run that uses the same
    scales expanded per output channel, and then timed.

    Returns:
        None when skipped, else a dict with keys ``m``, ``quant_type``,
        ``backend``, ``us`` and ``accuracy``.
    """
    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=2 * n,
        N2=k,
        K=k,
        top_k=topk,
        block_size=0,
        dtype=in_dtype,
        quant_type=quant_type,
    )
    tag = f"aiter_moe_{quant_type}_tw"
    if not status:
        aiter.logger.info(f"[{tag}] SKIP {m=}: no backend available")
        return None
    # Tensorwise scale shape (E, 1, 1) is currently implemented by the
    # moe_c Marlin path.
    if moe_cfg.solution_type != MoeSolutionType.MOE_C:
        aiter.logger.info(f"[{tag}] SKIP {m=}: tensorwise requires moe_c, got {moe_cfg.solution_type}")
        return None

    data = prepare_w8a8_tensorwise_inputs(m, k, n, e, topk, in_dtype, quant_type)

    def moe_kwargs(w1_scale, w2_scale):
        # The three runs below differ only in which weight scales they use.
        return dict(
            hidden_states=data["input"],
            w1=data["w1_qweight_shuffle"],
            w2=data["w2_qweight_shuffle"],
            topk_weights=data["topk_weights"],
            topk_ids=data["topk_ids"],
            moe_config=moe_cfg,
            inplace=inplace,
            activation="silu",
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=None,
            w2_zp=None,
            a1_scale=data["a1_scales"],
            a2_scale=None,
            block_shape=None,
            global_num_experts=e,
            expert_map=None,
            out_dtype=out_dtype,
        )

    # The reference uses the existing channelwise moe_c path with scales
    # expanded from (E, 1, 1) to (E, out_dim, 1). This isolates the
    # tensorwise kernel logic.
    ref_out = _run_aiter_moe(
        **moe_kwargs(data["w1_scales_channelwise"], data["w2_scales_channelwise"])
    )
    aiter_out = _run_aiter_moe(**moe_kwargs(data["w1_scales"], data["w2_scales"]))

    msg = f"[{tag}] {m=} {k=} {n=} {e=}, backend={moe_cfg.solution_type}"
    compare_tensors(aiter_out, ref_out)
    print("===============m k n e ===================")
    print(m, k, n, e)
    check_ret = checkAllclose(ref_out.to(aiter_out.dtype), aiter_out, rtol=0.01, atol=100, msg=msg)

    _, aiter_us = _run_aiter_moe_perf(**moe_kwargs(data["w1_scales"], data["w2_scales"]))

    if check_ret == 0:
        passed = "passed"
    else:
        passed = 1 - check_ret
    return {
        "m": m,
        "quant_type": quant_type,
        "backend": moe_cfg.solution_type,
        "us": aiter_us,
        "accuracy": passed,
    }
if __name__ == "__main__":
    # Command line: choose between the int8 and fp8 tensor-wise paths.
    parser = argparse.ArgumentParser(
        description="Test aiter_moe with tensor-wise w8a8 quantization",
    )
    parser.add_argument(
        "--quant",
        choices=["int8", "fp8"],
        default="fp8",
        help="Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)",
    )
    args = parser.parse_args()

    use_fp8 = args.quant == "fp8"
    quant_type = MoeQuantType.FP8_W8A8 if use_fp8 else MoeQuantType.W8A8
    inplace = False  # must not be True when in_dtype != out_dtype
    in_dtype = dtypes.fp8 if use_fp8 else dtypes.bf16
    out_dtype = dtypes.bf16

    # Problem size. A wider sweep that was tried at some point:
    #   k in [2048, 4096, 6144, 7168], n in [128, 256, 512, 1024, 2048]
    e = 256
    topk = 8
    k = 2048
    n = 2048
    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192, 16384]

    # Part 1: config lookup only.
    for banner_line in (
        "=" * 60,
        f"Part 1: Testing get_aiter_moe_config for {quant_type} tensor-wise",
        "=" * 60,
    ):
        aiter.logger.info(banner_line)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, in_dtype, quant_type)

    # Part 2: end-to-end accuracy and timing; skipped cases return None.
    for banner_line in (
        "=" * 60,
        f"Part 2: Testing aiter_moe end-to-end for {quant_type} tensor-wise",
        "=" * 60,
    ):
        aiter.logger.info(banner_line)
    results = []
    for m in test_tokens:
        outcome = test_aiter_moe_w8a8_tensorwise(
            m, k, n, e, topk, in_dtype, out_dtype, quant_type, inplace
        )
        if outcome is None:
            continue
        results.append(outcome)

    if results:
        summary = pd.DataFrame(results)
        summary.to_csv("w8a8_tensorwise.csv", index=False)
        aiter.logger.info(f"summary:\n{summary}")
# SPDX-License-Identifier: MIT
import torch
import aiter
import pandas as pd
from aiter import dtypes
# from aiter.test_common import checkAllclose, perftest
......@@ -26,6 +27,16 @@ tensor0.copy_(random_data0)
print("shape0", shape0)
print("strride0:", stride0)
def get_profiler_totals(prof):
    """Extract the total CPU/CUDA time strings from a profiler summary table.

    Renders ``prof.key_averages().table(...)`` and scans it for the lines
    starting with "Self CPU time total:" and "Self CUDA time total:".

    Returns:
        dict that may contain "CPU total time" and/or "CUDA total time",
        each mapped to the text after the first colon of its line with
        surrounding whitespace removed. A key is absent when its line is
        not in the table.
    """
    labels = {
        "Self CPU time total:": "CPU total time",
        "Self CUDA time total:": "CUDA total time",
    }
    report = prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)
    totals = {}
    for row in report.splitlines():
        for prefix, label in labels.items():
            if row.startswith(prefix):
                totals[label] = row.partition(":")[2].strip()
                break
    return totals
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
profile_memory=True,
......@@ -37,6 +48,7 @@ with profile(
# cache_flush1 = torch.randn(10000, 10000, requires_grad=True, device="cuda", dtype=dtypes.fp32).to(dtypes.i32)
result = torch.sigmoid(tensor0)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
torch_totals = get_profiler_totals(prof)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
......@@ -49,7 +61,23 @@ with profile(
# cache_flush1 = torch.randn(10000, 10000, requires_grad=True, device="cuda", dtype=dtypes.fp32).to(dtypes.i32)
output = aiter.sigmoid(tensor0)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
aiter_totals = get_profiler_totals(prof)
result_equal = torch.equal(result, output)
summary = pd.DataFrame([
{
"metric": "CPU total time",
"torch.sigmoid": torch_totals.get("CPU total time"),
"aiter.sigmoid": aiter_totals.get("CPU total time"),
},
{
"metric": "CUDA total time",
"torch.sigmoid": torch_totals.get("CUDA total time"),
"aiter.sigmoid": aiter_totals.get("CUDA total time"),
}
])
print(torch.equal(result, output))
print("result:", result)
print("output:", output)
equal_msg = f"Whether the two outputs are equal: {str(result_equal)}"
summary.to_csv("test_aiter_sigmoid.csv", index=False)
with open("test_aiter_sigmoid.csv", "a", encoding="utf-8") as f:
f.write(equal_msg + "\n")
......@@ -8,6 +8,7 @@ from typing_extensions import Optional
import torch
import torch.distributed as dist
import pandas as pd
from aiter import dtypes
from aiter.dist.communication_op import tensor_model_parallel_all_reduce
......@@ -35,10 +36,18 @@ def allreduce_custom(
x,
withGraph=False,
distributed_init_method: Optional[str] = None,
enable_register_for_capturing: bool = True,
):
device = torch.device(f"cuda:{rankID}")
torch.cuda.set_device(device)
# init
# Forward the user-requested capturing-registration policy down to the
# CustomAllreduce constructor via the env var consumed inside
# CudaCommunicator. Must be set BEFORE init_distributed_environment so
# the worker process picks it up when the communicator is built.
os.environ["AITER_AR_ENABLE_REG_CAPTURE"] = (
"1" if enable_register_for_capturing else "0"
)
logger.info(f"RANK: {rankID} {tp_size} init_process_group...")
set_custom_all_reduce(True)
init_distributed_environment(
......@@ -92,6 +101,7 @@ def test_allreduce_custom(
dtype,
withGraph=False,
distributed_init_method: Optional[str] = None,
enable_register_for_capturing: bool = True,
):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "49373"
......@@ -104,19 +114,38 @@ def test_allreduce_custom(
rets.append(
pool.apply_async(
allreduce_custom,
args=(tp_size, pp_size, i, x, withGraph, distributed_init_method),
args=(
tp_size,
pp_size,
i,
x,
withGraph,
distributed_init_method,
enable_register_for_capturing,
),
)
)
pool.close()
pool.join()
rets = [el.get() for el in rets]
all_us = [us for _, us in rets]
max_err = 0.0
for out, us in rets:
msg = f"test_allreduce_custom: {shape=} {dtype=} {withGraph=} {us:>8.2f}"
checkAllclose(ref, out.to(ref), msg=msg)
msg = (
f"test_allreduce_custom: {shape=} {dtype=} "
f"{withGraph=} reg_cap={enable_register_for_capturing} {us:>8.2f}"
)
err = checkAllclose(ref, out.to(ref), msg=msg)
max_err = max(max_err, err)
return {
"min_us": min(all_us),
"max_us": max(all_us),
"err": max_err,
}
l_dtype = ["fp16", "bf16"]
l_shape = [(128, 8192)]
l_shape = [(2, 7168), (128, 8192)]
parser = argparse.ArgumentParser(description="config input of test")
parser.add_argument(
......@@ -138,6 +167,23 @@ parser.add_argument(
default=None,
help="shape. e.g. -s 128,8192",
)
parser.add_argument(
"-g",
"--with-graph",
type=lambda x: str(x).lower() in ["true", "1", "yes"],
default=True,
help="use CUDA graph (default: True). e.g. -g true or -g false",
)
parser.add_argument(
"--reg-capturing",
type=str,
choices=["true", "false", "both"],
default="both",
help=(
"whether CustomAllreduce.enable_register_for_capturing is True/False. "
"'both' (default) exercises both paths as a regression sweep."
),
)
if __name__ == "__main__":
......@@ -149,16 +195,43 @@ if __name__ == "__main__":
l_dtype = [dtypes.d_dtypes[args.dtype]]
if args.shape is not None:
l_shape = [args.shape]
if args.reg_capturing == "true":
l_reg = [True]
elif args.reg_capturing == "false":
l_reg = [False]
else:
l_reg = [True, False]
df = []
for dtype in l_dtype:
for shape in l_shape:
test_allreduce_custom(
for reg in l_reg:
ret = test_allreduce_custom(
8,
1,
shape,
dtype,
withGraph=True,
withGraph=args.with_graph,
distributed_init_method=get_distributed_init_method(
get_ip(), get_open_port()
),
enable_register_for_capturing=reg,
)
df.append(ret)
df = pd.DataFrame(df)
show_cols = [
"tp_size",
"shape",
"dtype",
"withGraph",
"enable_register_for_capturing",
"min_us",
"max_us",
"err",
]
show_cols = [c for c in show_cols if c in df.columns]
df[show_cols].to_csv("test_custom_allreduce.csv", index=False)
logger.info(
"custom allreduce summary (markdown):\n%s",
df[show_cols].to_markdown(index=False),
)
# test_allreduce_custom(8, 1, shape, dtype, withGraph=False)
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
import os
from typing import Optional
......@@ -8,6 +9,7 @@ import torch.nn.functional as F
import torch.distributed as dist
import argparse
import itertools
import pandas as pd
from aiter import dtypes
from aiter.dist.parallel_state import (
......@@ -23,6 +25,7 @@ from aiter.dist.utils import get_open_port, get_distributed_init_method, get_ip
from aiter.dist.communication_op import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_fused_allreduce_rmsnorm,
tensor_model_parallel_fused_allreduce_rmsnorm_quant,
)
from aiter.test_common import (
checkAllclose,
......@@ -37,35 +40,6 @@ logger = logging.getLogger("aiter")
set_start_method("spawn", force=True)
def prebuild_rmsnorm_module(dtype):
    """Run ``aiter.rmsnorm2d_fwd_with_add`` once on cuda:0 in this process.

    Intended to trigger the JIT build in the parent process a single time,
    so that worker processes do not contend on the file-baton lock. Does
    nothing when CUDA is unavailable.

    Raises:
        RuntimeError: If the call or the following synchronize fails; the
            original exception is chained as the cause.
    """
    if not torch.cuda.is_available():
        return

    device = torch.device("cuda:0")
    # Tiny (1, 128) problem: only the build matters, not the result.
    x = torch.randn((1, 128), dtype=dtype, device=device)
    ar_out = torch.randn_like(x)
    out = torch.empty_like(x)
    residual_out = torch.empty_like(x)
    weight = torch.randn((x.shape[-1],), dtype=dtype, device=device)

    try:
        print(f"Prebuilding rmsnorm module for {dtype=}.")
        aiter.rmsnorm2d_fwd_with_add(out, ar_out, x, residual_out, weight, 1e-6)
        torch.cuda.synchronize(device)
    except Exception as e:
        raise RuntimeError(
            f"Failed to prebuild rmsnorm module for {dtype=}. "
            "Please check aiter JIT build environment before running multi-process tests."
        ) from e
    print(f"Prebuilding rmsnorm module for {dtype=} done.")
def fused_ar_rmsnorm(
tp_size,
pp_size,
......@@ -75,6 +49,7 @@ def fused_ar_rmsnorm(
eps,
withGraph=False,
distributed_init_method: Optional[str] = None,
post_per_token_quant: bool = False,
):
device = torch.device(f"cuda:{rankID}")
torch.cuda.set_device(device)
......@@ -100,9 +75,16 @@ def fused_ar_rmsnorm(
graph = torch.cuda.CUDAGraph()
with graph_capture() as gc:
with torch.cuda.graph(graph, stream=gc.stream):
res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(
if not post_per_token_quant:
out, res_out = tensor_model_parallel_fused_allreduce_rmsnorm(
x, x, weight, eps
)
else:
out, res_out, scale_out = (
tensor_model_parallel_fused_allreduce_rmsnorm_quant(
x, x, weight, eps
)
)
out.fill_(0)
res_out.fill_(0)
......@@ -111,17 +93,32 @@ def fused_ar_rmsnorm(
graph.replay()
_, us = run_ca()
if not post_per_token_quant:
out = (out, us)
else:
out = (out.float() * scale_out, us)
else:
@perftest()
def run_ca(x):
res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(
if not post_per_token_quant:
out, res_out = tensor_model_parallel_fused_allreduce_rmsnorm(
x, x, weight, eps
)
return out
else:
out, res_out, scale_out = (
tensor_model_parallel_fused_allreduce_rmsnorm_quant(
x, x, weight, eps
)
)
return out, scale_out
if not post_per_token_quant:
out = run_ca(x)
else:
out = run_ca(x)
out = (out[0][0].float() * out[0][1], out[1])
# destroy
if dist.is_initialized():
......@@ -166,7 +163,7 @@ def get_acc_value_with_cudagraph(
with graph_capture() as gc:
with torch.cuda.graph(graph, stream=gc.stream):
# out = torch.empty_like(x)
res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(
out, res_out = tensor_model_parallel_fused_allreduce_rmsnorm(
x, x, weight, eps
)
out.fill_(0)
......@@ -218,7 +215,7 @@ def get_acc_value_only(
torch.cuda.synchronize()
for i in range(loop_time):
res, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, x, weight, eps)
out, res = tensor_model_parallel_fused_allreduce_rmsnorm(x, x, weight, eps)
# destroy
if dist.is_initialized():
......@@ -272,7 +269,9 @@ def split_ar_rmsnorm(
x,
residual_out,
weight,
eps)
eps,
0,
)
out.fill_(0)
@perftest()
......@@ -294,7 +293,8 @@ def split_ar_rmsnorm(
x,
residual_out,
weight,
eps
eps,
0,
)
return out
......@@ -308,129 +308,6 @@ def split_ar_rmsnorm(
return out
def split_vs_fused_ar_rmsnorm(
    tp_size,
    pp_size,
    rankID,
    x,
    weight,
    eps,
    withGraph=False,
    distributed_init_method: Optional[str] = None,
):
    """Per-rank worker: run split and fused allreduce+rmsnorm on one input.

    Initialises the distributed environment for ``rankID``, times the
    two-step path (all-reduce followed by ``rmsnorm2d_fwd_with_add``) and
    the fused op, then tears the environment down again. ``withGraph`` is
    accepted but not used here; both paths run eagerly.

    Returns:
        ``(rankID, split_out_cpu, fused_out_cpu, split_us, fused_us)``.
    """
    device = torch.device(f"cuda:{rankID}")
    torch.cuda.set_device(device)
    logger.info(f"RANK: {rankID} {tp_size} init_process_group...")
    set_custom_all_reduce(True)
    init_distributed_environment(
        world_size=tp_size,
        rank=rankID,
        distributed_init_method=distributed_init_method,
    )
    ensure_model_parallel_initialized(tp_size, pp_size)
    x = x.to(device)
    weight = weight.to(device)

    # One plain all-reduce on the TP group before any measurement.
    group = get_tp_group().device_group
    dist.all_reduce(torch.zeros(1).cuda(), group=group)
    torch.cuda.synchronize()

    # For split-vs-fused numerical comparison, eager mode is sufficient and
    # avoids graph-capture + collective interaction risks in one worker
    # lifecycle.
    @perftest()
    def run_split(inp):
        reduced = tensor_model_parallel_all_reduce(inp)
        normed = torch.empty_like(reduced)
        residual_out = torch.empty_like(reduced)
        aiter.rmsnorm2d_fwd_with_add(
            normed,
            reduced,
            inp,
            residual_out,
            weight,
            eps,
        )
        return normed

    @perftest()
    def run_fused(inp):
        # The second element of the returned pair is used as the output.
        fused_pair = tensor_model_parallel_fused_allreduce_rmsnorm(
            inp, inp, weight, eps
        )
        return fused_pair[1]

    split_out, split_us = run_split(x)
    fused_out, fused_us = run_fused(x)

    if dist.is_initialized():
        destroy_model_parallel()
        destroy_distributed_environment()
        torch.cuda.empty_cache()

    return (
        rankID,
        split_out.detach().cpu(),
        fused_out.detach().cpu(),
        split_us,
        fused_us,
    )
@benchmark()
def test_split_ar_rmsnorm(
    tp_size,
    pp_size,
    shape,
    dtype,
    withGraph=False,
    distributed_init_method: Optional[str] = None,
):
    """Run ``split_ar_rmsnorm`` on ``tp_size`` workers and check each result.

    Each rank gets its own random input and weight. The expected output
    for rank ``i`` is ``F.rms_norm`` of (sum of all inputs + rank i's
    input) with rank i's weight, computed on the host.
    """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "49373"
    pool = Pool(processes=tp_size)
    ref = torch.zeros(shape, dtype=dtype)
    pending = []
    res_inp = []
    weight_list = []
    n = shape[1]
    eps = 1e-6

    for rank in range(tp_size):
        x = torch.randn(shape, dtype=dtype)
        res_inp.append(x)
        ref += x
        weight = torch.randn((n,), dtype=dtype)
        weight_list.append(weight)
        worker_args = (
            tp_size,
            pp_size,
            rank,
            x,
            weight,
            eps,
            withGraph,
            distributed_init_method,
        )
        pending.append(pool.apply_async(split_ar_rmsnorm, args=worker_args))
    pool.close()
    pool.join()

    # Host reference per rank, built only after ref holds the full sum.
    cpu_rslt = [
        F.rms_norm(
            input=(ref + rank_input),
            normalized_shape=(ref.shape[-1],),
            weight=rank_weight,
            eps=eps,
        )
        for rank_input, rank_weight in zip(res_inp, weight_list)
    ]

    rets = [el.get() for el in pending]
    for out, us in rets:
        msg = f"test_split_ar_rmsnorm: {shape=} {dtype=} {withGraph=} {us:>8.2f}"
        # The result's device index selects the matching reference.
        checkAllclose(cpu_rslt[out.device.index], out.to(ref), rtol=0.03, atol=0.03, msg=msg)
@benchmark()
def test_fused_ar_rmsnorm(
tp_size,
......@@ -439,6 +316,7 @@ def test_fused_ar_rmsnorm(
dtype,
withGraph=False,
distributed_init_method: Optional[str] = None,
post_per_token_quant: bool = False,
):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "49373"
......@@ -449,16 +327,13 @@ def test_fused_ar_rmsnorm(
weight_list = []
res_inp = []
# print(type(shape[0]), shape[1], ref.device)
m = shape[0]
n = shape[1]
eps = 1e-6
for i in range(tp_size):
weight = torch.randn((n,), dtype=dtype)
x = torch.randn(shape, dtype=dtype)
# x = torch.ones(shape, dtype=dtype)
ref = x * tp_size
for i in range(tp_size):
res_inp.append(x)
# print(f"device {i}, x[0][0] = {x[0][0]}")
ref += x
weight = torch.randn((n,), dtype=dtype)
weight_list.append(weight)
rets.append(
pool.apply_async(
......@@ -472,6 +347,7 @@ def test_fused_ar_rmsnorm(
eps,
withGraph,
distributed_init_method,
post_per_token_quant,
),
)
)
......@@ -490,202 +366,40 @@ def test_fused_ar_rmsnorm(
cpu_rslt.append(host_rslt)
rets = [el.get() for el in rets]
all_us = [us for _, us in rets]
atol = 5e-2 if post_per_token_quant else 1e-2
rtol = atol
max_err = 0.0
for out, us in rets:
msg = f"test_fused_ar_rmsnorm: {shape=} {dtype=} {withGraph=} {us:>8.2f}"
# print(cpu_rslt[out.device.index])
checkAllclose(cpu_rslt[out.device.index], out.to(ref),rtol=0.03, atol=0.03 ,msg=msg)
# checkAllclose(ref, out.to(ref), msg=msg)
@benchmark()
def test_split_vs_fused_ar_rmsnorm(
    tp_size,
    pp_size,
    shape,
    dtype,
    withGraph=False,
    distributed_init_method: Optional[str] = None,
):
    """Compare split and fused allreduce+rmsnorm outputs rank by rank.

    Spawns ``tp_size`` ``split_vs_fused_ar_rmsnorm`` workers, prints timing
    and the worst absolute/relative difference for each rank, and checks
    the two outputs with ``checkAllclose`` (rtol=atol=0.03).
    """
    eps = 1e-6
    n = shape[1]
    # All inputs are drawn before any weight.
    x_list = [torch.randn(shape, dtype=dtype) for _ in range(tp_size)]
    weight_list = [torch.randn((n,), dtype=dtype) for _ in range(tp_size)]

    compare_dist_init_method = (
        distributed_init_method
        if distributed_init_method is not None
        else get_distributed_init_method(get_ip(), get_open_port())
    )

    pool = Pool(processes=tp_size)
    pending = [
        pool.apply_async(
            split_vs_fused_ar_rmsnorm,
            args=(
                tp_size,
                pp_size,
                rank,
                x_list[rank].clone(),
                weight_list[rank].clone(),
                eps,
                withGraph,
                compare_dist_init_method,
            ),
        )
        for rank in range(tp_size)
    ]
    pool.close()
    rank_rets = [el.get() for el in pending]
    pool.join()

    for rank_id, split_out, fused_out, split_us, fused_us in rank_rets:
        diff = (split_out - fused_out).abs()
        max_abs = diff.max().item()
        # Relative error against the fused output; the denominator is
        # floored at 1e-6.
        rel_flat = (diff / fused_out.abs().clamp_min(1e-6)).reshape(-1)
        worst = torch.max(rel_flat, dim=0)
        max_rel = worst[0].item()
        max_rel_idx = worst[1].item()
        split_at_max_rel = split_out.reshape(-1)[max_rel_idx].item()
        fused_at_max_rel = fused_out.reshape(-1)[max_rel_idx].item()
        print(
            f"[split_vs_fused]: rank={rank_id} "
            f"split_us={split_us:>6.2f}, fused_us={fused_us:>6.2f}, "
            f"max_abs={max_abs:.8f}, max_rel={max_rel:.8f}, "
            f"idx={max_rel_idx}, split_val={split_at_max_rel:.8f}, fused_val={fused_at_max_rel:.8f}. "
        )
        msg = f"split_vs_fused: {shape=} {dtype=} {withGraph=} rank={rank_id}"
        checkAllclose(split_out, fused_out, rtol=0.03, atol=0.03, msg=msg)
def acc_test(
    tp_size, pp_size, shape, dtype, distributed_init_method: Optional[str] = None
):
    """Run ``get_acc_value_only`` on ``tp_size`` workers and check the results.

    Every rank receives its own random input and weight; each worker's
    returned tensor is compared with the host-side sum of all inputs via
    ``checkAllclose``.
    """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "49373"
    pool = Pool(processes=tp_size)
    ref = torch.zeros(shape, dtype=dtype)
    pending = []
    n = shape[1]
    eps = 1e-6

    for rank in range(tp_size):
        x = torch.randn(shape, dtype=dtype)
        ref += x
        weight = torch.randn((n,), dtype=dtype)
        # The literal 1 is the seventh positional argument of the worker.
        worker_args = (tp_size, pp_size, rank, x, weight, eps, 1, distributed_init_method)
        pending.append(pool.apply_async(get_acc_value_only, args=worker_args))
    pool.close()
    pool.join()

    # Collect every result first, then compare each against the sum.
    ar_rslt = [handle.get() for handle in pending]
    for rank_out in ar_rslt:
        checkAllclose(ref, rank_out.to(ref))
def acc_test_cudagraph_on(
tp_size,
pp_size,
shape,
dtype,
loop_time=1,
distributed_init_method: Optional[str] = None,
):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "49373"
pool = Pool(processes=tp_size)
ref = torch.zeros(shape, dtype=dtype)
rets = []
cpu_rslt = []
weight_list = []
# print(type(shape[0]), shape[1], ref.device)
m = shape[0]
n = shape[1]
eps = 1e-6
for i in range(tp_size):
x = torch.randn(shape, dtype=dtype)
ref += x
weight = torch.randn((n,), dtype=dtype)
weight_list.append(weight)
rets.append(
pool.apply_async(
get_acc_value_with_cudagraph,
args=(
tp_size,
pp_size,
i,
x,
weight,
eps,
loop_time,
distributed_init_method,
),
)
err = checkAllclose(
cpu_rslt[out.device.index], out.to(ref), msg=msg, atol=atol, rtol=rtol
)
pool.close()
pool.join()
ar_rslt = []
for i, ret in enumerate(rets):
rslt = ret.get()
ar_rslt.append(rslt)
for i in ar_rslt:
checkAllclose(ref, i.to(ref))
# def acc_test(tp_size, pp_size, shape, dtype):
# os.environ["MASTER_ADDR"] = "127.0.0.1"
# os.environ["MASTER_PORT"] = "49373"
# pool = Pool(processes=tp_size)
# ref = torch.zeros(shape, dtype=dtype)
# rets = []
# cpu_rslt = []
# weight_list = []
# # print(type(shape[0]), shape[1], ref.device)
# m = shape[0]
# n = shape[1]
# eps = 1e-6
# for i in range(tp_size):
# x = torch.randn(shape, dtype=dtype)
# print(f"device {i}, x[0][0] = {x[0][0]}")
# ref += x
# weight = torch.randn((n,), dtype=dtype)
# weight_list.append(weight)
# rets.append(
# pool.apply_async(get_acc_value_only, args=(tp_size, pp_size, i, x, weight, eps))
# )
# pool.close()
# pool.join()
# for i in range(tp_size):
# host_rslt = F.rms_norm(
# input=ref, normalized_shape=(ref.shape[-1],), weight=weight_list[i], eps=eps
# )
# cpu_rslt.append(host_rslt)
#
# ar_rslt = []
# for i, ret in enumerate(rets):
# rslt = ret.get()
# ar_rslt.append(rslt)
# for i in range(len(ar_rslt)):
# checkAllclose(cpu_rslt[i], ar_rslt[i].to(ref))
l_dtype = ["bf16"]
l_shape = [(6,7168)]
max_err = max(max_err, err)
# checkAllclose(ref, out.to(ref), msg=msg)
suffix = "quant" if post_per_token_quant else "fused"
return {
f"{suffix}_min_us": min(all_us),
f"{suffix}_max_us": max(all_us),
f"{suffix}_err": max_err,
}
l_dtype = ["fp16", "bf16"]
# (13, 2880): GPT-OSS-120B / GPT-OSS-20B hidden_size (n_bytes=5760, 4096 < 5760 < 8192)
l_shape = [
(13, 512),
(13, 1024),
(13, 2048),
(13, 2880),
(17, 4096),
(17, 7168),
(19, 8192),
]
l_tp = [8]
l_pp = [1]
l_graph = [True, False]
l_graph = [False, True]
parser = argparse.ArgumentParser(description="config input of test")
parser.add_argument(
......@@ -702,10 +416,9 @@ parser.add_argument(
"-s",
"--shape",
type=dtypes.str2tuple,
nargs="?",
const=None,
nargs="*",
default=None,
help="shape. e.g. -s 128,8192",
help="shape(s). e.g. -s 128,8192 256,7168",
)
parser.add_argument(
......@@ -738,6 +451,16 @@ parser.add_argument(
help="open cudagraph. e.g. -g 1",
)
l_test_types = ["fused", "quant"]
parser.add_argument(
"--test",
type=str,
choices=l_test_types,
nargs="*",
default=None,
help="test type(s) to run. e.g. --test fused quant",
)
if __name__ == "__main__":
freeze_support()
......@@ -747,7 +470,7 @@ if __name__ == "__main__":
else:
l_dtype = [dtypes.d_dtypes[args.dtype]]
if args.shape is not None:
l_shape = [args.shape]
l_shape = args.shape
if args.tp is not None:
l_tp = [args.tp]
if args.pp is not None:
......@@ -755,15 +478,14 @@ if __name__ == "__main__":
if args.graphon is not None:
print(args.graphon)
l_graph = [args.graphon]
# Prebuild split-path JIT kernel to avoid workers hanging in JIT lock wait.
for dtype in l_dtype:
prebuild_rmsnorm_module(dtype)
run_tests = args.test if args.test else l_test_types
df = []
for dtype, shape, tp, pp, graph_on in itertools.product(
l_dtype, l_shape, l_tp, l_pp, l_graph
):
test_split_ar_rmsnorm(
row = {}
if "fused" in run_tests:
ret = test_fused_ar_rmsnorm(
tp,
pp,
shape,
......@@ -772,8 +494,11 @@ if __name__ == "__main__":
distributed_init_method=get_distributed_init_method(
get_ip(), get_open_port()
),
post_per_token_quant=False,
)
test_fused_ar_rmsnorm(
row.update(ret)
if "quant" in run_tests:
ret = test_fused_ar_rmsnorm(
tp,
pp,
shape,
......@@ -782,15 +507,26 @@ if __name__ == "__main__":
distributed_init_method=get_distributed_init_method(
get_ip(), get_open_port()
),
post_per_token_quant=True,
)
test_split_vs_fused_ar_rmsnorm(
tp,
pp,
shape,
dtype,
withGraph=graph_on,
distributed_init_method=get_distributed_init_method(
get_ip(), get_open_port()
),
row.update(ret)
df.append(row)
df = pd.DataFrame(df)
show_cols = [
"tp_size",
"shape",
"dtype",
"withGraph",
"fused_min_us",
"fused_max_us",
"fused_err",
"quant_min_us",
"quant_max_us",
"quant_err",
]
show_cols = [c for c in show_cols if c in df.columns]
df[show_cols].to_csv("test_fused_ar_rms.csv", index=False)
logger.info(
"fused allreduce rmsnorm summary (markdown):\n%s",
df[show_cols].to_markdown(index=False),
)
# SPDX-License-Identifier: MIT
import torch
import torch.nn.functional as F
import pandas as pd
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter import dtypes
......@@ -35,15 +36,6 @@ def run_torch(input, weight, bias, eps, residual=None, x_bias=None):
def run_ck(input, weight, bias, eps, residual=None, x_bias=None):
if residual is None:
residual_out = None
output = aiter.layer_norm(input, weight, bias, eps, x_bias)
# output = torch.empty_like(input)
# aiter.layernorm2d_fwd(
# output,
# input,
# weight,
# bias,
# eps
# )
else:
residual_out = torch.empty_like(input)
output = torch.empty_like(input)
......@@ -78,7 +70,17 @@ def test_layernorm2d(dtype, m, n):
(a, *_), avg_a = run_torch(input, weight, bias, 1e-5)
(b, *_), avg_b = run_ck(input, weight, bias, 1e-5)
msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
checkAllclose(a, b, msg=msg)
check_ret = checkAllclose(a, b, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
return {
"m": m,
"n": n,
"dtype": str(dtype),
"torch_us": avg_a,
"ck_us": avg_b,
"uplift": f"{avg_a / avg_b - 1:.1%}",
"accuracy": ret_output,
}
def test_layernorm2d_fuseAdd(dtype, m, n):
......@@ -97,21 +99,57 @@ def test_layernorm2d_fuseAdd(dtype, m, n):
# (c, res_c, *_), avg_c = run_asm(input, weight, bias, 1e-5, residual=res)
msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
checkAllclose(a, b, atol=0.03, msg=msg)
checkAllclose(res_a, res_b, msg="res check")
check_ret = checkAllclose(a, b, atol=0.03, msg=msg)
ret_output = "passed" if check_ret == 0 else (1 - check_ret)
residual_check_ret = checkAllclose(res_a, res_b, msg="res check")
residual_output = "passed" if residual_check_ret == 0 else (1-residual_check_ret)
return {
"m": m,
"n": n,
"dtype": str(dtype),
"torch_us": avg_a,
"ck_us": avg_b,
"uplift": f"{avg_a / avg_b - 1:.1%}",
"accuracy": ret_output,
"residual_accuracy": residual_output,
}
# checkAllclose(a, c, atol=0.03, msg="asm")
# checkAllclose(res_a, res_c, atol=0.01, msg="asm res")
# for dtype in [dtypes.fp16, dtypes.bf16]:
# for m in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
# for n in [4096, 8192, 16384, 32768, 65536]:
# test_layernorm2d(dtype, m, n)
test_layernorm2d_fuseAdd(dtypes.bf16, 128, 8192)
if __name__ == "__main__":
df = []
df_fuse_add = []
# for dtype in [dtypes.fp16, dtypes.bf16]:
# for m in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
# for n in [4096, 8192, 16384, 32768, 65536]:
# ret = test_layernorm2d(dtype, m, n)
# if ret is not None:
# df.append(ret)
# ret = test_layernorm2d(dtypes.bf16, 128, 8192)
# if ret is not None:
# df.append(ret)
# print('\nstart fuse add test')
# for dtype in [dtypes.fp16, dtypes.bf16]:
# for m in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
# for n in [4096, 8192, 16384, 32768, 65536]:
# ret = test_layernorm2d_fuseAdd(dtype, m, n)
# if ret is not None:
# df_fuse_add.append(ret)
ret = test_layernorm2d_fuseAdd(dtypes.bf16, 128, 8192)
if ret is not None:
df_fuse_add.append(ret)
# if df:
# df = pd.DataFrame(df)
# aiter.logger.info(f"layernorm2d summary:\n{df}")
# df.to_csv("test_layernorm2d.csv", index=False)
# print('\nstart fuse add test')
# for dtype in [dtypes.fp16, dtypes.bf16]:
# for m in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
# for n in [4096, 8192, 16384, 32768, 65536]:
# test_layernorm2d_fuseAdd(dtype, m, n)
if df_fuse_add:
df_fuse_add = pd.DataFrame(df_fuse_add)
aiter.logger.info(f"layernorm2d fuseAdd summary:\n{df_fuse_add}")
df_fuse_add.to_csv("test_layernorm2d_fuseAdd.csv", index=False)
......@@ -559,6 +559,7 @@ for dtype in l_dtype:
ret = test_topk_softmax(dtype, m, e, l_topk)
df.append(ret)
df = pd.DataFrame(df)
df.to_csv("topk_softmax.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
df = []
......@@ -575,6 +576,7 @@ for token in l_token:
)
df.append(ret)
df = pd.DataFrame(df)
df.to_csv("biased_grouped_topk.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
df = []
......@@ -594,6 +596,7 @@ for token in l_token:
)
df.append(ret)
df = pd.DataFrame(df)
df.to_csv("biased_grouped_topk_with_shared_expert.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
df = []
......@@ -619,4 +622,5 @@ for token in l_token:
)
df.append(ret)
df = pd.DataFrame(df)
df.to_csv("grouped_topk_with_shared_expert.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
......@@ -76,6 +76,43 @@ def test_moe_sorting_ck(
expert_mask=expert_mask,
)
@perftest()
def test_moe_sorting_ck_no_moebuf(
    topk_ids, topk_weights, num_experts, block_size=BLOCK_SIZE_M, expert_mask=None
):
    """Call ``aiter.moe_sorting_fwd`` with ``None`` in the moe_buf position.

    Output buffers are allocated uninitialized on the device of ``topk_ids``
    and filled by the kernel call.

    Args:
        topk_ids: 2-D tensor (its shape is unpacked into two values); cast to
            int32 before the kernel call when it has another dtype.
        topk_weights: forwarded unchanged to ``aiter.moe_sorting_fwd``.
        num_experts: expert count; used for buffer sizing and forwarded.
        block_size: block granularity used for buffer sizing and forwarded.
        expert_mask: optional mask forwarded unchanged (``None`` by default).

    Returns:
        ``(sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids)``.
        NOTE(review): callers in this file unpack ``(outputs, avg_time)``, so
        the ``perftest`` decorator appears to add a timing value - confirm.
    """
    dev = topk_ids.device
    # Only the second dimension of topk_ids is needed below.
    _, topk = topk_ids.shape

    def _empty(length, dtype):
        # Uninitialized 1-D buffer on the same device as the inputs.
        return torch.empty((length,), dtype=dtype, device=dev)

    # Upper bound on the number of sorted slots, and the number of
    # block_size-sized blocks needed to hold them (rounded up).
    padded_len = topk_ids.numel() + num_experts * block_size - topk
    num_blocks = int((padded_len + block_size - 1) // block_size)

    sorted_ids = _empty(padded_len, dtypes.i32)
    sorted_weights = _empty(padded_len, dtypes.fp32)
    sorted_expert_ids = _empty(num_blocks, dtypes.i32)
    tokens_positions_per_expert = _empty(num_experts * 2, dtypes.i32)
    num_valid_ids = _empty(1, dtypes.i32)

    if topk_ids.dtype != dtypes.i32:
        topk_ids = topk_ids.to(dtypes.i32)

    aiter.moe_sorting_fwd(
        topk_ids,
        topk_weights,
        sorted_ids,
        sorted_weights,
        sorted_expert_ids,
        tokens_positions_per_expert,
        num_valid_ids,
        None,  # moe_buf=None
        num_experts,
        block_size,
        expert_mask,
    )
    return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids
@benchmark()
def test_moe_sorting(
......@@ -114,29 +151,117 @@ def test_moe_sorting(
print(
f"[perf] {token=}, {model_dim=}, {inter_dim=}, {E=}, {topk=}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
)
checkAllclose(
num_tokens_post_padded_ret = checkAllclose(
num_tokens_post_padded_a,
num_tokens_post_padded_b,
atol=0,
msg="num_tokens_post_padded",
)
num_tokens_post_padded_acc = "passed" if num_tokens_post_padded_ret == 0 else "failed"
mask = sorted_ids_a != (topk << 24 | token)
num_tokens_post_pad = num_tokens_post_padded_a.item()
checkAllclose(
sorted_ids_ret = checkAllclose(
sorted_ids_a[:num_tokens_post_pad],
sorted_ids_b[:num_tokens_post_pad],
msg="sorted_ids",
)
checkAllclose(sorted_weights_a[mask], sorted_weights_b[mask], msg="sorted_weights")
sorted_ids_acc = "passed" if sorted_ids_ret == 0 else "failed"
sorted_weights_ret = checkAllclose(
sorted_weights_a[mask],
sorted_weights_b[mask],
msg="sorted_weights"
)
sorted_weights_acc = "passed" if sorted_weights_ret == 0 else "failed"
expert_mask = sorted_expert_ids_a != -1
checkAllclose(
sorted_expert_ids_ret = checkAllclose(
sorted_expert_ids_a[expert_mask],
sorted_expert_ids_b[expert_mask],
msg="sorted_expert_ids",
)
return {"us": avg_b}
sorted_expert_ids_acc = "passed" if sorted_expert_ids_ret == 0 else "failed"
return {"us": avg_b,
"num_tokens_post_padded_out": num_tokens_post_padded_acc,
"sorted_ids_out": sorted_ids_acc,
"sorted_weights_out": sorted_weights_acc,
"sorted_expert_ids_out": sorted_expert_ids_acc
}
@benchmark()
def test_moe_sorting_none_moebuf(
    dtype, token, model_dim, inter_dim, E, topk, has_expert_mask=False
):
    """Compare ``test_moe_sorting_ck_no_moebuf`` against ``test_moe_sorting_naive``.

    Random activations and scores are generated on "cuda", routed through
    ``fused_topk``, and both sorting implementations are run on the result.
    Four outputs are compared with ``checkAllclose`` and each comparison is
    reported as "passed" (return value 0) or "failed".

    Args:
        dtype: dtype of the random input and score tensors.
        token: number of rows in the random input and score tensors.
        model_dim: number of columns in the random input tensor.
        inter_dim: only used in the printed perf line.
        E: number of experts (columns of the score tensor).
        topk: number of experts selected per token.
        has_expert_mask: when true, a random 0/1 mask of length ``E`` is passed
            to both implementations; otherwise ``None`` is passed.

    Returns:
        dict with ``"us"`` (the timing reported for the ck path) and one
        "passed"/"failed" entry per compared output.
    """

    def _verdict(check_ret):
        # checkAllclose's return value is treated as 0 == success.
        return "passed" if check_ret == 0 else "failed"

    hidden = torch.randn((token, model_dim), dtype=dtype, device="cuda")
    score = torch.rand((token, E), device="cuda", dtype=dtype)
    topk_weights, topk_ids = fused_topk(hidden, score, topk, True)

    if has_expert_mask:
        expert_mask = torch.randint(0, 2, (E,), dtype=topk_ids.dtype, device="cuda")
    else:
        expert_mask = None

    # Reference implementation.
    ref_out, avg_a = test_moe_sorting_naive(topk_ids, topk_weights, E, expert_mask)
    (
        sorted_ids_a,
        sorted_weights_a,
        sorted_expert_ids_a,
        num_tokens_post_padded_a,
    ) = ref_out

    # Implementation under test (moe_buf passed as None).
    ck_out, avg_b = test_moe_sorting_ck_no_moebuf(
        topk_ids, topk_weights, E, expert_mask=expert_mask
    )
    (
        sorted_ids_b,
        sorted_weights_b,
        sorted_expert_ids_b,
        num_tokens_post_padded_b,
    ) = ck_out

    print(
        f"[perf-none-moebuf] {token=}, {model_dim=}, {inter_dim=}, {E=}, {topk=}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
    )

    padded_ret = checkAllclose(
        num_tokens_post_padded_a,
        num_tokens_post_padded_b,
        atol=0,
        msg="num_tokens_post_padded",
    )

    # Exclude reference entries equal to (topk << 24 | token) from the weight
    # comparison - presumably the padding sentinel; confirm against the kernel.
    weight_mask = sorted_ids_a != (topk << 24 | token)
    # Only the first num_tokens_post_padded entries of sorted_ids are compared.
    compare_len = num_tokens_post_padded_a.item()
    ids_ret = checkAllclose(
        sorted_ids_a[:compare_len],
        sorted_ids_b[:compare_len],
        msg="sorted_ids",
    )
    weights_ret = checkAllclose(
        sorted_weights_a[weight_mask],
        sorted_weights_b[weight_mask],
        msg="sorted_weights",
    )

    # Compare expert ids only where the reference entry is not -1.
    block_mask = sorted_expert_ids_a != -1
    expert_ids_ret = checkAllclose(
        sorted_expert_ids_a[block_mask],
        sorted_expert_ids_b[block_mask],
        msg="sorted_expert_ids",
    )

    return {
        "us": avg_b,
        "num_tokens_post_padded_out": _verdict(padded_ret),
        "sorted_ids_out": _verdict(ids_ret),
        "sorted_weights_out": _verdict(weights_ret),
        "sorted_expert_ids_out": _verdict(expert_ids_ret),
    }
import pandas as pd
......@@ -144,11 +269,14 @@ df = []
print("test test_moe_sorting, no expert mask")
for dtype in [dtypes.bf16]:
for m in [1, 7, 31, 64, 128, 256, 163840][:]:
for E in [32, 256][:]:
for E in [3, 5, 32, 40, 256][:]:
for top in [5, 8][:]:
if top > E:
continue
ret = test_moe_sorting(dtype, m, 7168, 4096, E, top)
df.append(ret)
df = pd.DataFrame(df)
df.to_csv("moe_sorting_no_expert_mask.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
......@@ -156,11 +284,45 @@ df = []
print("test test_moe_sorting, with expert mask")
for dtype in [dtypes.bf16]:
for m in [1, 7, 31, 64, 128, 256, 163840]:
for E in [32, 256]:
for E in [3, 5, 32, 40, 256]:
for top in [5, 8]:
if top > E:
continue
ret = test_moe_sorting(
dtype, m, 4096, 4096, E, top, has_expert_mask=True
)
df.append(ret)
df = pd.DataFrame(df)
df.to_csv("moe_sorting_with_expert_mask.csv", index=False)
aiter.logger.info(f"summary:\n{df}")
# Sweep test_moe_sorting_none_moebuf without an expert mask and collect one
# result row per (token count, expert count, topk) combination.
df = []
print("test test_moe_sorting_none_moebuf, no expert mask")
for dtype in [dtypes.bf16]:
    for m in [1, 7, 31, 64, 128, 256, 163840][:]:
        for E in [3, 5, 32, 40, 96, 192, 256][:]:
            for top in [5, 8][:]:
                # Skip combinations that ask for more experts per token than exist.
                if top > E:
                    continue
                # Positional 7168 / 4096 are model_dim / inter_dim.
                ret = test_moe_sorting_none_moebuf(dtype, m, 7168, 4096, E, top)
                df.append(ret)
# Persist the collected rows and log them.
df = pd.DataFrame(df)
df.to_csv("moe_sorting_none_moebuf_no_expert_mask.csv", index=False)
aiter.logger.info(f"summary-none-moebuf:\n{df}")
# Sweep test_moe_sorting_none_moebuf with a random expert mask enabled and
# collect one result row per (token count, expert count, topk) combination.
df = []
print("test test_moe_sorting_none_moebuf, with expert mask")
for dtype in [dtypes.bf16]:
    for m in [1, 7, 31, 64, 128, 256, 163840]:
        for E in [3, 5, 32, 40, 96, 192, 256]:
            for top in [5, 8]:
                # Skip combinations that ask for more experts per token than exist.
                if top > E:
                    continue
                # Positional 4096 / 4096 are model_dim / inter_dim.
                ret = test_moe_sorting_none_moebuf(
                    dtype, m, 4096, 4096, E, top, has_expert_mask=True
                )
                df.append(ret)
# Persist the collected rows and log them.
df = pd.DataFrame(df)
df.to_csv("moe_sorting_none_moebuf_with_expert_mask.csv", index=False)
aiter.logger.info(f"summary-none-moebuf-mask:\n{df}")
......@@ -144,4 +144,11 @@ for (
ret = test_quant(m, n, q_type, q_dtype, h_dtype)
df.append(ret)
df = pd.DataFrame(df)
q_type_name = getattr(q_type, 'name', str(q_type)).split('.')[-1]
q_dtype_name = str(q_dtype).split('.')[-1]
h_dtype_name = str(h_dtype).split('.')[-1]
csv_filename = f"quant_{q_type_name}_{q_dtype_name}_{h_dtype_name}.csv"
df.to_csv(csv_filename, index=False)
aiter.logger.info(f"summary:\n{df}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment