#pragma once // SPDX-License-Identifier: MIT #include #include "aiter_enum.h" void fmoe(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, hidden_dim, dim] N,K torch::Tensor &down, // [expert, hidden_dim, dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk // ); void fmoe_int8_g1u0(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, hidden_dim, dim] N,K torch::Tensor &down, // [expert, hidden_dim, dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [token_cnt, 1] torch::Tensor &fc1_scale, // [expert, 1, hidden_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] torch::Tensor &fc2_smooth_scale, // [expert, 1, hidden_dim] ActivationType activation = ActivationType::Silu); void fmoe_g1u1(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, hidden_dim*2, dim] N,K torch::Tensor &down, // [expert, hidden_dim, dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [token_cnt, 1] torch::Tensor &fc1_scale, // [expert, 1, hidden_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] std::optional fc2_smooth_scale = std::nullopt, // [expert, 1, hidden_dim] ActivationType activation = ActivationType::Silu); void fmoe_g1u1_tkw1(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, hidden_dim*2, dim] N,K torch::Tensor &down, // [expert, hidden_dim, dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [token_cnt, 1] torch::Tensor &fc1_scale, // [expert, 1, hidden_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] std::optional fc2_smooth_scale = std::nullopt, // [expert, 1, hidden_dim] ActivationType activation = ActivationType::Silu); void fmoe_int8_g1u0_a16(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &fc1_scale, // [expert, 1, hidden_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] torch::Tensor &fc1_smooth_scale, // [expert, 1, hidden_dim] torch::Tensor &fc2_smooth_scale // [expert, 1, hidden_dim] ); void fmoe_g1u1_a16(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &fc1_scale, // [expert, 1, hidden_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] torch::Tensor &fc1_smooth_scale, // [expert, 1, hidden_dim] torch::Tensor &fc2_smooth_scale // [expert, 1, hidden_dim] ); void fmoe_fp8_blockscale_g1u1(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim*2, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [expert, 1, dim] torch::Tensor &fc1_scale, // [expert, 1, inter_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] int fc_scale_blkn = 128, // = 128, int fc_scale_blkk = 128, // = 128 std::optional fc2_smooth_scale = std::nullopt, // [expert, 1, inter_dim] ActivationType activation = ActivationType::Silu); void moe_stage1_g1u1(torch::Tensor &input, // [token_cnt, model_dim] M,K torch::Tensor &w1, // [expert, inter_dim*2, model_dim] N,K torch::Tensor &w2, // [expert, model_dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] torch::Tensor &out, // [token_cnt, topk, inter_dim] int inter_dim, std::string &kernelName, int block_m, int ksplit, ActivationType activation, QuantType quant_type, std::optional a1_scale, // [token_cnt, 1], token scale std::optional w1_scale, // [expert, 1, inter_dim], gate(up) scale std::optional sorted_weights);