// SPDX-License-Identifier: MIT #include #include #include #include #include #include "aiter_hip_common.h" #include "moe_op.h" #include "py_itfs_common.h" struct __attribute__((packed)) KernelArgs { void *ptr_O; p2 _p0; void *ptr_X; p2 _p1; void *ptr_GU; p2 _p2; void *ptr_XC; p2 _p3; void *ptr_D; p2 _p4; void *ptr_XQ; p2 _p5; void *ptr_GUQ; p2 _p6; void *ptr_DQ; p2 _p7; void *ptr_SMQ; p2 _p8; void *ptr_STP; p2 _p9; void *ptr_SW; p2 _p10; void *ptr_SEP; p2 _p11; unsigned int dim; p3 _p12; unsigned int inter_dim; p3 _p13; unsigned int token_cnt; p3 _p14; unsigned int eprt_cnt; p3 _p15; unsigned int Xs; p3 _p16; unsigned int GUs; p3 _p17; unsigned int Ds; p3 _p18; unsigned int Os; p3 _p19; unsigned int eGUs; p3 _p20; unsigned int eDs; p3 _p21; unsigned int eGUQs; p3 _p22; unsigned int eDQs; p3 _p23; unsigned int eSMQs; p3 _p24; unsigned int topk; p3 _p25; }; class FMoeKernel { private: hipModule_t module; hipFunction_t kernel_func; uint32_t sub_GU = 512; bool is_int4 = false; public: FMoeKernel(const char *name, const char *hsaco, uint32_t sub_GU = 512) { const char *AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str())); HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); std::cout << " Success" << std::endl; this->sub_GU = sub_GU; }; void set_int4(bool is_int4_) { is_int4 = is_int4_; } template void launch_kernel(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &w1, // [expert, inter_dim, dim] N,K torch::Tensor &w2, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // std::optional input_dqn = std::nullopt, std::optional w1_dqn = std::nullopt, std::optional w2_dqn = std::nullopt, std::optional w2_smooth_qnt = std::nullopt // ) { int token_cnt = out.size(0); int dim = input.size(1); int sub_X_cnt = sorted_expert_ids.size(0); int eprt = w1.size(0); int inter_dim = is_int4 ? w2.size(2) * 8 : w2.size(2); uint32_t sub_GU = this->sub_GU; uint32_t I_elemSize = sizeof(T); uint32_t O_elemSize = sizeof(T_O); int stride_X = input.stride(0) * input.element_size(); int stride_GU = dim * I_elemSize; int stride_D = inter_dim * I_elemSize; if (is_int4) { stride_GU /= 2; stride_D /= 2; } int stride_expert_GU = stride_GU * inter_dim; int stride_expert_D = stride_D * dim; int stride_expert_GUDQN = w1_dqn.has_value() ? w1_dqn.value().stride(0) * sizeof(float) : 0; int stride_expert_DDQN = w2_dqn.has_value() ? w2_dqn.value().stride(0) * sizeof(float) : 0; int stride_expert_SMTDQN = inter_dim * sizeof(float); int stride_O = dim * O_elemSize; if (inter_dim * 2 == w1.size(1)) { stride_expert_GU *= 2; // stride_expert_GUDQN *= 2; } KernelArgs args; size_t arg_size = sizeof(args); args.ptr_O = out.data_ptr(); args.ptr_X = input.data_ptr(); args.ptr_GU = w1.data_ptr(); args.ptr_XC = num_valid_ids.data_ptr(); args.ptr_D = w2.data_ptr(); if constexpr (std::is_same::value) { args.ptr_XQ = input_dqn.value().data_ptr(); args.ptr_GUQ = w1_dqn.value().data_ptr(); args.ptr_DQ = w2_dqn.value().data_ptr(); args.ptr_SMQ = w2_smooth_qnt.has_value() ? w2_smooth_qnt.value().data_ptr() : nullptr; } else { args.ptr_XQ = nullptr; args.ptr_GUQ = nullptr; args.ptr_DQ = nullptr; args.ptr_SMQ = nullptr; } args.ptr_STP = sorted_token_ids.data_ptr(); args.ptr_SW = sorted_weights.data_ptr(); args.ptr_SEP = sorted_expert_ids.data_ptr(); args.dim = dim; args.inter_dim = inter_dim; args.token_cnt = token_cnt; args.eprt_cnt = eprt; args.Xs = stride_X; args.GUs = stride_GU; args.Ds = stride_D; args.Os = stride_O; args.eGUs = stride_expert_GU; args.eDs = stride_expert_D; args.eGUQs = stride_expert_GUDQN; args.eDQs = stride_expert_DDQN; args.eSMQs = stride_expert_SMTDQN; args.topk = topk; void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &args, HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size, HIP_LAUNCH_PARAM_END}; int bdx = 256; int gdx = ((inter_dim + sub_GU - 1) / sub_GU); int gdy = sub_X_cnt; int gdz = 1; // std::cout << "args.dim: " << args.dim << std::endl; // std::cout << "args.inter_dim: " << args.inter_dim << std::endl; // std::cout << "args.token_cnt: " << args.token_cnt << std::endl; // std::cout << "args.eprt_cnt: " << args.eprt_cnt << std::endl; // std::cout << "args.stride_X: " << args.Xs << std::endl; // std::cout << "args.stride_GU: " << args.GUs << std::endl; // std::cout << "args.stride_D: " << args.Ds << std::endl; // std::cout << "args.stride_O: " << args.Os << std::endl; // std::cout << "args.stride_expert_GU: " << args.eGUs << std::endl; // std::cout << "args.stride_expert_D: " << args.eDs << std::endl; // std::cout << "args.stride_expert_GUDQN: " << args.eGUQs << std::endl; // std::cout << "args.stride_expert_DDQN: " << args.eDQs << std::endl; // std::cout << "args.stride_expert_SMTDQN: " << args.eSMQs << std::endl; // std::cout << "args.topk: " << args.topk << std::endl; // std::cout << "gdx: " << gdx << std::endl; // std::cout << "gdy: " << gdy << std::endl; const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if constexpr (switchGxy) { HIP_CALL(hipModuleLaunchKernel(kernel_func, gdy, gdx, gdz, bdx, 1, 1, 0, stream, nullptr, (void **)&config)); } else { HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, stream, nullptr, (void **)&config)); } }; }; int get_heuristic_tile(int inter_dim, int sub_X_cnt, const std::vector &available_tiles) { // int tiles[7] = {512, 448, 384, 320, 256, 192, 128}; hipDevice_t dev; hipDeviceProp_t dev_prop; HIP_CALL(hipGetDevice(&dev)); HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); uint32_t num_cu = dev_prop.multiProcessorCount; uint32_t empty_cu = num_cu; uint32_t tg_num = 0; uint32_t round = 0xffffffff; int selectedTile = 0; for (auto tile : available_tiles) { if ((inter_dim % tile) == 0) { tg_num = inter_dim / tile * sub_X_cnt; uint32_t local_round = (tg_num + num_cu - 1) / num_cu; if (local_round < round) { round = local_round; selectedTile = tile; empty_cu = local_round * num_cu - tg_num; } else if (local_round == round) { if (empty_cu > (local_round * num_cu - tg_num)) { round = local_round; selectedTile = tile; empty_cu = local_round * num_cu - tg_num; } } } } return selectedTile; }; void fmoe(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk // ) { // g1u0 FMoeKernel *impl_ptr = nullptr; if (input.dtype() == at::ScalarType::Half) { static FMoeKernel impl_f16("fmoe_kernel_func", "fmoe_f16.co"); impl_ptr = &impl_f16; } else if (input.dtype() == at::ScalarType::BFloat16) { static FMoeKernel impl_b16("fmoe_kernel_func", "fmoe_b16.co"); impl_ptr = &impl_b16; } TORCH_CHECK(impl_ptr != nullptr, __func__, ": unsupport current input type:", input.scalar_type()); impl_ptr->launch_kernel(out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, topk); } void fmoe_int8_g1u0(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [token_cnt, 1] torch::Tensor &fc1_scale, // [expert, 1, inter_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] torch::Tensor &fc2_smooth_scale, // [expert, 1, inter_dim], ActivationType activation) { FMoeKernel *impl_ptr = nullptr; int inter_dim = down.size(2); static std::unordered_map> impl_ptr_map; struct FMoeKernelConfig { std::string name; std::string co_name; int tile_size; }; if (input.dtype() == at::ScalarType::Char || input.dtype() == at::ScalarType::Byte) { static std::unordered_map gelu_kernel_int8_configs = { {512, {"fmoe_int8_g1u0_subGU_512_gelu", "fmoe/gelu/fmoe_int8_g1u0_subGU_512_gelu.co", 512}}, {448, {"fmoe_int8_g1u0_subGU_448_gelu", "fmoe/gelu/fmoe_int8_g1u0_subGU_448_gelu.co", 448}}, {384, {"fmoe_int8_g1u0_subGU_384_gelu", "fmoe/gelu/fmoe_int8_g1u0_subGU_384_gelu.co", 384}}, {320, {"fmoe_int8_g1u0_subGU_320_gelu", "fmoe/gelu/fmoe_int8_g1u0_subGU_320_gelu.co", 320}}, {256, {"fmoe_int8_g1u0_subGU_256_gelu", "fmoe/gelu/fmoe_int8_g1u0_subGU_256_gelu.co", 256}}, {192, {"fmoe_int8_g1u0_subGU_192_gelu", "fmoe/gelu/fmoe_int8_g1u0_subGU_192_gelu.co", 192}}, {128, {"fmoe_int8_g1u0_subGU_128_gelu", "fmoe/gelu/fmoe_int8_g1u0_subGU_128_gelu.co", 128}}}; static std::unordered_map silu_kernel_int8_configs = { {512, {"fmoe_int8_g1u0_subGU_512", "fmoe/silu/fmoe_int8_g1u0_subGU_512.co", 512}}, {448, {"fmoe_int8_g1u0_subGU_448", "fmoe/silu/fmoe_int8_g1u0_subGU_448.co", 448}}, {384, {"fmoe_int8_g1u0_subGU_384", "fmoe/silu/fmoe_int8_g1u0_subGU_384.co", 384}}, {320, {"fmoe_int8_g1u0_subGU_320", "fmoe/silu/fmoe_int8_g1u0_subGU_320.co", 320}}, {256, {"fmoe_int8_g1u0_subGU_256", "fmoe/silu/fmoe_int8_g1u0_subGU_256.co", 256}}, {192, {"fmoe_int8_g1u0_subGU_192", "fmoe/silu/fmoe_int8_g1u0_subGU_192.co", 192}}, {128, {"fmoe_int8_g1u0_subGU_128", "fmoe/silu/fmoe_int8_g1u0_subGU_128.co", 128}}}; std::unordered_map *config_map = nullptr; if (activation == ActivationType::Gelu) { config_map = &gelu_kernel_int8_configs; } else if (activation == ActivationType::Silu) { config_map = &silu_kernel_int8_configs; } if (!config_map) { TORCH_CHECK(false, __func__, " Input only supput Int8!"); } const int tiles[] = {512, 448, 384, 320, 256, 192, 128}; int selectedTile = 0; for (int tile : tiles) { if (inter_dim % tile == 0) { selectedTile = tile; break; } } if (selectedTile == 0) { TORCH_CHECK(false, __func__, " Unsupported inter_dim " + std::to_string(inter_dim) + ", which should be divisible by 128, 192, 256, 320, 384, 448 or 512"); } auto it = config_map->find(selectedTile); if (it != config_map->end()) { const auto &config = it->second; const char *name = config.name.c_str(); const char *co_name = config.co_name.c_str(); auto result = impl_ptr_map.emplace(name, nullptr); if (result.second) { result.first->second = std::make_unique(name, co_name, config.tile_size); } impl_ptr = result.first->second.get(); } } impl_ptr->launch_kernel(out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, topk, // quant args input_scale, fc1_scale, fc2_scale, fc2_smooth_scale); } void fmoe_g1u1(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim*2, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [token_cnt, 1] torch::Tensor &fc1_scale, // [expert, 1, inter_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] std::optional fc2_smooth_scale, // [expert, 1, inter_dim] ActivationType activation) { struct FMoeKernelConfig { std::string name; std::string co_name; int tile_size; }; FMoeKernel *impl_ptr = nullptr; int inter_dim = down.size(2); int sub_X_cnt = sorted_expert_ids.size(0); static std::unordered_map> impl_ptr_map; if (gate.dtype() == at::ScalarType::UInt32 || gate.dtype() == at::ScalarType::Int) { int selectedTile = get_heuristic_tile(inter_dim, sub_X_cnt, {512, 256, 128}); // todo,add tune interface here if (selectedTile == 512) { static FMoeKernel impl_int4_512("fmoe_int4fp8_g1u1_subGU_512_gelu", "fmoe_int4fp8_g1u1_subGU_512_gelu.co", 512); impl_ptr = &impl_int4_512; } else if (selectedTile == 256) { static FMoeKernel impl_int4_256("fmoe_int4fp8_g1u1_subGU_256_gelu", "fmoe_int4fp8_g1u1_subGU_256_gelu.co", 256); impl_ptr = &impl_int4_256; } else if (selectedTile == 128) { static FMoeKernel impl_int4_128("fmoe_int4fp8_g1u1_subGU_128_gelu", "fmoe_int4fp8_g1u1_subGU_128_gelu.co", 128); impl_ptr = &impl_int4_128; } else { TORCH_CHECK(false, __func__, " Unsupported inter_dim " + std::to_string(inter_dim) + ", which should be divisible by 128, 256, or 512"); } impl_ptr->set_int4(true); } else if (input.dtype() == at::ScalarType::Char || input.dtype() == at::ScalarType::Byte) { static std::unordered_map multix_kernel_int8_configs = { {512, {"fmoe_int8_g1u1_multix_subGU_512", "fmoe_int8_g1u1_multix_subGU_512.co", 512}}, {448, {"fmoe_int8_g1u1_multix_subGU_448", "fmoe_int8_g1u1_multix_subGU_448.co", 448}}, {384, {"fmoe_int8_g1u1_multix_subGU_384", "fmoe_int8_g1u1_multix_subGU_384.co", 384}}, {320, {"fmoe_int8_g1u1_multix_subGU_320", "fmoe_int8_g1u1_multix_subGU_320.co", 320}}, {256, {"fmoe_int8_g1u1_multix_subGU_256", "fmoe_int8_g1u1_multix_subGU_256.co", 256}}, {192, {"fmoe_int8_g1u1_multix_subGU_192", "fmoe_int8_g1u1_multix_subGU_192.co", 192}}, {128, {"fmoe_int8_g1u1_multix_subGU_128", "fmoe_int8_g1u1_multix_subGU_128.co", 128}}}; static std::unordered_map silu_kernel_int8_configs = { {512, {"fmoe_int8_g1u1_subGU_512", "fmoe/silu/fmoe_int8_g1u1_subGU_512.co", 512}}, {448, {"fmoe_int8_g1u1_subGU_448", "fmoe/silu/fmoe_int8_g1u1_subGU_448.co", 448}}, {384, {"fmoe_int8_g1u1_subGU_384", "fmoe/silu/fmoe_int8_g1u1_subGU_384.co", 384}}, {320, {"fmoe_int8_g1u1_subGU_320", "fmoe/silu/fmoe_int8_g1u1_subGU_320.co", 320}}, {256, {"fmoe_int8_g1u1_subGU_256", "fmoe/silu/fmoe_int8_g1u1_subGU_256.co", 256}}, {192, {"fmoe_int8_g1u1_subGU_192", "fmoe/silu/fmoe_int8_g1u1_subGU_192.co", 192}}, {128, {"fmoe_int8_g1u1_subGU_128", "fmoe/silu/fmoe_int8_g1u1_subGU_128.co", 128}}}; static std::unordered_map gelu_kernel_int8_configs = { {512, {"fmoe_int8_g1u1_subGU_512_gelu", "fmoe/gelu/fmoe_int8_g1u1_subGU_512_gelu.co", 512}}, {448, {"fmoe_int8_g1u1_subGU_448_gelu", "fmoe/gelu/fmoe_int8_g1u1_subGU_448_gelu.co", 448}}, {384, {"fmoe_int8_g1u1_subGU_384_gelu", "fmoe/gelu/fmoe_int8_g1u1_subGU_384_gelu.co", 384}}, {320, {"fmoe_int8_g1u1_subGU_320_gelu", "fmoe/gelu/fmoe_int8_g1u1_subGU_320_gelu.co", 320}}, {256, {"fmoe_int8_g1u1_subGU_256_gelu", "fmoe/gelu/fmoe_int8_g1u1_subGU_256_gelu.co", 256}}, {192, {"fmoe_int8_g1u1_subGU_192_gelu", "fmoe/gelu/fmoe_int8_g1u1_subGU_192_gelu.co", 192}}, {128, {"fmoe_int8_g1u1_subGU_128_gelu", "fmoe/gelu/fmoe_int8_g1u1_subGU_128_gelu.co", 128}}}; int selectedTile = get_heuristic_tile(inter_dim, sub_X_cnt, {512, 448, 384, 320, 256, 192, 128}); // todo,add tune interface here std::unordered_map *config_map = nullptr; if (fc2_smooth_scale.has_value()) { config_map = &multix_kernel_int8_configs; } else if (activation == ActivationType::Gelu) { config_map = &gelu_kernel_int8_configs; } else if (activation == ActivationType::Silu) { config_map = &silu_kernel_int8_configs; } if (config_map) { auto it = config_map->find(selectedTile); if (it != config_map->end()) { const auto &config = it->second; const char *name = config.name.c_str(); const char *co_name = config.co_name.c_str(); auto result = impl_ptr_map.emplace(name, nullptr); if (result.second) { result.first->second = std::make_unique(name, co_name, config.tile_size); } impl_ptr = result.first->second.get(); } else TORCH_CHECK(false, __func__, " Unsupported inter_dim " + std::to_string(inter_dim) + ", which should be divisible by 128, 192, 256, 320, 384, 448 or 512"); } else { TORCH_CHECK(false, __func__, "No valid kernel selected!"); } } else if (input.dtype() == torch_fp8) { static std::unordered_map multix_kernel_fp8_configs = { {512, {"fmoe_fp8_g1u1_multix_subGU_512", "fmoe_fp8_g1u1_multix_subGU_512.co", 512}}, {448, {"fmoe_fp8_g1u1_multix_subGU_448", "fmoe_fp8_g1u1_multix_subGU_448.co", 448}}, {384, {"fmoe_fp8_g1u1_multix_subGU_384", "fmoe_fp8_g1u1_multix_subGU_384.co", 384}}, {320, {"fmoe_fp8_g1u1_multix_subGU_320", "fmoe_fp8_g1u1_multix_subGU_320.co", 320}}, {256, {"fmoe_fp8_g1u1_multix_subGU_256", "fmoe_fp8_g1u1_multix_subGU_256.co", 256}}, {192, {"fmoe_fp8_g1u1_multix_subGU_192", "fmoe_fp8_g1u1_multix_subGU_192.co", 192}}, {128, {"fmoe_fp8_g1u1_multix_subGU_128", "fmoe_fp8_g1u1_multix_subGU_128.co", 128}}}; static std::unordered_map silu_kernel_fp8_configs = { {512, {"fmoe_fp8_g1u1_subGU_512", "fmoe/silu/fmoe_fp8_g1u1_subGU_512.co", 512}}, {448, {"fmoe_fp8_g1u1_subGU_448", "fmoe/silu/fmoe_fp8_g1u1_subGU_448.co", 448}}, {384, {"fmoe_fp8_g1u1_subGU_384", "fmoe/silu/fmoe_fp8_g1u1_subGU_384.co", 384}}, {320, {"fmoe_fp8_g1u1_subGU_320", "fmoe/silu/fmoe_fp8_g1u1_subGU_320.co", 320}}, {256, {"fmoe_fp8_g1u1_subGU_256", "fmoe/silu/fmoe_fp8_g1u1_subGU_256.co", 256}}, {192, {"fmoe_fp8_g1u1_subGU_192", "fmoe/silu/fmoe_fp8_g1u1_subGU_192.co", 192}}, {128, {"fmoe_fp8_g1u1_subGU_128", "fmoe/silu/fmoe_fp8_g1u1_subGU_128.co", 128}}}; static std::unordered_map gelu_kernel_fp8_configs = { {512, {"fmoe_fp8_g1u1_subGU_512_gelu", "fmoe/gelu/fmoe_fp8_g1u1_subGU_512_gelu.co", 512}}, {448, {"fmoe_fp8_g1u1_subGU_448_gelu", "fmoe/gelu/fmoe_fp8_g1u1_subGU_448_gelu.co", 448}}, {384, {"fmoe_fp8_g1u1_subGU_384_gelu", "fmoe/gelu/fmoe_fp8_g1u1_subGU_384_gelu.co", 384}}, {320, {"fmoe_fp8_g1u1_subGU_320_gelu", "fmoe/gelu/fmoe_fp8_g1u1_subGU_320_gelu.co", 320}}, {256, {"fmoe_fp8_g1u1_subGU_256_gelu", "fmoe/gelu/fmoe_fp8_g1u1_subGU_256_gelu.co", 256}}, {192, {"fmoe_fp8_g1u1_subGU_192_gelu", "fmoe/gelu/fmoe_fp8_g1u1_subGU_192_gelu.co", 192}}, {128, {"fmoe_fp8_g1u1_subGU_128_gelu", "fmoe/gelu/fmoe_fp8_g1u1_subGU_128_gelu.co", 128}}}; int selectedTile = get_heuristic_tile(inter_dim, sub_X_cnt, {512, 448, 384, 320, 256, 192, 128}); std::unordered_map *config_map = nullptr; if (fc2_smooth_scale.has_value()) { config_map = &multix_kernel_fp8_configs; } else if (activation == ActivationType::Gelu) { config_map = &gelu_kernel_fp8_configs; } else if (activation == ActivationType::Silu) { config_map = &silu_kernel_fp8_configs; } if (config_map) { auto it = config_map->find(selectedTile); if (it != config_map->end()) { const auto &config = it->second; const char *name = config.name.c_str(); const char *co_name = config.co_name.c_str(); auto result = impl_ptr_map.emplace(name, nullptr); if (result.second) { result.first->second = std::make_unique(name, co_name, config.tile_size); } impl_ptr = result.first->second.get(); } else TORCH_CHECK(false, __func__, " Unsupported inter_dim " + std::to_string(inter_dim) + ", which should be divisible by 128, 192, 256, 320, 384, 448 or 512"); } else { TORCH_CHECK(false, __func__, "No valid kernel selected!"); } } else { TORCH_CHECK(false, __func__, " Input only supput Int8/Fp8!"); } impl_ptr->launch_kernel(out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, topk, // quant args input_scale, fc1_scale, fc2_scale, fc2_smooth_scale); } void fmoe_g1u1_tkw1(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim*2, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [token_cnt, 1] torch::Tensor &fc1_scale, // [expert, 1, inter_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] std::optional fc2_smooth_scale, // [expert, 1, inter_dim] ActivationType activation) { struct FMoeKernelConfig { std::string name; std::string co_name; int tile_size; }; FMoeKernel *impl_ptr = nullptr; int inter_dim = down.size(2); static std::unordered_map> impl_ptr_map; const int token_cnt = input.size(0); const int block_m = 32; // fmoe sorting kernel and fmoe kernel only support 32 for now const int estimated_sub_X_cnt = (token_cnt * topk + block_m - 1) / block_m; if (input.dtype() == torch_fp8) { static std::unordered_map silu_kernel_fp8_configs = { {512, {"fmoe_fp8_g1u1_subGU_512_silu_tkw1", "fmoe/silu/fmoe_fp8_g1u1_subGU_512_silu_tkw1.co", 512}}, {448, {"fmoe_fp8_g1u1_subGU_448_silu_tkw1", "fmoe/silu/fmoe_fp8_g1u1_subGU_448_silu_tkw1.co", 448}}, {384, {"fmoe_fp8_g1u1_subGU_384_silu_tkw1", "fmoe/silu/fmoe_fp8_g1u1_subGU_384_silu_tkw1.co", 384}}, {320, {"fmoe_fp8_g1u1_subGU_320_silu_tkw1", "fmoe/silu/fmoe_fp8_g1u1_subGU_320_silu_tkw1.co", 320}}, {256, {"fmoe_fp8_g1u1_subGU_256_silu_tkw1", "fmoe/silu/fmoe_fp8_g1u1_subGU_256_silu_tkw1.co", 256}}, {192, {"fmoe_fp8_g1u1_subGU_192_silu_tkw1", "fmoe/silu/fmoe_fp8_g1u1_subGU_192_silu_tkw1.co", 192}}, {128, {"fmoe_fp8_g1u1_subGU_128_silu_tkw1", "fmoe/silu/fmoe_fp8_g1u1_subGU_128_silu_tkw1.co", 128}}}; static std::unordered_map gelu_kernel_fp8_configs = { {512, {"fmoe_fp8_g1u1_subGU_512_gelu_tkw1", "fmoe/gelu/fmoe_fp8_g1u1_subGU_512_gelu_tkw1.co", 512}}, {448, {"fmoe_fp8_g1u1_subGU_448_gelu_tkw1", "fmoe/gelu/fmoe_fp8_g1u1_subGU_448_gelu_tkw1.co", 448}}, {384, {"fmoe_fp8_g1u1_subGU_384_gelu_tkw1", "fmoe/gelu/fmoe_fp8_g1u1_subGU_384_gelu_tkw1.co", 384}}, {320, {"fmoe_fp8_g1u1_subGU_320_gelu_tkw1", "fmoe/gelu/fmoe_fp8_g1u1_subGU_320_gelu_tkw1.co", 320}}, {256, {"fmoe_fp8_g1u1_subGU_256_gelu_tkw1", "fmoe/gelu/fmoe_fp8_g1u1_subGU_256_gelu_tkw1.co", 256}}, {192, {"fmoe_fp8_g1u1_subGU_192_gelu_tkw1", "fmoe/gelu/fmoe_fp8_g1u1_subGU_192_gelu_tkw1.co", 192}}, {128, {"fmoe_fp8_g1u1_subGU_128_gelu_tkw1", "fmoe/gelu/fmoe_fp8_g1u1_subGU_128_gelu_tkw1.co", 128}}}; int selectedTile = get_heuristic_tile(inter_dim, estimated_sub_X_cnt, {512, 448, 384, 320, 256, 192, 128}); std::unordered_map *config_map = nullptr; if (fc2_smooth_scale.has_value()) { TORCH_CHECK(false, __func__, " Only supput non-smooth tkw1!"); } else if (activation == ActivationType::Gelu) { config_map = &gelu_kernel_fp8_configs; } else if (activation == ActivationType::Silu) { config_map = &silu_kernel_fp8_configs; } if (config_map) { auto it = config_map->find(selectedTile); if (it != config_map->end()) { const auto &config = it->second; const char *name = config.name.c_str(); const char *co_name = config.co_name.c_str(); auto result = impl_ptr_map.emplace(name, nullptr); if (result.second) { result.first->second = std::make_unique(name, co_name, config.tile_size); } impl_ptr = result.first->second.get(); } else TORCH_CHECK(false, __func__, " Unsupported inter_dim " + std::to_string(inter_dim) + ", which should be divisible by 128, 192, 256, 320, 384, 448 or 512"); } else { TORCH_CHECK(false, __func__, "No valid kernel selected!"); } } else { TORCH_CHECK(false, __func__, " Unsupported input dtype:", input.dtype()); } impl_ptr->launch_kernel(out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, topk, // quant args input_scale, fc1_scale, fc2_scale, fc2_smooth_scale); } void fmoe_int8_g1u0_a16(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &fc1_scale, // [expert, 1, inter_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] torch::Tensor &fc1_smooth_scale, // [expert, 1, dim] torch::Tensor &fc2_smooth_scale // [expert, 1, inter_dim] ) { static FMoeKernel impl("fmoe_kernel_func", "fmoe_int8_g1u0_smf.co"); impl.launch_kernel(out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, topk, // quant args fc1_smooth_scale, fc1_scale, fc2_scale, fc2_smooth_scale); } void fmoe_g1u1_a16(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim*2, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &fc1_scale, // [expert, 1, inter_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] torch::Tensor &fc1_smooth_scale, // [expert, 1, dim] torch::Tensor &fc2_smooth_scale // [expert, 1, inter_dim] ) { FMoeKernel *impl_ptr = nullptr; int inter_dim = down.size(2); int sub_X_cnt = sorted_expert_ids.size(0); if (gate.dtype() == at::ScalarType::Char || gate.dtype() == at::ScalarType::Byte) { TORCH_CHECK(inter_dim % 320 == 0, __func__, "int8 quant Unsupported inter_dim " + std::to_string(inter_dim) + ", which should be divisible by 320"); static FMoeKernel impl_int8_320("fmoe_int8_g1u1_smf_subGU_320", "fmoe_int8_g1u1_smf_subGU_320.co", 320); impl_ptr = &impl_int8_320; } else if (gate.dtype() == torch_fp8) { int selectedTile = get_heuristic_tile(inter_dim, sub_X_cnt, {512, 320}); // todo,add tune interface here if (selectedTile == 512) { static FMoeKernel impl_fp8_512("fmoe_fp8_g1u1_smf_subGU_512", "fmoe_fp8_g1u1_smf_subGU_512.co", 512); impl_ptr = &impl_fp8_512; } else if (selectedTile == 320) { static FMoeKernel impl_fp8_320("fmoe_fp8_g1u1_smf_subGU_320", "fmoe_fp8_g1u1_smf_subGU_320.co", 320); impl_ptr = &impl_fp8_320; } else TORCH_CHECK(false, __func__, "fp8 quant Unsupported inter_dim " + std::to_string(inter_dim) + ", which should be divisible by 320 or 512"); } else { TORCH_CHECK(false, __func__, " gate/down weight only supput Int8/Fp8!"); } impl_ptr->launch_kernel(out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, topk, // quant args fc1_smooth_scale, fc1_scale, fc2_scale, fc2_smooth_scale); } void fmoe_fp8_blockscale_g1u1(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim*2, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] uint32_t topk, // torch::Tensor &input_scale, // [expert, 1, dim] torch::Tensor &fc1_scale, // [expert, 1, inter_dim] torch::Tensor &fc2_scale, // [expert, 1, dim] int fc_scale_blkn, int fc_scale_blkk, std::optional fc2_smooth_scale, ActivationType activation) { FMoeKernel *impl_ptr = nullptr; int inter_dim = down.size(2); int sub_X_cnt = sorted_expert_ids.size(0); // int selectedTile = get_heuristic_tile(inter_dim, sub_X_cnt); // todo,add tune interface here const char *enable_vskip = std::getenv("AITER_ENABLE_VSKIP"); if (out.dtype() == at::ScalarType::BFloat16 && inter_dim % 256 == 0 && fc_scale_blkn == 128 && fc_scale_blkk == 128) { if (enable_vskip != nullptr && strcmp(enable_vskip, "1") == 0) { static FMoeKernel impl_256("_ZN5aiter34fmoe_fp8_blockscale_g1u1_subGU_256E", "/fmoe/fmoe_fp8_blockscale_g1u1_subGU_256.co", 256); impl_ptr = &impl_256; } else { static FMoeKernel impl_256_novs("_ZN5aiter39fmoe_fp8_blockscale_g1u1_novs_subGU_256E", "/fmoe/fmoe_fp8_blockscale_g1u1_novs_subGU_256.co", 256); impl_ptr = &impl_256_novs; } } else TORCH_CHECK(false, __func__, " Only support out dtype = bf16, inter_dim % 256 = 0 and fc_scale_blkn and fc_scale_blkk is 128"); impl_ptr->launch_kernel(out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, topk, // quant args input_scale, fc1_scale, fc2_scale, fc2_smooth_scale); }