// #include // #include // #include // #include // #include // #include // #include "moe_wna16_utils.h" // #include "moe_ops.h" #pragma once #include "moe_w8a16_block_wise.h" #include "moe_w8a16_awq.h" #include "moe_w4a16.h" #include "moe_w4a16_2.h" #include "moe_w4a16_base.h" #include "moe_w8a8_block_wise.h" #include "moe_w8a8_block_wise_kernel2.h" #include "moe_w8a8_block_wise_fp8.h" #include "moe_w8a8_block_wise_kernel2_fp8.h" #include "topk_softmax_kernel.h" #include "moe_align_sum_kernels.h" #include "moe_sum_kernels_opt_v2.h" #include "silu_mul_kernels.h" #include #include #include #include #include // #include "moe_w8a8_utils.h" // #include "moe_w8a8_config.h" #include "moe_w8a8_opt.h" #include "moe_w4a16_opt.h" #include "moe_w4a8_opt.h" #include "moe_w8a16_chan_opt.h" #define BIT_SWITCH(bit, BIT, ...) \ [&] { \ if (bit == 8) { \ constexpr static int BIT = 8; \ return __VA_ARGS__(); \ }else if (bit == 4) { \ constexpr static int BIT = 4; \ return __VA_ARGS__(); \ } \ else { \ std::cout<<"unsupported BIT"< static void moe_marlin_w8a8_dispatch_gemm_stages( bool first_stage, const torch::Tensor& input, const torch::Tensor& b_qweight, torch::Tensor& output_alias, const torch::Tensor& a_scale, const torch::Tensor& b_scale, const float* topk_weights_ptr, const torch::Tensor& sorted_token_ids, const torch::Tensor& expert_ids, int num_pad, const torch::Tensor& num_tokens_post_pad, int size_m, int size_n, int size_k, int stride_asm, int stride_ask, int stride_bse, int stride_bsn, int stride_bsk, int64_t top_k, uint32_t real_topk, bool is_marlin, int64_t mode, int config_m, bool tensorwise_scale = false) { static_assert( std::is_same_v || std::is_same_v, "ElemT must be at::Float8_e4m3fn or int8_t"); if (first_stage) { const int64_t EM = sorted_token_ids.size(0); GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (OutT*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, 
sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale); if (config_m <= 512) { if constexpr (std::is_same_v) { auto it = kernel_maps_gemm1_decode_fp8.find(mode); if (it != kernel_maps_gemm1_decode_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); } } else { auto it = kernel_maps_gemm1_decode.find(mode); if (it != kernel_maps_gemm1_decode.end()) { it->second(params_in); } else { if constexpr (std::is_same_v) { printf("half version gemm1 No matching kernel configuration found, using default settings \n"); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); } } } } else { if constexpr (std::is_same_v) { auto it = kernel_maps_gemm1_prefill_fp8.find(mode); if (it != kernel_maps_gemm1_prefill_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); } } else { auto it = kernel_maps_gemm1_prefill.find(mode); if (it != kernel_maps_gemm1_prefill.end()) { it->second(params_in); } else { if constexpr (std::is_same_v) { printf("half version gemm1 No matching kernel configuration found \n"); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); } } } } } else { const int64_t EM = sorted_token_ids.size(0); GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (OutT*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale); if 
(config_m <= 512) { if constexpr (std::is_same_v) { auto it = kernel_maps_gemm2_decode_fp8.find(mode); if (it != kernel_maps_gemm2_decode_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); } } else { auto it = kernel_maps_gemm2_decode.find(mode); if (it != kernel_maps_gemm2_decode.end()) { it->second(params_in); } else { if constexpr (std::is_same_v) { printf("half version gemm2 No matching kernel configuration found, using default settings \n"); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); } } } } else { if constexpr (std::is_same_v) { auto it = kernel_maps_gemm2_prefill_fp8.find(mode); if (it != kernel_maps_gemm2_prefill_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); } } else { auto it = kernel_maps_gemm2_prefill.find(mode); if (it != kernel_maps_gemm2_prefill.end()) { it->second(params_in); } else { if constexpr (std::is_same_v) { printf("half version gemm2 No matching kernel configuration found \n"); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); } } } } } } template static void moe_marlin_w8a8_dispatch_gemm_stages_n160( bool first_stage, const torch::Tensor& input, const torch::Tensor& b_qweight, torch::Tensor& output_alias, const torch::Tensor& a_scale, const torch::Tensor& b_scale, const float* topk_weights_ptr, const torch::Tensor& sorted_token_ids, const torch::Tensor& expert_ids, int num_pad, const torch::Tensor& num_tokens_post_pad, int size_m, int size_n, int size_k, int stride_asm, int stride_ask, int stride_bse, int stride_bsn, int stride_bsk, int64_t top_k, uint32_t real_topk, bool is_marlin, int64_t mode, int config_m, bool tensorwise_scale = false) { static_assert( std::is_same_v || std::is_same_v, "ElemT must be at::Float8_e4m3fn or int8_t"); 
if (first_stage) { const int64_t EM = sorted_token_ids.size(0); GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (OutT*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale); if (config_m <= 512) { if constexpr (std::is_same_v) { // std::cout<<"gemm1******************mode "<.find(mode); if (it != kernel_maps_gemm1_decode_n160_fp8.end()) { it->second(params_in); } else { std::cout<<"moe_marlin_w8a8_dispatch_gemm_stages mode \n"<.find(mode); // if (it != kernel_maps_gemm1_n160_decode.end()) { // it->second(params_in); // } else { // if constexpr (std::is_same_v) { // printf("half version gemm1 No matching kernel configuration found, using default settings \n"); // } else { // printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // } // } // } } else { if constexpr (std::is_same_v) { auto it = kernel_maps_gemm1_prefill_n160_fp8.find(mode); if (it != kernel_maps_gemm1_prefill_n160_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); } } // else { // auto it = kernel_maps_gemm1_prefill_n160.find(mode); // if (it != kernel_maps_gemm1_prefill_n160.end()) { // it->second(params_in); // } else { // if constexpr (std::is_same_v) { // printf("half version gemm1 No matching kernel configuration found \n"); // } else { // printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // } // } // } } } else { const int64_t EM = sorted_token_ids.size(0); GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (OutT*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), 
(float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale); if (config_m <= 512 ) { if constexpr (std::is_same_v) { auto it = kernel_maps_gemm2_decode_n160_fp8.find(mode); if (it != kernel_maps_gemm2_decode_n160_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); } } // else { // auto it = kernel_maps_gemm2_n160_decode.find(mode); // if (it != kernel_maps_gemm2_n160_decode.end()) { // it->second(params_in); // } else { // if constexpr (std::is_same_v) { // printf("half version gemm2 No matching kernel configuration found, using default settings \n"); // } else { // printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // } // } // } } else { if constexpr (std::is_same_v) { auto it = kernel_maps_gemm2_prefill_n160_fp8.find(mode); if (it != kernel_maps_gemm2_prefill_n160_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); } } // else { // auto it = kernel_maps_gemm2_prefill_n160.find(mode); // if (it != kernel_maps_gemm2_prefill_n160.end()) { // it->second(params_in); // } else { // if constexpr (std::is_same_v) { // printf("half version gemm2 No matching kernel configuration found \n"); // } else { // printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // } // } // } } } } // 模板抽象 // BLOCK_MNK: [16, 128, 128] // WARP_MNK: [16, 32, 64] // MMA_MNK: [16, 16, 32] // BLOCK_MNK / WARP_MNK = [1, 4, 2] 代表warp在MN方向的排布 block_k/warp_k代表stage // WARP_MNK / MMA_MNK = [1, 2, 2] 代表warp在MNK方向重复计算的次数 会分配额外的寄存器 static torch::Tensor moe_c_moe_gemm_marlin_w8a8_impl(torch::Tensor input, torch::Tensor 
b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta, int64_t config_m, bool tensorwise_scale ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"< params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale ); if(config_m <= 512){ auto it = kernel_maps_gemm1_decode.find(mode); if (it != kernel_maps_gemm1_decode.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // 
launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale ); if(config_m <= 512 ){ auto it = kernel_maps_gemm2_decode.find(mode); if (it != kernel_maps_gemm2_decode.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // 
printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale ); if(config_m <= 512){ auto it = kernel_maps_gemm1_decode.find(mode); if (it != kernel_maps_gemm1_decode.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found \n"); // launch_moe_w8a8_first_stage_decode<16, 16, 512, 16, 16, 128, 4>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin, tensorwise_scale 
); if(config_m <= 512 ){ auto it = kernel_maps_gemm2_decode.find(mode); if (it != kernel_maps_gemm2_decode.end()) { it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"*************************float16"<.find(mode); if ( it != kernel_maps_gemm2_prefill.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } return output; } torch::Tensor moe_c_moe_gemm_marlin_w8a8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t mode, int64_t delta, int64_t config_m ) { return moe_c_moe_gemm_marlin_w8a8_impl(input, b_qweight, output, a_scale, b_scale, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_pad, top_k, mode, delta, config_m, false); } torch::Tensor moe_c_moe_gemm_marlin_w8a8_tensorwise(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t mode, int64_t delta, int64_t config_m ) { return moe_c_moe_gemm_marlin_w8a8_impl(input, b_qweight, output, a_scale, b_scale, 
topk_weights, sorted_token_ids, expert_ids, num_tokens_post_pad, top_k, mode, delta, config_m, true); } // 模板抽象 // BLOCK_MNK: [16, 128, 128] // WARP_MNK: [16, 32, 64] // MMA_MNK: [16, 16, 32] // BLOCK_MNK / WARP_MNK = [1, 4, 2] 代表warp在MN方向的排布 block_k/warp_k代表stage // WARP_MNK / MMA_MNK = [1, 2, 2] 代表warp在MNK方向重复计算的次数 会分配额外的寄存器 static torch::Tensor moe_c_moe_gemm_marlin_w8a8_fp8_impl(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta, int64_t config_m, bool tensorwise_scale ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"< mode \n"<( first_stage, input, b_qweight, output_alias, a_scale, b_scale, topk_weights_ptr, sorted_token_ids, expert_ids, num_pad, num_tokens_post_pad, size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, top_k, real_topk, is_marlin, mode, config_m, tensorwise_scale); } else{ // std::cout<<"moe_marlin_w8a8_dispatch_gemm_stages_n160\n"; moe_marlin_w8a8_dispatch_gemm_stages( first_stage, input, b_qweight, output_alias, a_scale, b_scale, topk_weights_ptr, sorted_token_ids, expert_ids, num_pad, num_tokens_post_pad, size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, top_k, real_topk, is_marlin, mode, config_m, tensorwise_scale); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports fp8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Float8_e4m3fn){ if(EXPERTS == 288){ moe_marlin_w8a8_dispatch_gemm_stages_n160( first_stage, 
input, b_qweight, output_alias, a_scale, b_scale, topk_weights_ptr, sorted_token_ids, expert_ids, num_pad, num_tokens_post_pad, size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, top_k, real_topk, is_marlin, mode, config_m, tensorwise_scale); } else{ moe_marlin_w8a8_dispatch_gemm_stages( first_stage, input, b_qweight, output_alias, a_scale, b_scale, topk_weights_ptr, sorted_token_ids, expert_ids, num_pad, num_tokens_post_pad, size_m, size_n, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, top_k, real_topk, is_marlin, mode, config_m, tensorwise_scale); } // hipDeviceSynchronize(); } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports fp8"); } } return output; } torch::Tensor moe_c_moe_gemm_marlin_w8a8_fp8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t mode, int64_t delta, int64_t config_m ) { return moe_c_moe_gemm_marlin_w8a8_fp8_impl(input, b_qweight, output, a_scale, b_scale, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_pad, top_k, mode, delta, config_m, false); } torch::Tensor moe_c_moe_gemm_marlin_w8a8_fp8_tensorwise(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t mode, int64_t delta, int64_t config_m ) { return moe_c_moe_gemm_marlin_w8a8_fp8_impl(input, b_qweight, output, a_scale, b_scale, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_pad, top_k, mode, delta, config_m, true); } // 模板抽象 // BLOCK_MNK: [16, 128, 128] // WARP_MNK: [16, 32, 64] // MMA_MNK: [16, 16, 32] // BLOCK_MNK / WARP_MNK = [1, 4, 2] 代表warp在MN方向的排布 block_k/warp_k代表stage // WARP_MNK / MMA_MNK = 
[1, 2, 2] 代表warp在MNK方向重复计算的次数 会分配额外的寄存器 torch::Tensor moe_c_moe_gemm_marlin_w4a8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta, int64_t config_m ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"< params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("Path1 bfloat version decode gemm1 No matching kernel configuration found, using default settings mode %d \n",mode); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8_gemm1n256.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8_gemm1n256.end()) { it->second(params_in); } else { printf("Path1 bfloat version prefill gemm1 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 
1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false /* size_m <= 512 * real_topk */){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("Path1 bfloat version decode gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("Path1 bfloat version prefill gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // 
launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false/* size_m <= 512 */){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("Path1 half version decode gemm1 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8_gemm1n256.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8_gemm1n256.end()) { it->second(params_in); } else { printf("Path1 half version prefill gemm1 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_decode<16, 16, 512, 16, 16, 128, 4>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), 
(const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false/* size_m <= 512 * real_topk */){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("Path1 half version decode gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"*************************float16"<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("Path1 half version prefill gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } } else{ if (output.scalar_type() == at::ScalarType::BFloat16){ if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, 
//num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(config_m <= 512){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("Path2 bfloat version decode gemm1 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8.end()) { it->second(params_in); } else { printf("Path2 bfloat version prefill gemm1 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if( config_m <= 512 ){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != 
kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("Path2 bfloat version decode gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("Path2 bfloat version prefill gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, 
is_marlin ); if(false/* size_m <= 512 */){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("Path2 half version decode gemm1 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8.end()) { it->second(params_in); } else { printf("Path2 half version prefill gemm1 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_decode<16, 16, 512, 16, 16, 128, 4>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false/* size_m <= 512 * real_topk */){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("Path2 half version decode gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"*************************float16"<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // 
printf("**********************************%d",mode); it->second(params_in); } else { printf("Path2 half version prefill gemm2 No matching kernel configuration found, using default settings mode %d size_m %d\n",mode,size_m); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } } return output; } /* moe_c_moe_gemm_marlin_w4a16: host-side launcher for the W4A16 MoE GEMM. It runs gemm1 when topk_weights is empty and gemm2 otherwise: the tensors are packed into GemmParams_w4a16 (output_alias.data_ptr() is passed as the output pointer) and the launcher stored under `mode` in kernel_maps_gemm1_decode_w4a16 / kernel_maps_gemm2_decode_w4a16 is called; `output` is returned. Only mode < 500 launches anything: the mode >= 500 prefill branches are commented out, and a mode missing from the map is ignored without any error. When the WHICH_TO_TEST environment variable is set, the kernel is timed with CUDA events and the elapsed milliseconds are appended to ./w4a16_kernel_1_timecost (gemm1) or ./w4a16_kernel_2_timecost (gemm2). Half and BFloat16 inputs are handled; any other input dtype fails TORCH_CHECK. NOTE(review): this copy of the file appears to have lost text: template arguments (on std::optional, data_ptr and the half-typed kernel maps / GemmParams_w4a16) and a span after size_k that presumably declared size_n and the stride variables. Verify against upstream before changing the code. */ torch::Tensor moe_c_moe_gemm_marlin_w4a16(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor b_scale, torch::Tensor b_zeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // topk for gemm1, 1 for gemm2, because the gemm1 input is [m, k] while the gemm2 input is [m*topk, k] int64_t mode, int64_t delta ) { const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"<(); // one-line printf of all stride variables, labeled for easy identification // printf("stride_asm: %d, stride_ask: %d, stride_bse: %d, stride_bsn: %d, stride_bsk: %d\n", // stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk); constexpr int GROUP_N = 1; constexpr int GROUP_K = 1; bool is_marlin = true; // when the weight is [E, N, K] this means no reordering is applied bool first_stage = true; torch::Tensor output_alias = output.alias(); //printf("size_n: %d stride_bse: %d stride_bsn: %d stride_bsk: %d\n", size_n, stride_bse, stride_bsn, stride_bsk); const float* topk_weights_ptr; // expected to be null in the first stage /* NOTE(review): unlike moe_c_moe_gemm_marlin_w8a16, this pointer is not initialized, so when topk_weights is empty (gemm1) an indeterminate value is copied into params_in; it should be initialized to nullptr. */ if (topk_weights.has_value()){ topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); first_stage = false; } int num_pad = 0; if (input.scalar_type() == at::ScalarType::Half){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 1-D linearized token id GemmParams_w4a16 params_in(
(half*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), reinterpret_cast(b_zeros.data_ptr()), (half*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), // reading the value here would cause a device->host copy and some pipeline bubbles num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm1_prefill.find(mode); // if (it != kernel_maps_gemm1_prefill.end()) { // it->second(params_in); // } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_decode_w4a16.find(mode); // printf() if (it != kernel_maps_gemm1_decode_w4a16.end() ) { float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // record start } it->second(params_in); if (find_best) { cudaEventRecord(stop); // record stop cudaEventSynchronize(stop); // wait for the kernel to finish cudaEventElapsedTime(&milliseconds, start, stop); // compute elapsed time /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_1_timecost", std::ios::app); // append mode if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<16, 64, 256, 16, 32, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 1-D linearized token id GemmParams_w4a16 params_in(
(half*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), reinterpret_cast(b_zeros.data_ptr()), (half*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), // reading the value here would cause a device->host copy and some pipeline bubbles num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm2_prefill.find(mode); // if (it != kernel_maps_gemm2_prefill.end()) { // // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // it->second(params_in); // } else { // printf("gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else { // // std::cout<<"mode: "<.find(mode); /* NOTE(review): in this copy the statement that declares `it` for the half gemm2 lookup (presumably an `auto it = ...find(mode);` on the half-typed kernel_maps_gemm2_decode_w4a16) appears to have been swallowed into the comment above, leaving `it` undeclared in the following if; compare the gemm1 branch. */ if (it != kernel_maps_gemm2_decode_w4a16.end() ) { float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // record start } it->second(params_in); if (find_best) { cudaEventRecord(stop); // record stop cudaEventSynchronize(stop); // wait for the kernel to finish cudaEventElapsedTime(&milliseconds, start, stop); // compute elapsed time /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_2_timecost", std::ios::app); // append mode if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } } else { // // printf("gemm2 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 128, 16, 64, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // // launch_moe_w8a8_first_stage_decode<16, 64,
128, 16, 16, 64, 2>(params_in); // // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else if (input.scalar_type() == at::ScalarType::BFloat16){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 1-D linearized token id GemmParams_w4a16<__hip_bfloat16> params_in( (__hip_bfloat16*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (__hip_bfloat16*)output_alias.data_ptr(), reinterpret_cast(b_zeros.data_ptr()), (__hip_bfloat16*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), // reading the value here would cause a device->host copy and some pipeline bubbles num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm1_prefill.find(mode); // if (it != kernel_maps_gemm1_prefill.end()) { // it->second(params_in); // } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_decode_w4a16<__hip_bfloat16>.find(mode); // printf() if (it != kernel_maps_gemm1_decode_w4a16<__hip_bfloat16>.end() ) { float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // record start } it->second(params_in); if (find_best) { cudaEventRecord(stop); // record stop cudaEventSynchronize(stop); // wait for the kernel to finish cudaEventElapsedTime(&milliseconds, start, stop); // compute elapsed time /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_1_timecost", std::ios::app); // append mode if (ofs.is_open()) { ofs << milliseconds <<
std::endl; ofs.close(); } } } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<16, 64, 256, 16, 32, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 1-D linearized token id GemmParams_w4a16<__hip_bfloat16> params_in( (__hip_bfloat16*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (__hip_bfloat16*)output_alias.data_ptr(), reinterpret_cast(b_zeros.data_ptr()), (__hip_bfloat16*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), // reading the value here would cause a device->host copy and some pipeline bubbles num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm2_prefill.find(mode); // if (it != kernel_maps_gemm2_prefill.end()) { // // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // it->second(params_in); // } else { // printf("gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else { // // std::cout<<"mode: "<.find(mode); /* NOTE(review): in this copy the statement that declares `it` for the bfloat16 gemm2 lookup (presumably an `auto it = ...find(mode);` on kernel_maps_gemm2_decode_w4a16 for __hip_bfloat16) appears to have been swallowed into the comment above, leaving `it` undeclared in the following if; compare the gemm1 branch. */ if (it != kernel_maps_gemm2_decode_w4a16<__hip_bfloat16>.end() ) { float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // record start } it->second(params_in); if (find_best) { cudaEventRecord(stop); // record stop cudaEventSynchronize(stop); // wait for the kernel to finish cudaEventElapsedTime(&milliseconds, start, stop); // compute elapsed time /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_2_timecost", std::ios::app); // append mode if
(ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } } else { // // printf("gemm2 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 128, 16, 64, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); /* NOTE(review): this message names moe_w8a8_gemm and int8, but this branch of moe_c_moe_gemm_marlin_w4a16 is reached when the input is neither Half nor BFloat16. */ } return output; } /* moe_c_moe_gemm_marlin_w8a16: host-side launcher for the W8A16 MoE GEMM (no zero-point tensor). It runs gemm1 when topk_weights is empty (topk_weights_ptr then stays nullptr) and gemm2 otherwise: the tensors are packed into GemmParams_w8a16 and the launcher stored under `mode` in kernel_maps_gemm1_decode_w8a16 / kernel_maps_gemm2_decode_w8a16 is called; `output` is returned. Only mode < 500 launches anything; a mode >= 500 or a mode missing from the map returns without launching a kernel or reporting an error. On the Half path, setting the WHICH_TO_TEST environment variable times the kernel with CUDA events and appends the elapsed milliseconds to ./w8a16_kernel_1_timecost (gemm1) or ./w8a16_kernel_2_timecost (gemm2); the BFloat16 path has no timing. Inputs that are neither Half nor BFloat16 fail TORCH_CHECK. NOTE(review): this copy of the file appears to have lost text: template arguments (on std::optional, data_ptr and the half-typed kernel maps / GemmParams_w8a16) and a span after size_k that presumably declared size_n. Verify against upstream before changing the code. */ torch::Tensor moe_c_moe_gemm_marlin_w8a16(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // topk for gemm1, 1 for gemm2, because the gemm1 input is [m, k] while the gemm2 input is [m*topk, k] int64_t mode, int64_t delta ) { const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"<(); // one-line printf of all stride variables, labeled for easy identification // printf("stride_asm: %d, stride_ask: %d, stride_bse: %d, stride_bsn: %d, stride_bsk: %d\n", // stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk); constexpr int GROUP_N = 1; constexpr int GROUP_K = 1; bool is_marlin = true; // when the weight is [E, N, K] this means no reordering is applied bool first_stage = true; torch::Tensor output_alias = output.alias(); //printf("size_n: %d stride_bse: %d stride_bsn: %d stride_bsk: %d\n", size_n, stride_bse, stride_bsn, stride_bsk); const float* topk_weights_ptr = nullptr; // null in the first stage if (topk_weights.has_value()){ topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); first_stage = false; } int num_pad = 0; // #if (DEBUG_W8A8_PERCHANNEL) if (input.scalar_type() == at::ScalarType::Half){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); //
一维线性化的token id GemmParams_w8a16 params_in( (half*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (half*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), // reading the value here would cause a device->host copy and some pipeline bubbles num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm1_prefill.find(mode); // if (it != kernel_maps_gemm1_prefill.end()) { // it->second(params_in); // } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_decode_w8a16.find(mode); // printf() if (it != kernel_maps_gemm1_decode_w8a16.end() ) { float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // record start } it->second(params_in); if (find_best) { cudaEventRecord(stop); // record stop cudaEventSynchronize(stop); // wait for the kernel to finish cudaEventElapsedTime(&milliseconds, start, stop); // compute elapsed time // std::cout << "8888kernel 1 time-----------------: " << milliseconds << " ms" << std::endl; cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w8a16_kernel_1_timecost", std::ios::app); // append mode if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<16, 64, 256, 16, 32, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 1-D linearized token id
GemmParams_w8a16 params_in( (half*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (half*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), // reading the value here would cause a device->host copy and some pipeline bubbles num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm2_prefill.find(mode); // if (it != kernel_maps_gemm2_prefill.end()) { // // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // it->second(params_in); // } else { // printf("gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else { // // std::cout<<"mode: "<.find(mode); /* NOTE(review): in this copy the statement that declares `it` for the half gemm2 lookup (presumably an `auto it = ...find(mode);` on the half-typed kernel_maps_gemm2_decode_w8a16) appears to have been swallowed into the comment above, leaving `it` undeclared in the following if; compare the gemm1 branch. */ if (it != kernel_maps_gemm2_decode_w8a16.end() ) { float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // record start } it->second(params_in); if (find_best) { cudaEventRecord(stop); // record stop cudaEventSynchronize(stop); // wait for the kernel to finish cudaEventElapsedTime(&milliseconds, start, stop); // compute elapsed time // std::cout << "------------kernel 1 time g2-----------------: " << milliseconds << " ms" << std::endl; cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w8a16_kernel_2_timecost", std::ios::app); // append mode if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } } else { // // printf("gemm2 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 128, 16, 64, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // //
launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else if (input.scalar_type() == at::ScalarType::BFloat16){ int64_t EM = sorted_token_ids.size(0); // 1-D linearized token id GemmParams_w8a16<__hip_bfloat16> params_in( (__hip_bfloat16*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (__hip_bfloat16*)output_alias.data_ptr(), (__hip_bfloat16*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), // reading the value here would cause a device->host copy and some pipeline bubbles num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, EM, top_k, delta, is_marlin ); if(mode < 500){ if(first_stage){ auto it = kernel_maps_gemm1_decode_w8a16<__hip_bfloat16>.find(mode); if (it != kernel_maps_gemm1_decode_w8a16<__hip_bfloat16>.end() ) { it->second(params_in); } } else { auto it = kernel_maps_gemm2_decode_w8a16<__hip_bfloat16>.find(mode); if (it != kernel_maps_gemm2_decode_w8a16<__hip_bfloat16>.end() ) { it->second(params_in); } } } } else { TORCH_CHECK(false, "moe_w8a16_gemm only supports float16 and bfloat16 input"); } // #endif return output; } torch::Tensor moe_c_moe_w8a8_gemm_block_wise(torch::Tensor input, torch::Tensor a_scales,torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value())
topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_w8a8_gemm_block_wise( // (const uint32_t*)input.data_ptr(), // // (const half*)d_input, // (const float*)a_scales.data_ptr(), // use_atomic 
?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream ofs("./w8a8_kernel_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) 
{ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // 
BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_w8a8_gemm_block_wise_kernel2( // (const uint32_t*)input.data_ptr(), // // (const half*)d_input, // (const float*)a_scales.data_ptr(), // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream ofs("./w8a8_kerne2_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional 
topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), 
mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_w8a8_gemm_block_wise_fp8( // (const uint32_t*)input.data_ptr(), // // (const half*)d_input, // (const float*)a_scales.data_ptr(), // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream ofs("./w8a8_kernel_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kBFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // 
hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { // BIT_SWITCH(bit, BIT, 
[&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_w8a8_gemm_block_wise_kernel2_fp8( // (const uint32_t*)input.data_ptr(), // // (const half*)d_input, // (const float*)a_scales.data_ptr(), // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream 
ofs("./w8a8_kerne2_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kBFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t bit) { // const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); /*经验值4-8个block,lds为64k,左矩阵BM*BK*2 范围为8k-16k, 所以BM*BN应在4k-8k*/ // std::cout<<"size_m IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); // std::cout<<"size_m is : "< output_fp32; if 
(use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } // hipDeviceSynchronize(); // hipEventRecord(stop, 0 ); // hipEventSynchronize( stop ); // float ave_time; // hipEventElapsedTime( &ave_time,start, stop ); // printf( "Time to generate: %9f ms\n", ave_time ); // std::cout<<"input scalar type is :"<( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( half_t*) d_w_out, /*for debug*/ // (const half*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // // run_moe_wna16_gemm_blockwise_( // // (const half*)input.data_ptr(), // // // (const half*)d_input, // // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // // (float*)output_fp32->data_ptr(), // // (const uint32_t*)b_qweight.data_ptr(), // // // (const uint32_t*)d_w_test, // // ( half_t*) d_w_out, /*for debug*/ // // (const half*)b_scales.data_ptr(), // // // (const half*)d_scale, // // b_qzeros_ptr, // // // (const uint32_t*)d_scale, // // topk_weights_ptr, // // sorted_token_ids.data_ptr(), // // expert_ids.data_ptr(), // // num_tokens_post_pad.data_ptr(), // // // num_tokens_post_pad_value(), // // // num_tokens_post_pad_data_ptr[0], // // num_token_blocks, // // size_m, // // size_n, // // size_k // // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_w8a16_gemm_awq only supports float16"); // } // hipEvent_t start, stop; // 
hipEventCreate(&start); // hipEventCreate(&stop); // hipEventRecord(start, 0); if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); } // hipDeviceSynchronize(); // hipEventRecord(stop, 0 ); // hipEventSynchronize( stop ); // float ave_time; // hipEventElapsedTime( &ave_time,start, stop ); // printf( "Time to generate: %9f ms\n", ave_time ); return output; } torch::Tensor moe_c_moe_w8a16_gemm_block_wise(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t bit) { // const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); // const int group_size = size_k / b_scales.size(2); /*经验值4-8个block,lds为64k,左矩阵BM*BK*2 范围为8k-16k, 所以BM*BN应在4k-8k*/ int64_t BLOCK_SIZE_N = std::min(64, size_n); // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != 
BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } // if (input.scalar_type() == at::ScalarType::Half) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm_block_wise( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( half_t*) d_w_out, /*for debug*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a16_gemm_block_wise only supports float16"); // } // // half_t* tmp = reinterpret_cast(output.data_ptr()); // // half_t* host_tmp = new half[1]; // 仅拷贝第一个值 // // 
hipMemcpy(host_tmp, tmp, sizeof(half_t), hipMemcpyDeviceToHost); // // float first_value = __half2float(host_tmp[0]); // // std::cout << "first value: " << first_value << std::endl; // // delete[] host_tmp; // if (use_atomic){ // // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); // } return output; } torch::Tensor moe_c_moe_wna16_gemm_base(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); // BLOCK_SIZE_K = std::min(group_size, BLOCK_SIZE_K); BLOCK_SIZE_K = 64; BLOCK_SIZE_N = 256; //std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); torch::Tensor output_fp32 = torch::empty(output.sizes(),output.options().dtype(torch::kFloat32)); // if (input.scalar_type() == at::ScalarType::Half) { // half_t * d_w_out =nullptr; // TOPK_SWITCH(top_k, 
TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_M, BLOCK_SIZE_M_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_SWITCH(group_size, GROUP_SIZE, [&]{ // // run_moe_wna16_gemm( // run_moe_wna16_gemm_base( // (const half*)input.data_ptr(), // (float*)output_fp32.data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // (const half*)b_scales.data_ptr(), // b_qzeros_ptr, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // num_token_blocks, // size_m, // size_n, // size_k // ); // }); // }); // }); // }); // }); // } // else if (input.scalar_type() == at::ScalarType::BFloat16) { // __hip_bfloat16 * d_w_out =nullptr; // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_M, BLOCK_SIZE_M_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_SWITCH(group_size, GROUP_SIZE, [&]{ // // run_moe_wna16_gemm( // run_moe_wna16_gemm_base<__hip_bfloat16, 4, TOPK, BLOCK_SIZE_M_, 256, BLOCK_SIZE_K_, true, mul_topk_weight, GROUP_SIZE>( // (const __hip_bfloat16*)input.data_ptr(), // (float*)output_fp32.data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // (const __hip_bfloat16*)b_scales.data_ptr(), // b_qzeros_ptr, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // num_token_blocks, // size_m, // size_n, // size_k // ); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_wna16_gemm not supports"); // } if (input.scalar_type() == at::ScalarType::Half) { output.copy_(output_fp32.to(torch::kFloat16)); } else if (input.scalar_type() == at::ScalarType::BFloat16) { output.copy_(output_fp32.to(torch::kBFloat16)); // 转换为 BF16 } return output; } torch::Tensor 
moe_c_moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(3); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); int64_t BLOCK_SIZE_N = std::min(64, size_n); // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (input.scalar_type() == at::ScalarType::Half) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // 
BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( half_t*) d_w_out_half, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const half*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else if (input.scalar_type() == at::ScalarType::BFloat16) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , 
BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm<__hip_bfloat16, 4, TOPK, BLOCK_SIZE_M_, BLOCK_SIZE_N_, BLOCK_SIZE_K_, true, mul_topk_weight, GROUP_SIZE_K, BLOCK_SIZE_M_LOOPS, BLOCK_SIZE_N_LOOPS, BLOCK_SIZE_K_LOOPS, USE_ATOMIC,256>( // (const __hip_bfloat16*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( __hip_bfloat16*) d_w_out_bf, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const __hip_bfloat16*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_w8a16_gemm_awq only supports float16"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ if (input.scalar_type() == at::ScalarType::Half) { output.copy_(output_fp32->to(torch::kFloat16)); } else if (input.scalar_type() == at::ScalarType::BFloat16) { output.copy_(output_fp32->to(torch::kBFloat16)); // 转换为 BF16 } } return output; } torch::Tensor moe_c_moe_wna16_gemm_2(torch::Tensor input, 
torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); int64_t BLOCK_SIZE_N = std::min(64, size_n); // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (input.scalar_type() == at::ScalarType::Half) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // 
BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm_2( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( half_t*) d_w_out_half, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const half*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else if (input.scalar_type() == at::ScalarType::BFloat16) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // 
BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm_2<__hip_bfloat16, 4, TOPK, BLOCK_SIZE_M_, BLOCK_SIZE_N_, BLOCK_SIZE_K_, true, mul_topk_weight, GROUP_SIZE_K, BLOCK_SIZE_M_LOOPS, BLOCK_SIZE_N_LOOPS, BLOCK_SIZE_K_LOOPS, USE_ATOMIC,256>( // (const __hip_bfloat16*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( __hip_bfloat16*) d_w_out_bf, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const __hip_bfloat16*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_w8a16_gemm_awq only supports float16"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_2_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ if (input.scalar_type() == at::ScalarType::Half) { output.copy_(output_fp32->to(torch::kFloat16)); } else if (input.scalar_type() == at::ScalarType::BFloat16) { output.copy_(output_fp32->to(torch::kBFloat16)); // 转换为 BF16 } } return output; } void moe_c_silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., 2 
* d] int64_t rows_per_block = 1, int64_t vec_size = 2) { moe_c::silu_and_mul(out, input, static_cast(rows_per_block), static_cast(vec_size)); }