// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include torch::Tensor moe_c_moe_gemm_marlin_w8a8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ); torch::Tensor moe_c_moe_gemm_marlin_w4a8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ); torch::Tensor moe_c_moe_gemm_marlin_w8a8_fp8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ); torch::Tensor moe_c_moe_gemm_marlin_w4a16(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor b_scale, torch::Tensor b_zeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ); torch::Tensor moe_c_moe_w8a8_gemm_block_wise(torch::Tensor input, torch::Tensor a_scales,torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) ; torch::Tensor moe_c_moe_w8a8_gemm_block_wise_kernel2(torch::Tensor input, torch::Tensor a_scales,torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ); torch::Tensor moe_c_moe_w8a8_gemm_block_wise_fp8(torch::Tensor input, torch::Tensor a_scales,torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ); torch::Tensor moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8(torch::Tensor input, torch::Tensor a_scales,torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ); torch::Tensor moe_c_moe_w8a16_gemm_awq(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t bit); torch::Tensor moe_c_moe_w8a16_gemm_block_wise(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t bit); torch::Tensor moe_c_moe_wna16_gemm_base(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); torch::Tensor moe_c_moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit) ; torch::Tensor moe_c_moe_wna16_gemm_2(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit) ; void moe_c_silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) ; // [..., 2 * d] void moe_c_topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, torch::Tensor& gating_output); void moe_c_moe_sum(torch::Tensor& input, torch::Tensor& output, torch::Tensor topk_ids); void moe_c_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); void moe_c_sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad);