Unverified commit cbb96f2b, authored by Kshitij Lakhani and committed by GitHub
Browse files

Export only necessary symbols from libtransformer_engine.so (#1511)



* Expose only required symbols from libtransformer_engine.so during linking for pytorch
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Augment libtransformer_engine.version for jax compatibility
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Augment the libtransformer_engine.version to ensure compatibility with CPP tests
Remove getenv from the .version file
Combine system.cpp and system.h
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Nit: Remove commented code for not including common.h
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Replace explicit getenv instantiations with a helper template
Use filesystem calls in file_exists()
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert comment to falsy instead of false
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 90d5d45d
......@@ -80,7 +80,6 @@ list(APPEND transformer_engine_SOURCES
util/cuda_driver.cpp
util/cuda_runtime.cpp
util/rtc.cpp
util/system.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
......
{
/* Linker version script: names matching a `global:` pattern are exported,
 * names matching `local:` are hidden.
 *
 * NOTE(review): two `global:` sections follow -- a broad wildcard line and an
 * explicit list. This looks like both sides of a diff shown together; confirm
 * which one is intended before relying on this file.
 */
global: *nvte*; *transformer_engine*;
global:
/* Patterns inside extern "C++" are matched against demangled C++ names. */
extern "C++" {
nvte_*;
transformer_engine::cuda::sm_count*;
transformer_engine::cuda::sm_arch*;
transformer_engine::cuda::supports_multicast*;
transformer_engine::cuda::stream_priority_range*;
transformer_engine::cuda::current_device*;
transformer_engine::cuda_driver::get_symbol*;
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*;
transformer_engine::typeToSize*;
*transformer_engine::CommOverlapBase*;
*transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore*
};
/* Catch-all: every symbol not matched above stays local. */
local: *;
};
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/system.h"
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <sstream>
#include <string>
#include "../common.h"
namespace transformer_engine {
namespace {
// Read an environment variable and parse it as an arithmetic type.
//
// This overload is selected (through std::enable_if) when T is arithmetic.
//   variable       name of the environment variable to look up
//   default_value  returned when the variable is unset or set to ""
// A value that cannot be parsed is reported through NVTE_CHECK.
template <typename T>
inline typename std::enable_if<std::is_arithmetic<T>::value, T>::type getenv_helper(
    const char *variable, const T &default_value) {
  const char *env = std::getenv(variable);
  if (env == nullptr || env[0] == '\0') {
    return default_value;
  }
  std::istringstream iss(env);
  if constexpr (std::is_same<T, signed char>::value || std::is_same<T, unsigned char>::value) {
    // Stream extraction into signed/unsigned char (i.e. int8_t/uint8_t) reads a
    // single character, so "5" would come back as 53. Parse through int and
    // narrow, rejecting values that do not fit in T.
    int wide = 0;
    iss >> wide;
    NVTE_CHECK(iss, "Invalid environment variable value");
    NVTE_CHECK(static_cast<int>(static_cast<T>(wide)) == wide,
               "Invalid environment variable value");
    return static_cast<T>(wide);
  } else {
    T value;
    iss >> value;
    NVTE_CHECK(iss, "Invalid environment variable value");
    return value;
  }
}
// Read an environment variable as a string-like type.
//
// This overload is selected (through std::enable_if) when T is NOT arithmetic.
// The raw C string is converted to T by T's implicit conversion from
// const char *. default_value is returned when the variable is unset or "".
template <typename T>
inline typename std::enable_if<!std::is_arithmetic<T>::value, T>::type getenv_helper(
    const char *variable, const T &default_value) {
  const char *raw = std::getenv(variable);
  const bool has_value = (raw != nullptr) && (raw[0] != '\0');
  if (!has_value) {
    return default_value;
  }
  return raw;
}
} // namespace
// Defines the two getenv<T> specializations for one type T:
//   - getenv<T>(variable, default_value_) forwards the caller's default;
//   - getenv<T>(variable) uses the macro's `default_value` argument as default.
// Both forward to getenv_helper<T>.
#define NVTE_INSTANTIATE_GETENV(T, default_value) \
template <> \
T getenv<T>(const char *variable, const T &default_value_) { \
return getenv_helper<T>(variable, default_value_); \
} \
template <> \
T getenv<T>(const char *variable) { \
return getenv_helper<T>(variable, default_value); \
}
// Types for which getenv<T> is provided, each with its fallback value.
NVTE_INSTANTIATE_GETENV(bool, false);
NVTE_INSTANTIATE_GETENV(float, 0.f);
NVTE_INSTANTIATE_GETENV(double, 0.);
NVTE_INSTANTIATE_GETENV(int8_t, 0);
NVTE_INSTANTIATE_GETENV(int16_t, 0);
NVTE_INSTANTIATE_GETENV(int32_t, 0);
NVTE_INSTANTIATE_GETENV(int64_t, 0);
NVTE_INSTANTIATE_GETENV(uint8_t, 0);
NVTE_INSTANTIATE_GETENV(uint16_t, 0);
NVTE_INSTANTIATE_GETENV(uint32_t, 0);
NVTE_INSTANTIATE_GETENV(uint64_t, 0);
NVTE_INSTANTIATE_GETENV(std::string, std::string());
NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path());
// Returns true when `path` can be opened for reading as an input file stream.
bool file_exists(const std::string &path) {
  std::ifstream probe(path.c_str());
  // Same test as the stream's bool conversion: true unless the open failed.
  return !probe.fail();
}
} // namespace transformer_engine
......@@ -7,25 +7,96 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <sstream>
#include <string>
#include <system_error>
#include <type_traits>

#include "logging.h"
namespace transformer_engine {
namespace detail {
/*! \brief Read an environment variable and parse it as an arithmetic type
 *
 * This overload is selected (through std::enable_if) when T is arithmetic.
 *
 * \param[in] variable       Name of the environment variable
 * \param[in] default_value  Returned when the variable is unset or empty
 * \return The parsed value, or default_value when unset or empty
 *
 * A value that cannot be parsed is reported through NVTE_CHECK.
 */
template <typename T>
inline typename std::enable_if<std::is_arithmetic<T>::value, T>::type getenv_helper(
    const char *variable, const T &default_value) {
  const char *env = std::getenv(variable);
  if (env == nullptr || env[0] == '\0') {
    return default_value;
  }
  std::istringstream iss(env);
  if constexpr (std::is_same<T, signed char>::value || std::is_same<T, unsigned char>::value) {
    // Stream extraction into signed/unsigned char (i.e. int8_t/uint8_t) reads a
    // single character, so "5" would come back as 53. Parse through int and
    // narrow, rejecting values that do not fit in T.
    int wide = 0;
    iss >> wide;
    NVTE_CHECK(iss, "Invalid environment variable value");
    NVTE_CHECK(static_cast<int>(static_cast<T>(wide)) == wide,
               "Invalid environment variable value");
    return static_cast<T>(wide);
  } else {
    T value;
    iss >> value;
    NVTE_CHECK(iss, "Invalid environment variable value");
    return value;
  }
}
/*! \brief Read an environment variable as a string-like type
 *
 * This overload is selected (through std::enable_if) when T is not arithmetic.
 * The raw C string is converted to T through T's implicit conversion from
 * const char *.
 *
 * \param[in] variable       Name of the environment variable
 * \param[in] default_value  Returned when the variable is unset or empty
 */
template <typename T>
inline typename std::enable_if<!std::is_arithmetic<T>::value, T>::type getenv_helper(
    const char *variable, const T &default_value) {
  const char *raw = std::getenv(variable);
  const bool unset_or_empty = (raw == nullptr) || (raw[0] == '\0');
  if (unset_or_empty) {
    return default_value;
  }
  return raw;
}
/*! \brief Fallback value used when an environment variable is unset or empty
 *
 * Primary template: yields zero converted to T. Specializations below cover
 * bool, std::string and std::filesystem::path.
 */
template <typename T>
inline T getenv_default_value() {
  return static_cast<T>(0);
}
/*! \brief Fallback for bool: false */
template <>
inline bool getenv_default_value<bool>() {
  return false;
}
/*! \brief Fallback for std::string: the empty string */
template <>
inline std::string getenv_default_value<std::string>() {
  return {};
}
/*! \brief Fallback for std::filesystem::path: the empty path */
template <>
inline std::filesystem::path getenv_default_value<std::filesystem::path>() {
  return {};
}
} // namespace detail
/*! \brief Get environment variable and convert to type
*
* If the environment variable is unset or empty, a falsy value is
* returned.
*
* The fallback comes from detail::getenv_default_value<T>(). T defaults to
* std::string when no template argument is given.
*/
*/
// NOTE(review): a bodiless declaration and an `inline` definition of the same
// function appear back to back below, preceded by a stray comment terminator.
// This looks like both sides of a diff shown together -- confirm which lines
// belong in the file.
template <typename T = std::string>
T getenv(const char *variable);
inline T getenv(const char *variable) {
// Delegate to the helper, supplying the per-type fallback value.
return detail::getenv_helper<T>(variable, detail::getenv_default_value<T>());
}
/*! \brief Get environment variable and convert to type
 *
 * Forwards to detail::getenv_helper<T> with the caller-supplied default_value.
 * T defaults to std::string when no template argument is given.
 */
// NOTE(review): a bodiless declaration and an `inline` definition of the same
// function appear back to back below. This looks like both sides of a diff
// shown together -- confirm which lines belong in the file.
template <typename T = std::string>
T getenv(const char *variable, const T &default_value);
inline T getenv(const char *variable, const T &default_value) {
return detail::getenv_helper<T>(variable, default_value);
}
/*! \brief Check if a path refers to an existing regular file
 *
 * Does not test whether the file can be opened for reading.
 *
 * \param[in] path  Path to check
 * \return true when the path resolves to a regular file; false when it does
 *         not exist, is another kind of file (e.g. a directory), or its
 *         status cannot be queried
 */
inline bool file_exists(const std::string &path) {
  // Use the error_code overload so a failing status query (permission denied,
  // over-long name, ...) yields false instead of throwing filesystem_error.
  // is_regular_file already returns false for a missing path, so a separate
  // exists() call is not needed.
  std::error_code ec;
  return std::filesystem::is_regular_file(path, ec);
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_
......@@ -540,7 +540,8 @@ static void FusedAttnBackwardImpl(
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype),
stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
......@@ -558,8 +559,9 @@ static void FusedAttnBackwardImpl(
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype),
stream);
}
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
......@@ -581,9 +583,9 @@ static void FusedAttnBackwardImpl(
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, transformer_engine::product(v_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
......
......@@ -26,5 +26,13 @@ struct Shape {
std::vector<size_t> MakeShapeVector(NVTEShape shape);
// Multiplies all extents in `shape` together; an empty shape yields 1.
inline size_t product(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (auto it = shape.begin(); it != shape.end(); ++it) {
    total *= *it;
  }
  return total;
}
} // namespace jax
} // namespace transformer_engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment