"...web/git@developer.sourcefind.cn:wuxk1/dcu-comui.git" did not exist on "000fbec3c9a9093386f5a7d8eae6b4682e9d5084"
Unverified Commit 6b311da2 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Refactor logging macros (#382)



* Do not include logging macros in installed C headers
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug logging macros
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug C++ tests

Use Google style for header includes.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Update CUDA driver macros

Incorporating changes from #389.
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Use core error checking macros in PyTorch extensions

Hack to get around macro redefinition warning.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix missing arg when getting CUDA driver error string
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Reuse logging header in frameworks
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
parent 91b754e0
...@@ -467,6 +467,7 @@ def setup_pytorch_extension() -> setuptools.Extension: ...@@ -467,6 +467,7 @@ def setup_pytorch_extension() -> setuptools.Extension:
include_dirs = [ include_dirs = [
root_path / "transformer_engine" / "common" / "include", root_path / "transformer_engine" / "common" / "include",
root_path / "transformer_engine" / "pytorch" / "csrc", root_path / "transformer_engine" / "pytorch" / "csrc",
root_path / "transformer_engine",
root_path / "3rdparty" / "cudnn-frontend" / "include", root_path / "3rdparty" / "cudnn-frontend" / "include",
] ]
...@@ -539,6 +540,7 @@ def setup_paddle_extension() -> setuptools.Extension: ...@@ -539,6 +540,7 @@ def setup_paddle_extension() -> setuptools.Extension:
include_dirs = [ include_dirs = [
root_path / "transformer_engine" / "common" / "include", root_path / "transformer_engine" / "common" / "include",
root_path / "transformer_engine" / "paddle" / "csrc", root_path / "transformer_engine" / "paddle" / "csrc",
root_path / "transformer_engine",
] ]
# Compiler flags # Compiler flags
......
...@@ -4,16 +4,17 @@ ...@@ -4,16 +4,17 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h> #include <cstring>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <memory>
#include <random> #include <random>
#include <cstring>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h> #include <cmath>
#include <transformer_engine/logging.h> #include <cstring>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory> #include <memory>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <random> #include <random>
#include <cstring>
#include <cmath> #include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h> #include <cmath>
#include <transformer_engine/logging.h> #include <cstring>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory> #include <memory>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <random> #include <random>
#include <cstring>
#include <cmath> #include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/transpose.h>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <random> #include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,11 +4,6 @@ ...@@ -4,11 +4,6 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <iomanip> #include <iomanip>
...@@ -16,6 +11,12 @@ ...@@ -16,6 +11,12 @@
#include <memory> #include <memory>
#include <random> #include <random>
#include <type_traits> #include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,11 +4,6 @@ ...@@ -4,11 +4,6 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <iomanip> #include <iomanip>
...@@ -16,6 +11,12 @@ ...@@ -16,6 +11,12 @@
#include <memory> #include <memory>
#include <random> #include <random>
#include <type_traits> #include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,18 +4,19 @@ ...@@ -4,18 +4,19 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cmath> #include <cmath>
#include <cstring>
#include <memory> #include <memory>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <random> #include <random>
#include <cstring>
#include <type_traits> #include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,17 +4,19 @@ ...@@ -4,17 +4,19 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/layer_norm.h> #include <cmath>
#include <transformer_engine/transformer_engine.h> #include <cstring>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory> #include <memory>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <random> #include <random>
#include <cstring>
#include <cmath> #include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cstring> #include <cstring>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <memory> #include <memory>
#include <random> #include <random>
#include <vector> #include <vector>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,19 +4,19 @@ ...@@ -4,19 +4,19 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "gtest/gtest.h" #include <cstring>
#include <transformer_engine/cast.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <memory>
#include <random> #include <random>
#include <cstring>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h" #include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,17 +4,19 @@ ...@@ -4,17 +4,19 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/transformer_engine.h>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <random> #include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -4,16 +4,17 @@ ...@@ -4,16 +4,17 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h> #include <cstring>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <iostream>
#include <memory>
#include <random> #include <random>
#include <cstring>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
......
...@@ -6,13 +6,16 @@ ...@@ -6,13 +6,16 @@
#include "test_common.h" #include "test_common.h"
#include "transformer_engine/logging.h"
#include "transformer_engine/transformer_engine.h"
#include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <random> #include <random>
#include <gtest/gtest.h>
#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"
namespace test { namespace test {
std::vector<DType> all_fp_types = {DType::kFloat32, std::vector<DType> all_fp_types = {DType::kFloat32,
......
...@@ -6,15 +6,17 @@ ...@@ -6,15 +6,17 @@
#pragma once #pragma once
#include <iostream>
#include <memory> #include <memory>
#include <transformer_engine/transformer_engine.h> #include <vector>
#include <transformer_engine/logging.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <vector>
#include <iostream> #include <transformer_engine/transformer_engine.h>
#include "util/logging.h"
namespace test { namespace test {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -252,4 +254,3 @@ bool isFp8Type(DType type); ...@@ -252,4 +254,3 @@ bool isFp8Type(DType type);
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
...@@ -7,20 +7,22 @@ ...@@ -7,20 +7,22 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>
#include <type_traits>
#include <unordered_map>
#include <functional> #include <functional>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <type_traits>
#include <unordered_map>
#include <vector> #include <vector>
#include "nvtx.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>
#include "./nvtx.h"
#include "./util/logging.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -4,18 +4,21 @@ ...@@ -4,18 +4,21 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/softmax.h>
#include <transformer_engine/logging.h>
#include <assert.h> #include <assert.h>
#include <stdint.h> #include <stdint.h>
#include <cfloat> #include <cfloat>
#include <limits> #include <limits>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "../utils.cuh" #include <cuda_runtime.h>
#include <transformer_engine/softmax.h>
#include "../common.h" #include "../common.h"
#include "../utils.cuh"
#include "../util/logging.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -4,18 +4,21 @@ ...@@ -4,18 +4,21 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/softmax.h>
#include <transformer_engine/logging.h>
#include <assert.h> #include <assert.h>
#include <stdint.h> #include <stdint.h>
#include <cfloat> #include <cfloat>
#include <limits> #include <limits>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "../utils.cuh" #include <cuda_runtime.h>
#include <transformer_engine/softmax.h>
#include "../common.h" #include "../common.h"
#include "../utils.cuh"
#include "../util/logging.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/gemm.h> #include <transformer_engine/gemm.h>
#include <cuda.h>
#include <cublasLt.h> #include <cublasLt.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/transformer_engine.h>
#include "../common.h" #include "../common.h"
#include "../util/logging.h"
namespace { namespace {
...@@ -259,9 +261,12 @@ void cublas_gemm(const Tensor *inputA, ...@@ -259,9 +261,12 @@ void cublas_gemm(const Tensor *inputA,
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize))); &workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc,
Ddesc, preference, 1, &heuristicResult, Ddesc, preference, 1, &heuristicResult,
&returnedResults)); &returnedResults);
NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
"Unable to find suitable cuBLAS GEMM algorithm");
NVTE_CHECK_CUBLAS(status);
if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_LOGGING_H_
#define TRANSFORMER_ENGINE_LOGGING_H_
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <cudnn.h>
#include <nvrtc.h>
#include <string>
#include <stdexcept>
#define NVTE_ERROR(x) \
do { \
throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
" in function " + __func__ + ": " + x); \
} while (false)
#define NVTE_CHECK(x, ...) \
do { \
if (!(x)) { \
NVTE_ERROR(std::string("Assertion failed: " #x ". ") + std::string(__VA_ARGS__)); \
} \
} while (false)
namespace {
inline void check_cuda_(cudaError_t status) {
if ( status != cudaSuccess ) {
NVTE_ERROR("CUDA Error: " + std::string(cudaGetErrorString(status)));
}
}
inline void check_cublas_(cublasStatus_t status) {
if ( status != CUBLAS_STATUS_SUCCESS ) {
NVTE_ERROR("CUBLAS Error: " + std::string(cublasGetStatusString(status)));
}
}
inline void check_cudnn_(cudnnStatus_t status) {
if ( status != CUDNN_STATUS_SUCCESS ) {
std::string message;
message.reserve(1024);
message += "CUDNN Error: ";
message += cudnnGetErrorString(status);
message += (". "
"For more information, enable cuDNN error logging "
"by setting CUDNN_LOGERR_DBG=1 and "
"CUDNN_LOGDEST_DBG=stderr in the environment.");
NVTE_ERROR(message);
}
}
inline void check_nvrtc_(nvrtcResult status) {
if ( status != NVRTC_SUCCESS ) {
NVTE_ERROR("NVRTC Error: " + std::string(nvrtcGetErrorString(status)));
}
}
} // namespace
#define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); }
#define NVTE_CHECK_CUBLAS(ans) { check_cublas_(ans); }
#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); }
#define NVTE_CHECK_NVRTC(ans) { check_nvrtc_(ans); }
#endif // TRANSFORMER_ENGINE_LOGGING_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment