#pragma once #include #include #include "aiter_enum.h" // void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, // torch::Tensor& token_expert_indices, // torch::Tensor& gating_output); // void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); // #ifndef USE_ROCM // torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, // torch::Tensor b_qweight, torch::Tensor b_scales, // std::optional b_qzeros, // std::optional topk_weights, // torch::Tensor sorted_token_ids, // torch::Tensor expert_ids, // torch::Tensor num_tokens_post_pad, int64_t top_k, // int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, // int64_t BLOCK_SIZE_K, int64_t bit); // #endif // std::vector moe_fused_gate( // torch::Tensor& input, // torch::Tensor& bias, // int64_t num_expert_group, // int64_t topk_group, // int64_t topk, // int64_t n_share_experts_fusion, // double routed_scaling_factor);