// SPDX-License-Identifier: MIT #pragma once #include namespace aiter { void mhc_pre_gemm_sqrsum(torch::Tensor& out, // (split_k, m, hc_mult3) / (m, hc_mult3) torch::Tensor& sqrsum, // (split_k, m) / (m) torch::Tensor& x, // (m, hc_hidden_size) torch::Tensor& fn, // (hc_mult3, hc_hidden_size) int tile_k = 128, bool use_tf32 = false); void mhc_pre_gemm_sqrsum_stage1_m128(torch::Tensor& out, // (split_k, m, hc_mult3) / (m, hc_mult3) torch::Tensor& sqrsum, // (split_k, m) / (m) torch::Tensor& x, // (m, hc_hidden_size) torch::Tensor& fn, // (hc_mult3, hc_hidden_size) bool use_tf32 = false ); void mhc_pre_reduce_splitk(torch::Tensor& out_red, // (1, m, hc_mult3) torch::Tensor& sqrsum_red, // (1, m) torch::Tensor& out, // (split_k, m, hc_mult3) torch::Tensor& sqrsum // (split_k, m) ); void mhc_pre_big_fuse(torch::Tensor& post_mix, // (m, hc_mult) torch::Tensor& comb_mix, // (m, hc_mult * hc_mult) torch::Tensor& layer_input, // (m, hidden_size) torch::Tensor& gemm_out_mul, // (split_k, m, hc_mult3) torch::Tensor& gemm_out_sqrsum, // (split_k, m) torch::Tensor& hc_scale, // (3) torch::Tensor& hc_base, // (hc_mult3) torch::Tensor& residual, // (m, hc_mult, hidden_size) float rms_eps = 1e-6, float hc_pre_eps = 1e-6, float hc_sinkhorn_eps = 1e-6, float hc_post_mult_value = 1.0, int sinkhorn_repeat = 20); void mhc_pre_big_fuse_tlstyle(torch::Tensor& post_mix, // (m, hc_mult) torch::Tensor& comb_mix, // (m, hc_mult * hc_mult) torch::Tensor& layer_input, // (m, hidden_size) torch::Tensor& gemm_out_mul, // (split_k, m, hc_mult3) torch::Tensor& gemm_out_sqrsum, // (split_k, m) torch::Tensor& hc_scale, // (3) torch::Tensor& hc_base, // (hc_mult3) torch::Tensor& residual, // (m, hc_mult, hidden_size) float rms_eps = 1e-6, float hc_pre_eps = 1e-6, float hc_sinkhorn_eps = 1e-6, float hc_post_mult_value = 1.0, int sinkhorn_repeat = 20); void mhc_post(torch::Tensor& out, // (m, hc_mult, hidden_size) torch::Tensor& x, // (m, hidden_size) torch::Tensor& residual, // (m, hc_mult, hidden_size) torch::Tensor& post_layer_mix, // (m, hc_mult) torch::Tensor& comb_res_mix // (m, hc_mult, hc_mult) ); } // namespace aiter