// SPDX-License-Identifier: MIT #pragma once #include "aiter_enum.h" #include "ck_tile/core.hpp" #include #include #include #include #include enum class GPUArch { gfx936, gfx938, gfx946, }; namespace aiter_detail { inline thread_local bool g_aiter_can_throw = false; // Fatal (non-recoverable) error handler — used by HIP_CALL. // Always aborts; does not consult g_aiter_can_throw. template [[noreturn, gnu::noinline]] inline void aiter_check_fatal(const char* file, size_t line, Args&&... args) { std::cerr << "[AITER] " << file << ":" << line << " "; (std::cerr << ... << std::forward(args)) << std::endl; std::abort(); } template [[noreturn]] inline void check_fail(const char* file, int line, Args&&... args) { std::ostringstream oss; oss << "[AITER] " << file << ":" << line; if constexpr(sizeof...(Args) > 0) { oss << " "; (oss << ... << std::forward(args)); } else { oss << " check failed"; } std::string msg = oss.str(); std::cerr << msg << std::endl; if(g_aiter_can_throw) { throw std::runtime_error(std::move(msg)); } std::abort(); } } // namespace aiter_detail #define AITER_CHECK(x, ...) \ do \ { \ if(!(x)) [[unlikely]] \ { \ aiter_detail::check_fail(__FILE__, __LINE__ __VA_OPT__(, ) __VA_ARGS__); \ } \ } while(0) #define HIP_CALL(call) \ do \ { \ hipError_t err = call; \ if(err != hipSuccess) [[unlikely]] \ { \ aiter_detail::aiter_check_fatal(__FILE__, \ __LINE__, \ "fail to call " #call " ---> [HIP error](", \ hipGetErrorString(err), \ ')'); \ } \ } while(0) struct p3 { unsigned int _p0; unsigned int _p1; unsigned int _p2; }; struct p2 { unsigned int _p0; unsigned int _p1; }; struct p1 { unsigned int _p0; }; struct AiterAsmKernelArgs { void *args_ptr; void *arg_size_ptr; int gdx; int gdy; int gdz; int bdx; int bdy; int bdz; const hipStream_t stream; }; class AiterAsmKernel { private: hipModule_t module; hipFunction_t kernel_func; public: AiterAsmKernel(const char *name, const char *hsaco) { const char *AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str())); HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); std::cout << " Success" << std::endl; }; ~AiterAsmKernel() { HIP_CALL(hipModuleUnload(module)); } void launch_kernel(const AiterAsmKernelArgs &kargs) { void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kargs.args_ptr, HIP_LAUNCH_PARAM_BUFFER_SIZE, kargs.arg_size_ptr, HIP_LAUNCH_PARAM_END}; HIP_CALL(hipModuleLaunchKernel(kernel_func, kargs.gdx, kargs.gdy, kargs.gdz, kargs.bdx, kargs.bdy, kargs.bdz, 0, kargs.stream, nullptr, (void **)&config)); }; }; static const std::string get_gpu_arch() { int device_count; hipError_t err = hipGetDeviceCount(&device_count); if(err != hipSuccess || device_count == 0) { return "No GPU Found"; } hipDeviceProp_t prop; hipGetDeviceProperties(&prop, 0); std::string arch_full = prop.gcnArchName; size_t colon_pos = arch_full.find(':'); if(colon_pos != std::string::npos) { return arch_full.substr(0, colon_pos); } else { return arch_full; } } static const uint32_t get_num_cu_func() { auto get_num_cu_local = []() { hipDevice_t dev; hipDeviceProp_t dev_prop; HIP_CALL(hipGetDevice(&dev)); HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); return dev_prop.multiProcessorCount; }; static const uint32_t num_cu = get_num_cu_local(); return num_cu; } /// RAII guard that saves the current HIP device and restores it on destruction. /// Required by AiterTensor factory methods and any code that temporarily switches devices. class HipDeviceGuard { public: explicit HipDeviceGuard(int device_id) { HIP_CALL(hipGetDevice(&prev_device_)); HIP_CALL(hipSetDevice(device_id)); } ~HipDeviceGuard() noexcept { HIP_CALL(hipSetDevice(prev_device_)); } HipDeviceGuard(const HipDeviceGuard&) = delete; HipDeviceGuard& operator=(const HipDeviceGuard&) = delete; private: int prev_device_{}; };