"""Check the fused-MoE ``relu2`` activation path against a PyTorch reference.

Running this file builds random inputs on the GPU, evaluates them with
``fused_experts_impl(..., activation="relu2")`` and with the loop-based
reference below, and raises ``AssertionError`` if ``checkAllclose`` reports a
non-zero error ratio.
"""

import argparse
from typing import Optional, Sequence

import torch
import torch.nn.functional as F


def torch_fused_experts_relu2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference mixture-of-experts forward pass with a squared-ReLU activation.

    For every token and each of its selected experts this computes
    ``relu(x @ w1[e].T) ** 2 @ w2[e].T``, scales it by the matching routing
    weight and sums the contributions of that token.

    Args:
        hidden_states: 2-D tensor ``(num_tokens, hidden_size)``.
        w1: Per-expert up projection; ``w1[e]`` must have shape
            ``(intermediate_size, hidden_size)``.
        w2: Per-expert down projection; ``w2[e]`` must have shape
            ``(hidden_size, intermediate_size)``.
        topk_weights: 2-D tensor ``(num_tokens, topk)`` of routing weights.
        topk_ids: 2-D integer tensor ``(num_tokens, topk)`` of expert indices
            into the first dimension of ``w1`` / ``w2``.

    Returns:
        Tensor of shape ``(num_tokens, hidden_size)`` with the dtype and
        device of ``hidden_states``.
    """
    num_tokens, hidden_size = hidden_states.shape
    out = torch.zeros(
        (num_tokens, hidden_size),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    # Pull every expert index to the host in one transfer; calling ``.item()``
    # per (token, expert) pair would force a device sync on each iteration.
    expert_rows = topk_ids.tolist()

    for token_idx in range(num_tokens):
        x = hidden_states[token_idx]
        token_out = torch.zeros(
            (hidden_size,),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        for k, expert_id in enumerate(expert_rows[token_idx]):
            expert_id = int(expert_id)
            weight = topk_weights[token_idx, k]
            up = x @ w1[expert_id].transpose(0, 1)
            act = torch.square(F.relu(up))
            down = act @ w2[expert_id].transpose(0, 1)
            token_out = token_out + weight * down
        out[token_idx] = token_out
    return out


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run the relu2 fused-experts comparison.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, so calling ``main()`` behaves
            as before.

    Raises:
        SystemExit: If the arguments are invalid (reported by ``argparse``).
        RuntimeError: If no CUDA device is available to PyTorch.
        AssertionError: If ``checkAllclose`` returns a truthy error ratio.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=32)
    parser.add_argument("--hidden_size", type=int, default=128)
    parser.add_argument("--intermediate_size", type=int, default=256)
    parser.add_argument("--num_experts", type=int, default=4)
    parser.add_argument("--topk", type=int, default=1)
    parser.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16"])
    args = parser.parse_args(argv)

    # Reject bad sizes here instead of failing later inside torch.topk or the
    # kernel with a message that does not mention the offending option.
    for option in ("m", "hidden_size", "intermediate_size", "num_experts", "topk"):
        if getattr(args, option) < 1:
            parser.error(f"--{option} must be a positive integer")
    if args.topk > args.num_experts:
        parser.error(
            f"--topk ({args.topk}) cannot exceed --num_experts ({args.num_experts})"
        )

    if not torch.cuda.is_available():
        raise RuntimeError("relu2 fused_experts test requires a CUDA device")

    # Imported here so the pure-PyTorch reference above stays importable on
    # machines that do not have the aiter Triton stack installed.
    from aiter.ops.triton.fused_moe import fused_experts_impl
    from aiter.test_common import checkAllclose

    torch.manual_seed(0)
    dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
    device = "cuda"

    # Inputs and weights are scaled down by 10 after sampling.
    hidden_states = torch.randn(args.m, args.hidden_size, dtype=dtype, device=device) / 10
    w1 = torch.randn(
        args.num_experts,
        args.intermediate_size,
        args.hidden_size,
        dtype=dtype,
        device=device,
    ) / 10
    w2 = torch.randn(
        args.num_experts,
        args.hidden_size,
        args.intermediate_size,
        dtype=dtype,
        device=device,
    ) / 10

    # Routing: pick the top-k random scores per token and softmax over them.
    scores = torch.randn(args.m, args.num_experts, dtype=torch.float32, device=device)
    topk_weights, topk_ids = torch.topk(scores, k=args.topk, dim=-1)
    topk_weights = torch.softmax(topk_weights, dim=-1).to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)

    ref = torch_fused_experts_relu2(hidden_states, w1, w2, topk_weights, topk_ids)
    # NOTE(review): the two positional arguments after topk_ids (dtype, False)
    # are passed exactly as before; confirm their meaning against the
    # fused_experts_impl signature.
    out = fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        dtype,
        False,
        activation="relu2",
    )

    err = checkAllclose(ref, out, rtol=2e-2, atol=2e-2, printLog=True)
    if err:
        raise AssertionError(f"relu2 fused_experts mismatch, err_ratio={err}")
    print("relu2 fused_experts test passed")


if __name__ == "__main__":
    main()