Unverified Commit e9022290 authored by Tim Moon, committed by GitHub

Support for NVRTC kernels (#138)



* Initial implementation of NVRTC infrastructure
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial NVRTC impl for transpose

NVRTC gives compilation errors at runtime. Everything else compiles and passes tests as expected.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug NVRTC transpose impl

NVRTC kernel compiles, runs, and passes tests with FP32.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use variadic template for kernel arguments in RTC kernel launch func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactoring

Added utility header for CUDA Runtime API. Optimized concat_strings function.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add helper function for regex substitutions in strings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add option to disable NVRTC support
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for header includes in NVRTC kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Access lazily-initialized CUDA driver lib and add option to specify CUDA header dir
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Configure NVRTC transpose kernel with simple perf model
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert change to tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Style fixes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add prime-valued test cases
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix multiple definition error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Optimize NVRTC transpose kernel for small data sizes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Mention NVRTC in docs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unit tests for NVRTC and string utils
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add comment in install docs about NVRTC

Review suggestion from @nouiz
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug perf model for RTC transpose kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove NVRTC discussion from docs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Require CUDA headers unless NVRTC is explicitly disabled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use diagonal coords in transpose kernel to avoid partition camping
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use std::call_once for thread-safety
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Minor fixes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug CMake error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unnecessary call_once
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove diagonal coordinates from transpose kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use size_t indices instead of int
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @ptrendx

Check build-time CUDA include path for run-time CUDA headers. Handle case where CUDA context is initially uninitialized.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 0d251991
@@ -17,6 +17,9 @@ Prerequisites
4. `cuDNN 8.1 <https://developer.nvidia.com/cudnn>`__ or later.
5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 <https://developer.nvidia.com/cudnn>`__ or later.
If the CUDA Toolkit headers are not available at runtime in a standard
installation path, e.g. within `CUDA_HOME`, set
`NVTE_CUDA_INCLUDE_DIR` in the environment.
Transformer Engine in NGC Containers
------------------------------------
......
@@ -30,8 +30,10 @@ find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PAT
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common)
include_directories(${CMAKE_SOURCE_DIR})
find_package(CUDAToolkit REQUIRED)
add_subdirectory(operator)
add_subdirectory(util)
@@ -60,7 +60,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
{65536, 128},
{256, 256},
{120, 2080},
{8, 8},
{1223, 1583},  // 200th and 250th primes
{1, 541},      // 100th prime
{1987, 1}};    // 300th prime
} // namespace
class TTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
add_executable(test_util
test_nvrtc.cpp
test_string.cpp
../test_common.cu)
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB})
target_compile_options(test_util PRIVATE -O2)
include(GoogleTest)
gtest_discover_tests(test_util)
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <stdexcept>
#include <vector>
#include <gtest/gtest.h>
#include "util/rtc.h"
using namespace transformer_engine;
TEST(UtilTest, NVRTC) {
if (!rtc::is_enabled()) {
GTEST_SKIP() << "NVRTC not enabled, skipping tests";
}
// GPU data buffer
int *device_buffer;
std::vector<int> host_buffer(2);
cudaMalloc((void**)&device_buffer, 2*sizeof(int)); // NOLINT(*)
cudaMemset(device_buffer, 0, 2*sizeof(int));
// CUDA kernel implementations
const char code1[] = R"code(
#include <cuda_runtime.h>
__global__ void my_kernel(int2 *data) {
data->x = 123;
data->y = -456;
}
)code";
const char code2[] = R"code(
#include "utils.cuh"
__global__ void my_kernel(uint32_t *data) {
data[0] = 789;
data[1] = 12;
}
)code";
// Make sure kernels are not available
auto& nvrtc_manager = rtc::KernelManager::instance();
EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel1"));
EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel2"));
EXPECT_THROW(nvrtc_manager.launch("my gtest kernel1", 1, 1, 0, 0,
device_buffer),
std::runtime_error);
EXPECT_THROW(nvrtc_manager.launch("my gtest kernel2", 1, 1, 0, 0,
device_buffer),
std::runtime_error);
// Compile and run first kernel
EXPECT_NO_THROW(nvrtc_manager.compile("my gtest kernel1",
"my_kernel",
code1,
"test_nvrtc_kernel1.cu"));
EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel1"));
EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel2"));
EXPECT_NO_THROW(nvrtc_manager.launch("my gtest kernel1", 1, 1, 0, 0,
device_buffer));
EXPECT_EQ(cudaMemcpy(host_buffer.data(), device_buffer, 2*sizeof(int),
cudaMemcpyDeviceToHost),
cudaSuccess);
EXPECT_EQ(host_buffer[0], 123);
EXPECT_EQ(host_buffer[1], -456);
// Compile and run second kernel
EXPECT_NO_THROW(nvrtc_manager.compile("my gtest kernel2",
"my_kernel",
code2,
"test_nvrtc_kernel2.cu"));
EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel1"));
EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel2"));
EXPECT_NO_THROW(nvrtc_manager.launch("my gtest kernel2", 1, 1, 0, 0, device_buffer));
EXPECT_EQ(cudaMemcpy(host_buffer.data(), device_buffer, 2*sizeof(int),
cudaMemcpyDeviceToHost),
cudaSuccess);
EXPECT_EQ(host_buffer[0], 789);
EXPECT_EQ(host_buffer[1], 12);
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <string>
#include <gtest/gtest.h>
#include "util/string.h"
using namespace transformer_engine;
TEST(UtilTest, ToStringLike) { // to_string_like
// Strings
using namespace std::string_literals;
EXPECT_EQ(to_string_like(std::string("")), "");
EXPECT_EQ(to_string_like(""), "");
EXPECT_EQ(to_string_like(std::string("Hello")), "Hello");
EXPECT_EQ(to_string_like("world!"), "world!");
EXPECT_EQ(to_string_like(" \0\n\\\t\"\' This is a weird C++ string"s),
" \0\n\\\t\"\' This is a weird C++ string"s);
EXPECT_EQ(to_string_like(" Here is a bizarre C string \n\\\t\"\'"),
" Here is a bizarre C string \n\\\t\"\'");
EXPECT_EQ(to_string_like(19), "19");
// Zero integer types
EXPECT_EQ(to_string_like(static_cast<char>(0)), "0");
EXPECT_EQ(to_string_like(static_cast<unsigned char>(0)), "0");
EXPECT_EQ(to_string_like(static_cast<short int>(0)), "0");
EXPECT_EQ(to_string_like(static_cast<unsigned short int>(0)), "0");
EXPECT_EQ(to_string_like(static_cast<int>(0)), "0");
EXPECT_EQ(to_string_like(static_cast<unsigned int>(0)), "0");
EXPECT_EQ(to_string_like(static_cast<long long int>(0)), "0");
EXPECT_EQ(to_string_like(static_cast<unsigned long long int>(0)), "0");
// Non-zero integer types
EXPECT_EQ(to_string_like(static_cast<char>(1)), "1");
EXPECT_EQ(to_string_like(static_cast<char>(-1)), "-1");
EXPECT_EQ(to_string_like(static_cast<unsigned char>(2)), "2");
EXPECT_EQ(to_string_like(static_cast<short>(3)), "3");
EXPECT_EQ(to_string_like(static_cast<short>(-5)), "-5");
EXPECT_EQ(to_string_like(static_cast<unsigned short>(8)), "8");
EXPECT_EQ(to_string_like(static_cast<int>(13)), "13");
EXPECT_EQ(to_string_like(static_cast<int>(-21)), "-21");
EXPECT_EQ(to_string_like(static_cast<unsigned int>(34)), "34");
EXPECT_EQ(to_string_like(static_cast<long long>(55)), "55");
EXPECT_EQ(to_string_like(static_cast<long long>(-89)), "-89");
EXPECT_EQ(to_string_like(static_cast<unsigned long long>(144)), "144");
EXPECT_EQ(to_string_like(static_cast<size_t>(233)), "233");
// Floating-point types
EXPECT_EQ(std::stof(to_string_like(0.f)), 0.f);
EXPECT_EQ(std::stod(to_string_like(0.)), 0.);
EXPECT_EQ(std::stof(to_string_like(1.25f)), 1.25f);
EXPECT_EQ(std::stof(to_string_like(-2.5f)), -2.5f);
EXPECT_EQ(std::stod(to_string_like(2.25)), 2.25);
EXPECT_EQ(std::stod(to_string_like(-4.5)), -4.5);
}
TEST(UtilTest, ConcatStringsTest) { // concat_strings
// Strings
EXPECT_EQ(concat_strings(), "");
EXPECT_EQ(concat_strings(std::string("")), "");
EXPECT_EQ(concat_strings(""), "");
EXPECT_EQ(concat_strings(std::string(""), "", std::string(""), ""), "");
EXPECT_EQ(concat_strings("C string"), "C string");
EXPECT_EQ(concat_strings(std::string("C++ string")), "C++ string");
EXPECT_EQ(concat_strings("C string ", std::string("and"),
" ", std::string("C++ string")),
"C string and C++ string");
// Numeric types
EXPECT_EQ(concat_strings("int ", static_cast<int>(-123),
", uint ", static_cast<unsigned int>(456)),
"int -123, uint 456");
EXPECT_EQ(concat_strings("char ", static_cast<char>(13),
", uchar ", static_cast<unsigned char>(24)),
"char 13, uchar 24");
EXPECT_EQ(concat_strings("int16 ", static_cast<short>(-35),
", uint16 ", static_cast<unsigned short>(46)),
"int16 -35, uint16 46");
EXPECT_EQ(concat_strings("int64 ", static_cast<long long>(57),
", uint64 ", static_cast<unsigned long long>(68)),
"int64 57, uint64 68");
EXPECT_EQ(std::stof(concat_strings("-", 3.25f)), -3.25f);
EXPECT_EQ(std::stof(concat_strings(6.5f)), 6.5f);
EXPECT_EQ(std::stod(concat_strings("-", 4.25)), -4.25);
EXPECT_EQ(std::stod(concat_strings(8.5)), 8.5);
}
TEST(UtilTest, RegexReplaceTest) { // regex_replace
EXPECT_EQ(regex_replace("this test FAILS", "FAILS", "PASSES"),
"this test PASSES");
EXPECT_EQ(regex_replace("status = 0000", "0", 1), "status = 1111");
EXPECT_EQ(regex_replace("I um sound um \t very umconfident", R"(um\s*)", ""),
"I sound very confident");
}
@@ -24,6 +24,12 @@ list(APPEND transformer_engine_SOURCES
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cast.cu
util/cuda_driver.cpp
util/cuda_runtime.cpp
util/rtc.cpp
util/system.cpp
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
@@ -33,13 +39,37 @@ target_include_directories(transformer_engine PUBLIC
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cuda_driver
CUDA::cudart
CUDA::nvrtc
CUDA::nvToolsExt
CUDNN::cudnn)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include")
# Make header files with C++ strings
function(make_string_header STRING STRING_NAME)
configure_file(util/string_header.h.in
"string_headers/${STRING_NAME}.h"
@ONLY)
endfunction()
function(make_string_header_from_file file_ STRING_NAME)
file(READ "${file_}" STRING)
configure_file(util/string_header.h.in
"string_headers/${STRING_NAME}.h"
@ONLY)
endfunction()
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
string_path_cuda_include)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
target_include_directories(transformer_engine PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers")
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
......
@@ -59,6 +59,22 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
namespace detail {
template <typename T>
constexpr inline const char *type_name() noexcept;
#define TRANSFORMER_ENGINE_TYPE_NAME(T) \
template <> inline constexpr const char *type_name<T>() noexcept { return #T; }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#undef TRANSFORMER_ENGINE_TYPE_NAME
} // namespace detail
template <typename T>
struct TypeInfo{
@@ -96,6 +112,7 @@ struct TypeInfo{
constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T);
constexpr static const char *name = detail::type_name<T>();
};
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
......
@@ -10,6 +10,7 @@
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <cudnn.h>
#include <nvrtc.h>
#include <string>
#include <stdexcept>
@@ -54,6 +55,12 @@ inline void check_cudnn_(cudnnStatus_t status) {
}
}
inline void check_nvrtc_(nvrtcResult status) {
if ( status != NVRTC_SUCCESS ) {
NVTE_ERROR("NVRTC Error: " + std::string(nvrtcGetErrorString(status)));
}
}
} // namespace
#define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); }
@@ -62,4 +69,6 @@ inline void check_cudnn_(cudnnStatus_t status) {
#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); }
#define NVTE_CHECK_NVRTC(ans) { check_nvrtc_(ans); }
#endif // TRANSFORMER_ENGINE_LOGGING_H_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "utils.cuh"
using namespace transformer_engine;
namespace {
// Parameters
using Type = __TYPE__;
constexpr size_t load_size = __LOAD_SIZE__;
constexpr size_t store_size = __STORE_SIZE__;
constexpr size_t warps_per_tile = __WARPS_PER_TILE__;
constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace
__global__ void
__launch_bounds__(block_size)
transpose_optimized_kernel(const Type * __restrict__ const input,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(Type);
constexpr size_t nvec_out = store_size / sizeof(Type);
using IVec = Vec<Type, nvec_in>;
using OVec = Vec<Type, nvec_out>;
// Thread indices
// Note: Block is interpreted as a warp_size x num_warps grid
constexpr size_t bdimx = THREADS_PER_WARP;
constexpr size_t bdimy = warps_per_tile;
const size_t tid = threadIdx.x;
const size_t tidx = tid % bdimx;
const size_t tidy = tid / bdimx;
const size_t bid = blockIdx.x;
// Input tensors are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out;
constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in;
// Position of tile within tensor
const size_t num_tiles_m = num_rows / tile_dim_m;
const size_t tile_id_m = bid % num_tiles_m;
const size_t tile_id_n = bid / num_tiles_m;
const size_t tile_row = tile_id_m * tile_dim_m;
const size_t tile_col = tile_id_n * tile_dim_n;
// Number of nvec_out x nvec_in subtiles for each thread to
// load/store
constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile;
// Load input to registers and transpose
// Note: Each thread loads num_iterations subtiles and transposes in
// registers.
OVec local_output[nvec_in][num_iterations];
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
IVec local_input;
local_input.load_from(&input[row * row_length + col]);
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2];
}
}
}
// Copy from registers to global memory via shared memory
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1];
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output[j1][i1] = local_output[j2][iter];
}
__syncthreads();
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2;
shared_output[j1][i1].store_to(&output[col * num_rows + row]);
}
__syncthreads();
}
}
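The __TYPE__, __LOAD_SIZE__, __STORE_SIZE__, __WARPS_PER_TILE__, and __BLOCK_SIZE__ tokens above are placeholders that the host code substitutes before handing the source to NVRTC. The transpose launcher itself (including its perf model) is not part of this excerpt, so the following is only a hedged sketch of how the pieces introduced in this PR (regex_replace, TypeInfo<T>::name, and rtc::KernelManager) fit together; the tile parameters, kernel label, and include paths are illustrative assumptions, not the PR's actual values.
// Hedged sketch only -- not the launcher from this PR. Tile parameters,
// kernel label, and include paths are illustrative.
#include <string>
#include <cuda_runtime_api.h>
#include "common.h"                                  // TypeInfo
#include "util/rtc.h"                                // rtc::KernelManager
#include "util/string.h"                             // regex_replace
#include "string_code_transpose_rtc_transpose_cu.h"  // generated string header
void launch_rtc_transpose_fp32(const float *input, float *output,
                               size_t row_length, size_t num_rows,
                               cudaStream_t stream) {
  using namespace transformer_engine;
  // Assumed configuration; the real launcher picks these with a perf model.
  constexpr size_t load_size = 8, store_size = 8;
  constexpr size_t warps_per_tile = 4, block_size = 128;
  // Fill in the placeholders in the RTC kernel source.
  std::string code = string_code_transpose_rtc_transpose_cu;
  code = regex_replace(code, "__TYPE__", TypeInfo<float>::name);
  code = regex_replace(code, "__LOAD_SIZE__", load_size);
  code = regex_replace(code, "__STORE_SIZE__", store_size);
  code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
  code = regex_replace(code, "__BLOCK_SIZE__", block_size);
  // Compile once per kernel label, then reuse the cached kernel.
  auto &rtc_manager = rtc::KernelManager::instance();
  if (!rtc_manager.is_compiled("transpose_fp32")) {
    rtc_manager.compile("transpose_fp32", "transpose_optimized_kernel",
                        code, "transpose_rtc_fp32.cu");
  }
  // One block per (32*nvec_out) x (32*nvec_in) tile, assuming the tensor
  // dimensions divide evenly (the kernel above has no boundary handling).
  constexpr size_t tile_dim_m = 32 * (store_size / sizeof(float));
  constexpr size_t tile_dim_n = 32 * (load_size / sizeof(float));
  const size_t num_blocks = (num_rows / tile_dim_m) * (row_length / tile_dim_n);
  rtc_manager.launch("transpose_fp32",
                     static_cast<unsigned int>(num_blocks),
                     static_cast<unsigned int>(block_size),
                     0, stream,
                     input, output, row_length, num_rows);
}
Compiling once per kernel label and caching per device, as KernelManager does, keeps the NVRTC cost off the launch hot path.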
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <dlfcn.h>
#include <filesystem>
#include "../common.h"
#include "../util/cuda_runtime.h"
namespace transformer_engine {
namespace {
/*! \brief Wrapper class for a shared library
*
* \todo Windows support
*/
class Library {
public:
explicit Library(const char *filename) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
// TODO Windows support
NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL);
NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed");
#endif // _WIN32 or _WIN64 or __WINDOWS__
}
~Library() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
// TODO Windows support
#else
if (handle_ != nullptr) {
dlclose(handle_);
}
#endif // _WIN32 or _WIN64 or __WINDOWS__
}
Library(const Library&) = delete; // move-only
Library(Library&& other) noexcept {
swap(*this, other);
}
Library& operator=(Library other) noexcept {
// Copy-and-swap idiom
swap(*this, other);
return *this;
}
friend void swap(Library& first, Library& second) noexcept;
void *get() noexcept {
return handle_;
}
const void *get() const noexcept {
return handle_;
}
/*! \brief Get pointer corresponding to symbol in shared library */
void *get_symbol(const char *symbol) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
// TODO Windows support
NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
void *ptr = dlsym(handle_, symbol);
NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library");
return ptr;
#endif // _WIN32 or _WIN64 or __WINDOWS__
}
private:
void *handle_ = nullptr;
};
void swap(Library& first, Library& second) noexcept {
using std::swap;
swap(first.handle_, second.handle_);
}
/*! \brief Lazily-initialized shared library for CUDA driver */
Library& cuda_driver_lib() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
constexpr char lib_name[] = "nvcuda.dll";
#else
constexpr char lib_name[] = "libcuda.so.1";
#endif
static Library lib(lib_name);
return lib;
}
} // namespace
namespace cuda_driver {
void *get_symbol(const char *symbol) {
return cuda_driver_lib().get_symbol(symbol);
}
} // namespace cuda_driver
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#include <string>
#include <cuda.h>
#include "../common.h"
#include "../util/string.h"
namespace transformer_engine {
namespace cuda_driver {
/*! \brief Get pointer corresponding to symbol in CUDA driver library */
void *get_symbol(const char *symbol);
/*! \brief Call function in CUDA driver library
*
* The CUDA driver library (libcuda.so.1 on Linux) may be different at
* compile-time and run-time. In particular, the CUDA Toolkit provides
* stubs for the driver library in case compilation is on a system
* without GPUs. Indirect function calls into a lazily-initialized
* library ensures we are accessing the correct version.
*
* \param[in] symbol Function name
* \param[in] args Function arguments
*/
template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) {
using FuncT = CUresult(ArgTs...);
FuncT *func = reinterpret_cast<FuncT*>(get_symbol(symbol));
return (*func)(args...);
}
} // namespace cuda_driver
} // namespace transformer_engine
namespace {
/*! \brief Throw exception if CUDA driver call has failed */
inline void check_cuda_driver_(CUresult status) {
if (status != CUDA_SUCCESS) {
const char *description;
transformer_engine::cuda_driver::call("cuGetErrorString", &description);
NVTE_ERROR(transformer_engine::concat_strings("CUDA Error: ", description));
}
}
/*! \brief Call CUDA driver function and throw exception if it fails */
template <typename... ArgTs>
inline void call_and_check_cuda_driver_(const char *symbol,
ArgTs &&... args) {
check_cuda_driver_(transformer_engine::cuda_driver::call(symbol,
std::forward<ArgTs>(args)...));
}
} // namespace
#define NVTE_CHECK_CUDA_DRIVER(ans) { check_cuda_driver_(ans); }
#define NVTE_CALL_CHECK_CUDA_DRIVER(func, ...) \
{ call_and_check_cuda_driver_(#func, __VA_ARGS__); }
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
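As a brief illustration of the wrapper above (hedged: cuDriverGetVersion is a standard CUDA driver entry point, but this particular call does not appear in the PR):
// Illustrative only: resolve cuDriverGetVersion by name in the lazily-loaded
// driver library, call it, and throw if it fails.
#include "util/cuda_driver.h"
int query_driver_version() {
  int driver_version = 0;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuDriverGetVersion, &driver_version);
  return driver_version;
}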
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <filesystem>
#include <mutex>
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
namespace transformer_engine {
namespace cuda {
namespace {
// String with build-time CUDA include path
#include "string_path_cuda_include.h"
} // namespace
int num_devices() {
auto query_num_devices = [] () -> int {
int count;
NVTE_CHECK_CUDA(cudaGetDeviceCount(&count));
return count;
};
static int num_devices_ = query_num_devices();
return num_devices_;
}
int current_device() {
// Return 0 if CUDA context is not initialized
CUcontext context;
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
if (context == nullptr) {
return 0;
}
// Query device from CUDA runtime
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
return device_id;
}
int sm_arch(int device_id) {
static std::vector<int> cache(num_devices(), -1);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () {
cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = 10*prop.major + prop.minor;
};
std::call_once(flags[device_id], init);
return cache[device_id];
}
int sm_count(int device_id) {
static std::vector<int> cache(num_devices(), -1);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () {
cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = prop.multiProcessorCount;
};
std::call_once(flags[device_id], init);
return cache[device_id];
}
const std::string &include_directory(bool required) {
static std::string path;
// Update cached path if needed
static bool need_to_check_env = true;
if (path.empty() && required) {
need_to_check_env = true;
}
if (need_to_check_env) {
// Search for CUDA headers in common paths
using Path = std::filesystem::path;
std::vector<std::pair<std::string, Path>> search_paths = {
{"NVTE_CUDA_INCLUDE_DIR", ""},
{"CUDA_HOME", ""},
{"CUDA_DIR", ""},
{"", string_path_cuda_include},
{"", "/usr/local/cuda"}};
for (auto &[env, p] : search_paths) {
if (p.empty()) {
p = getenv<Path>(env);
}
if (!p.empty()) {
if (file_exists(p / "cuda_runtime.h")) {
path = p;
break;
}
if (file_exists(p / "include" / "cuda_runtime.h")) {
path = p / "include";
break;
}
}
}
// Throw exception if path is required but not found
if (path.empty() && required) {
std::string message;
message.reserve(2048);
message += "Could not find cuda_runtime.h in";
bool is_first = true;
for (const auto &[env, p] : search_paths) {
message += is_first ? " " : ", ";
is_first = false;
if (!env.empty()) {
message += env;
message += "=";
}
if (p.empty()) {
message += "<unset>";
} else {
message += p;
}
}
message += (". "
"Specify path to CUDA Toolkit headers "
"with NVTE_CUDA_INCLUDE_DIR "
"or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
NVTE_ERROR(message);
}
need_to_check_env = false;
}
// Return cached path
return path;
}
} // namespace cuda
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_
#include <string>
#include <cuda_runtime_api.h>
namespace transformer_engine {
namespace cuda {
/* \brief Number of accessible devices */
int num_devices();
/* \brief Which device is currently being used */
int current_device();
/* \brief Compute capability of device
*
* \param[in] device_id CUDA device (default is current device)
*
* \return Compute capability as int. Last digit is minor revision,
* remaining digits are major revision.
*/
int sm_arch(int device_id = -1);
/* \brief Number of multiprocessors on a device
*
* \param[in] device_id CUDA device (default is current device)
*
* \return Number of multiprocessors
*/
int sm_count(int device_id = -1);
/* \brief Path to CUDA Toolkit headers
*
* The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
* environment. Otherwise searches in common install paths.
*
* \param[in] required Whether to throw exception if not found
*
* \return Path to include directory, or an empty string if not found
*/
const std::string &include_directory(bool required = false);
} // namespace cuda
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstdlib>
#include <iostream>
#include <utility>
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/string.h"
#include "../util/system.h"
#include "../util/rtc.h"
namespace transformer_engine {
namespace rtc {
namespace {
// Strings with headers for RTC kernels
#include "string_code_utils_cuh.h"
/*! \brief Latest compute capability that NVRTC supports
*
* \return Compute capability as int. Last digit is minor revision,
* remaining digits are major revision.
*/
inline int max_supported_sm_arch() {
static int arch_ = -1;
if (arch_ < 0) {
int num_archs = 0;
NVTE_CHECK_NVRTC(nvrtcGetNumSupportedArchs(&num_archs));
NVTE_CHECK(num_archs > 0, "Could not determine SM archs that NVRTC supports");
std::vector<int> archs(num_archs);
NVTE_CHECK_NVRTC(nvrtcGetSupportedArchs(archs.data()));
arch_ = archs.back();
}
return arch_;
}
} // namespace
bool is_enabled() {
static bool is_enabled_ = false;
static bool need_to_check_env = true;
if (need_to_check_env) {
is_enabled_ = !getenv<bool>("NVTE_DISABLE_NVRTC");
need_to_check_env = false;
}
return is_enabled_;
}
Kernel::Kernel(std::string mangled_name, std::string compiled_code)
: mangled_name_{std::move(mangled_name)}
, compiled_code_{std::move(compiled_code)}
, modules_(cuda::num_devices(), null_module)
, functions_(cuda::num_devices(), null_function)
, init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} {
}
Kernel::~Kernel() {
for (int device_id=0; device_id<static_cast<int>(modules_.size()); ++device_id) {
// Unload CUDA modules if needed
if (modules_[device_id] != null_module) {
CUdevice device;
CUcontext context;
if (cuda_driver::call("cuDeviceGet", &device, device_id)
!= CUDA_SUCCESS) {
continue;
}
if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device)
!= CUDA_SUCCESS) {
continue;
}
cuda_driver::call("cuModuleUnload", modules_[device_id]);
cuda_driver::call("cuDevicePrimaryCtxRelease", device);
}
}
}
Kernel::Kernel(Kernel&& other) noexcept {
swap(*this, other);
}
Kernel& Kernel::operator=(Kernel other) noexcept {
// Copy-and-swap idiom
swap(*this, other);
return *this;
}
void swap(Kernel& first, Kernel& second) noexcept {
using std::swap;
swap(first.mangled_name_, second.mangled_name_);
swap(first.compiled_code_, second.compiled_code_);
swap(first.modules_, second.modules_);
swap(first.functions_, second.functions_);
swap(first.init_flags_, second.init_flags_);
}
CUfunction Kernel::get_function(int device_id) {
// Load kernel on device if needed
auto load_on_device = [&] () {
// Set driver context to proper device
CUdevice device;
CUcontext context;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, device_id);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
// Load function into driver context
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx,
&modules_[device_id],
compiled_code_.c_str(),
0, // numOptions
nullptr, // options
nullptr); // optionValues
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction,
&functions_[device_id],
modules_[device_id],
mangled_name_.c_str());
// Reset driver context
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
};
std::call_once(init_flags_->at(device_id), load_on_device);
// Return CUDA function
return functions_[device_id];
}
KernelManager& KernelManager::instance() {
NVTE_CHECK(is_enabled(), "NVRTC support is not enabled");
static KernelManager instance_;
return instance_;
}
void KernelManager::compile(const std::string &kernel_label,
const std::string &kernel_name,
const std::string &code,
const std::string &filename) {
std::lock_guard<std::mutex> lock_guard_(lock_);
// Choose whether to compile to PTX or cubin
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);
// Compilation flags
std::vector<std::string> opts = {
#ifndef NDEBUG
"-G",
#endif
"--std=c++17"};
if (compile_ptx) {
opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch));
} else {
opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch));
}
opts.push_back(concat_strings("-I", cuda::include_directory(true)));
std::vector<const char*> opts_ptrs;
for (const auto& opt : opts) {
opts_ptrs.push_back(opt.c_str());
}
// Compile source
nvrtcProgram program;
constexpr int num_headers = 1;
constexpr const char* headers[num_headers] = {string_code_utils_cuh};
constexpr const char* include_names[num_headers] = {"utils.cuh"};
NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program,
code.c_str(),
filename.c_str(),
num_headers,
headers,
include_names));
NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
const nvrtcResult compile_result = nvrtcCompileProgram(program,
opts_ptrs.size(),
opts_ptrs.data());
if (compile_result != NVRTC_SUCCESS) {
// Display log if compilation failed
std::string log = concat_strings("NVRTC compilation log for ",
filename, ":\n");
const size_t log_offset = log.size();
size_t log_size;
NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size));
log.resize(log_offset + log_size);
NVTE_CHECK_NVRTC(nvrtcGetProgramLog(program, &log[log_offset]));
log.back() = '\n';
std::cerr << log;
NVTE_CHECK_NVRTC(compile_result);
}
// Get mangled function name
const char *mangled_name;
NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program,
kernel_name.c_str(),
&mangled_name));
// Get compiled code
std::string compiled_code;
if (compile_ptx) {
size_t compiled_size;
NVTE_CHECK_NVRTC(nvrtcGetPTXSize(program, &compiled_size));
compiled_code.resize(compiled_size);
NVTE_CHECK_NVRTC(nvrtcGetPTX(program, compiled_code.data()));
} else {
size_t compiled_size;
NVTE_CHECK_NVRTC(nvrtcGetCUBINSize(program, &compiled_size));
compiled_code.resize(compiled_size);
NVTE_CHECK_NVRTC(nvrtcGetCUBIN(program, compiled_code.data()));
}
// Cache compiled code
const auto key = get_kernel_cache_key(kernel_label, device_id);
kernel_cache_.insert({key, Kernel(mangled_name, std::move(compiled_code))});
kernel_cache_.at(key).get_function(device_id); // Make sure kernel is available on device
// Clean up
NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program));
}
bool KernelManager::is_compiled(const std::string &kernel_label, int device_id) const {
const auto key = get_kernel_cache_key(kernel_label, device_id);
return kernel_cache_.count(key) > 0;
}
std::string KernelManager::get_kernel_cache_key(const std::string &kernel_label,
int device_id) const {
return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label);
}
} // namespace rtc
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
namespace transformer_engine {
namespace rtc {
/*! \brief Whether NVRTC support is enabled
*
* NVRTC support can be disabled by setting NVTE_DISABLE_NVRTC=1 in
* the environment.
*/
bool is_enabled();
/*! \brief Wrapper class for a runtime-compiled CUDA kernel */
class Kernel {
public:
Kernel(std::string mangled_name, std::string compiled_code);
~Kernel();
Kernel(const Kernel&) = delete; // move-only
Kernel(Kernel&&) noexcept;
Kernel& operator=(Kernel) noexcept;
friend void swap(Kernel& first, Kernel& second) noexcept;
/*! \brief Launch CUDA kernel
*
* Loads the kernel into the device the first time the device is
* accessed.
*
* \param[in] device_id CUDA device
* \param[in] grid_dim Grid dimensions in blocks
* \param[in] block_dim Thread block dimensions
* \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
* bytes
* \param[in] stream CUDA stream
* \param[in] args Kernel arguments
*/
template <typename... ArgTs>
void launch(int device_id,
const dim3 grid_dim,
const dim3 block_dim,
unsigned int shared_mem_bytes,
cudaStream_t stream,
ArgTs &&... args) {
void* arg_ptrs[] = { const_cast<void*>(static_cast<const void*>(&args))... };
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel,
get_function(device_id),
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
shared_mem_bytes,
static_cast<CUstream>(stream),
arg_ptrs,
nullptr);
}
/*! \brief CUDA function for given CUDA device
*
* Loads the kernel into the device the first time the device is
* accessed.
*/
CUfunction get_function(int device_id);
private:
/*! \brief Mangled function name */
std::string mangled_name_;
/*! \brief Compiled assembly, either in PTX or cubin format */
std::string compiled_code_;
/*! CUDA module for each CUDA device */
std::vector<CUmodule> modules_;
/*! CUDA function for each CUDA device */
std::vector<CUfunction> functions_;
/*! Flags for thread-safe kernel initialization */
std::unique_ptr<std::vector<std::once_flag>> init_flags_;
/*! \brief Uninitialized CUDA module */
static constexpr CUmodule null_module = static_cast<CUmodule>(nullptr);
/*! Uninitialized CUDA function */
static constexpr CUfunction null_function = static_cast<CUfunction>(nullptr);
};
/*! \brief Singleton class to manage runtime-compiled CUDA kernels */
class KernelManager {
public:
/*! \brief Get singleton instance */
static KernelManager& instance();
/*! \brief Compile CUDA kernel for current CUDA device
*
* The compiled kernel is cached and made available for launching.
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] kernel_name Kernel name within source code
* \param[in] code Kernel source code
* \param[in] filename Path to associate with source code,
* primarily for debugging
*/
void compile(const std::string &kernel_label,
const std::string &kernel_name,
const std::string &code,
const std::string &filename);
/*! \brief Whether CUDA kernel has been compiled for CUDA device
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] device_id CUDA device (default is current device)
* \return Whether kernel has been compiled
*/
bool is_compiled(const std::string &kernel_label,
int device_id = -1) const;
/*! \brief Launch CUDA kernel on current CUDA device
*
* Assumes the kernel has already been compiled.
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] grid_dim Grid dimensions in blocks
* \param[in] block_dim Thread block dimensions
* \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
* bytes
* \param[in] stream CUDA stream
* \param[in] args Kernel arguments
*/
template <typename... ArgTs>
void launch(const std::string &kernel_label,
const dim3 grid_dim,
const dim3 block_dim,
unsigned int shared_mem_bytes,
cudaStream_t stream,
ArgTs &&... args) {
const int device_id = cuda::current_device();
const auto key = get_kernel_cache_key(kernel_label, device_id);
NVTE_CHECK(kernel_cache_.count(key) > 0,
"Attempted to launch RTC kernel before compilation");
kernel_cache_.at(key).launch(device_id,
grid_dim,
block_dim,
shared_mem_bytes,
stream,
std::forward<ArgTs>(args)...);
}
private:
/*! \brief Compiled kernels */
std::unordered_map<std::string, Kernel> kernel_cache_;
/*! \brief Mutex for thread-safe compilation */
std::mutex lock_;
KernelManager() = default;
~KernelManager() = default;
KernelManager(const KernelManager&) = delete;
KernelManager& operator=(const KernelManager&) = delete;
/*! \brief Construct key for kernel cache
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] device_id CUDA device (default is current device)
*
* \return Key for kernel cache
*/
std::string get_kernel_cache_key(const std::string &kernel_label,
int device_id) const;
};
} // namespace rtc
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_STRING_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_STRING_H_
#include <regex> // NOLINT(*)
#include <string>
#include <type_traits>
namespace transformer_engine {
/*! \brief Convert to C-style or C++-style string */
template <typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
inline std::string to_string_like(const T &val) {
return std::to_string(val);
}
inline const std::string& to_string_like(const std::string& val) noexcept {
return val;
}
constexpr const char *to_string_like(const char *val) noexcept {
return val;
}
/*! \brief Convert arguments to strings and concatenate */
template <typename... Ts>
inline std::string concat_strings(const Ts &... args) {
std::string str;
str.reserve(1024); // Assume strings are <1 KB
(..., (str += to_string_like(args)));
return str;
}
/*! \brief Substitute regex occurrences in string
*
* This is a convenience wrapper around std::regex_replace.
*/
template <typename T>
inline std::string regex_replace(const std::string &str,
const std::string &pattern,
const T &replacement) {
return std::regex_replace(str,
std::regex(pattern),
to_string_like(replacement));
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_STRING_H_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
static constexpr char @STRING_NAME@[]
= R"__STRING_DELIM__(@STRING@)__STRING_DELIM__";
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <sstream>
#include <string>
#include "../common.h"
#include "../util/system.h"
namespace transformer_engine {
namespace {
template <typename T>
inline typename std::enable_if<std::is_arithmetic<T>::value, T>::type
getenv_helper(const std::string &variable, const T &default_value) {
// Implementation for numeric types
const char *env = std::getenv(variable.c_str());
if (env == nullptr || env[0] == '\0') {
return default_value;
}
T value;
std::istringstream iss(env);
iss >> value;
NVTE_CHECK(iss, "Invalid environment variable value");
return value;
}
template <typename T>
inline typename std::enable_if<!std::is_arithmetic<T>::value, T>::type
getenv_helper(const std::string &variable, const T &default_value) {
// Implementation for string-like types
const char *env = std::getenv(variable.c_str());
if (env == nullptr || env[0] == '\0') {
return default_value;
} else {
return env;
}
}
} // namespace
#define NVTE_INSTANTIATE_GETENV(T, default_value) \
template <> T getenv<T>(const std::string &variable, \
const T &default_value_) { \
return getenv_helper<T>(variable, default_value_); \
} \
template <> T getenv<T>(const std::string &variable) { \
return getenv_helper<T>(variable, default_value); \
}
NVTE_INSTANTIATE_GETENV(bool, false);
NVTE_INSTANTIATE_GETENV(float, 0.f);
NVTE_INSTANTIATE_GETENV(double, 0.);
NVTE_INSTANTIATE_GETENV(int8_t, 0);
NVTE_INSTANTIATE_GETENV(int16_t, 0);
NVTE_INSTANTIATE_GETENV(int32_t, 0);
NVTE_INSTANTIATE_GETENV(int64_t, 0);
NVTE_INSTANTIATE_GETENV(uint8_t, 0);
NVTE_INSTANTIATE_GETENV(uint16_t, 0);
NVTE_INSTANTIATE_GETENV(uint32_t, 0);
NVTE_INSTANTIATE_GETENV(uint64_t, 0);
NVTE_INSTANTIATE_GETENV(std::string, std::string());
NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path());
bool file_exists(const std::string &path) {
return static_cast<bool>(std::ifstream(path.c_str()));
}
} // namespace transformer_engine