#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include class CUDAError : public std::runtime_error { public: CUDAError(cudaError_t errorCode, std::source_location location) : std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {} public: const cudaError_t errorCode; const std::source_location location; private: static std::string format(cudaError_t errorCode, std::source_location location) { return spdlog::fmt_lib::format("CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line()); } }; inline cudaError_t checkCUDA(cudaError_t retValue, const std::source_location location = std::source_location::current()) { if (retValue != cudaSuccess) { throw CUDAError(retValue, location); } return retValue; } inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue, const std::source_location location = std::source_location::current()) { if (retValue != CUBLAS_STATUS_SUCCESS) { throw std::runtime_error(spdlog::fmt_lib::format("CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line())); } return retValue; } inline thread_local std::stack stackCUDAStreams; inline cudaStream_t getCurrentCUDAStream() { if (stackCUDAStreams.empty()) { return 0; } return stackCUDAStreams.top(); } struct CUDAStreamContext { cudaStream_t stream; CUDAStreamContext(cudaStream_t stream) : stream(stream) { stackCUDAStreams.push(stream); } CUDAStreamContext(const CUDAStreamContext &) = delete; CUDAStreamContext(CUDAStreamContext &&) = delete; ~CUDAStreamContext() { assert(stackCUDAStreams.top() == stream); stackCUDAStreams.pop(); } }; struct CUDAStreamWrapper { cudaStream_t stream; CUDAStreamWrapper() { checkCUDA(cudaStreamCreate(&stream)); } CUDAStreamWrapper(const CUDAStreamWrapper &) = delete; CUDAStreamWrapper(CUDAStreamWrapper &&) = delete; ~CUDAStreamWrapper() { checkCUDA(cudaStreamDestroy(stream)); } }; struct CUDAEventWrapper { cudaEvent_t event; CUDAEventWrapper(unsigned int flags = cudaEventDefault) { checkCUDA(cudaEventCreateWithFlags(&event, flags)); } CUDAEventWrapper(const CUDAEventWrapper &) = delete; CUDAEventWrapper(CUDAEventWrapper &&) = delete; ~CUDAEventWrapper() { checkCUDA(cudaEventDestroy(event)); } }; inline cudaDeviceProp *getCurrentDeviceProperties() { static thread_local cudaDeviceProp prop; static thread_local bool propAvailable = false; if (!propAvailable) { int device; checkCUDA(cudaGetDevice(&device)); checkCUDA(cudaGetDeviceProperties(&prop, device)); propAvailable = true; } return ∝ } template constexpr T ceilDiv(T a, T b) { return (a + b - 1) / b; } template constexpr int log2Up(T value) { if (value <= 0) return 0; if (value == 1) return 0; return log2Up((value + 1) / 2) + 1; } struct CUBLASWrapper { cublasHandle_t handle = nullptr; CUBLASWrapper() { checkCUBLAS(cublasCreate(&handle)); } CUBLASWrapper(CUBLASWrapper &&) = delete; CUBLASWrapper(const CUBLASWrapper &&) = delete; ~CUBLASWrapper() { if (handle) { checkCUBLAS(cublasDestroy(handle)); } } }; inline std::shared_ptr getCUBLAS() { static thread_local std::weak_ptr inst; std::shared_ptr result = inst.lock(); if (result) { return result; } result = std::make_shared(); inst = result; return result; }