// SPDX-License-Identifier: MIT #include "asm_gemm_kernel_config.h" #include #include #include #include #include #include "aiter_hip_common.h" #include //#include "py_itfs_common.h" //#define DEBUG_BUFFER struct __attribute__((packed)) AwqGemmAsmArgs { uint32_t gemm_count; uint32_t internalArgs; uint32_t internalArgs1; uint32_t numWorkGroups; uint32_t m; //!< size m uint32_t n; //!< size n uint32_t batch; //!< size batch uint32_t k; //!< size k void *d; //!< The d matrix input pointer. void *c; //!< The c matrix input pointer. void *a; //!< The a matrix input pointer. void *b; //!< The b matrix input pointer. uint32_t strideD1; //!< The d leading dimension. uint32_t strideD2; //!< The d batch stride uint32_t strideC1; //!< The c leading dimension. uint32_t strideC2; //!< The c batch stride uint32_t strideA1; //!< The a leading dimension. uint32_t strideA2; //!< The a batch stride uint32_t strideB1; //!< The b leading dimension. uint32_t strideB2; //!< The b batch stride float alpha; //!< The alpha value. float beta; //!< The beta value. void *debugBuffer; //!< The d matrix input pointer. void *dstD; //!< The c matrix input pointer. void *Synchronizer; //!< The a matrix input pointer. uint32_t GSUSync; //!< The b matrix input pointer. }; struct KernelConfigs { uint32_t mt0; uint32_t mt1; uint32_t numThreads; uint32_t wgm; }; class AwqGemmAsmKernel { private: hipModule_t module; hipFunction_t kernel_func; public: AwqGemmAsmKernel(const char *name, const char *hsaco) { const char *AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str())); HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); std::cout << " Success" << std::endl; }; size_t debugBufferElementsPerThread = 16; size_t debugBufferSize = 0; std::shared_ptr debugBufferHostPtr; unsigned int* debugBufferDevicePtr = nullptr; void CreateDebugBuffer(size_t numWorkGroups, size_t numThreads) { std::cout << "Smart Lt debugKernel is enabled !!!! " << std::endl; size_t debugBufferNumElem = debugBufferElementsPerThread; debugBufferNumElem *= numWorkGroups; debugBufferNumElem *= numThreads; debugBufferSize = debugBufferNumElem * 4; hipMalloc(&debugBufferDevicePtr, debugBufferSize); debugBufferHostPtr = std::shared_ptr( (unsigned int *)std::malloc(debugBufferSize), std::free); memset(debugBufferHostPtr.get(), 0, debugBufferSize); hipMemcpy(debugBufferDevicePtr, debugBufferHostPtr.get(), debugBufferSize, hipMemcpyHostToDevice); }; void debug_buffer_print() { hipMemcpy(debugBufferHostPtr.get(), debugBufferDevicePtr, debugBufferSize, hipMemcpyDeviceToHost); unsigned int * dbg_ptr = debugBufferHostPtr.get(); const char *field_names[16] = { "tid","wg0","wg1","groA","groB", "lraA","lraB","lwaA","lwaB"}; for (unsigned int i = 0; i < debugBufferSize / 4 / debugBufferElementsPerThread; i++) { if (i % 64 == 0) { printf("\n"); for (unsigned int j = 0; j < debugBufferElementsPerThread; j++) { printf("%12s,", field_names[j]); } printf("\n"); } char flags[16] = {'u', 'u', 'u', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x'}; //if((i%64) < 4 || (i%64) >= 60) //if(i > 768) { for (unsigned int j = 0; j < debugBufferElementsPerThread; j++) { if (flags[j] == 'u') printf(" %8u,", dbg_ptr[i * debugBufferElementsPerThread + j]); else if (flags[j] == 'x') printf(" 0x%08x,", dbg_ptr[i * debugBufferElementsPerThread + j]); else if (flags[j] == 'f') printf(" %8.4f,", ((float *)dbg_ptr)[i * debugBufferElementsPerThread + j]); else if (flags[j] == 'd') printf(" %8d,", dbg_ptr[i * debugBufferElementsPerThread + j]); } printf("\n"); } } printf("\n"); }; void launch_kernel(bool isFused, torch::Tensor &out, torch::Tensor &mat1, // [token_cnt, dim] std::optional &zero, std::optional &scale, KernelConfigs *Kconfigs = nullptr ) { AwqGemmAsmArgs userArg_h; size_t arg_size = sizeof(userArg_h); void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &userArg_h, HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size, HIP_LAUNCH_PARAM_END}; //Kconfigs->mt0; //Kconfigs->mt1; //Kconfigs->numThreads; //Kconfigs->wgm; userArg_h.gemm_count = 1; userArg_h.internalArgs = 0x00200001; userArg_h.internalArgs1 = 1; userArg_h.numWorkGroups = 1; userArg_h.batch = 1; userArg_h.m = mat1.size(1); userArg_h.n = 16; userArg_h.k = mat1.size(0); userArg_h.strideD1 = out.size(1); userArg_h.strideD2 = out.size(1); userArg_h.strideC1 = out.size(1); userArg_h.strideC2 = out.size(1); userArg_h.strideB1 = 16; userArg_h.strideB2 = 16; if(isFused) { userArg_h.strideB1 = 16; userArg_h.strideB2 = 16; userArg_h.dstD = zero->data_ptr(); userArg_h.Synchronizer = scale->data_ptr(); //std::cout << "userArg_h.zero: " << std::hex << zero->data_ptr() << std::dec << std::endl; //std::cout << "userArg_h.Synchronizer: " << std::hex << scale->data_ptr() << std::dec << std::endl; //std::cout <<"m, n, k, strideD1 " << userArg_h.m << ", " << userArg_h.n << ", " << userArg_h.k << out.size(1) << std::endl; } size_t wg0 = (userArg_h.m + Kconfigs->mt0 - 1) / Kconfigs->mt0; size_t wg1 = (userArg_h.n + Kconfigs->mt1 - 1) / Kconfigs->mt1; userArg_h.numWorkGroups = wg0 * wg1 ; userArg_h.d = out.data_ptr(); userArg_h.c = out.data_ptr(); userArg_h.a = mat1.data_ptr(); userArg_h.b = mat1.data_ptr(); // no use userArg_h.strideA1 = mat1.size(1); userArg_h.strideA2 = mat1.size(1); userArg_h.alpha = 1.0; userArg_h.beta = 0.0; userArg_h.debugBuffer = nullptr; userArg_h.GSUSync = 0; #if 0 std::cout << "userArg_h.m: " << userArg_h.m << std::endl; std::cout << "userArg_h.n: " << userArg_h.n << std::endl; std::cout << "userArg_h.k: " << userArg_h.k << std::endl; std::cout << "userArg_h.batch: " << userArg_h.batch << std::endl; std::cout << "userArg_h.a: " << userArg_h.a << std::endl; std::cout << "userArg_h.b: " << userArg_h.b << std::endl; std::cout << "userArg_h.c: " << userArg_h.c << std::endl; std::cout << "userArg_h.d: " << userArg_h.d << std::endl; std::cout << "userArg_h.strideD1: " << userArg_h.strideD1 << std::endl; std::cout << "userArg_h.strideD2: " << userArg_h.strideD2 << std::endl; std::cout << "userArg_h.strideC1: " << userArg_h.strideC1 << std::endl; std::cout << "userArg_h.strideC2: " << userArg_h.strideC2 << std::endl; std::cout << "userArg_h.strideA1: " << userArg_h.strideA1 << std::endl; std::cout << "userArg_h.strideA2: " << userArg_h.strideA2 << std::endl; std::cout << "userArg_h.strideB1: " << userArg_h.strideB1 << std::endl; std::cout << "userArg_h.strideB2: " << userArg_h.strideB2 << std::endl; std::cout << "userArg_h.alpha: " << userArg_h.alpha << std::endl; std::cout << "userArg_h.beta: " << userArg_h.beta << std::endl; #endif int bdx = Kconfigs->numThreads; int gdx = userArg_h.numWorkGroups; int gdy = 1; int gdz = 1; //std::cout <<"L256, bdx= " << bdx << ", gdx= " << gdx << std::endl; #ifdef DEBUG_BUFFER CreateDebugBuffer(gdx, bdx); userArg_h.debugBuffer = debugBufferDevicePtr; #endif const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, stream, nullptr, (void **)&config)); #ifdef DEBUG_BUFFER debug_buffer_print(); #endif }; }; static std::unordered_map> g_kernel_cache; static std::mutex g_kernel_mu; KernelCfg cfg_dq { "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x32_dq", "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT4_2_dequant.co", 64, 32, 768, 1 }; KernelCfg cfg_dq_bf16 { "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x32_dq", "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT4_2_dequant.co", 64, 32, 768, 1 }; static AwqGemmAsmKernel* get_or_create_kernel(const KernelCfg& c) { std::lock_guard lk(g_kernel_mu); if (auto it = g_kernel_cache.find(c.kernel_name); it != g_kernel_cache.end()) return it->second.get(); auto ptr = std::make_unique(c.kernel_name.c_str(), c.co_file.c_str()); auto* raw = ptr.get(); g_kernel_cache.emplace(c.kernel_name, std::move(ptr)); return raw; } static constexpr const char* GROUP_AWQ_DEFAULT = "w4a16"; static constexpr const char* GROUP_AWQ_FUSED = "fused_w4a16"; void awq_dq_asm(torch::Tensor &out, torch::Tensor &mat1, std::optional &zero, std::optional &scalar ) { const bool isFused = zero.has_value() && zero->defined() && zero->numel() > 0; // std::cout << "isFused=" << isFused<< std::endl; const char* GROUP_AWQ = isFused ? GROUP_AWQ_FUSED : GROUP_AWQ_DEFAULT; if (!isFused && zero.has_value()) { std::cerr << "[awq_dq_asm][warn] zero is null/empty; using fused group '" << GROUP_AWQ_FUSED << "'\n"; } AwqGemmAsmKernel *impl_ptr = nullptr; KernelConfigs Kconfigs; if (mat1.dtype() == at::ScalarType::Char) { std::string key; KernelCfg cfg = cfg_dq; if (scalar->dtype() == at::ScalarType::BFloat16 || scalar->dtype() == torch::kBFloat16) cfg = cfg_dq_bf16; impl_ptr = get_or_create_kernel(cfg); Kconfigs.mt0 = cfg.mt0; Kconfigs.mt1 = cfg.mt1; Kconfigs.numThreads = cfg.numThreads; Kconfigs.wgm = cfg.wgm; TORCH_CHECK(impl_ptr != nullptr, __func__, ": unsupport current input type:", mat1.scalar_type()); impl_ptr->launch_kernel(isFused, out, mat1, zero, scalar, &Kconfigs); } else { TORCH_CHECK(false, "awq_dq_asm: dtype not supported yet"); } }