Unverified Commit 9416519d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Apply formatting (#929)



* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
...@@ -7,6 +7,10 @@ ...@@ -7,6 +7,10 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <string> #include <string>
...@@ -14,10 +18,6 @@ ...@@ -14,10 +18,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>
#include "../common.h" #include "../common.h"
#include "../util/cuda_driver.h" #include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h" #include "../util/cuda_runtime.h"
...@@ -38,10 +38,10 @@ class Kernel { ...@@ -38,10 +38,10 @@ class Kernel {
public: public:
Kernel(std::string mangled_name, std::string compiled_code); Kernel(std::string mangled_name, std::string compiled_code);
~Kernel(); ~Kernel();
Kernel(const Kernel&) = delete; // move-only Kernel(const Kernel &) = delete; // move-only
Kernel(Kernel&&) noexcept; Kernel(Kernel &&) noexcept;
Kernel& operator=(Kernel) noexcept; Kernel &operator=(Kernel) noexcept;
friend void swap(Kernel& first, Kernel& second) noexcept; friend void swap(Kernel &first, Kernel &second) noexcept;
/*! \brief Launch CUDA kernel /*! \brief Launch CUDA kernel
* *
...@@ -57,25 +57,12 @@ class Kernel { ...@@ -57,25 +57,12 @@ class Kernel {
* \param[in] args Kernel arguments * \param[in] args Kernel arguments
*/ */
template <typename... ArgTs> template <typename... ArgTs>
void launch(int device_id, void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
const dim3 grid_dim, unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
const dim3 block_dim, void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
unsigned int shared_mem_bytes, NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
cudaStream_t stream, grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
ArgTs &&... args) { static_cast<CUstream>(stream), arg_ptrs, nullptr);
void* arg_ptrs[] = { const_cast<void*>(static_cast<const void*>(&args))... };
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel,
get_function(device_id),
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
shared_mem_bytes,
static_cast<CUstream>(stream),
arg_ptrs,
nullptr);
} }
/*! \brief CUDA function for given CUDA device /*! \brief CUDA function for given CUDA device
...@@ -114,7 +101,7 @@ class Kernel { ...@@ -114,7 +101,7 @@ class Kernel {
class KernelManager { class KernelManager {
public: public:
/*! \brief Get singleton instance */ /*! \brief Get singleton instance */
static KernelManager& instance(); static KernelManager &instance();
/*! \brief Compile CUDA kernel for current CUDA device /*! \brief Compile CUDA kernel for current CUDA device
* *
...@@ -126,10 +113,8 @@ class KernelManager { ...@@ -126,10 +113,8 @@ class KernelManager {
* \param[in] filename Path to associate with source code, * \param[in] filename Path to associate with source code,
* primarily for debugging * primarily for debugging
*/ */
void compile(const std::string &kernel_label, void compile(const std::string &kernel_label, const std::string &kernel_name,
const std::string &kernel_name, const std::string &code, const std::string &filename);
const std::string &code,
const std::string &filename);
/*! \brief Whether CUDA kernel has been compiled for CUDA device /*! \brief Whether CUDA kernel has been compiled for CUDA device
* *
...@@ -138,8 +123,7 @@ class KernelManager { ...@@ -138,8 +123,7 @@ class KernelManager {
* \return Whether kernel has been compiled * \return Whether kernel has been compiled
*/ */
bool is_compiled(const std::string &kernel_label, bool is_compiled(const std::string &kernel_label, int device_id = -1) const;
int device_id = -1) const;
/*! \brief Launch CUDA kernel on current CUDA device /*! \brief Launch CUDA kernel on current CUDA device
* *
...@@ -154,21 +138,12 @@ class KernelManager { ...@@ -154,21 +138,12 @@ class KernelManager {
* \param[in] args Kernel arguments * \param[in] args Kernel arguments
*/ */
template <typename... ArgTs> template <typename... ArgTs>
void launch(const std::string &kernel_label, void launch(const std::string &kernel_label, const dim3 grid_dim, const dim3 block_dim,
const dim3 grid_dim, unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
const dim3 block_dim,
unsigned int shared_mem_bytes,
cudaStream_t stream,
ArgTs &&... args) {
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
const auto key = get_kernel_cache_key(kernel_label, device_id); const auto key = get_kernel_cache_key(kernel_label, device_id);
NVTE_CHECK(kernel_cache_.count(key) > 0, NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to launch RTC kernel before compilation");
"Attempted to launch RTC kernel before compilation"); kernel_cache_.at(key).launch(device_id, grid_dim, block_dim, shared_mem_bytes, stream,
kernel_cache_.at(key).launch(device_id,
grid_dim,
block_dim,
shared_mem_bytes,
stream,
std::forward<ArgTs>(args)...); std::forward<ArgTs>(args)...);
} }
...@@ -189,8 +164,8 @@ class KernelManager { ...@@ -189,8 +164,8 @@ class KernelManager {
KernelManager() = default; KernelManager() = default;
~KernelManager() = default; ~KernelManager() = default;
KernelManager(const KernelManager&) = delete; KernelManager(const KernelManager &) = delete;
KernelManager& operator=(const KernelManager&) = delete; KernelManager &operator=(const KernelManager &) = delete;
/*! \brief Construct key for kernel cache /*! \brief Construct key for kernel cache
* *
...@@ -199,8 +174,7 @@ class KernelManager { ...@@ -199,8 +174,7 @@ class KernelManager {
* *
* \return Key for kernel cache * \return Key for kernel cache
*/ */
std::string get_kernel_cache_key(const std::string &kernel_label, std::string get_kernel_cache_key(const std::string &kernel_label, int device_id) const;
int device_id) const;
}; };
} // namespace rtc } // namespace rtc
......
...@@ -14,23 +14,18 @@ ...@@ -14,23 +14,18 @@
namespace transformer_engine { namespace transformer_engine {
/*! \brief Convert to C-style or C++-style string */ /*! \brief Convert to C-style or C++-style string */
template <typename T, template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
inline std::string to_string_like(const T &val) { inline std::string to_string_like(const T &val) {
return std::to_string(val); return std::to_string(val);
} }
inline const std::string& to_string_like(const std::string& val) noexcept { inline const std::string &to_string_like(const std::string &val) noexcept { return val; }
return val;
}
constexpr const char *to_string_like(const char *val) noexcept { constexpr const char *to_string_like(const char *val) noexcept { return val; }
return val;
}
/*! \brief Convert arguments to strings and concatenate */ /*! \brief Convert arguments to strings and concatenate */
template <typename... Ts> template <typename... Ts>
inline std::string concat_strings(const Ts &... args) { inline std::string concat_strings(const Ts &...args) {
std::string str; std::string str;
str.reserve(1024); // Assume strings are <1 KB str.reserve(1024); // Assume strings are <1 KB
(..., (str += to_string_like(args))); (..., (str += to_string_like(args)));
...@@ -42,12 +37,9 @@ inline std::string concat_strings(const Ts &... args) { ...@@ -42,12 +37,9 @@ inline std::string concat_strings(const Ts &... args) {
* This is a convenience wrapper around std::regex_replace. * This is a convenience wrapper around std::regex_replace.
*/ */
template <typename T> template <typename T>
inline std::string regex_replace(const std::string &str, inline std::string regex_replace(const std::string &str, const std::string &pattern,
const std::string &pattern,
const T &replacement) { const T &replacement) {
return std::regex_replace(str, return std::regex_replace(str, std::regex(pattern), to_string_like(replacement));
std::regex(pattern),
to_string_like(replacement));
} }
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../util/system.h"
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <filesystem> #include <filesystem>
...@@ -12,15 +14,14 @@ ...@@ -12,15 +14,14 @@
#include <string> #include <string>
#include "../common.h" #include "../common.h"
#include "../util/system.h"
namespace transformer_engine { namespace transformer_engine {
namespace { namespace {
template <typename T> template <typename T>
inline typename std::enable_if<std::is_arithmetic<T>::value, T>::type inline typename std::enable_if<std::is_arithmetic<T>::value, T>::type getenv_helper(
getenv_helper(const char *variable, const T &default_value) { const char *variable, const T &default_value) {
// Implementation for numeric types // Implementation for numeric types
const char *env = std::getenv(variable); const char *env = std::getenv(variable);
if (env == nullptr || env[0] == '\0') { if (env == nullptr || env[0] == '\0') {
...@@ -34,8 +35,8 @@ getenv_helper(const char *variable, const T &default_value) { ...@@ -34,8 +35,8 @@ getenv_helper(const char *variable, const T &default_value) {
} }
template <typename T> template <typename T>
inline typename std::enable_if<!std::is_arithmetic<T>::value, T>::type inline typename std::enable_if<!std::is_arithmetic<T>::value, T>::type getenv_helper(
getenv_helper(const char *variable, const T &default_value) { const char *variable, const T &default_value) {
// Implementation for string-like types // Implementation for string-like types
const char *env = std::getenv(variable); const char *env = std::getenv(variable);
if (env == nullptr || env[0] == '\0') { if (env == nullptr || env[0] == '\0') {
...@@ -47,13 +48,14 @@ getenv_helper(const char *variable, const T &default_value) { ...@@ -47,13 +48,14 @@ getenv_helper(const char *variable, const T &default_value) {
} // namespace } // namespace
#define NVTE_INSTANTIATE_GETENV(T, default_value) \ #define NVTE_INSTANTIATE_GETENV(T, default_value) \
template <> T getenv<T>(const char *variable, \ template <> \
const T &default_value_) { \ T getenv<T>(const char *variable, const T &default_value_) { \
return getenv_helper<T>(variable, default_value_); \ return getenv_helper<T>(variable, default_value_); \
} \ } \
template <> T getenv<T>(const char *variable) { \ template <> \
return getenv_helper<T>(variable, default_value); \ T getenv<T>(const char *variable) { \
return getenv_helper<T>(variable, default_value); \
} }
NVTE_INSTANTIATE_GETENV(bool, false); NVTE_INSTANTIATE_GETENV(bool, false);
NVTE_INSTANTIATE_GETENV(float, 0.f); NVTE_INSTANTIATE_GETENV(float, 0.f);
...@@ -69,8 +71,6 @@ NVTE_INSTANTIATE_GETENV(uint64_t, 0); ...@@ -69,8 +71,6 @@ NVTE_INSTANTIATE_GETENV(uint64_t, 0);
NVTE_INSTANTIATE_GETENV(std::string, std::string()); NVTE_INSTANTIATE_GETENV(std::string, std::string());
NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path()); NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path());
bool file_exists(const std::string &path) { bool file_exists(const std::string &path) { return static_cast<bool>(std::ifstream(path.c_str())); }
return static_cast<bool>(std::ifstream(path.c_str()));
}
} // namespace transformer_engine } // namespace transformer_engine
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_
#include <type_traits> #include <type_traits>
#include "../common.h" #include "../common.h"
#include "../utils.cuh" #include "../utils.cuh"
...@@ -30,15 +31,13 @@ class VectorizedStorage { ...@@ -30,15 +31,13 @@ class VectorizedStorage {
} scratch_; } scratch_;
inline __device__ VectorizedStorage() {} inline __device__ VectorizedStorage() {}
inline __device__ VectorizedStorage(const VectorizedStorage<DType, n>& y2) { inline __device__ VectorizedStorage(const VectorizedStorage<DType, n> &y2) {
scratch_.aligned = y2.scratch_.aligned; scratch_.aligned = y2.scratch_.aligned;
}
inline __device__ VectorizedStorage(const LType &y2) {
scratch_.aligned = y2;
} }
inline __device__ VectorizedStorage<DType, n>& operator+=( inline __device__ VectorizedStorage(const LType &y2) { scratch_.aligned = y2; }
const VectorizedStorage<DType, n>& rhs) { inline __device__ VectorizedStorage<DType, n> &operator+=(
#pragma unroll const VectorizedStorage<DType, n> &rhs) {
#pragma unroll
for (int i = 0; i < nvec; ++i) { for (int i = 0; i < nvec; ++i) {
scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]); scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]);
} }
...@@ -58,7 +57,6 @@ struct select_const<const DType, LType> { ...@@ -58,7 +57,6 @@ struct select_const<const DType, LType> {
using type = const LType; using type = const LType;
}; };
/* \brief Helper class that enables accessing multiple values of type DType /* \brief Helper class that enables accessing multiple values of type DType
as 1 value of type LType. Additional aligned template argument as 1 value of type LType. Additional aligned template argument
allows performance optimizations if the pointer and the size of allows performance optimizations if the pointer and the size of
...@@ -67,44 +65,37 @@ struct select_const<const DType, LType> { ...@@ -67,44 +65,37 @@ struct select_const<const DType, LType> {
template <typename DType, int nvec, bool aligned = false> template <typename DType, int nvec, bool aligned = false>
class VectorizedAccessor { class VectorizedAccessor {
public: public:
using StorageType = VectorizedStorage<typename std::remove_const<DType>::type, using StorageType = VectorizedStorage<typename std::remove_const<DType>::type, nvec>;
nvec>;
using LType = typename select_const<DType, typename StorageType::LType>::type; using LType = typename select_const<DType, typename StorageType::LType>::type;
StorageType storage_; StorageType storage_;
LType* aligned_ptr_; LType *aligned_ptr_;
DType* unaligned_ptr_; DType *unaligned_ptr_;
int alignment_; int alignment_;
size_t n_elems_; size_t n_elems_;
inline __device__ VectorizedAccessor(DType* const ptr, const size_t size) { inline __device__ VectorizedAccessor(DType *const ptr, const size_t size) {
unaligned_ptr_ = ptr; unaligned_ptr_ = ptr;
if (aligned) { if (aligned) {
alignment_ = 0; alignment_ = 0;
aligned_ptr_ = reinterpret_cast<LType*>(ptr); aligned_ptr_ = reinterpret_cast<LType *>(ptr);
n_elems_ = (size + nvec - 1) / nvec; n_elems_ = (size + nvec - 1) / nvec;
} else { } else {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr); size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType); alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_); aligned_ptr_ = reinterpret_cast<LType *>(ptr - alignment_);
n_elems_ = (size + alignment_ + nvec - 1) / nvec; n_elems_ = (size + alignment_ + nvec - 1) / nvec;
} }
} }
/* \brief Alignment of the input pointer in elements. */ /* \brief Alignment of the input pointer in elements. */
inline __device__ int alignment() const { inline __device__ int alignment() const { return alignment_; }
return alignment_;
}
/* \brief Access to separate elements. */ /* \brief Access to separate elements. */
inline __device__ DType* separate() { inline __device__ DType *separate() { return storage_.scratch_.separate; }
return storage_.scratch_.separate;
}
/* \brief Number of aligned elements that span the entire input tensor. */ /* \brief Number of aligned elements that span the entire input tensor. */
inline __device__ size_t num_aligned_elements() const { inline __device__ size_t num_aligned_elements() const { return n_elems_; }
return n_elems_;
}
/* \brief Load values from the input. /* \brief Load values from the input.
\param id Aligned index of the element. \param id Aligned index of the element.
...@@ -119,7 +110,7 @@ class VectorizedAccessor { ...@@ -119,7 +110,7 @@ class VectorizedAccessor {
} else { } else {
#pragma unroll #pragma unroll
for (int j = 0; j < nvec; ++j) { for (int j = 0; j < nvec; ++j) {
DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j; DType *ptr = reinterpret_cast<DType *>(&(aligned_ptr_[id])) + j;
if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) && if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) { reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
storage_.scratch_.separate[j] = *ptr; storage_.scratch_.separate[j] = *ptr;
...@@ -136,18 +127,16 @@ class VectorizedAccessor { ...@@ -136,18 +127,16 @@ class VectorizedAccessor {
template <typename DType, int nvec, bool aligned = false> template <typename DType, int nvec, bool aligned = false>
class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> { class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> {
public: public:
inline __device__ VectorizedLoader(const DType* ptr, const size_t N) : inline __device__ VectorizedLoader(const DType *ptr, const size_t N)
VectorizedAccessor<const DType, nvec, aligned>(ptr, N) { : VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {}
}
}; };
/* \brief Class used for vectorized writable access. */ /* \brief Class used for vectorized writable access. */
template <typename DType, int nvec, bool aligned = false> template <typename DType, int nvec, bool aligned = false>
class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> { class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
public: public:
inline __device__ VectorizedStorer(DType* ptr, const size_t N) : inline __device__ VectorizedStorer(DType *ptr, const size_t N)
VectorizedAccessor<DType, nvec, aligned>(ptr, N) { : VectorizedAccessor<DType, nvec, aligned>(ptr, N) {}
}
/* \brief Store values to the output. /* \brief Store values to the output.
\param id Aligned index of the element. \param id Aligned index of the element.
...@@ -162,7 +151,7 @@ class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> { ...@@ -162,7 +151,7 @@ class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
} else { } else {
#pragma unroll #pragma unroll
for (int j = 0; j < nvec; ++j) { for (int j = 0; j < nvec; ++j) {
DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j; DType *ptr = reinterpret_cast<DType *>(&(this->aligned_ptr_[id])) + j;
if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) && if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) { reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
*ptr = this->storage_.scratch_.separate[j]; *ptr = this->storage_.scratch_.separate[j];
...@@ -175,34 +164,24 @@ class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> { ...@@ -175,34 +164,24 @@ class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
constexpr int unary_kernel_threads = 512; constexpr int unary_kernel_threads = 512;
template <int nvec, bool aligned, template <int nvec, bool aligned, typename ComputeType, typename Param,
typename ComputeType, ComputeType (*OP)(ComputeType, const Param &), typename InputType, typename OutputType>
typename Param, __launch_bounds__(unary_kernel_threads) __global__
ComputeType (*OP)(ComputeType, const Param&), void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
typename InputType, ComputeType *amax, Param p, const size_t N,
typename OutputType> const size_t num_aligned_elements) {
__launch_bounds__(unary_kernel_threads)
__global__ void unary_kernel(const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *amax,
Param p,
const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N); VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N); VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0; ComputeType max = 0;
ComputeType s = 0; ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) { if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale; if (scale != nullptr) s = *scale;
} }
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements; const size_t M = num_aligned_elements;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
tid < M;
tid += gridDim.x * blockDim.x) {
loader.load(tid, N); loader.load(tid, N);
#pragma unroll #pragma unroll
for (int i = 0; i < nvec; ++i) { for (int i = 0; i < nvec; ++i) {
...@@ -224,43 +203,32 @@ __global__ void unary_kernel(const InputType *input, ...@@ -224,43 +203,32 @@ __global__ void unary_kernel(const InputType *input,
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id); max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) { if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value); static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max); atomicMaxFloat(amax, max);
} }
} }
} }
template <int nvec, bool aligned, template <int nvec, bool aligned, typename ComputeType, typename Param,
typename ComputeType, ComputeType (*OP)(ComputeType, const Param &), typename InputType, typename InputTypeGrad,
typename Param,
ComputeType (*OP)(ComputeType, const Param&),
typename InputType,
typename InputTypeGrad,
typename OutputType> typename OutputType>
__launch_bounds__(unary_kernel_threads) __launch_bounds__(unary_kernel_threads) __global__
__global__ void unary_grad_kernel(const InputTypeGrad *grad, void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output,
const InputType *input, const ComputeType *scale, ComputeType *amax, Param p, const size_t N,
OutputType *output, const size_t num_aligned_elements) {
const ComputeType *scale,
ComputeType *amax,
Param p,
const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N); VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedLoader<InputTypeGrad, nvec, aligned> grad_loader(grad, N); VectorizedLoader<InputTypeGrad, nvec, aligned> grad_loader(grad, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N); VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0; ComputeType max = 0;
ComputeType s = 0; ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) { if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale; if (scale != nullptr) s = *scale;
} }
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements; const size_t M = num_aligned_elements;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
tid < M;
tid += gridDim.x * blockDim.x) {
loader.load(tid, N); loader.load(tid, N);
grad_loader.load(tid, N); grad_loader.load(tid, N);
#pragma unroll #pragma unroll
...@@ -284,25 +252,25 @@ __global__ void unary_grad_kernel(const InputTypeGrad *grad, ...@@ -284,25 +252,25 @@ __global__ void unary_grad_kernel(const InputTypeGrad *grad,
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id); max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) { if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value); static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max); atomicMaxFloat(amax, max);
} }
} }
} }
namespace { namespace {
inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim, inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim, const int nvec,
const int nvec, const int size) { const int size) {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr); size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
int alignment = (ptr_as_number % (nvec * size)) / size; int alignment = (ptr_as_number % (nvec * size)) / size;
return DIVUP(lead_dim + alignment, static_cast<size_t>(nvec)); return DIVUP(lead_dim + alignment, static_cast<size_t>(nvec));
} }
enum class Alignment { enum class Alignment {
SAME_ALIGNED, // All tensors aligned SAME_ALIGNED, // All tensors aligned
SAME_UNALIGNED, // All tensors have the same misalignment SAME_UNALIGNED, // All tensors have the same misalignment
DIFFERENT // Tensors have different alignment DIFFERENT // Tensors have different alignment
}; };
inline int CalcAlignment(const void *ptr, const int size) { inline int CalcAlignment(const void *ptr, const int size) {
...@@ -317,10 +285,7 @@ inline int CalcAlignment(const void *ptr, const int size) { ...@@ -317,10 +285,7 @@ inline int CalcAlignment(const void *ptr, const int size) {
\param ptrs Inputs and Outputs to the operator. \param ptrs Inputs and Outputs to the operator.
*/ */
template <typename... T> template <typename... T>
Alignment CheckAlignment(const size_t lead_dim, Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) {
const int nvec,
const T... ptrs
) {
std::vector<int> alignments; std::vector<int> alignments;
alignments.reserve(sizeof...(T)); alignments.reserve(sizeof...(T));
...@@ -328,13 +293,12 @@ Alignment CheckAlignment(const size_t lead_dim, ...@@ -328,13 +293,12 @@ Alignment CheckAlignment(const size_t lead_dim,
(..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec))); (..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec)));
bool all_same = std::all_of(alignments.cbegin(), alignments.cend(), bool all_same = std::all_of(alignments.cbegin(), alignments.cend(),
[alignments](int val) {return val == alignments.front();}); [alignments](int val) { return val == alignments.front(); });
if (!all_same) { if (!all_same) {
return Alignment::DIFFERENT; return Alignment::DIFFERENT;
} }
if (alignments.front() == 0 && if (alignments.front() == 0 && lead_dim % nvec == 0) {
lead_dim % nvec == 0) {
// all alignment are 0 // all alignment are 0
return Alignment::SAME_ALIGNED; return Alignment::SAME_ALIGNED;
} else { } else {
...@@ -344,22 +308,15 @@ Alignment CheckAlignment(const size_t lead_dim, ...@@ -344,22 +308,15 @@ Alignment CheckAlignment(const size_t lead_dim,
} // namespace } // namespace
template <int nvec, typename Param, template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typename InputType,
fp32 (*OP)(const fp32, const Param&),
typename InputType,
typename OutputType> typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input, void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
OutputType *output, fp32 *amax, const size_t N, const Param params,
const fp32 *scale,
fp32 *amax,
const size_t N,
const Param params,
cudaStream_t stream) { cudaStream_t stream) {
if (N != 0) { if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output); auto align = CheckAlignment(N, nvec, input, output);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
sizeof(InputType));
constexpr size_t threads = unary_kernel_threads; constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads); size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535; constexpr size_t max_blocks = 65535;
...@@ -376,32 +333,23 @@ void VectorizedUnaryKernelLauncher(const InputType *input, ...@@ -376,32 +333,23 @@ void VectorizedUnaryKernelLauncher(const InputType *input,
break; break;
case Alignment::DIFFERENT: { case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize // If the pointers are aligned differently we cannot vectorize
unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>( unary_kernel<1, true, fp32, Param, OP>
input, output, scale, amax, params, N, N); <<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, params, N, N);
break; break;
} }
} }
} }
} }
template <int nvec, typename Param, template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename InputType,
fp32 (*OP)(fp32, const Param&), typename InputTypeGrad, typename OutputType>
typename InputType, void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
typename InputTypeGrad, OutputType *output, const fp32 *scale, fp32 *amax,
typename OutputType> const size_t N, const Param params, cudaStream_t stream) {
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad,
const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *amax,
const size_t N,
const Param params,
cudaStream_t stream) {
if (N != 0) { if (N != 0) {
auto align = CheckAlignment(N, nvec, input, grad, output); auto align = CheckAlignment(N, nvec, input, grad, output);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
sizeof(InputType));
constexpr size_t threads = unary_kernel_threads; constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads); size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535; constexpr size_t max_blocks = 65535;
...@@ -418,33 +366,23 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, ...@@ -418,33 +366,23 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad,
break; break;
case Alignment::DIFFERENT: { case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize // If the pointers are aligned differently we cannot vectorize
unary_grad_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>( unary_grad_kernel<1, true, fp32, Param, OP>
grad, input, output, scale, amax, params, N, N); <<<num_blocks, threads, 0, stream>>>(grad, input, output, scale, amax, params, N, N);
break; break;
} }
} }
} }
} }
template <int nvec, bool aligned, template <int nvec, bool aligned, typename ComputeType, typename Param,
typename ComputeType, ComputeType (*Activation)(const ComputeType, const Param &), typename InputType,
typename Param,
ComputeType (*Activation)(const ComputeType, const Param&),
typename InputType,
typename OutputType> typename OutputType>
__launch_bounds__(unary_kernel_threads) __launch_bounds__(unary_kernel_threads) __global__
__global__ void gated_act_kernel(const InputType *input, void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
OutputType *output, ComputeType *amax, const size_t m, const size_t n, const Param p,
const ComputeType *scale, const size_t num_aligned_elements) {
ComputeType *amax,
const size_t m,
const size_t n,
const Param p,
const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m; const size_t M = num_aligned_elements * m;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
tid < M;
tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements; const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements; const size_t id_y = tid / num_aligned_elements;
VectorizedLoader<InputType, nvec, aligned> loader0(input + id_y * n * 2, n); VectorizedLoader<InputType, nvec, aligned> loader0(input + id_y * n * 2, n);
...@@ -453,7 +391,7 @@ __global__ void gated_act_kernel(const InputType *input, ...@@ -453,7 +391,7 @@ __global__ void gated_act_kernel(const InputType *input,
ComputeType max = 0; ComputeType max = 0;
ComputeType s = 0; ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) { if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale; if (scale != nullptr) s = *scale;
} }
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
...@@ -478,26 +416,18 @@ __global__ void gated_act_kernel(const InputType *input, ...@@ -478,26 +416,18 @@ __global__ void gated_act_kernel(const InputType *input,
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id); max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) { if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value); static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max); atomicMaxFloat(amax, max);
} }
} }
} }
} }
template <int nvec, template <int nvec, typename ComputeType, typename Param,
typename ComputeType, ComputeType (*Activation)(const ComputeType, const Param &), typename InputType,
typename Param,
ComputeType (*Activation)(const ComputeType, const Param&),
typename InputType,
typename OutputType> typename OutputType>
void GatedActivationKernelLauncher(const InputType *input, void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
OutputType *output, fp32 *amax, const size_t m, const size_t n, const Param &p,
const fp32 *scale,
fp32 *amax,
const size_t m,
const size_t n,
const Param &p,
cudaStream_t stream) { cudaStream_t stream) {
if (m != 0 && n != 0) { if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType)); size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType));
...@@ -509,44 +439,34 @@ void GatedActivationKernelLauncher(const InputType *input, ...@@ -509,44 +439,34 @@ void GatedActivationKernelLauncher(const InputType *input,
switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
case Alignment::SAME_ALIGNED: case Alignment::SAME_ALIGNED:
gated_act_kernel<nvec, true, ComputeType, Param, Activation> gated_act_kernel<nvec, true, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>( <<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p,
input, output, scale, amax, m, n, p, num_aligned_elements); num_aligned_elements);
break; break;
case Alignment::SAME_UNALIGNED: case Alignment::SAME_UNALIGNED:
gated_act_kernel<nvec, false, ComputeType, Param, Activation> gated_act_kernel<nvec, false, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>( <<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p,
input, output, scale, amax, m, n, p, num_aligned_elements); num_aligned_elements);
break; break;
case Alignment::DIFFERENT: { case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize // If the pointers are aligned differently we cannot vectorize
gated_act_kernel<1, true, ComputeType, Param, Activation> gated_act_kernel<1, true, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>( <<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p, n);
input, output, scale, amax, m, n, p, n);
break; break;
} }
} }
} }
} }
template <int nvec, bool aligned, template <int nvec, bool aligned, typename ComputeType, typename Param,
typename ComputeType, ComputeType (*Activation)(const ComputeType, const Param &),
typename Param, ComputeType (*Dactivation)(const ComputeType, const Param &), typename InputType,
ComputeType (*Activation)(const ComputeType, const Param&),
ComputeType (*Dactivation)(const ComputeType, const Param&),
typename InputType,
typename OutputType> typename OutputType>
__launch_bounds__(unary_kernel_threads) __launch_bounds__(unary_kernel_threads) __global__
__global__ void dgated_act_kernel(const InputType *grad, void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output,
const InputType *input, const size_t m, const size_t n, const Param p,
OutputType *output, const size_t num_aligned_elements) {
const size_t m,
const size_t n,
const Param p,
const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m; const size_t M = num_aligned_elements * m;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
tid < M;
tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements; const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements; const size_t id_y = tid / num_aligned_elements;
VectorizedLoader<InputType, nvec, aligned> grad_loader(grad + id_y * n, n); VectorizedLoader<InputType, nvec, aligned> grad_loader(grad + id_y * n, n);
...@@ -576,23 +496,15 @@ __global__ void dgated_act_kernel(const InputType *grad, ...@@ -576,23 +496,15 @@ __global__ void dgated_act_kernel(const InputType *grad,
} }
} }
template <int nvec, template <int nvec, typename ComputeType, typename Param,
typename ComputeType, ComputeType (*Activation)(const ComputeType, const Param &),
typename Param, ComputeType (*Dactivation)(const ComputeType, const Param &), typename InputType,
ComputeType (*Activation)(const ComputeType, const Param&),
ComputeType (*Dactivation)(const ComputeType, const Param&),
typename InputType,
typename OutputType> typename OutputType>
void DGatedActivationKernelLauncher(const InputType *grad, void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input,
const InputType *input, OutputType *output, const size_t m, const size_t n,
OutputType *output, const Param &p, cudaStream_t stream) {
const size_t m,
const size_t n,
const Param &p,
cudaStream_t stream) {
if (m != 0 && n != 0) { if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType));
sizeof(InputType));
constexpr size_t threads = unary_kernel_threads; constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements * m, threads); size_t num_blocks = DIVUP(num_aligned_elements * m, threads);
constexpr size_t max_blocks = 65535; constexpr size_t max_blocks = 65535;
...@@ -601,16 +513,18 @@ void DGatedActivationKernelLauncher(const InputType *grad, ...@@ -601,16 +513,18 @@ void DGatedActivationKernelLauncher(const InputType *grad,
switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) {
case Alignment::SAME_ALIGNED: case Alignment::SAME_ALIGNED:
dgated_act_kernel<nvec, true, ComputeType, Param, Activation, Dactivation> dgated_act_kernel<nvec, true, ComputeType, Param, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, num_aligned_elements); <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p,
num_aligned_elements);
break; break;
case Alignment::SAME_UNALIGNED: case Alignment::SAME_UNALIGNED:
dgated_act_kernel<nvec, false, ComputeType, Param, Activation, Dactivation> dgated_act_kernel<nvec, false, ComputeType, Param, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, num_aligned_elements); <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p,
num_aligned_elements);
break; break;
case Alignment::DIFFERENT: { case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize // If the pointers are aligned differently we cannot vectorize
dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, n); <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, n);
break; break;
} }
} }
......
...@@ -31,47 +31,45 @@ constexpr uint32_t THREADS_PER_WARP = 32; ...@@ -31,47 +31,45 @@ constexpr uint32_t THREADS_PER_WARP = 32;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b) { // NOLINT(*) inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y}; return {a.x + b.x, a.y + b.y};
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void operator+=(float2 & a, const float2 & b) { // NOLINT(*) inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*)
a.x += b.x; a.x += b.x;
a.y += b.y; a.y += b.y;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template <typename T>
struct Sum { struct Sum {
inline __device__ Sum() {} inline __device__ Sum() {}
inline __device__ T operator()(const T &a, const T &b) const { inline __device__ T operator()(const T &a, const T &b) const { return a + b; }
return a + b;
}
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template <typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx) { inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx); return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx);
} }
template<> template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx) { inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)};
} }
template<typename T> template <typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx) { inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx); return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx);
} }
template<> template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx) { inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)};
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -81,533 +79,517 @@ namespace transformer_engine { ...@@ -81,533 +79,517 @@ namespace transformer_engine {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint16 { struct uint16 {
uint4 u; uint4 u;
uint4 v; uint4 v;
uint4 s; uint4 s;
uint4 t; uint4 t;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint8 { struct uint8 {
uint4 u; uint4 u;
uint4 v; uint4 v;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> template <int BYTES>
struct BytesToType {}; struct BytesToType {};
template<> template <>
struct BytesToType<64> { struct BytesToType<64> {
using Type = uint16; using Type = uint16;
static_assert(sizeof(Type) == 64); static_assert(sizeof(Type) == 64);
}; };
template<> template <>
struct BytesToType<32> { struct BytesToType<32> {
using Type = uint8; using Type = uint8;
static_assert(sizeof(Type) == 32); static_assert(sizeof(Type) == 32);
}; };
template<> template <>
struct BytesToType<16> { struct BytesToType<16> {
using Type = uint4; using Type = uint4;
static_assert(sizeof(Type) == 16); static_assert(sizeof(Type) == 16);
}; };
template<> template <>
struct BytesToType<8> { struct BytesToType<8> {
using Type = uint64_t; using Type = uint64_t;
static_assert(sizeof(Type) == 8); static_assert(sizeof(Type) == 8);
}; };
template<> template <>
struct BytesToType<4> { struct BytesToType<4> {
using Type = uint32_t; using Type = uint32_t;
static_assert(sizeof(Type) == 4); static_assert(sizeof(Type) == 4);
}; };
template<> template <>
struct BytesToType<2> { struct BytesToType<2> {
using Type = uint16_t; using Type = uint16_t;
static_assert(sizeof(Type) == 2); static_assert(sizeof(Type) == 2);
}; };
template<> template <>
struct BytesToType<1> { struct BytesToType<1> {
using Type = uint8_t; using Type = uint8_t;
static_assert(sizeof(Type) == 1); static_assert(sizeof(Type) == 1);
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template <typename T>
struct TypeToVec2 {}; struct TypeToVec2 {};
template<> template <>
struct TypeToVec2<float> { struct TypeToVec2<float> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct TypeToVec2<half> { struct TypeToVec2<half> {
using Type = half2; using Type = half2;
}; };
template<> template <>
struct TypeToVec2<nv_bfloat16> { struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162; using Type = nv_bfloat162;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename IType, typename IType2, typename OType, typename CType> template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam { struct CTDBiasDActParam {
using InputType = IType; using InputType = IType;
using InputType2 = IType2; using InputType2 = IType2;
using OutputType = OType; using OutputType = OType;
using ComputeType = CType; using ComputeType = CType;
const IType *input; const IType *input;
const IType2 *act_input; const IType2 *act_input;
OType *output_c; OType *output_c;
OType *output_t; OType *output_t;
const CType *scale_ptr; const CType *scale_ptr;
CType *amax; CType *amax;
CType *scale_inv; CType *scale_inv;
CType *workspace; CType *workspace;
CType *warp_scales_inv; CType *warp_scales_inv;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int INDEX> template <int INDEX>
struct Get { struct Get {
template<typename T, typename R> template <typename T, typename R>
static inline __device__ R of(const T &vec); static inline __device__ R of(const T &vec);
}; };
template<> template <>
template<typename T, typename R> template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) { inline __device__ R Get<0>::of(const T &vec) {
return vec.x; return vec.x;
} }
template<> template <>
template<typename T, typename R> template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) { inline __device__ R Get<1>::of(const T &vec) {
return vec.y; return vec.y;
} }
template<> template <>
template<typename T, typename R> template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) { inline __device__ R Get<2>::of(const T &vec) {
return vec.z; return vec.z;
} }
template<> template <>
template<typename T, typename R> template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) { inline __device__ R Get<3>::of(const T &vec) {
return vec.w; return vec.w;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Src, typename Dst> template <typename Src, typename Dst>
struct Converter{ struct Converter {
static inline __device__ Dst convert(const Src &from) { static inline __device__ Dst convert(const Src &from) { return Dst(from); }
return Dst(from);
}
}; };
template<> template <>
struct Converter<float2, half2>{ struct Converter<float2, half2> {
static inline __device__ half2 convert(const float2 &x) { static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
return __float22half2_rn(x);
}
}; };
template<> template <>
struct Converter<float2, nv_bfloat162>{ struct Converter<float2, nv_bfloat162> {
static inline __device__ nv_bfloat162 convert(const float2 &x) { static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800 #if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x); return __float22bfloat162_rn(x);
#else #else
union { union {
nv_bfloat162 raw; nv_bfloat162 raw;
nv_bfloat16 elt[2]; nv_bfloat16 elt[2];
} tmp; } tmp;
tmp.elt[0] = __float2bfloat16_rn(x.x); tmp.elt[0] = __float2bfloat16_rn(x.x);
tmp.elt[1] = __float2bfloat16_rn(x.y); tmp.elt[1] = __float2bfloat16_rn(x.y);
return tmp.raw; return tmp.raw;
#endif #endif
} }
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template <typename T>
struct Zeros{ struct Zeros {
static inline __device__ T get() { static inline __device__ T get() { return T(0.f); }
return T(0.f);
}
}; };
template<> template <>
struct Zeros<float2>{ struct Zeros<float2> {
static inline __device__ float2 get() { static inline __device__ float2 get() { return make_float2(0.f, 0.f); }
return make_float2(0.f, 0.f);
}
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Elt_type, uint32_t NUM_ELT> template <typename Elt_type, uint32_t NUM_ELT>
struct Vec { struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) }; enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using type = Elt_type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data; using Vec_type = typename BytesToType<BYTES>::Type;
using type = Elt_type;
template<typename S> using Alias_type = union {
inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*) Vec_type vec;
#pragma unroll Elt_type elt[NUM_ELT];
for ( int it = 0; it < NUM_ELT; it++ ) { };
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op> Alias_type data;
inline __device__ void assign(const Op &op) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
// Pointer is cast to vector type template <typename S>
inline __device__ void load_from(const void *base_ptr, size_t idx = 0) { inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*)
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx]; #pragma unroll
} for (int it = 0; it < NUM_ELT; it++) {
other.data.elt[it] = S(this->data.elt[it]);
// Pointer is cast to vector type
inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
// Pointer is cast to element type. Loads min(count, NUM_ELT)
// elements and any remaining elements are set to zero.
inline __device__ void load_from_elts(const void *base_ptr,
size_t idx = 0,
size_t count = NUM_ELT) {
const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
if ( count < NUM_ELT
|| reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = (it < count
? elt_ptr[it]
: Elt_type(0.f));
}
} else {
this->load_from(elt_ptr);
}
} }
}
// Pointer is cast to element type. Stores min(count, NUM_ELT) template <typename Op>
// elements. inline __device__ void assign(const Op &op) {
inline __device__ void store_to_elts(void *base_ptr, #pragma unroll
size_t idx = 0, for (int it = 0; it < NUM_ELT; it++) {
size_t count = NUM_ELT) const { this->data.elt[it] = op(it);
Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx; }
if ( count < NUM_ELT }
|| reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
#pragma unroll // Pointer is cast to vector type
for ( int it = 0; it < NUM_ELT; it++ ) { inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
if ( it < count ) { this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
elt_ptr[it] = this->data.elt[it]; }
}
} // Pointer is cast to vector type
} else { inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
this->store_to(elt_ptr); static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
// Pointer is cast to element type. Loads min(count, NUM_ELT)
// elements and any remaining elements are set to zero.
inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
size_t count = NUM_ELT) {
const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f));
}
} else {
this->load_from(elt_ptr);
}
}
// Pointer is cast to element type. Stores min(count, NUM_ELT)
// elements.
inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
size_t count = NUM_ELT) const {
Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
if (it < count) {
elt_ptr[it] = this->data.elt[it];
} }
}
} else {
this->store_to(elt_ptr);
} }
}
inline __device__ void clear() { inline __device__ void clear() {
#pragma unroll #pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) { for (int it = 0; it < NUM_ELT; it++) {
this->data.elt[it] = Elt_type(0.f); this->data.elt[it] = Elt_type(0.f);
}
} }
}
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
struct InterCTASync { struct InterCTASync {
inline __device__ InterCTASync(int *barrier, inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
int group, : phase_counter_(0),
int num_groups, b0_(barrier + group) // The barrier for this group of CTAs.
int group_size) ,
: phase_counter_(0) b1_(barrier + group + num_groups) // The barrier for this group of CTAs.
, b0_(barrier + group) // The barrier for this group of CTAs. ,
, b1_(barrier + group + num_groups) // The barrier for this group of CTAs. group_size_(group_size) {
, group_size_(group_size) { // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! }
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
inline __device__ void spin_wait_(int *barrier, int step, int expected) { asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); for (int found = -1; found != expected;) {
for ( int found = -1; found != expected; ) { asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); }
} }
}
inline __device__ void sync() {
inline __device__ void sync() { // ALL THREADS MUST ENTER!
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
// We switch barrier every iteration. int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; // We decrement every other iteration.
// We decrement every other iteration. bool dec = phase_counter_ & 0x2;
bool dec = phase_counter_ & 0x2; int step = dec ? -1 : 1;
int step = dec ? -1 : 1; int expected = dec ? 0 : group_size_;
int expected = dec ? 0 : group_size_; // There are only 4 phases: up/down for b0/b1.
// There are only 4 phases: up/down for b0/b1. phase_counter_ = (phase_counter_ + 1) & 0x3;
phase_counter_ = (phase_counter_ + 1) & 0x3;
if (threadIdx.x == 0) {
if ( threadIdx.x == 0 ) { spin_wait_(barrier, step, expected);
spin_wait_(barrier, step, expected); }
} // CTA waits for thread 0
// CTA waits for thread 0 __syncthreads();
__syncthreads(); }
}
int phase_counter_; int phase_counter_;
int * b0_; int *b0_;
int * b1_; int *b1_;
int group_size_; int group_size_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N> template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> { struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using Base = Reducer<T, 1, WARPS_M, WARPS_N>; using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type; using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES }; enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + enum {
WS_DATA_BYTES }; WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
};
template<typename Params>
inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, template <typename Params>
uint32_t warp_n, uint32_t lane, void * smem) inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem) uint32_t warp_n, uint32_t lane, void *smem)
, inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW) : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
, bidn_(bidn) // CTA id within the group. inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) bidn_(bidn) // CTA id within the group.
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {} ,
w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
template<typename Op> w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {}
inline __device__ T allreduce(T data, const Op &op) {
data = Base::reduce(data, op); template <typename Op>
// We switch workspace every iteration. inline __device__ T allreduce(T data, const Op &op) {
T * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; data = Base::reduce(data, op);
// We switch workspace every iteration.
// Warp leaders 0 hold the CTA-local results. T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
if ( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data; // Warp leaders 0 hold the CTA-local results.
} if (this->warp_n_ == 0 && this->lane_ == 0) {
inter_cta_.sync(); workspace[bidn_] = data;
static_assert(CTAS_PER_ROW <= 32); }
T total = Zeros<T>::get(); inter_cta_.sync();
if (this->lane_ < CTAS_PER_ROW) { static_assert(CTAS_PER_ROW <= 32);
total = workspace[this->lane_]; T total = Zeros<T>::get();
} if (this->lane_ < CTAS_PER_ROW) {
total = Reducer<T, 1, 1, 1>::allreduce_(total, op); total = workspace[this->lane_];
}
return total; total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
}
return total;
InterCTASync inter_cta_; }
T * const w0_; InterCTASync inter_cta_;
T * const w1_;
int bidn_; T *const w0_;
T *const w1_;
int bidn_;
}; };
////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-level reducer (a single warp along the N dimension): no shared memory or
// global workspace is required, everything happens through warp shuffles.
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };
  enum { THREADS_PER_WARP = 32 };

  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

  // Butterfly (XOR) reduction: every lane of the warp ends up with the full result.
  template <typename Op>
  static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
    for (int offset = 1; offset < THREADS_PER_WARP; offset *= 2) {
      data = op(data, warp_shuffle_xor(data, offset));
    }
    return data;
  }

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    return allreduce_(data, op);
  }

  // Shuffle-down tree reduction: only lane 0 holds the final result!
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
#pragma unroll
    for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) {
      data = op(data, warp_shuffle_down(data, offset));
    }
    return data;
  }

  int warp_n_;  // warp index along the N (reduction) dimension
  int lane_;    // lane index within the warp
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Intra-CTA reducer spanning WARPS_N warps: each warp reduces with shuffles first,
// then the per-warp partials are exchanged through a double-buffered shared-memory
// staging area (double buffering allows a single __syncthreads per call).
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;
  using Type = T;
  // Two staging buffers of WARPS_M * WARPS_N partials each.
  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };
  enum { THREADS_PER_WARP = 32 };

  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N])),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // Every thread of the row obtains the reduced value.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
#pragma unroll
    for (int w = 0; w < WARPS_N; w++) {
      out = op(out, smem[w]);
    }
    return out;
  }

  // Only the intra-CTA group leader (warp_n_ == 0 && lane_ == 0) holds the result!
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
    if (this->warp_n_ == 0 && this->lane_ == 0) {
#pragma unroll
      for (int w = 0; w < WARPS_N; w++) {
        out = op(out, smem[w]);
      }
    }
    return out;
  }

  T *const smem0_;  // staging buffer A (one partial per warp)
  T *const smem1_;  // staging buffer B
  bool use0_;       // selects which buffer the next call uses
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Reducer that also spans multiple CTAs in a row, with the group size known only at
// runtime (params.ctas_per_row). CTA-local partials are exchanged through a
// double-buffered global-memory workspace guarded by an inter-CTA barrier.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  template <typename Params>
  inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
                                   uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
        w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row),
        bidn_(bidn) {}  // CTA id within the group.

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Trivial case: the group is a single CTA, no inter-CTA traffic needed.
    if (inter_cta_.group_size_ == 1) {
      return Base::allreduce(data, op);
    }

    data = Base::reduce(data, op);
    // We switch workspace every iteration; the barrier phase selects the buffer.
    T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();

    // One warp gathers all CTA partials and finishes with an intra-warp allreduce.
    T total = Zeros<T>::get();
    for (int it = this->lane_; it < inter_cta_.group_size_; it += THREADS_PER_WARP) {
      total = op(total, workspace[it]);
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
    return total;
  }

  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    return allreduce(data, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;  // workspace buffer A (one partial per CTA in the group)
  T *const w1_;  // workspace buffer B
  int bidn_;     // CTA id within the group
};
////////////////////////////////////////////////////////////////////////////////////////////////////
A detailed reference on the exact version implemented (with better numerical stability):
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/
// Chan-style parallel merge of (mean m_a, second moment m2_a, count n_a) across the
// first num_active lanes of a warp. After the call, all lanes hold lane 0's merged
// stats (broadcast at the end).
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
  int top_bit = (8 * sizeof(num_active)) - __clz(num_active - 1);

#pragma unroll
  for (int stride = (1 << (top_bit - 1)); stride > 0; stride /= 2) {
    // Exchange partner stats.
    T other_n = warp_shuffle_down(n_a, stride);
    T other_m = warp_shuffle_down(m_a, stride);
    T other_m2 = warp_shuffle_down(m2_a, stride);

    // Merge: we can handle one of the counts being 0, not both.
    const T n_ab = n_a + other_n;
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - other_m;
    // NOTE(review): accumulated in float even for generic T — presumably T is
    // always float here; confirm before instantiating with double.
    const float m2_ab = m2_a + other_m2 + delta * delta * n_a * other_n * rn_ab;
    const float m_ab = (n_a * m_a + other_n * other_m) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Statistics (mean / m2) across a group of CTAS_PER_ROW CTAs: block-local stats are
// computed first, published to a global workspace, then merged by every warp locally.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we
  // would have to exchange 3 instead of 2 fields.

  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;
  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn),  // CTA id within the group.
        warp_n_(warp_n),
        lane_(lane) {}

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    stats_t block_stats = block_stats_.compute(elts, block_rn);

    // Double-buffered workspace, selected by the barrier phase.
    stats_t *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
    if (warp_n_ == 0 && lane_ == 0) {
      workspace[bidn_] = block_stats;
    }

    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
    inter_cta_.sync();
    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32);

    // Every warp does the final reduction locally.
    if (lane_ < CTAS_PER_ROW) {
      stats_t result = workspace[lane_];
      n = ELTS_PER_ROW_PER_CTA;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }
    warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
    return {m, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *const w0_;  // workspace buffer A (one stats_t per CTA in the group)
  stats_t *const w1_;  // workspace buffer B
  int bidn_;
  int warp_n_;
  int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// CTA-local statistics across WARPS_N warps: warp-local stats are computed first,
// published to a double-buffered shared-memory staging area, then merged.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;
  // Two staging buffers of WARPS_M * WARPS_N entries each.
  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {
    smem0_ = static_cast<stats_t *>(smem) + warp_m * WARPS_N;
    smem1_ = smem0_ + WARPS_M * WARPS_N;
  }

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    stats_t *smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    // Compute warp-local stats for all WARPS_N warps.
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    stats_t warp_stats = warp_stats_.compute(elts, warp_rn);

    // Each warp leader publishes its warp's stats.
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    const auto lane = warp_stats_.reducer_.lane_;
    if (lane == 0) {
      smem[warp_n] = warp_stats;
    }
    __syncthreads();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume that there are less than 32 warps, such that we can finalize with a single warp
    static_assert(WARPS_N <= 32);
    if (lane < WARPS_N) {
      stats_t result = smem[lane];
      n = N * THREADS_PER_WARP;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, WARPS_N);

    return {m, m2};
  }

  WarpStats warp_stats_;
  stats_t *smem0_;  // staging buffer A
  stats_t *smem1_;  // staging buffer B
  bool use0_;       // selects which buffer the next call uses
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-local statistics: two-pass computation — warp-allreduced mean first, then the
// centered second moment — using only warp shuffles (no shared memory).
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  enum { SMEM_BYTES = 0 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto sum = Sum<T>();

    // Pass 1: mean over the row (rn is the reciprocal of the element count).
    T m = Zeros<T>::get();
#pragma unroll
    for (int i = 0; i < N; i++) {
      m += elts[i];
    }
    m = reducer_.allreduce(m, sum) * rn;

    // Pass 2: sum of squared deviations from the mean.
    T m2 = Zeros<T>::get();
#pragma unroll
    for (int i = 0; i < N; i++) {
      T diff = (elts[i] - m);
      m2 += diff * diff;
    }
    m2 = reducer_.allreduce(m2, sum);
    return {m, m2};
  }

  Reducer reducer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Shuffle-down max reduction over the first num_elems lanes; lane 0 holds the result.
