// #include // #include // #include // #include // #include // #include // #include "moe_wna16_utils.h" // #include "moe_ops.h" #pragma once #include "moe_w8a16_block_wise.h" #include "moe_w8a16_awq.h" #include "moe_w4a16.h" #include "moe_w4a16_2.h" #include "moe_w4a16_base.h" #include "moe_w8a8_block_wise.h" #include "moe_w8a8_block_wise_kernel2.h" #include "moe_w8a8_block_wise_fp8.h" #include "moe_w8a8_block_wise_kernel2_fp8.h" #include "topk_softmax_kernel.h" #include "moe_align_sum_kernels.h" #include #include #include #include #include // #include "moe_w8a8_utils.h" // #include "moe_w8a8_config.h" #include "moe_w8a8_opt.h" #include "moe_w4a16_opt.h" #include "moe_w4a8_opt.h" #define BIT_SWITCH(bit, BIT, ...) \ [&] { \ if (bit == 8) { \ constexpr static int BIT = 8; \ return __VA_ARGS__(); \ }else if (bit == 4) { \ constexpr static int BIT = 4; \ return __VA_ARGS__(); \ } \ else { \ std::cout<<"unsupported BIT"< topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ) { const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"< params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512){ auto it = kernel_maps_gemm1_decode.find(mode); if (it != kernel_maps_gemm1_decode.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512 * real_topk){ auto it = kernel_maps_gemm2_decode.find(mode); if (it != kernel_maps_gemm2_decode.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512){ auto it = kernel_maps_gemm1_decode.find(mode); if (it != kernel_maps_gemm1_decode.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found \n"); // launch_moe_w8a8_first_stage_decode<16, 16, 512, 16, 16, 128, 4>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512 * real_topk){ auto it = kernel_maps_gemm2_decode.find(mode); if (it != kernel_maps_gemm2_decode.end()) { it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"*************************float16"<.find(mode); if ( it != kernel_maps_gemm2_prefill.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } return output; } // 模板抽象 // BLOCK_MNK: [16, 128, 128] // WARP_MNK: [16, 32, 64] // MMA_MNK: [16, 16, 32] // BLOCK_MNK / WARP_MNK = [1, 4, 2] 代表warp在MN方向的排布 block_k/warp_k代表stage // WARP_MNK / MMA_MNK = [1, 2, 2] 代表warp在MNK方向重复计算的次数 会分配额外的寄存器 torch::Tensor moe_c_moe_gemm_marlin_w8a8_fp8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ) { const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"< params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512){ auto it = kernel_maps_gemm1_decode_fp8.find(mode); if (it != kernel_maps_gemm1_decode_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_fp8.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512 * real_topk){ auto it = kernel_maps_gemm2_decode_fp8.find(mode); if (it != kernel_maps_gemm2_decode_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill_fp8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Float8_e4m3fn){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512){ auto it = kernel_maps_gemm1_decode_fp8.find(mode); if (it != kernel_maps_gemm1_decode_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_fp8.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(size_m <= 512 * real_topk){ auto it = kernel_maps_gemm2_decode_fp8.find(mode); if (it != kernel_maps_gemm2_decode_fp8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill_fp8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } return output; } // 模板抽象 // BLOCK_MNK: [16, 128, 128] // WARP_MNK: [16, 32, 64] // MMA_MNK: [16, 16, 32] // BLOCK_MNK / WARP_MNK = [1, 4, 2] 代表warp在MN方向的排布 block_k/warp_k代表stage // WARP_MNK / MMA_MNK = [1, 2, 2] 代表warp在MNK方向重复计算的次数 会分配额外的寄存器 torch::Tensor moe_c_moe_gemm_marlin_w4a8(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor a_scale, torch::Tensor b_scale, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ) { const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"< params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8_gemm1n256.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8_gemm1n256.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false /* size_m <= 512 * real_topk */){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false/* size_m <= 512 */){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8_gemm1n256.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8_gemm1n256.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found \n"); // launch_moe_w8a8_first_stage_decode<16, 16, 512, 16, 16, 128, 4>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false/* size_m <= 512 * real_topk */){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"*************************float16"<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } } else{ if (output.scalar_type() == at::ScalarType::BFloat16){ if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8.end()) { it->second(params_in); } else { printf("bfloat version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<32, 256, 64, 32, 128, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<64, 64, 128, 64, 64, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<48, 16, 256, 48, 16, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<32, 64, 128, 32, 64, 64, 1>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (bhalf_t*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false /* size_m <= 512 * real_topk */){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"mode: "<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("bfloat version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 1024, 64, 16, 128, 64, 1>(params_in); // launch_moe_w8a8_second_stage_decode<32, 128, 64, 32, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<64, 64, 64, 64, 32, 64, 2>(params_in); // launch_moe_w8a8_second_stage_decode<32, 256, 128, 32, 128, 128, 1>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { // TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } else if (output.scalar_type() == at::ScalarType::Half) { // printf("********************************************halfhalfhalf"); if (input.scalar_type() == at::ScalarType::Char){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false/* size_m <= 512 */){ auto it = kernel_maps_gemm1_decode_w4a8.find(mode); if (it != kernel_maps_gemm1_decode_w4a8.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_prefill_w4a8.find(mode); // printf() if ( it != kernel_maps_gemm1_prefill_w4a8.end()) { it->second(params_in); } else { printf("half version gemm1 No matching kernel configuration found \n"); // launch_moe_w8a8_first_stage_decode<16, 16, 512, 16, 16, 128, 4>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id // 使用int8类型处理 GemmParams_w4a8 params_in( (const char*)input.data_ptr(), (const char*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), (float*)a_scale.data_ptr(), (float*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), num_tokens_post_pad.data_ptr(), size_m, stride_bse, size_k, stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk, EM, top_k, real_topk, is_marlin ); if(false/* size_m <= 512 * real_topk */){ auto it = kernel_maps_gemm2_decode_w4a8.find(mode); if (it != kernel_maps_gemm2_decode_w4a8.end()) { it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); } }else { // std::cout<<"*************************float16"<.find(mode); if ( it != kernel_maps_gemm2_prefill_w4a8.end()) { // printf("**********************************%d",mode); it->second(params_in); } else { printf("half version gemm2 No matching kernel configuration found \n"); // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } } } return output; } torch::Tensor moe_c_moe_gemm_marlin_w4a16(torch::Tensor input, torch::Tensor b_qweight, torch::Tensor output, torch::Tensor b_scale, torch::Tensor b_zeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, // gemm1为topk gemm2为1 因为gemm1输入为[m, k] gemm2输入为[m*topk, k] int64_t mode, int64_t delta ) { const int size_m = input.size(0); const int EXPERTS = b_qweight.size(0); const int size_k = input.size(1); // std::cout<<"size_k"<(); // 单行printf打印所有步长变量,带标签便于识别 // printf("stride_asm: %d, stride_ask: %d, stride_bse: %d, stride_bsn: %d, stride_bsk: %d\n", // stride_asm, stride_ask, stride_bse, stride_bsn, stride_bsk); constexpr int GROUP_N = 1; constexpr int GROUP_K = 1; bool is_marlin = true; // weight为[E, N, K]时 代表不进行重排 bool first_stage = true; torch::Tensor output_alias = output.alias(); //printf("size_n: %d stride_bse: %d stride_bsn: %d stride_bsk: %d\n", size_n, stride_bse, stride_bsn, stride_bsk); const float* topk_weights_ptr; // 第一阶段这里为null if (topk_weights.has_value()){ topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); first_stage = false; } int num_pad = 0; if (input.scalar_type() == at::ScalarType::Half){ if(first_stage){ int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a16 params_in( (half*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), reinterpret_cast(b_zeros.data_ptr()), (half*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm1_prefill.find(mode); // if (it != kernel_maps_gemm1_prefill.end()) { // it->second(params_in); // } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_first_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else{ //decode // std::cout<<"***************************************decode \n" << mode; auto it = kernel_maps_gemm1_decode_w4a16.find(mode); // printf() if (it != kernel_maps_gemm1_decode_w4a16.end() ) { float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } it->second(params_in); if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } } else { // printf("gemm1 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_first_stage_decode<16, 64, 256, 16, 32, 128, 2>(params_in); // launch_moe_w8a8_first_stage_decode<16, 32, 64, 16, 16, 64, 2>(params_in); } } }else{ //gemm2 int64_t EM = sorted_token_ids.size(0); // 一维线性化的token id GemmParams_w4a16 params_in( (half*)input.data_ptr(), (uint32_t*)b_qweight.data_ptr(), (half*)output_alias.data_ptr(), reinterpret_cast(b_zeros.data_ptr()), (half*)b_scale.data_ptr(), topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_pad, //num_tokens_post_pad[0].item(), //这里获取值 会造成device->host的拷贝和一部分空泡 num_tokens_post_pad.data_ptr(), size_m, size_n, size_k, // stride_asm, // stride_ask, // stride_bse, // stride_bsn, // stride_bsk, EM, top_k, delta, is_marlin ); if(mode >= 500){ // auto it = kernel_maps_gemm2_prefill.find(mode); // if (it != kernel_maps_gemm2_prefill.end()) { // // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // it->second(params_in); // } else { // printf("gemm2 No matching kernel configuration found, using default settings \n"); // launch_moe_w8a8_second_stage_prefill<16, 64, 128, 16, 16, 64>(params_in); // } }else { // // std::cout<<"mode: "<second(params_in); if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_2_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } } else { // // printf("gemm2 No matching kernel configuration found, using default settings \n"); // // launch_moe_w8a8_second_stage_decode<16, 64, 128, 16, 16, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 128, 16, 64, 128, 2>(params_in); // launch_moe_w8a8_second_stage_decode<16, 256, 64, 16, 32, 64, 2>(params_in); // // launch_moe_w8a8_first_stage_decode<16, 64, 128, 16, 16, 64, 2>(params_in); // // launch_moe_w8a8_first_stage_decode<16, 128, 128, 16, 16, 64, 2>(params_in); } } // hipDeviceSynchronize(); } } else { TORCH_CHECK(false, "moe_w8a8_gemm only supports int8"); } return output; } torch::Tensor moe_c_moe_w8a8_gemm_block_wise(torch::Tensor input, torch::Tensor a_scales,torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_w8a8_gemm_block_wise( // (const uint32_t*)input.data_ptr(), // // (const half*)d_input, // (const float*)a_scales.data_ptr(), // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream ofs("./w8a8_kernel_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_w8a8_gemm_block_wise_kernel2( // (const uint32_t*)input.data_ptr(), // // (const half*)d_input, // (const float*)a_scales.data_ptr(), // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream ofs("./w8a8_kerne2_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { BIT_SWITCH(bit, BIT, [&]{ TOPK_SWITCH(top_k, TOPK, [&]{ BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ run_moe_w8a8_gemm_block_wise_fp8( (const uint32_t*)input.data_ptr(), // (const half*)d_input, (const float*)a_scales.data_ptr(), use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // (float*)output_fp32->data_ptr(), (const uint32_t*)b_qweight.data_ptr(), // (const uint32_t*)d_w_test, ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // (const half*)b_scales.data_ptr(), (const float*)b_scales.data_ptr(), // (const half*)d_scale, b_qzeros_ptr, // (const uint32_t*)d_scale, topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), // num_tokens_post_pad_value(), // num_tokens_post_pad_data_ptr[0], num_token_blocks, size_m, size_n, size_k ); // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma }); }); }); }); }); }); }); }); }); }); }); }); } else { TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream ofs("./w8a8_kernel_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kBFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); int64_t BLOCK_SIZE_N = 64; // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } if (true/* input.scalar_type() == at::ScalarType::QInt8 */) { BIT_SWITCH(bit, BIT, [&]{ TOPK_SWITCH(top_k, TOPK, [&]{ BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ run_moe_w8a8_gemm_block_wise_kernel2_fp8( (const uint32_t*)input.data_ptr(), // (const half*)d_input, (const float*)a_scales.data_ptr(), use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // (float*)output_fp32->data_ptr(), (const uint32_t*)b_qweight.data_ptr(), // (const uint32_t*)d_w_test, ( int*)&d_w_out[0], /*for debug*/ /*使用时需修改为device端地址,这里仅为占位使用*/ // (const half*)b_scales.data_ptr(), (const float*)b_scales.data_ptr(), // (const half*)d_scale, b_qzeros_ptr, // (const uint32_t*)d_scale, topk_weights_ptr, sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), // num_tokens_post_pad_value(), // num_tokens_post_pad_data_ptr[0], num_token_blocks, size_m, size_n, size_k ); // printf("run_moe_w8a8_gemm_block_wise\n"); // kernel-1 mma }); }); }); }); }); }); }); }); }); }); }); }); } else { TORCH_CHECK(false, "moe_w8a8_gemm_block_wise only supports int8_t"); } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ std::ofstream ofs("./w8a8_kerne2_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kBFloat16)); } cudaEventDestroy(start); cudaEventDestroy(stop); // {//for debug // hipDeviceSynchronize(); // hipMemcpy(&d_w_out[0], dev_d_w, 16*64 * sizeof(int), hipMemcpyDeviceToHost); // for(int i =0;i<16;i++){ // for(int j = 0;j<64 ;j++){ // std::cout<<(int)(d_w_out[i*64+j])<<"\t"; // } // std::cout<<"|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"< b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t bit) { // const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); /*经验值4-8个block,lds为64k,左矩阵BM*BK*2 范围为8k-16k, 所以BM*BN应在4k-8k*/ // std::cout<<"size_m IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); // std::cout<<"size_m is : "< output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } // hipDeviceSynchronize(); // hipEventRecord(stop, 0 ); // hipEventSynchronize( stop ); // float ave_time; // hipEventElapsedTime( &ave_time,start, stop ); // printf( "Time to generate: %9f ms\n", ave_time ); // std::cout<<"input scalar type is :"<( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( half_t*) d_w_out, /*for debug*/ // (const half*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // // run_moe_wna16_gemm_blockwise_( // // (const half*)input.data_ptr(), // // // (const half*)d_input, // // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // // (float*)output_fp32->data_ptr(), // // (const uint32_t*)b_qweight.data_ptr(), // // // (const uint32_t*)d_w_test, // // ( half_t*) d_w_out, /*for debug*/ // // (const half*)b_scales.data_ptr(), // // // (const half*)d_scale, // // b_qzeros_ptr, // // // (const uint32_t*)d_scale, // // topk_weights_ptr, // // sorted_token_ids.data_ptr(), // // expert_ids.data_ptr(), // // num_tokens_post_pad.data_ptr(), // // // num_tokens_post_pad_value(), // // // num_tokens_post_pad_data_ptr[0], // // num_token_blocks, // // size_m, // // size_n, // // size_k // // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_w8a16_gemm_awq only supports float16"); // } // hipEvent_t start, stop; // hipEventCreate(&start); // hipEventCreate(&stop); // hipEventRecord(start, 0); if (use_atomic){ // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); } // hipDeviceSynchronize(); // hipEventRecord(stop, 0 ); // hipEventSynchronize( stop ); // float ave_time; // hipEventElapsedTime( &ave_time,start, stop ); // printf( "Time to generate: %9f ms\n", ave_time ); return output; } torch::Tensor moe_c_moe_w8a16_gemm_block_wise(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t group_size_n, int64_t group_size_k, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t bit) { // const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); // const int group_size = size_k / b_scales.size(2); /*经验值4-8个block,lds为64k,左矩阵BM*BK*2 范围为8k-16k, 所以BM*BN应在4k-8k*/ int64_t BLOCK_SIZE_N = std::min(64, size_n); // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size_k; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size_k == 0, "BLOCK_SIZE_K must divisible by group_size_k"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } // if (input.scalar_type() == at::ScalarType::Half) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_N_SWITCH(group_size_n, GROUP_SIZE_N, [&]{ // GROUP_SIZE_K_SWITCH(group_size_k, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm_block_wise( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // ( half_t*) d_w_out, /*for debug*/ // // (const half*)b_scales.data_ptr(), // (const float*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } else { // TORCH_CHECK(false, "moe_w8a16_gemm_block_wise only supports float16"); // } // // half_t* tmp = reinterpret_cast(output.data_ptr()); // // half_t* host_tmp = new half[1]; // 仅拷贝第一个值 // // hipMemcpy(host_tmp, tmp, sizeof(half_t), hipMemcpyDeviceToHost); // // float first_value = __half2float(host_tmp[0]); // // std::cout << "first value: " << first_value << std::endl; // // delete[] host_tmp; // if (use_atomic){ // // std::cout<<" fp32 convert to fp16 "<to(torch::kFloat16)); // } return output; } torch::Tensor moe_c_moe_wna16_gemm_base(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); // BLOCK_SIZE_K = std::min(group_size, BLOCK_SIZE_K); BLOCK_SIZE_K = 64; BLOCK_SIZE_N = 256; //std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); torch::Tensor output_fp32 = torch::empty(output.sizes(),output.options().dtype(torch::kFloat32)); // if (input.scalar_type() == at::ScalarType::Half) { // half_t * d_w_out =nullptr; // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_M, BLOCK_SIZE_M_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_SWITCH(group_size, GROUP_SIZE, [&]{ // // run_moe_wna16_gemm( // run_moe_wna16_gemm_base( // (const half*)input.data_ptr(), // (float*)output_fp32.data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // (const half*)b_scales.data_ptr(), // b_qzeros_ptr, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // num_token_blocks, // size_m, // size_n, // size_k // ); // }); // }); // }); // }); // }); // } // else if (input.scalar_type() == at::ScalarType::BFloat16) { // __hip_bfloat16 * d_w_out =nullptr; // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_M, BLOCK_SIZE_M_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_SWITCH(group_size, GROUP_SIZE, [&]{ // // run_moe_wna16_gemm( // run_moe_wna16_gemm_base<__hip_bfloat16, 4, TOPK, BLOCK_SIZE_M_, 256, BLOCK_SIZE_K_, true, mul_topk_weight, GROUP_SIZE>( // (const __hip_bfloat16*)input.data_ptr(), // (float*)output_fp32.data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // (const __hip_bfloat16*)b_scales.data_ptr(), // b_qzeros_ptr, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // num_token_blocks, // size_m, // size_n, // size_k // ); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_wna16_gemm not supports"); // } if (input.scalar_type() == at::ScalarType::Half) { output.copy_(output_fp32.to(torch::kFloat16)); } else if (input.scalar_type() == at::ScalarType::BFloat16) { output.copy_(output_fp32.to(torch::kBFloat16)); // 转换为 BF16 } return output; } torch::Tensor moe_c_moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(3); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); int64_t BLOCK_SIZE_N = std::min(64, size_n); // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (input.scalar_type() == at::ScalarType::Half) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( half_t*) d_w_out_half, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const half*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else if (input.scalar_type() == at::ScalarType::BFloat16) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm<__hip_bfloat16, 4, TOPK, BLOCK_SIZE_M_, BLOCK_SIZE_N_, BLOCK_SIZE_K_, true, mul_topk_weight, GROUP_SIZE_K, BLOCK_SIZE_M_LOOPS, BLOCK_SIZE_N_LOOPS, BLOCK_SIZE_K_LOOPS, USE_ATOMIC,256>( // (const __hip_bfloat16*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( __hip_bfloat16*) d_w_out_bf, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const __hip_bfloat16*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_w8a16_gemm_awq only supports float16"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_1_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ if (input.scalar_type() == at::ScalarType::Half) { output.copy_(output_fp32->to(torch::kFloat16)); } else if (input.scalar_type() == at::ScalarType::BFloat16) { output.copy_(output_fp32->to(torch::kBFloat16)); // 转换为 BF16 } } return output; } torch::Tensor moe_c_moe_wna16_gemm_2(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_m, int64_t BLOCK_SIZE_n, int64_t BLOCK_SIZE_k, int64_t kloops, int64_t nloops, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int size_m = input.size(0); const int size_n = b_qweight.size(1); const int size_k = input.size(1); const int group_size = size_k / b_scales.size(2); int64_t BLOCK_SIZE_N = std::min(64, size_n); // std::cout<<"BLOCK_SIZE_N IS : "<(); const float* topk_weights_ptr; if (topk_weights.has_value()) topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, "BLOCK_SIZE_K must divisible by group_size"); TORCH_CHECK(BLOCK_SIZE_m <= 64, "BLOCK_SIZE_m must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); bool use_atomic = (size_k != BLOCK_SIZE_K*block_size_k_loops); std::optional output_fp32; if (use_atomic){ output_fp32 = torch::zeros(output.sizes(),output.options().dtype(torch::kFloat32)); // output_fp32->zero_(); } float milliseconds = 0; cudaEvent_t start, stop; const char* find_best = std::getenv("WHICH_TO_TEST"); if (find_best) { cudaEventCreate(&start); cudaEventCreate(&stop); cudaEventRecord(start); // 记录开始 } // if (input.scalar_type() == at::ScalarType::Half) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm_2( // (const half*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( half_t*) d_w_out_half, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const half*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else if (input.scalar_type() == at::ScalarType::BFloat16) { // BIT_SWITCH(bit, BIT, [&]{ // TOPK_SWITCH(top_k, TOPK, [&]{ // BLOCK_M_SWITCH(BLOCK_SIZE_m, BLOCK_SIZE_M_, [&]{ // BLOCK_N_SWITCH(BLOCK_SIZE_N, BLOCK_SIZE_N_, [&]{ // BLOCK_K_SWITCH(BLOCK_SIZE_K, BLOCK_SIZE_K_, [&]{ // // BOOL_SWITCH(b_qzeros.has_value(), has_zp, [&]{ // BOOL_SWITCH(topk_weights.has_value(), mul_topk_weight, [&]{ // GROUP_SIZE_K_SWITCH(group_size, GROUP_SIZE_K, [&]{ // BLOCK_SIZE_M_LOOPS_SWITCH(block_size_m_loops , BLOCK_SIZE_M_LOOPS, [&]{ // BLOCK_SIZE_N_LOOPS_SWITCH(block_size_n_loops , BLOCK_SIZE_N_LOOPS, [&]{ // BLOCK_SIZE_K_LOOPS_SWITCH(block_size_k_loops , BLOCK_SIZE_K_LOOPS, [&]{ // BOOL_SWITCH(use_atomic , USE_ATOMIC, [&]{ // run_moe_wna16_gemm_2<__hip_bfloat16, 4, TOPK, BLOCK_SIZE_M_, BLOCK_SIZE_N_, BLOCK_SIZE_K_, true, mul_topk_weight, GROUP_SIZE_K, BLOCK_SIZE_M_LOOPS, BLOCK_SIZE_N_LOOPS, BLOCK_SIZE_K_LOOPS, USE_ATOMIC,256>( // (const __hip_bfloat16*)input.data_ptr(), // // (const half*)d_input, // use_atomic ?(float*)output_fp32->data_ptr():(float*)output.data_ptr(), // // (float*)output_fp32->data_ptr(), // (const uint32_t*)b_qweight.data_ptr(), // // (const uint32_t*)d_w_test, // // ( __hip_bfloat16*) d_w_out_bf, /*for debug*/ // // ( float*) float_d_out, /*for debug*/ // (const __hip_bfloat16*)b_scales.data_ptr(), // // (const half*)d_scale, // b_qzeros_ptr, // // (const uint32_t*)d_scale, // topk_weights_ptr, // sorted_token_ids.data_ptr(), // expert_ids.data_ptr(), // num_tokens_post_pad.data_ptr(), // // num_tokens_post_pad_value(), // // num_tokens_post_pad_data_ptr[0], // num_token_blocks, // size_m, // size_n, // size_k // ); // kernel-1 mma // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // }); // } // else { // TORCH_CHECK(false, "moe_w8a16_gemm_awq only supports float16"); // } if (find_best) { cudaEventRecord(stop); // 记录结束 cudaEventSynchronize(stop); // 等待 kernel 执行完成 cudaEventElapsedTime(&milliseconds, start, stop); // 计算时间 /* std::cout << "kernel 1 time: " << milliseconds << " ms" << std::endl; */ cudaEventDestroy(start); cudaEventDestroy(stop); std::ofstream ofs("./w4a16_kernel_2_timecost", std::ios::app); // 追加写入 if (ofs.is_open()) { ofs << milliseconds << std::endl; ofs.close(); } } if (use_atomic){ if (input.scalar_type() == at::ScalarType::Half) { output.copy_(output_fp32->to(torch::kFloat16)); } else if (input.scalar_type() == at::ScalarType::BFloat16) { output.copy_(output_fp32->to(torch::kBFloat16)); // 转换为 BF16 } } return output; } void moe_c_silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(moe_c::silu_kernel, true); }