// SPDX-License-Identifier: MIT #pragma once #include torch::Tensor ck_moe(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_int4_w4a8_block,// use int4 w4a8 block quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale std::optional w2_scale, // [e, 1, k], down scale std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional solution_id = 0, // solution id std::optional expert_mask = std::nullopt); torch::Tensor ck_shuffle_moe(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_int4_w4a8_block,// use int4 w4a8 block quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale or ... std::optional w2_scale, // [e, 1, k], down scale or ... std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional solution_id = 0, // solution id std::optional expert_mask = std::nullopt); std::vector ck_moe_get_solutions(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_int4_w4a8_block,// use int4 w4a8 block quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale or ... std::optional w2_scale, // [e, 1, k], down scale or ... std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional expert_mask = std::nullopt); void ck_moe_per_token_quant(torch::Tensor &input, // [m, k], input token torch::Tensor &out_quant, // [m, k], output token torch::Tensor &out_scale); // [m, 1], output scale void ck_moe_stage_1(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &tokens_positions_per_expert, // [num_experts*2] torch::Tensor &num_valid_ids, // [1] torch::Tensor &out, // [max_num_tokens_padded, inter_dim] int topk, std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_fp8_w8a8_block, // use fp8 w8a8 block quantization std::optional w1_scale, // [e, 1, n], gate(up) scale std::optional a1_scale, // [m, 1], token scale std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m, std::optional sorted_weights, std::optional act_op); void ck_moe_stage_2(torch::Tensor &inter_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &tokens_positions_per_expert, // [num_experts*2] torch::Tensor &num_valid_ids, // [1] torch::Tensor &out, // [max_num_tokens_padded, inter_dim] int topk, std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_fp8_w8a8_block, // use fp8 w8a8 block quantization std::optional w2_scale, // [e, 1, n], gate(up) scale std::optional a2_scale, // [m, 1], token scale std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m, std::optional sorted_weights); // [max_num_tokens_padded]);