# SPDX-License-Identifier: MIT
"""Accuracy/perf check for int8 w8a8 MoE with shared experts fused into top-k.

The script builds random local and shared expert weights, computes a torch
reference, checks that ``aiter.biased_grouped_topk`` appends the shared
experts to the routed top-k as expected, and compares the Triton fused-experts
kernel against the reference.  The parameter sweep at the bottom runs at
module import time.
"""

import pytest
import torch
import itertools
from typing import Optional, List
from op_tests.utility.scalar_type import ScalarType, scalar_types
from op_tests.utility.utils import quantize_weights
from op_tests.utility.utils import torch_moe as torch_score_moe
from aiter.ops.triton.moe_op import fused_moe
from aiter.fused_moe import fused_topk, torch_moe
from aiter import ActivationType
from aiter.test_common import checkAllclose, perftest, benchmark
from aiter import ck_moe, ck_shuffle_moe, dtypes, silu_and_mul, gelu_and_mul, moe_sum
from aiter.ops.triton.utils.moe_config_utils import get_optimal_moe_config_func
from aiter.ops.triton.fused_moe import fused_experts_impl
from aiter.ops.triton.utils.types import torch_to_triton_dtype
from aiter.fused_moe_asm_wna16 import fused_experts_asm_impl
from aiter import per_token_quant_hip, per_block_quant_wrapper
import pandas as pd
import aiter
import os

BLOCK_SIZE_M = 32
MAX_TOKENS = 65536
# NOTE(review): this switch is consumed outside this file (presumably by the
# Triton AMD backend) — confirm before changing or moving it.
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"


@perftest(num_warmup=1, num_iters=2)
def torch_moe_test(
    hidden_states,
    w1,
    w2,
    topk_weight,
    topk_ids,
    # following for int8 quant
    fc1_scale=None,  # [expert, inter_dim, 1]
    fc2_scale=None,  # [expert, model_dim, 1]
    fc1_smooth_scale=None,  # [expert, 1, model_dim]
    fc2_smooth_scale=None,  # [expert, 1, inter_dim]
    expert_mask=None,
):
    """Run the reference ``torch_moe`` under ``perftest`` timing.

    All arguments are forwarded positionally and unchanged.  Callers in this
    file unpack the decorated call as ``(output, average_time)``.
    """
    return torch_moe(
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        fc1_scale,
        fc2_scale,
        fc1_smooth_scale,
        fc2_smooth_scale,
        expert_mask,
    )


@perftest(num_warmup=5, num_iters=10, testGraph=False)
def asm_fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    dtype,
    inplace: bool = False,
    activation: str = "silu",
    is_gated: Optional[bool] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w4a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    routed_scaling_factor: Optional[float] = 1.0,
    gemm1_alpha: Optional[float] = None,
    gemm1_limit: Optional[float] = None,
):
    """Run ``fused_experts_asm_impl`` under ``perftest`` timing.

    Arguments up to ``block_shape`` are forwarded positionally in signature
    order; the last three are forwarded by keyword.
    """
    return fused_experts_asm_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        dtype,
        inplace,
        activation,
        is_gated,
        use_fp8_w8a8,
        use_int8_w8a8,
        use_int8_w4a8,
        use_int8_w8a16,
        use_int4_w4a16,
        per_channel_quant,
        global_num_experts,
        expert_map,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        routed_scaling_factor=routed_scaling_factor,
        gemm1_alpha=gemm1_alpha,
        gemm1_limit=gemm1_limit,
    )


@perftest(num_warmup=5, num_iters=10, testGraph=False)
def triton_fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    odtype,
    inplace: bool = False,
    activation: str = "silu",
    is_gated: Optional[bool] = None,
    b1: Optional[torch.Tensor] = None,
    b2: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_int4_w4a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    no_combine: bool = False,
    routed_scaling_factor: Optional[float] = 1.0,
    gemm1_alpha: Optional[float] = None,
    gemm1_limit: Optional[float] = None,
):
    """Run the Triton ``fused_experts_impl`` under ``perftest`` timing.

    Every argument is forwarded positionally in signature order, so this
    signature must stay aligned with ``fused_experts_impl``.
    """
    return fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        odtype,
        inplace,
        activation,
        is_gated,
        b1,
        b2,
        apply_router_weight_on_input,
        use_fp8_w8a8,
        use_int8_w8a8,
        use_int8_w8a16,
        use_int4_w4a16,
        use_int4_w4a8,
        per_channel_quant,
        global_num_experts,
        expert_map,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        no_combine,
        routed_scaling_factor,
        gemm1_alpha,
        gemm1_limit,
    )


def _quantize_per_channel_int8(weight):
    """Symmetrically quantize ``weight`` to int8 along its last dimension.

    The scale is ``max(|weight|) / 127`` taken over the last dimension (kept
    as a size-1 axis), with the maximum floored at 1e-5 so all-zero rows do
    not divide by zero.

    Returns:
        ``(qweight, scales)`` where ``qweight`` is ``torch.int8`` with the
        shape of ``weight`` and ``scales`` is float32 with a trailing
        dimension of 1.
    """
    max_vals = torch.abs(weight.to(torch.float32)).max(dim=-1, keepdim=True)[0]
    max_vals = max_vals.clamp(min=1e-5)
    scales = max_vals / 127.0
    qweight = (
        (weight / max_vals * 127.0).round().clamp(min=-128, max=127).to(torch.int8)
    )
    return qweight, scales


@benchmark()
def test_fmoe_w8a8_fused_shared_experts(
    dtype,
    token,
    model_dim,
    inter_dim,
    local_E,
    shared_E,
    local_topk,
    has_zp,
    quant="No",
):
    """Check the int8 w8a8 Triton MoE with shared experts fused into top-k.

    Args:
        dtype: Floating dtype of activations and unquantized weights.
        token: Number of input rows.
        model_dim: Hidden size of the input/output.
        inter_dim: Intermediate size; gate/up weights have ``2 * inter_dim`` rows.
        local_E: Number of routed (local) experts.
        shared_E: Number of shared experts appended after the local ones.
        local_topk: Number of local experts selected per token.
        has_zp: Not read by this function body.
        quant: Not read by this function body.

    Returns:
        Dict with ``"local_avg_torch"`` and ``"avg_triton"`` timings as
        reported by ``perftest`` (logged here with a "us" suffix).
    """
    hidden_states = torch.randn((token, model_dim), dtype=dtype, device="cuda") / 10
    device = hidden_states.device

    # 1. original weights in the floating dtype
    w1_local = torch.randn((local_E, inter_dim * 2, model_dim), dtype=dtype, device="cuda") / 2
    w2_local = torch.randn((local_E, model_dim, inter_dim), dtype=dtype, device="cuda") / 2
    w1_shared = torch.randn((shared_E, inter_dim * 2, model_dim), dtype=dtype, device="cuda") / 2
    w2_shared = torch.randn((shared_E, model_dim, inter_dim), dtype=dtype, device="cuda") / 2

    w1_local_ref = w1_local.clone()
    w2_local_ref = w2_local.clone()
    w1_shared_ref = w1_shared.clone()
    w2_shared_ref = w2_shared.clone()
    # (local_E + shared_E, inter_dim * 2, model_dim)
    w1_total_ref = torch.cat([w1_local, w1_shared], dim=0)
    # (local_E + shared_E, model_dim, inter_dim)
    w2_total_ref = torch.cat([w2_local, w2_shared], dim=0)

    # 2. quantize weights to int8 (per output channel)
    w1_local_qweight, w1_local_scales = _quantize_per_channel_int8(w1_local)
    w2_local_qweight, w2_local_scales = _quantize_per_channel_int8(w2_local)
    w1_shared_qweight, w1_shared_scales = _quantize_per_channel_int8(w1_shared)
    w2_shared_qweight, w2_shared_scales = _quantize_per_channel_int8(w2_shared)

    w1_total_qweight = torch.cat([w1_local_qweight, w1_shared_qweight], dim=0)
    w2_total_qweight = torch.cat([w2_local_qweight, w2_shared_qweight], dim=0)
    w1_total_scales = torch.cat([w1_local_scales, w1_shared_scales], dim=0)
    w2_total_scales = torch.cat([w2_local_scales, w2_shared_scales], dim=0)

    # 3. generate scores for the local experts
    local_score = torch.randn((token, local_E), device="cuda", dtype=dtype)
    # Add a small unique offset to each expert to ensure unique scores
    local_score += torch.arange(local_E, device="cuda", dtype=dtype).unsqueeze(0) * 1e-5
    # biased_grouped_topk currently does not support the softmax scheme; it
    # applies sigmoid to local_score instead, hence is_softmax=False.
    local_topk_weights, local_topk_ids = fused_topk(
        hidden_states, local_score, local_topk, True, is_softmax=False
    )

    # 4. local expert golden output
    torch_local_out, torch_local_avg = torch_moe_test(
        hidden_states, w1_local_ref, w2_local_ref, local_topk_weights, local_topk_ids
    )

    # 5. shared expert golden output.
    # Summing the outputs of several shared experts directly is equivalent to
    # giving every shared expert the same top-k weight, i.e.
    # fused_topk_weights[:, local_topk:] = 1.0
    SHARED_EXPERT_WEIGHT = 1.0
    torch_shared_out = torch.zeros((token, model_dim), dtype=dtype, device=device)
    for e in range(shared_E):
        gate_up_e = hidden_states @ w1_shared_ref[e].transpose(0, 1)  # (token, inter_dim * 2)
        silu_out_e = torch.empty((token, inter_dim), dtype=dtype, device=device)
        silu_and_mul(silu_out_e, gate_up_e)
        out_e = silu_out_e @ w2_shared_ref[e].transpose(0, 1)  # (token, model_dim)
        torch_shared_out += out_e * SHARED_EXPERT_WEIGHT

    # 6. local + shared output
    torch_total_out = torch_local_out + torch_shared_out

    # 7. get fused topk ids and weights
    # 7.1 manual construction: local top-k followed by the shared experts
    total_topk = local_topk + shared_E
    fused_topk_ids = torch.empty(token, total_topk, dtype=dtypes.i32, device=device)
    fused_topk_ids[:, :local_topk] = local_topk_ids
    # Append the shared-expert ids after the local top-k ids
    fused_topk_ids[:, local_topk:] = (
        torch.arange(local_E, local_E + shared_E, device=device)
        .unsqueeze(0)
        .expand(token, -1)
    )
    fused_topk_weights = torch.empty(token, total_topk, dtype=dtypes.fp32, device=device)
    fused_topk_weights[:, :local_topk] = local_topk_weights
    fused_topk_weights[:, local_topk:] = SHARED_EXPERT_WEIGHT

    # 7.2 fused shared expert topk method
    fused_topk_ids_0 = torch.empty(token, total_topk, dtype=dtypes.i32, device=device)
    fused_topk_weights_0 = torch.empty(token, total_topk, dtype=dtypes.fp32, device=device)
    aiter.biased_grouped_topk(
        gating_output=local_score,
        correction_bias=torch.zeros(local_E, device=device),  # bias for all experts, zeros
        topk_weights=fused_topk_weights_0,
        topk_ids=fused_topk_ids_0,
        num_expert_group=1,
        topk_group=1,
        num_fused_shared_experts=shared_E,
        need_renorm=True,
        routed_scaling_factor=1.0,
    )
    msg = f"[Fused shared expert topk ids] {token=}, {local_E=}, {local_topk=}, {shared_E=}"
    checkAllclose(fused_topk_ids, fused_topk_ids_0, rtol=0.01, atol=0.01, msg=msg)
    msg = f"[Fused shared expert topk weights] {token=}, {local_E=}, {local_topk=}, {shared_E=}"
    checkAllclose(fused_topk_weights, fused_topk_weights_0, rtol=0.01, atol=0.01, msg=msg)

    # 8. torch fused shared expert output
    torch_fused_out, torch_fused_avg = torch_moe_test(
        hidden_states, w1_total_ref, w2_total_ref, fused_topk_weights, fused_topk_ids
    )

    # 9. verify the correctness of the fused-experts formulation
    msg = f"[Fused shared expert logit] {token=}, {model_dim=}, {inter_dim=}, {local_E=}, {shared_E=}, {local_topk=}, dtype: {dtype}, torch_fused_avg: {torch_fused_avg:<8.2f} us"
    checkAllclose(torch_total_out, torch_fused_out, rtol=0.01, atol=10, msg=msg)

    # 10. triton fused shared expert output
    input_q, input_scale = per_token_quant_hip(hidden_states, quant_dtype=torch.int8)
    triton_output, avg_triton = triton_fused_experts_impl(
        input_q,
        w1_total_qweight,
        w2_total_qweight,
        fused_topk_weights_0,
        fused_topk_ids_0,
        dtype,
        inplace=False,
        activation="silu",
        use_fp8_w8a8=False,
        use_int8_w8a8=True,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        use_int4_w4a8=False,
        per_channel_quant=True,
        global_num_experts=local_E + shared_E,
        expert_map=None,
        w1_scale=w1_total_scales,
        w2_scale=w2_total_scales,
        w1_zp=None,
        w2_zp=None,
        a1_scale=input_scale,
        a2_scale=None,
        block_shape=None,
    )

    msg = f"[TRITON_fused_test] {token=}, {model_dim=}, {inter_dim=}, {local_E=}, {shared_E=}, {local_topk=}, dtype: {dtype}, triton_avg: {avg_triton:>8.2f} us, torch_fused_avg: {torch_fused_avg:>8.2f} us"
    checkAllclose(torch_total_out, triton_output, rtol=0.01, atol=50, msg=msg)

    return {
        "local_avg_torch": torch_local_avg,
        "avg_triton": avg_triton,
    }


# Parameter sweep; runs at import time.  itertools.product iterates with the
# rightmost factor varying fastest, matching the nested-loop order
# dtype -> m -> dim -> hdim -> sh_e.
results = []
for dtype, m, dim, hdim, sh_e in itertools.product(
    [dtypes.fp16], [1, 16], [7168], [256], [1, 2]
):
    ret = test_fmoe_w8a8_fused_shared_experts(
        dtype, m, dim, hdim, 128, sh_e, 8, False, quant="int8"
    )
    results.append(ret)

df = pd.DataFrame(results)
# df.to_csv("test_moe_w8a8_fused_shared_experts.csv")
aiter.logger.info(f"summary:\n{df}")