// SPDX-License-Identifier: MIT #include #include #include #include #include #include "py_itfs_common.h" #include "fused_moe.hpp" #include "moe_quant.hpp" #include "fused_moe_2stage.hpp" torch::Tensor ck_moe(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_int4_w4a8_block,// use int4 w4a8 block quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale std::optional w2_scale, // [e, 1, k], down scale std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional solution_id = 0, // solution id std::optional expert_mask = std::nullopt) { const at::cuda::OptionalCUDAGuard device_guard(device_of(hidden_states)); auto device = hidden_states.device(); int topk_ids_numel = topk_ids.numel(); int experts = w1.size(0); int topk = topk_ids.size(1); int tokens = topk_ids.size(0); int hidden_size = w1.size(2); int shared_intermediate_size_0 = w1.size(1); int shared_intermediate_size = w2.size(-1); int block_size = block_m.value(); int fused_quant = 0; if (!w1_scale.has_value()) { fused_quant = 0; } else if (a1_scale.has_value() && a2_scale.has_value()) { fused_quant = 1; } else if (use_int8_w8a16.has_value() && use_int8_w8a16.value()) { fused_quant = 2; } else if (use_int4_w4a16.has_value() && use_int4_w4a16.value()) { fused_quant = 3; hidden_size = w2.size(-2); // w2 shape = (e, k, n // 2) shared_intermediate_size = w2.size(-1) * 2; // w2 shape = (e, k, n // 2) shared_intermediate_size_0 = w1.size(-2) ; // w1 shape = (e, 2 * n, k // 2) } else if (use_int8_w8a8_block.has_value() && use_int8_w8a8_block.value()) { fused_quant = 4; } else if (use_int4_w4a8_block.has_value() && use_int4_w4a8_block.value()) { fused_quant = 5; hidden_size = w2.size(-2); // w2 shape = (e, k, n // 2) shared_intermediate_size = w2.size(-1) * 2; // w2 shape = (e, k, n // 2) shared_intermediate_size_0 = w1.size(-2) ; // w1 shape = (e, 2 * n, k // 2) } else { fused_quant = 10; } int gate_only = 1; int activation = 0; if (shared_intermediate_size_0 == 2 * shared_intermediate_size) { gate_only = 0; activation = 1; } int max_num_tokens_padded = topk_ids_numel + experts * block_size - topk; int max_num_m_blocks = (max_num_tokens_padded + block_size - 1) / block_size; auto sorted_ids = torch::empty({max_num_tokens_padded}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto sorted_weights = torch::empty({max_num_tokens_padded}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); auto sorted_expert_ids = torch::empty({max_num_m_blocks}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto tokens_positions_per_expert = torch::empty({experts*2}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto num_valid_ids = torch::empty({1}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto out = torch::zeros({tokens, hidden_size}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); auto prec_i = torchDTypeToStr(hidden_states.dtype()); auto prec_w = torchDTypeToStr(w1.dtype()); auto prec_o = torchDTypeToStr(out.dtype()); auto prec_kw = torchDTypeToStr(topk_weights.dtype()); int stride = hidden_size; std::string prec_st = !a1_scale ? "fp32" : torchDTypeToStr(a1_scale->dtype()); std::string prec_sw = !w1_scale ? "fp32" : torchDTypeToStr(w1_scale->dtype()); std::string prec_sq = !a2_scale ? "fp32" : torchDTypeToStr(a2_scale->dtype()); std::string prec_zp = !w1_zp ? "uint8" : torchDTypeToStr(w1_zp->dtype()); // TODO: after moe_sorting patch done, enable topk int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk); // int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts); void *ws_ptr = nullptr; if (workspace_size > 0) { auto ws = torch::zeros({workspace_size}, torch::TensorOptions().dtype(topk_ids.dtype()).device(device_of(topk_ids))); ws_ptr = ws.data_ptr(); } fused_moe_traits traits{ prec_i, prec_w, prec_o, prec_st, prec_sw, prec_sq, prec_kw, prec_zp, block_size, activation, gate_only, fused_quant, solution_id.has_value() ? solution_id.value() : 0, false, // use_wt_shuffle expert_mask.has_value(), }; fused_moe_args args{hidden_states.data_ptr(), a1_scale.has_value() ? a1_scale.value().data_ptr() : nullptr, w1.data_ptr(), w2.data_ptr(), w1_scale.has_value() ? w1_scale.value().data_ptr() : nullptr, w2_scale.has_value() ? w2_scale.value().data_ptr() : nullptr, w1_zp.has_value() ? w1_zp.value().data_ptr() : nullptr, w2_zp.has_value() ? w2_zp.value().data_ptr() : nullptr, a2_scale.has_value() ? a2_scale.value().data_ptr() : nullptr, expert_mask.has_value() ? expert_mask.value().data_ptr() : nullptr, out.data_ptr(), ws_ptr, topk_ids.data_ptr(), topk_weights.data_ptr(), sorted_ids.data_ptr(), sorted_weights.data_ptr(), sorted_expert_ids.data_ptr(), tokens_positions_per_expert.data_ptr(), num_valid_ids.data_ptr(), block_size, hidden_size, shared_intermediate_size, tokens, experts, topk, stride, block_shape_n.has_value() ? block_shape_n.value() : 0, block_shape_k.has_value() ? block_shape_k.value() : 0}; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); fused_moe(traits, args, {stream}); return out; } torch::Tensor ck_shuffle_moe(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_int4_w4a8_block,// use int4 w4a8 block quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale or ... std::optional w2_scale, // [e, 1, k], down scale or ... std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional solution_id = 0, // solution id std::optional expert_mask = std::nullopt) { const at::cuda::OptionalCUDAGuard device_guard(device_of(hidden_states)); auto device = hidden_states.device(); int topk_ids_numel = topk_ids.numel(); int experts = w1.size(0); int topk = topk_ids.size(1); int tokens = topk_ids.size(0); int hidden_size = hidden_states.size(1); int shared_intermediate_size_0 = w1.size(1) * 128; // need to change 128 to a variable int shared_intermediate_size = w2.size(1) * 128; // need to change 128 to a variable int block_size = block_m.value(); int fused_quant = 0; if (!w1_scale.has_value()) { fused_quant = 0; } else if (a1_scale.has_value() && a2_scale.has_value()) { fused_quant = 1; } else if (use_int8_w8a16.has_value() && use_int8_w8a16.value()) { fused_quant = 2; } else if (use_int4_w4a16.has_value() && use_int4_w4a16.value()) { fused_quant = 3; hidden_size = w2.size(-2); // w2 shape = (e, k, n // 2) shared_intermediate_size = w2.size(-1) * 2; // w2 shape = (e, k, n // 2) shared_intermediate_size_0 = w1.size(-2) ; // w1 shape = (e, 2 * n, k // 2) } else if (use_int8_w8a8_block.has_value() && use_int8_w8a8_block.value()) { // out_dtype = in_dtype.toScalarType(); fused_quant = 4; } else if (use_int4_w4a8_block.has_value() && use_int4_w4a8_block.value()) { fused_quant = 5; hidden_size = w2.size(-2); // w2 shape = (e, k, n // 2) shared_intermediate_size = w2.size(-1) * 2; // w2 shape = (e, k, n // 2) shared_intermediate_size_0 = w1.size(-2) ; // w1 shape = (e, 2 * n, k // 2) } else { fused_quant = 10; } int gate_only = 1; int activation = 0; if (shared_intermediate_size_0 == 2 * shared_intermediate_size) { gate_only = 0; activation = 1; } int max_num_tokens_padded = topk_ids_numel + experts * block_size - topk; int max_num_m_blocks = (max_num_tokens_padded + block_size - 1) / block_size; auto sorted_ids = torch::empty({max_num_tokens_padded}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto sorted_weights = torch::empty({max_num_tokens_padded}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); auto sorted_expert_ids = torch::empty({max_num_m_blocks}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto tokens_positions_per_expert = torch::empty({experts*2}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto num_valid_ids = torch::empty({1}, torch::TensorOptions().dtype(torch::kInt32).device(device)); auto out = torch::zeros({tokens, hidden_size}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); auto prec_i = torchDTypeToStr(hidden_states.dtype()); auto prec_w = torchDTypeToStr(w1.dtype()); auto prec_o = torchDTypeToStr(out.dtype()); auto prec_kw = torchDTypeToStr(topk_weights.dtype()); int stride = hidden_size; std::string prec_st = !a1_scale ? "fp32" : torchDTypeToStr(a1_scale->dtype()); std::string prec_sw = !w1_scale ? "fp32" : torchDTypeToStr(w1_scale->dtype()); std::string prec_sq = !a2_scale ? "fp32" : torchDTypeToStr(a2_scale->dtype()); std::string prec_zp = !w1_zp ? "uint8" : torchDTypeToStr(w1_zp->dtype()); // TODO: after moe_sorting patch done, enable topk int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk); // int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts); void *ws_ptr = nullptr; if (workspace_size > 0) { auto ws = torch::zeros({workspace_size}, torch::TensorOptions().dtype(topk_ids.dtype()).device(device_of(topk_ids))); ws_ptr = ws.data_ptr(); } fused_moe_traits traits{ prec_i, prec_w, prec_o, prec_st, prec_sw, prec_sq, prec_kw, prec_zp, block_size, activation, gate_only, fused_quant, solution_id.has_value() ? solution_id.value() : 0, true, // use_wt_shuffle expert_mask.has_value(), }; fused_moe_args args{hidden_states.data_ptr(), a1_scale.has_value() ? a1_scale.value().data_ptr() : nullptr, w1.data_ptr(), w2.data_ptr(), w1_scale.has_value() ? w1_scale.value().data_ptr() : nullptr, w2_scale.has_value() ? w2_scale.value().data_ptr() : nullptr, w1_zp.has_value() ? w1_zp.value().data_ptr() : nullptr, w2_zp.has_value() ? w2_zp.value().data_ptr() : nullptr, a2_scale.has_value() ? a2_scale.value().data_ptr() : nullptr, expert_mask.has_value() ? expert_mask.value().data_ptr() : nullptr, out.data_ptr(), ws_ptr, topk_ids.data_ptr(), topk_weights.data_ptr(), sorted_ids.data_ptr(), sorted_weights.data_ptr(), sorted_expert_ids.data_ptr(), tokens_positions_per_expert.data_ptr(), num_valid_ids.data_ptr(), block_size, hidden_size, shared_intermediate_size, tokens, experts, topk, stride, block_shape_n.has_value() ? block_shape_n.value() : 0, block_shape_k.has_value() ? block_shape_k.value() : 0}; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); fused_moe(traits, args, {stream}); return out; } std::vector ck_moe_get_solutions(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_int4_w4a8_block,// use int4 w4a8 block quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale or ... std::optional w2_scale, // [e, 1, k], down scale or ... std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional expert_mask = std::nullopt) { int experts = w1.size(0); int topk = topk_ids.size(1); int tokens = topk_ids.size(0); int hidden_size = w1.size(2); int shared_intermediate_size_0 = w1.size(1); int shared_intermediate_size = w2.size(-1); int block_size = block_m.has_value() ? block_m.value() : 0; int fused_quant = 0; if (!w1_scale.has_value()) { fused_quant = 0; } else if (a1_scale.has_value() && a2_scale.has_value()) { fused_quant = 1; } else if (use_int8_w8a16.has_value() && use_int8_w8a16.value()) { fused_quant = 2; } else if (use_int4_w4a16.has_value() && use_int4_w4a16.value()) { fused_quant = 3; hidden_size = w2.size(-2); // w2 shape = (e, k, n // 2) shared_intermediate_size = w2.size(-1) * 2; // w2 shape = (e, k, n // 2) shared_intermediate_size_0 = w1.size(-2) ; // w1 shape = (e, 2 * n, k // 2) } else if (use_int8_w8a8_block.has_value() && use_int8_w8a8_block.value()) { fused_quant = 4; } else if (use_int4_w4a8_block.has_value() && use_int4_w4a8_block.value()) { fused_quant = 5; hidden_size = w2.size(-2); // w2 shape = (e, k, n // 2) shared_intermediate_size = w2.size(-1) * 2; // w2 shape = (e, k, n // 2) shared_intermediate_size_0 = w1.size(-2) ; // w1 shape = (e, 2 * n, k // 2) } else { fused_quant = 10; } int gate_only = 1; int activation = 0; if (shared_intermediate_size_0 == 2 * shared_intermediate_size) { gate_only = 0; activation = 1; } auto prec_i = torchDTypeToStr(hidden_states.dtype()); auto prec_w = torchDTypeToStr(w1.dtype()); auto prec_kw = torchDTypeToStr(topk_weights.dtype()); int stride = hidden_size; std::string prec_st = !a1_scale ? "fp32" : torchDTypeToStr(a1_scale->dtype()); std::string prec_sw = !w1_scale ? "fp32" : torchDTypeToStr(w1_scale->dtype()); std::string prec_sq = !a2_scale ? "fp32" : torchDTypeToStr(a2_scale->dtype()); std::string prec_zp = !w1_zp ? "uint8" : torchDTypeToStr(w1_zp->dtype()); std::string prec_o = "fp32"; fused_moe_traits traits{ prec_i, prec_w, prec_o, prec_st, prec_sw, prec_sq, prec_kw, prec_zp, block_size, activation, gate_only, fused_quant, 0, expert_mask.has_value(), }; fused_moe_args args{hidden_states.data_ptr(), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, //out nullptr, //ws_ptr nullptr, nullptr, //sorted_ids nullptr, //sorted_weights nullptr, //sorted_expert_ids nullptr, //tokens_positions_per_expert nullptr, //num_valid_ids block_size, hidden_size, shared_intermediate_size, tokens, experts, topk, stride, block_shape_n.has_value() ? block_shape_n.value() : 0, block_shape_k.has_value() ? block_shape_k.value() : 0}; int solution_size = 0; fused_moe_get_solutions(traits, args, nullptr, &solution_size); //get solution size std::vector solutionsSolve(solution_size); fused_moe_get_solutions(traits, args, solutionsSolve.data(), &solution_size); //get solutions std::vector validSolutions; for (auto sol : solutionsSolve) { if (true) { //check sol is valid validSolutions.push_back(sol); } } return validSolutions; } void ck_moe_per_token_quant(torch::Tensor &input, // [m, k], input token torch::Tensor &out_quant, // [m, k], output token torch::Tensor &out_scale) // [m, 1], output scale { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); std::string prec_in = torchDTypeToStr(input.dtype()); std::string prec_out = torchDTypeToStr(out_quant.dtype()); ck_tile::index_t tokens = static_cast(input.size(0)); ck_tile::index_t hidden_size = static_cast(input.size(1)); ck_tile::index_t x_stride = static_cast(input.stride(1)); ck_tile::index_t y_stride = static_cast(out_quant.stride(0)); // printf("ck_moe_per_token_quant: x_stride: %d, y_stride: %d\n", x_stride, y_stride); // printf("ck_moe_per_token_quant: prec_in: %s, prec_out: %s\n", prec_in.c_str(), prec_out.c_str()); moe_quant_traits quant_traits{ prec_in, prec_out}; moe_quant_args moe_quant_args{ input.data_ptr(), out_scale.data_ptr(), out_quant.data_ptr(), tokens, // m hidden_size, // k x_stride, // stride y_stride, // stride }; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); moe_quant(quant_traits, moe_quant_args, {stream}); } void ck_moe_stage_1(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, 2*n, k] torch::Tensor &w2, // [e, k, n] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &tokens_positions_per_expert, // [num_experts*2] torch::Tensor &num_valid_ids, // [1] torch::Tensor &out, // [max_num_tokens_padded, inter_dim] int topk, std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_fp8_w8a8_block, // use fp8 w8a8 block quantization std::optional w1_scale, // [e, 1, n], gate(up) scale std::optional a1_scale, // [m, 1], token scale std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m, std::optional sorted_weights, std::optional act_op) { const at::cuda::OptionalCUDAGuard device_guard(device_of(hidden_states)); auto device = hidden_states.device(); int tokens = hidden_states.size(0); int experts = w1.size(0); int hidden_size = w1.size(2); int shared_intermediate_size_0 = w1.size(1); int shared_intermediate_size = w2.size(-1); int block_size = block_m.has_value() ? block_m.value() : 16; int gate_only = 1; int activation = 0; if (shared_intermediate_size_0 == 2 * shared_intermediate_size) { gate_only = 0; activation = 1; } fused_quant_mode fused_quant = fused_quant_mode::none; if (!w1_scale.has_value()) { fused_quant = fused_quant_mode::none; } else if (use_int8_w8a8_block.has_value() && use_int8_w8a8_block.value()) { fused_quant = fused_quant_mode::int8_w8a8_block; } else if (use_fp8_w8a8_block.has_value() && use_fp8_w8a8_block.value()) { fused_quant = fused_quant_mode::fp8_w8a8_block; } auto prec_i = torchDTypeToStr(hidden_states.dtype()); auto prec_w = torchDTypeToStr(w1.dtype()); auto prec_o = torchDTypeToStr(out.dtype()); int stride = hidden_size; std::string prec_st = !a1_scale ? "fp32" : torchDTypeToStr(a1_scale->dtype()); std::string prec_sw = !w1_scale ? "fp32" : torchDTypeToStr(w1_scale->dtype()); std::string prec_sq = "fp32"; std::string prec_zp = "uint8"; fused_moe_2stage_traits traits{ prec_i, prec_w, prec_o, prec_st, prec_sw, prec_sq, prec_zp, block_size, activation, gate_only, fused_quant, 0, // solution id false }; fused_moe_stage1_args args{ hidden_states.data_ptr(), a1_scale.has_value() ? a1_scale.value().data_ptr() : nullptr, w1.data_ptr(), w1_scale.has_value() ? w1_scale.value().data_ptr() : nullptr, nullptr, nullptr, // local_expert_mask_ptr out.data_ptr(), sorted_token_ids.data_ptr(), sorted_weights.has_value() ? sorted_weights.value().data_ptr() : nullptr, sorted_expert_ids.data_ptr(), tokens_positions_per_expert.data_ptr(), num_valid_ids.data_ptr(), block_size, hidden_size, shared_intermediate_size, tokens, experts, topk, stride, block_shape_n.has_value() ? block_shape_n.value() : 0, block_shape_k.has_value() ? block_shape_k.value() : 0 }; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); fused_moe_stage1(traits, args, {stream}); } void ck_moe_stage_2(torch::Tensor &inter_states, // [m, topk, n], input states torch::Tensor &w1, // [e, 2*n, k] torch::Tensor &w2, // [e, k, n] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &tokens_positions_per_expert, // [num_experts*2] torch::Tensor &num_valid_ids, // [1] torch::Tensor &out, // [max_num_tokens_padded, inter_dim] int topk, std::optional use_int8_w8a8_block,// use int8 w8a8 block quantization std::optional use_fp8_w8a8_block, // use fp8 w8a8 block quantization std::optional w2_scale, // [e, 1, n], gate(up) scale std::optional a2_scale, // [m, 1], token scale std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m, std::optional sorted_weights) // [max_num_tokens_padded]) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inter_states)); auto device = inter_states.device(); int tokens = inter_states.size(0) / topk; int experts = w2.size(0); int hidden_size = w2.size(1); int shared_intermediate_size_0 = w1.size(1); int shared_intermediate_size = w2.size(-1); int block_size = block_m.has_value() ? block_m.value() : 16; int gate_only = 1; int activation = 0; if (shared_intermediate_size_0 == 2 * shared_intermediate_size) { gate_only = 0; activation = 1; } fused_quant_mode fused_quant = fused_quant_mode::none; if (!w2_scale.has_value()){ fused_quant = fused_quant_mode::none; } else if (use_int8_w8a8_block.has_value() && use_int8_w8a8_block.value()) { fused_quant = fused_quant_mode::int8_w8a8_block; } else if (use_fp8_w8a8_block.has_value() && use_fp8_w8a8_block.value()) { fused_quant = fused_quant_mode::fp8_w8a8_block; } auto prec_i = torchDTypeToStr(inter_states.dtype()); auto prec_w = torchDTypeToStr(w2.dtype()); auto prec_o = torchDTypeToStr(out.dtype()); int stride = hidden_size; std::string prec_st = !a2_scale ? "fp32" : torchDTypeToStr(a2_scale->dtype()); std::string prec_sw = !w2_scale ? "fp32" : torchDTypeToStr(w2_scale->dtype()); std::string prec_sq = "fp32"; std::string prec_zp = "uint8"; fused_moe_2stage_traits traits{ prec_i, prec_w, prec_o, prec_st, prec_sw, prec_sq, prec_zp, block_size, activation, gate_only, fused_quant, 0, // solution id false }; fused_moe_stage2_args args{ inter_states.data_ptr(), a2_scale.has_value() ? a2_scale.value().data_ptr() : nullptr, w2.data_ptr(), w2_scale.has_value() ? w2_scale.value().data_ptr() : nullptr, nullptr, nullptr, // local_expert_mask_ptr out.data_ptr(), sorted_token_ids.data_ptr(), sorted_weights.has_value() ? sorted_weights.value().data_ptr() : nullptr, sorted_expert_ids.data_ptr(), tokens_positions_per_expert.data_ptr(), num_valid_ids.data_ptr(), block_size, hidden_size, shared_intermediate_size, tokens, experts, topk, stride, block_shape_n.has_value() ? block_shape_n.value() : 0, block_shape_k.has_value() ? block_shape_k.value() : 0 }; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); fused_moe_stage2(traits, args, {stream}); }