// NOTE(review): the halving loop assumes num_elems is a power of two, and the
// __builtin_assume hints assume non-negative inputs (amax values) — confirm at call sites.
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float tmp = m;
#pragma unroll
  for (int delta = num_elems / 2; delta > 0; delta /= 2) {
    const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta);
    __builtin_assume(tmp >= 0);
    __builtin_assume(other_m >= 0);
    tmp = fmaxf(tmp, other_m);
  }
  return tmp;
}
// CTA-wide max reduction through shared-memory staging; the valid result is held
// only by warp 0 (per warp_reduce_max, specifically its lane 0).
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  __shared__ float staging[num_warps];
  constexpr int warp_size = 32;
  // Warp-local max first; warp leaders publish to shared memory.
  const float my_warp_max = warp_reduce_max<warp_size>(static_cast<float>(m));
  if (threadIdx.x % 32 == 0) {
    staging[warpid] = my_warp_max;
  }
  __syncthreads();
  // Warp 0 folds the per-warp maxima.
  compute_t result = 0;
  if (warpid == 0) {
    const float staged = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    result = warp_reduce_max<num_warps>(staged);
  }
  return result;
}
// Works only on positive values: for non-negative IEEE-754 floats the bit pattern is
// monotonic when reinterpreted as a signed int, so an integer atomicMax suffices.
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}
// Works only on positive values: same bit-pattern monotonicity trick as atomicMaxFloat.
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  atomicMin(reinterpret_cast<int *>(addr), __float_as_int(value));
}
// Writes 1/value into *value_inv.
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  *value_inv = 1 / value;
}
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -7,10 +7,11 @@ import warnings ...@@ -7,10 +7,11 @@ import warnings
from enum import Enum from enum import Enum
warnings.filterwarnings( warnings.filterwarnings(
"module", category=DeprecationWarning, module="transformer_engine.common.utils") "module", category=DeprecationWarning, module="transformer_engine.common.utils"
)
class DeprecatedEnum: # pylint: disable=too-few-public-methods class DeprecatedEnum: # pylint: disable=too-few-public-methods
"""DeprecatedEnum""" """DeprecatedEnum"""
def __init__(self, enum_cls, msg): def __init__(self, enum_cls, msg):
...@@ -33,7 +34,7 @@ def deprecate_wrapper(obj, msg): ...@@ -33,7 +34,7 @@ def deprecate_wrapper(obj, msg):
if issubclass(obj, Enum): if issubclass(obj, Enum):
return DeprecatedEnum(obj, msg) return DeprecatedEnum(obj, msg)
class DeprecatedCls(obj): # pylint: disable=too-few-public-methods class DeprecatedCls(obj): # pylint: disable=too-few-public-methods
"""DeprecatedCls""" """DeprecatedCls"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -51,4 +52,5 @@ def deprecate_wrapper(obj, msg): ...@@ -51,4 +52,5 @@ def deprecate_wrapper(obj, msg):
return deprecated return deprecated
raise NotImplementedError( raise NotImplementedError(
f"deprecate_cls_wrapper only support Class and Function, but got {type(obj)}.") f"deprecate_cls_wrapper only support Class and Function, but got {type(obj)}."
)
...@@ -34,22 +34,24 @@ from .sharding import MajorShardingType, ShardingResource, ShardingType ...@@ -34,22 +34,24 @@ from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum from ..common.utils import DeprecatedEnum
# Deprecated sharding aliases, kept for backward compatibility.
# Fixed malformed warning text: "is deprecating in the near feature" ->
# "is deprecated and will be removed in the near future".
MajorShardingType = DeprecatedEnum(
    MajorShardingType, "MajorShardingType is deprecated and will be removed in the near future."
)
ShardingType = DeprecatedEnum(
    ShardingType, "ShardingType is deprecated and will be removed in the near future."
)
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)

# Public API of this module; deprecated names stay exported so existing imports work.
__all__ = [
    "NVTE_FP8_COLLECTION_NAME",
    "fp8_autocast",
    "update_collections",
    "get_delayed_scaling",
    "MeshResource",
    "MajorShardingType",
    "ShardingResource",
    "ShardingType",
    "flax",
    "praxis",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment