/************************************************************************* * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #include #ifdef __HIP_PLATFORM_AMD__ #include #ifdef USE_HIPBLASLT #include #endif #ifdef USE_ROCBLAS #define ROCBLAS_BETA_FEATURES_API #include #endif #else #include #include #endif // __HIP_PLATFORM_AMD__ #include #ifndef __HIP_PLATFORM_AMD__ #include "nccl.h" #endif #ifdef NVTE_WITH_CUBLASMP #include #endif // NVTE_WITH_CUBLASMP #include #include #include #include "../util/string.h" #define NVTE_WARN(...) \ do { \ std::cerr << ::transformer_engine::concat_strings( \ __FILE__ ":", __LINE__, " in function ", __func__, ": ", \ ::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \ } while (false) #define NVTE_ERROR(...) \ do { \ throw ::std::runtime_error(::transformer_engine::concat_strings( \ __FILE__ ":", __LINE__, " in function ", __func__, ": ", \ ::transformer_engine::concat_strings(__VA_ARGS__))); \ } while (false) #define NVTE_CHECK(expr, ...) \ do { \ if (!(expr)) { \ NVTE_ERROR("Assertion failed: " #expr ". ", \ ::transformer_engine::concat_strings(__VA_ARGS__)); \ } \ } while (false) #define NCCLCHECK(cmd) do { \ ncclResult_t r = cmd; \ if (r != ncclSuccess) { \ printf("NCCL error %s:%d: '%s'\n", __FILE__, __LINE__, \ ncclGetErrorString(r)); \ exit(EXIT_FAILURE); \ } \ } while(0) #define NVTE_CHECK_CUDA(expr) \ do { \ const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ NVTE_ERROR("CUDA Error: ", cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ } \ } while (false) #ifdef __HIP_PLATFORM_AMD__ #ifdef USE_HIPBLASLT //hipblaslt #define NVTE_CHECK_HIPBLASLT(expr) \ do { \ const hipblasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ NVTE_ERROR("HIPBLASLT Error: ", \ std::to_string((int)status_NVTE_CHECK_CUBLAS)); \ } \ } while (false) #endif #ifdef USE_ROCBLAS //rocblas #define NVTE_CHECK_ROCBLAS(expr) \ do { \ const rocblas_status status_NVTE_CHECK_CUBLAS = (expr); \ if (status_NVTE_CHECK_CUBLAS != rocblas_status_success) { \ NVTE_ERROR("ROCBLAS Error: " + \ std::string(rocblas_status_to_string(status_NVTE_CHECK_CUBLAS))); \ } \ } while (false) #endif #else //cublas #define NVTE_CHECK_CUBLAS(expr) \ do { \ const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \ } \ } while (false) #endif #define NVTE_CHECK_CUDNN(expr) \ do { \ const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \ if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \ NVTE_ERROR("cuDNN Error: ", cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \ ". " \ "For more information, enable cuDNN error logging " \ "by setting CUDNN_LOGERR_DBG=1 and " \ "CUDNN_LOGDEST_DBG=stderr in the environment."); \ } \ } while (false) #define NVTE_CHECK_CUDNN_FE(expr) \ do { \ const auto error = (expr); \ if (error.is_bad()) { \ NVTE_ERROR("cuDNN Error: ", error.err_msg, \ ". " \ "For more information, enable cuDNN error logging " \ "by setting CUDNN_LOGERR_DBG=1 and " \ "CUDNN_LOGDEST_DBG=stderr in the environment."); \ } \ } while (false) #define NVTE_CHECK_NVRTC(expr) \ do { \ const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \ if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \ NVTE_ERROR("NVRTC Error: ", nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \ } \ } while (false) #ifdef NVTE_WITH_CUBLASMP #define NVTE_CHECK_CUBLASMP(expr) \ do { \ const cublasMpStatus_t status = (expr); \ if (status != CUBLASMP_STATUS_SUCCESS) { \ NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ } \ } while (false) #endif // NVTE_WITH_CUBLASMP #define NVTE_CHECK_NCCL(expr) \ do { \ const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ if (status_NVTE_CHECK_NCCL != ncclSuccess) { \ NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ } \ } while (false) #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_