Unverified Commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
@@ -7,6 +7,10 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>
#include <memory>
#include <mutex>
#include <string>
@@ -14,10 +18,6 @@
#include <utility>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
@@ -38,10 +38,10 @@ class Kernel {
public:
Kernel(std::string mangled_name, std::string compiled_code);
~Kernel();
Kernel(const Kernel&) = delete; // move-only
Kernel(Kernel&&) noexcept;
Kernel& operator=(Kernel) noexcept;
friend void swap(Kernel& first, Kernel& second) noexcept;
Kernel(const Kernel &) = delete; // move-only
Kernel(Kernel &&) noexcept;
Kernel &operator=(Kernel) noexcept;
friend void swap(Kernel &first, Kernel &second) noexcept;
/*! \brief Launch CUDA kernel
*
@@ -57,25 +57,12 @@ class Kernel {
* \param[in] args Kernel arguments
*/
template <typename... ArgTs>
void launch(int device_id,
const dim3 grid_dim,
const dim3 block_dim,
unsigned int shared_mem_bytes,
cudaStream_t stream,
ArgTs &&... args) {
void* arg_ptrs[] = { const_cast<void*>(static_cast<const void*>(&args))... };
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel,
get_function(device_id),
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
shared_mem_bytes,
static_cast<CUstream>(stream),
arg_ptrs,
nullptr);
void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
static_cast<CUstream>(stream), arg_ptrs, nullptr);
}
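// Hedged usage sketch (illustrative names, not part of this diff): given a
// Kernel built from NVRTC output, arguments are captured by address and
// forwarded to cuLaunchKernel, e.g.
//   dim3 grid(128), block(256);
//   kern.launch(/*device_id=*/0, grid, block, /*shared_mem_bytes=*/0,
//               stream, data_ptr, num_elements);
// where `kern`, `stream`, `data_ptr`, and `num_elements` are placeholders.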
/*! \brief CUDA function for given CUDA device
@@ -114,7 +101,7 @@ class Kernel {
class KernelManager {
public:
/*! \brief Get singleton instance */
static KernelManager& instance();
static KernelManager &instance();
/*! \brief Compile CUDA kernel for current CUDA device
*
@@ -126,10 +113,8 @@ class KernelManager {
* \param[in] filename Path to associate with source code,
* primarily for debugging
*/
void compile(const std::string &kernel_label,
const std::string &kernel_name,
const std::string &code,
const std::string &filename);
void compile(const std::string &kernel_label, const std::string &kernel_name,
const std::string &code, const std::string &filename);
/*! \brief Whether CUDA kernel has been compiled for CUDA device
*
@@ -138,8 +123,7 @@
* \return Whether kernel has been compiled
*/
bool is_compiled(const std::string &kernel_label,
int device_id = -1) const;
bool is_compiled(const std::string &kernel_label, int device_id = -1) const;
/*! \brief Launch CUDA kernel on current CUDA device
*
@@ -154,21 +138,12 @@
* \param[in] args Kernel arguments
*/
template <typename... ArgTs>
void launch(const std::string &kernel_label,
const dim3 grid_dim,
const dim3 block_dim,
unsigned int shared_mem_bytes,
cudaStream_t stream,
ArgTs &&... args) {
void launch(const std::string &kernel_label, const dim3 grid_dim, const dim3 block_dim,
unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
const int device_id = cuda::current_device();
const auto key = get_kernel_cache_key(kernel_label, device_id);
NVTE_CHECK(kernel_cache_.count(key) > 0,
"Attempted to launch RTC kernel before compilation");
kernel_cache_.at(key).launch(device_id,
grid_dim,
block_dim,
shared_mem_bytes,
stream,
NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to launch RTC kernel before compilation");
kernel_cache_.at(key).launch(device_id, grid_dim, block_dim, shared_mem_bytes, stream,
std::forward<ArgTs>(args)...);
}
@@ -189,8 +164,8 @@ class KernelManager {
KernelManager() = default;
~KernelManager() = default;
KernelManager(const KernelManager&) = delete;
KernelManager& operator=(const KernelManager&) = delete;
KernelManager(const KernelManager &) = delete;
KernelManager &operator=(const KernelManager &) = delete;
/*! \brief Construct key for kernel cache
*
@@ -199,8 +174,7 @@
*
* \return Key for kernel cache
*/
std::string get_kernel_cache_key(const std::string &kernel_label,
int device_id) const;
std::string get_kernel_cache_key(const std::string &kernel_label, int device_id) const;
};
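// Hedged sketch of the compile-then-launch flow through the singleton. The
// label, kernel name, source string, and filename below are illustrative,
// not taken from this diff.
inline void example_rtc_roundtrip(cudaStream_t stream, float *buf, size_t n) {
  auto &mgr = KernelManager::instance();
  const std::string label = "example_scale_kernel";
  if (!mgr.is_compiled(label)) {
    mgr.compile(label,
                "scale_kernel",  // kernel name within the source string
                "__global__ void scale_kernel(float *x, size_t n) {"
                "  size_t i = blockIdx.x * blockDim.x + threadIdx.x;"
                "  if (i < n) x[i] *= 2.f;"
                "}",
                "example_scale.cu");  // filename, primarily for debugging
  }
  mgr.launch(label, dim3(128), dim3(256), /*shared_mem_bytes=*/0, stream,
             buf, n);
}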
} // namespace rtc
@@ -14,23 +14,18 @@
namespace transformer_engine {
/*! \brief Convert to C-style or C++-style string */
template <typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
inline std::string to_string_like(const T &val) {
return std::to_string(val);
}
inline const std::string& to_string_like(const std::string& val) noexcept {
return val;
}
inline const std::string &to_string_like(const std::string &val) noexcept { return val; }
constexpr const char *to_string_like(const char *val) noexcept {
return val;
}
constexpr const char *to_string_like(const char *val) noexcept { return val; }
/*! \brief Convert arguments to strings and concatenate */
template <typename... Ts>
inline std::string concat_strings(const Ts &... args) {
inline std::string concat_strings(const Ts &...args) {
std::string str;
str.reserve(1024); // Assume strings are <1 KB
(..., (str += to_string_like(args)));
@@ -42,12 +37,9 @@ inline std::string concat_strings(const Ts &... args) {
* This is a convenience wrapper around std::regex_replace.
*/
template <typename T>
inline std::string regex_replace(const std::string &str,
const std::string &pattern,
inline std::string regex_replace(const std::string &str, const std::string &pattern,
const T &replacement) {
return std::regex_replace(str,
std::regex(pattern),
to_string_like(replacement));
return std::regex_replace(str, std::regex(pattern), to_string_like(replacement));
}
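// A minimal usage sketch for the two helpers above (values illustrative):
inline std::string example_string_helpers() {
  // concat_strings converts each argument with to_string_like and appends.
  const std::string src = concat_strings("block size = ", 256, "\n");
  // The replacement may be any type accepted by to_string_like.
  return regex_replace(src, "256", 512);  // yields "block size = 512\n"
}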
} // namespace transformer_engine
@@ -4,6 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "../util/system.h"
#include <cstdint>
#include <cstdlib>
#include <filesystem>
@@ -12,15 +14,14 @@
#include <string>
#include "../common.h"
#include "../util/system.h"
namespace transformer_engine {
namespace {
template <typename T>
inline typename std::enable_if<std::is_arithmetic<T>::value, T>::type
getenv_helper(const char *variable, const T &default_value) {
inline typename std::enable_if<std::is_arithmetic<T>::value, T>::type getenv_helper(
const char *variable, const T &default_value) {
// Implementation for numeric types
const char *env = std::getenv(variable);
if (env == nullptr || env[0] == '\0') {
@@ -34,8 +35,8 @@ getenv_helper(const char *variable, const T &default_value) {
}
template <typename T>
inline typename std::enable_if<!std::is_arithmetic<T>::value, T>::type
getenv_helper(const char *variable, const T &default_value) {
inline typename std::enable_if<!std::is_arithmetic<T>::value, T>::type getenv_helper(
const char *variable, const T &default_value) {
// Implementation for string-like types
const char *env = std::getenv(variable);
if (env == nullptr || env[0] == '\0') {
@@ -47,13 +48,14 @@
} // namespace
#define NVTE_INSTANTIATE_GETENV(T, default_value) \
template <> T getenv<T>(const char *variable, \
const T &default_value_) { \
return getenv_helper<T>(variable, default_value_); \
} \
template <> T getenv<T>(const char *variable) { \
return getenv_helper<T>(variable, default_value); \
#define NVTE_INSTANTIATE_GETENV(T, default_value) \
template <> \
T getenv<T>(const char *variable, const T &default_value_) { \
return getenv_helper<T>(variable, default_value_); \
} \
template <> \
T getenv<T>(const char *variable) { \
return getenv_helper<T>(variable, default_value); \
}
NVTE_INSTANTIATE_GETENV(bool, false);
NVTE_INSTANTIATE_GETENV(float, 0.f);
@@ -69,8 +71,6 @@ NVTE_INSTANTIATE_GETENV(uint64_t, 0);
NVTE_INSTANTIATE_GETENV(std::string, std::string());
NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path());
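// Hedged usage sketch for the instantiations above (the NVTE_EXAMPLE_*
// variable names are illustrative):
inline void example_getenv_usage() {
  // Unset or empty variables fall back to the instantiated default (false).
  const bool debug = transformer_engine::getenv<bool>("NVTE_EXAMPLE_DEBUG");
  // An explicit default overrides the instantiated one.
  const std::string dir =
      transformer_engine::getenv<std::string>("NVTE_EXAMPLE_DIR", "/tmp");
  (void)debug;
  (void)dir;
}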
bool file_exists(const std::string &path) {
return static_cast<bool>(std::ifstream(path.c_str()));
}
bool file_exists(const std::string &path) { return static_cast<bool>(std::ifstream(path.c_str())); }
} // namespace transformer_engine
@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_
#include <type_traits>
#include "../common.h"
#include "../utils.cuh"
@@ -30,15 +31,13 @@ class VectorizedStorage {
} scratch_;
inline __device__ VectorizedStorage() {}
inline __device__ VectorizedStorage(const VectorizedStorage<DType, n>& y2) {
scratch_.aligned = y2.scratch_.aligned;
}
inline __device__ VectorizedStorage(const LType &y2) {
scratch_.aligned = y2;
inline __device__ VectorizedStorage(const VectorizedStorage<DType, n> &y2) {
scratch_.aligned = y2.scratch_.aligned;
}
inline __device__ VectorizedStorage<DType, n>& operator+=(
const VectorizedStorage<DType, n>& rhs) {
#pragma unroll
inline __device__ VectorizedStorage(const LType &y2) { scratch_.aligned = y2; }
inline __device__ VectorizedStorage<DType, n> &operator+=(
const VectorizedStorage<DType, n> &rhs) {
#pragma unroll
for (int i = 0; i < nvec; ++i) {
scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]);
}
@@ -58,7 +57,6 @@ struct select_const<const DType, LType> {
using type = const LType;
};
/* \brief Helper class that enables accessing multiple values of type DType
as 1 value of type LType. Additional aligned template argument
allows performance optimizations if the pointer and the size of
@@ -67,44 +65,37 @@ struct select_const<const DType, LType> {
template <typename DType, int nvec, bool aligned = false>
class VectorizedAccessor {
public:
using StorageType = VectorizedStorage<typename std::remove_const<DType>::type,
nvec>;
using StorageType = VectorizedStorage<typename std::remove_const<DType>::type, nvec>;
using LType = typename select_const<DType, typename StorageType::LType>::type;
StorageType storage_;
LType* aligned_ptr_;
DType* unaligned_ptr_;
LType *aligned_ptr_;
DType *unaligned_ptr_;
int alignment_;
size_t n_elems_;
inline __device__ VectorizedAccessor(DType* const ptr, const size_t size) {
inline __device__ VectorizedAccessor(DType *const ptr, const size_t size) {
unaligned_ptr_ = ptr;
if (aligned) {
alignment_ = 0;
aligned_ptr_ = reinterpret_cast<LType*>(ptr);
aligned_ptr_ = reinterpret_cast<LType *>(ptr);
n_elems_ = (size + nvec - 1) / nvec;
} else {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
aligned_ptr_ = reinterpret_cast<LType *>(ptr - alignment_);
n_elems_ = (size + alignment_ + nvec - 1) / nvec;
}
}
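// Worked example of the unaligned branch above (numbers illustrative): with
// DType = half (2 B) and nvec = 4, LType spans 8 B. A pointer whose address
// ends in 0x6 gives ptr_as_number % 8 = 6, so alignment_ = 6 / 2 = 3 leading
// padding elements, aligned_ptr_ = ptr - 3, and a size of 13 elements needs
// n_elems_ = (13 + 3 + 4 - 1) / 4 = 4 aligned LType accesses.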
/* \brief Alignment of the input pointer in elements. */
inline __device__ int alignment() const {
return alignment_;
}
inline __device__ int alignment() const { return alignment_; }
/* \brief Access to separate elements. */
inline __device__ DType* separate() {
return storage_.scratch_.separate;
}
inline __device__ DType *separate() { return storage_.scratch_.separate; }
/* \brief Number of aligned elements that span the entire input tensor. */
inline __device__ size_t num_aligned_elements() const {
return n_elems_;
}
inline __device__ size_t num_aligned_elements() const { return n_elems_; }
/* \brief Load values from the input.
\param id Aligned index of the element.
@@ -119,7 +110,7 @@ class VectorizedAccessor {
} else {
#pragma unroll
for (int j = 0; j < nvec; ++j) {
DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
DType *ptr = reinterpret_cast<DType *>(&(aligned_ptr_[id])) + j;
if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
storage_.scratch_.separate[j] = *ptr;
@@ -136,18 +127,16 @@
template <typename DType, int nvec, bool aligned = false>
class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> {
public:
inline __device__ VectorizedLoader(const DType* ptr, const size_t N) :
VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {
}
inline __device__ VectorizedLoader(const DType *ptr, const size_t N)
: VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {}
};
/* \brief Class used for vectorized writable access. */
template <typename DType, int nvec, bool aligned = false>
class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
public:
inline __device__ VectorizedStorer(DType* ptr, const size_t N) :
VectorizedAccessor<DType, nvec, aligned>(ptr, N) {
}
inline __device__ VectorizedStorer(DType *ptr, const size_t N)
: VectorizedAccessor<DType, nvec, aligned>(ptr, N) {}
/* \brief Store values to the output.
\param id Aligned index of the element.
@@ -162,7 +151,7 @@ class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
} else {
#pragma unroll
for (int j = 0; j < nvec; ++j) {
DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
DType *ptr = reinterpret_cast<DType *>(&(this->aligned_ptr_[id])) + j;
if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
*ptr = this->storage_.scratch_.separate[j];
@@ -175,34 +164,24 @@ class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
constexpr int unary_kernel_threads = 512;
template <int nvec, bool aligned,
typename ComputeType,
typename Param,
ComputeType (*OP)(ComputeType, const Param&),
typename InputType,
typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void unary_kernel(const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *amax,
Param p,
const size_t N,
const size_t num_aligned_elements) {
template <int nvec, bool aligned, typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param &), typename InputType, typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
ComputeType *amax, Param p, const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
if (scale != nullptr) s = *scale;
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
@@ -224,43 +203,32 @@ __global__ void unary_kernel(const InputType *input,
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
}
template <int nvec, bool aligned,
typename ComputeType,
typename Param,
ComputeType (*OP)(ComputeType, const Param&),
typename InputType,
typename InputTypeGrad,
template <int nvec, bool aligned, typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param &), typename InputType, typename InputTypeGrad,
typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void unary_grad_kernel(const InputTypeGrad *grad,
const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *amax,
Param p,
const size_t N,
const size_t num_aligned_elements) {
__launch_bounds__(unary_kernel_threads) __global__
void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output,
const ComputeType *scale, ComputeType *amax, Param p, const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedLoader<InputTypeGrad, nvec, aligned> grad_loader(grad, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
if (scale != nullptr) s = *scale;
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
grad_loader.load(tid, N);
#pragma unroll
@@ -284,25 +252,25 @@ __global__ void unary_grad_kernel(const InputTypeGrad *grad,
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
}
namespace {
inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim,
const int nvec, const int size) {
inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim, const int nvec,
const int size) {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
int alignment = (ptr_as_number % (nvec * size)) / size;
return DIVUP(lead_dim + alignment, static_cast<size_t>(nvec));
}
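// Worked example (illustrative): with nvec = 4 and size = sizeof(float) = 4,
// a vector spans 16 B. A pointer at an address ending in 0x8 gives
// alignment = (8 % 16) / 4 = 2 misaligned elements, so lead_dim = 10 elements
// needs DIVUP(10 + 2, 4) = 3 vectorized accesses.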
enum class Alignment {
SAME_ALIGNED, // All tensors aligned
SAME_ALIGNED, // All tensors aligned
SAME_UNALIGNED, // All tensors have the same misalignment
DIFFERENT // Tensors have different alignment
DIFFERENT // Tensors have different alignment
};
inline int CalcAlignment(const void *ptr, const int size) {
@@ -317,10 +285,7 @@ inline int CalcAlignment(const void *ptr, const int size) {
\param ptrs Inputs and Outputs to the operator.
*/
template <typename... T>
Alignment CheckAlignment(const size_t lead_dim,
const int nvec,
const T... ptrs
) {
Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) {
std::vector<int> alignments;
alignments.reserve(sizeof...(T));
@@ -328,13 +293,12 @@ Alignment CheckAlignment(const size_t lead_dim,
(..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec)));
bool all_same = std::all_of(alignments.cbegin(), alignments.cend(),
[alignments](int val) {return val == alignments.front();});
[alignments](int val) { return val == alignments.front(); });
if (!all_same) {
return Alignment::DIFFERENT;
}
if (alignments.front() == 0 &&
lead_dim % nvec == 0) {
if (alignments.front() == 0 && lead_dim % nvec == 0) {
// all alignment are 0
return Alignment::SAME_ALIGNED;
} else {
@@ -344,22 +308,15 @@
} // namespace
template <int nvec, typename Param,
fp32 (*OP)(const fp32, const Param&),
typename InputType,
template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typename InputType,
typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *amax,
const size_t N,
const Param params,
void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
fp32 *amax, const size_t N, const Param params,
cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec,
sizeof(InputType));
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535;
@@ -376,32 +333,23 @@ void VectorizedUnaryKernelLauncher(const InputType *input,
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, amax, params, N, N);
unary_kernel<1, true, fp32, Param, OP>
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, params, N, N);
break;
}
}
}
}
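// Hedged call-site sketch for the launcher above; `example_gelu` and
// `ExampleEmptyParam` are illustrative stand-ins, not part of this diff.
struct ExampleEmptyParam {};
__device__ inline fp32 example_gelu(fp32 x, const ExampleEmptyParam &) {
  return 0.5f * x * (1.f + tanhf(0.79788456f * (x + 0.044715f * x * x * x)));
}
inline void example_unary_launch(const fp32 *in, fp32 *out, const size_t N,
                                 cudaStream_t stream) {
  // nvec = 4: move four elements per access when alignment permits;
  // null scale/amax are valid since the output type here is not FP8.
  VectorizedUnaryKernelLauncher<4, ExampleEmptyParam, example_gelu>(
      in, out, /*scale=*/nullptr, /*amax=*/nullptr, N, ExampleEmptyParam{},
      stream);
}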
template <int nvec, typename Param,
fp32 (*OP)(fp32, const Param&),
typename InputType,
typename InputTypeGrad,
typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad,
const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *amax,
const size_t N,
const Param params,
cudaStream_t stream) {
template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename InputType,
typename InputTypeGrad, typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
OutputType *output, const fp32 *scale, fp32 *amax,
const size_t N, const Param params, cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, grad, output);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec,
sizeof(InputType));
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535;
@@ -418,33 +366,23 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad,
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
unary_grad_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
grad, input, output, scale, amax, params, N, N);
unary_grad_kernel<1, true, fp32, Param, OP>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, scale, amax, params, N, N);
break;
}
}
}
}
template <int nvec, bool aligned,
typename ComputeType,
typename Param,
ComputeType (*Activation)(const ComputeType, const Param&),
typename InputType,
template <int nvec, bool aligned, typename ComputeType, typename Param,
ComputeType (*Activation)(const ComputeType, const Param &), typename InputType,
typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void gated_act_kernel(const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *amax,
const size_t m,
const size_t n,
const Param p,
const size_t num_aligned_elements) {
__launch_bounds__(unary_kernel_threads) __global__
void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
ComputeType *amax, const size_t m, const size_t n, const Param p,
const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements;
VectorizedLoader<InputType, nvec, aligned> loader0(input + id_y * n * 2, n);
@@ -453,7 +391,7 @@ __global__ void gated_act_kernel(const InputType *input,
ComputeType max = 0;
ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
if (scale != nullptr) s = *scale;
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -478,26 +416,18 @@ __global__ void gated_act_kernel(const InputType *input,
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
}
}
template <int nvec,
typename ComputeType,
typename Param,
ComputeType (*Activation)(const ComputeType, const Param&),
typename InputType,
template <int nvec, typename ComputeType, typename Param,
ComputeType (*Activation)(const ComputeType, const Param &), typename InputType,
typename OutputType>
void GatedActivationKernelLauncher(const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *amax,
const size_t m,
const size_t n,
const Param &p,
void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
fp32 *amax, const size_t m, const size_t n, const Param &p,
cudaStream_t stream) {
if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType));
@@ -509,44 +439,34 @@ void GatedActivationKernelLauncher(const InputType *input,
switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
case Alignment::SAME_ALIGNED:
gated_act_kernel<nvec, true, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>(
input, output, scale, amax, m, n, p, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p,
num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
gated_act_kernel<nvec, false, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>(
input, output, scale, amax, m, n, p, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p,
num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
gated_act_kernel<1, true, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>(
input, output, scale, amax, m, n, p, n);
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p, n);
break;
}
}
}
}
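// Layout note with a small example (hedged; the compute loop is elided from
// this hunk): the gated input is [m, 2 * n], with activation inputs in
// columns [0, n) and gate values in columns [n, 2n), while the output is
// [m, n]. Assuming the usual gating convention, row r computes
//   output[r][c] = Activation(input[r][c], p) * input[r][n + c].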
template <int nvec, bool aligned,
typename ComputeType,
typename Param,
ComputeType (*Activation)(const ComputeType, const Param&),
ComputeType (*Dactivation)(const ComputeType, const Param&),
typename InputType,
template <int nvec, bool aligned, typename ComputeType, typename Param,
ComputeType (*Activation)(const ComputeType, const Param &),
ComputeType (*Dactivation)(const ComputeType, const Param &), typename InputType,
typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void dgated_act_kernel(const InputType *grad,
const InputType *input,
OutputType *output,
const size_t m,
const size_t n,
const Param p,
const size_t num_aligned_elements) {
__launch_bounds__(unary_kernel_threads) __global__
void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output,
const size_t m, const size_t n, const Param p,
const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements;
VectorizedLoader<InputType, nvec, aligned> grad_loader(grad + id_y * n, n);
@@ -576,23 +496,15 @@ __global__ void dgated_act_kernel(const InputType *grad,
}
}
template <int nvec,
typename ComputeType,
typename Param,
ComputeType (*Activation)(const ComputeType, const Param&),
ComputeType (*Dactivation)(const ComputeType, const Param&),
typename InputType,
template <int nvec, typename ComputeType, typename Param,
ComputeType (*Activation)(const ComputeType, const Param &),
ComputeType (*Dactivation)(const ComputeType, const Param &), typename InputType,
typename OutputType>
void DGatedActivationKernelLauncher(const InputType *grad,
const InputType *input,
OutputType *output,
const size_t m,
const size_t n,
const Param &p,
cudaStream_t stream) {
void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input,
OutputType *output, const size_t m, const size_t n,
const Param &p, cudaStream_t stream) {
if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec,
sizeof(InputType));
size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements * m, threads);
constexpr size_t max_blocks = 65535;
@@ -601,16 +513,18 @@ void DGatedActivationKernelLauncher(const InputType *grad,
switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) {
case Alignment::SAME_ALIGNED:
dgated_act_kernel<nvec, true, ComputeType, Param, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p,
num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
dgated_act_kernel<nvec, false, ComputeType, Param, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p,
num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, n);
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, n);
break;
}
}
@@ -31,47 +31,45 @@ constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y};
inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void operator+=(float2 & a, const float2 & b) { // NOLINT(*)
a.x += b.x;
a.y += b.y;
inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*)
a.x += b.x;
a.y += b.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
template <typename T>
struct Sum {
inline __device__ Sum() {}
inline __device__ T operator()(const T &a, const T &b) const {
return a + b;
}
inline __device__ Sum() {}
inline __device__ T operator()(const T &a, const T &b) const { return a + b; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx) {
return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx);
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx) {
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)};
}
template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx) {
return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx);
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx) {
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -81,533 +79,517 @@ namespace transformer_engine {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint8 {
uint4 u;
uint4 v;
uint4 u;
uint4 v;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES>
template <int BYTES>
struct BytesToType {};
template<>
template <>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
template <>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
template <>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
template <>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
template <>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
template <>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
template <>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
template <typename T>
struct TypeToVec2 {};
template<>
template <>
struct TypeToVec2<float> {
using Type = float2;
using Type = float2;
};
template<>
template <>
struct TypeToVec2<half> {
using Type = half2;
using Type = half2;
};
template<>
template <>
struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162;
using Type = nv_bfloat162;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam {
using InputType = IType;
using InputType2 = IType2;
using OutputType = OType;
using ComputeType = CType;
const IType *input;
const IType2 *act_input;
OType *output_c;
OType *output_t;
const CType *scale_ptr;
CType *amax;
CType *scale_inv;
CType *workspace;
CType *warp_scales_inv;
using InputType = IType;
using InputType2 = IType2;
using OutputType = OType;
using ComputeType = CType;
const IType *input;
const IType2 *act_input;
OType *output_c;
OType *output_t;
const CType *scale_ptr;
CType *amax;
CType *scale_inv;
CType *workspace;
CType *warp_scales_inv;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int INDEX>
template <int INDEX>
struct Get {
template<typename T, typename R>
static inline __device__ R of(const T &vec);
template <typename T, typename R>
static inline __device__ R of(const T &vec);
};
template<>
template<typename T, typename R>
template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
return vec.x;
return vec.x;
}
template<>
template<typename T, typename R>
template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
return vec.y;
return vec.y;
}
template<>
template<typename T, typename R>
template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
return vec.z;
return vec.z;
}
template<>
template<typename T, typename R>
template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
return vec.w;
return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Src, typename Dst>
struct Converter{
static inline __device__ Dst convert(const Src &from) {
return Dst(from);
}
template <typename Src, typename Dst>
struct Converter {
static inline __device__ Dst convert(const Src &from) { return Dst(from); }
};
template<>
struct Converter<float2, half2>{
static inline __device__ half2 convert(const float2 &x) {
return __float22half2_rn(x);
}
template <>
struct Converter<float2, half2> {
static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
};
template<>
struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) {
template <>
struct Converter<float2, nv_bfloat162> {
static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x);
return __float22bfloat162_rn(x);
#else
union {
nv_bfloat162 raw;
nv_bfloat16 elt[2];
} tmp;
tmp.elt[0] = __float2bfloat16_rn(x.x);
tmp.elt[1] = __float2bfloat16_rn(x.y);
return tmp.raw;
union {
nv_bfloat162 raw;
nv_bfloat16 elt[2];
} tmp;
tmp.elt[0] = __float2bfloat16_rn(x.x);
tmp.elt[1] = __float2bfloat16_rn(x.y);
return tmp.raw;
#endif
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Zeros{
static inline __device__ T get() {
return T(0.f);
}
template <typename T>
struct Zeros {
static inline __device__ T get() { return T(0.f); }
};
template<>
struct Zeros<float2>{
static inline __device__ float2 get() {
return make_float2(0.f, 0.f);
}
template <>
struct Zeros<float2> {
static inline __device__ float2 get() { return make_float2(0.f, 0.f); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Elt_type, uint32_t NUM_ELT>
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using type = Elt_type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
Alias_type data;
using Vec_type = typename BytesToType<BYTES>::Type;
using type = Elt_type;
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*)
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
Alias_type data;
// Pointer is cast to vector type
inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
// Pointer is cast to vector type
inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
// Pointer is cast to element type. Loads min(count, NUM_ELT)
// elements and any remaining elements are set to zero.
inline __device__ void load_from_elts(const void *base_ptr,
size_t idx = 0,
size_t count = NUM_ELT) {
const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
if ( count < NUM_ELT
|| reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = (it < count
? elt_ptr[it]
: Elt_type(0.f));
}
} else {
this->load_from(elt_ptr);
}
template <typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*)
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
// Pointer is cast to element type. Stores min(count, NUM_ELT)
// elements.
inline __device__ void store_to_elts(void *base_ptr,
size_t idx = 0,
size_t count = NUM_ELT) const {
Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
if ( count < NUM_ELT
|| reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
if ( it < count ) {
elt_ptr[it] = this->data.elt[it];
}
}
} else {
this->store_to(elt_ptr);
template <typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
this->data.elt[it] = op(it);
}
}
// Pointer is cast to vector type
inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
// Pointer is cast to vector type
inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
// Pointer is cast to element type. Loads min(count, NUM_ELT)
// elements and any remaining elements are set to zero.
inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
size_t count = NUM_ELT) {
const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f));
}
} else {
this->load_from(elt_ptr);
}
}
// Pointer is cast to element type. Stores min(count, NUM_ELT)
// elements.
inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
size_t count = NUM_ELT) const {
Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
if (it < count) {
elt_ptr[it] = this->data.elt[it];
}
}
} else {
this->store_to(elt_ptr);
}
}
inline __device__ void clear() {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = Elt_type(0.f);
}
inline __device__ void clear() {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
this->data.elt[it] = Elt_type(0.f);
}
}
};
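// Hedged usage sketch for Vec (sizes illustrative): a 4-wide float vector
// moves 16 B per access when the pointer is aligned, and falls back to
// element-wise copies otherwise.
__device__ inline void example_vec_copy(const float *src, float *dst,
                                        const size_t count) {
  Vec<float, 4> v;
  v.load_from_elts(src, /*idx=*/0, count);  // count < 4 zero-fills the tail
  v.store_to_elts(dst, /*idx=*/0, count);   // stores min(count, 4) elements
}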
////////////////////////////////////////////////////////////////////////////////////////////////////
struct InterCTASync {
inline __device__ InterCTASync(int *barrier,
int group,
int num_groups,
int group_size)
: phase_counter_(0)
, b0_(barrier + group) // The barrier for this group of CTAs.
, b1_(barrier + group + num_groups) // The barrier for this group of CTAs.
, group_size_(group_size) {
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for ( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
}
inline __device__ void sync() {
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
// We decrement every other iteration.
bool dec = phase_counter_ & 0x2;
int step = dec ? -1 : 1;
int expected = dec ? 0 : group_size_;
// There are only 4 phases: up/down for b0/b1.
phase_counter_ = (phase_counter_ + 1) & 0x3;
if ( threadIdx.x == 0 ) {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
}
inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
: phase_counter_(0),
b0_(barrier + group) // The barrier for this group of CTAs.
,
b1_(barrier + group + num_groups) // The barrier for this group of CTAs.
,
group_size_(group_size) {
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for (int found = -1; found != expected;) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
}
inline __device__ void sync() {
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
// We decrement every other iteration.
bool dec = phase_counter_ & 0x2;
int step = dec ? -1 : 1;
int expected = dec ? 0 : group_size_;
// There are only 4 phases: up/down for b0/b1.
phase_counter_ = (phase_counter_ + 1) & 0x3;
if (threadIdx.x == 0) {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
}
int phase_counter_;
int * b0_;
int * b1_;
int group_size_;
int phase_counter_;
int *b0_;
int *b1_;
int group_size_;
};
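// Worked trace of the 4-phase scheme above (group size G): phase 0 adds +1
// to b0_ until it reaches G, phase 1 does the same on b1_, and phases 2 and 3
// subtract 1 until each barrier returns to 0. Alternating two counters lets a
// new sync begin before every CTA has observed the previous one, and both
// barriers end back at their assumed initial value of 0.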
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES +
WS_DATA_BYTES };
template<typename Params>
inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {}
template<typename Op>
inline __device__ T allreduce(T data, const Op &op) {
data = Base::reduce(data, op);
// We switch workspace every iteration.
T * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if ( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
static_assert(CTAS_PER_ROW <= 32);
T total = Zeros<T>::get();
if (this->lane_ < CTAS_PER_ROW) {
total = workspace[this->lane_];
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
InterCTASync inter_cta_;
T * const w0_;
T * const w1_;
int bidn_;
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum {
WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
};
template <typename Params>
inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void *smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
bidn_(bidn) // CTA id within the group.
,
w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {}
template <typename Op>
inline __device__ T allreduce(T data, const Op &op) {
data = Base::reduce(data, op);
// We switch workspace every iteration.
T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if (this->warp_n_ == 0 && this->lane_ == 0) {
workspace[bidn_] = data;
}
inter_cta_.sync();
static_assert(CTAS_PER_ROW <= 32);
T total = Zeros<T>::get();
if (this->lane_ < CTAS_PER_ROW) {
total = workspace[this->lane_];
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
InterCTASync inter_cta_;
T *const w0_;
T *const w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
using Type = T;
enum { SMEM_BYTES = 0 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: warp_n_(warp_n)
, lane_(lane) {}
template<typename Op>
static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
for ( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
data = op(data, warp_shuffle_xor(data, it));
}
return data;
}
using Type = T;
enum { SMEM_BYTES = 0 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
template<typename Op>
inline __device__ T allreduce(T data, const Op &op) {
return allreduce_(data, op);
enum { THREADS_PER_WARP = 32 };
template <typename Params>
inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void *smem)
: warp_n_(warp_n), lane_(lane) {}
template <typename Op>
static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
data = op(data, warp_shuffle_xor(data, it));
}
return data;
}
template<typename Op>
inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!
#pragma unroll
for ( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
data = op(data, warp_shuffle_down(data, it));
}
return data;
template <typename Op>
inline __device__ T allreduce(T data, const Op &op) {
return allreduce_(data, op);
}
template <typename Op>
inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!
#pragma unroll
for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) {
data = op(data, warp_shuffle_down(data, it));
}
int warp_n_;
int lane_;
return data;
}
int warp_n_;
int lane_;
};
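// Worked example for allreduce_ above with Sum: the xor shuffle exchanges
// partners at strides 1, 2, 4, 8, 16, so after 5 steps every lane of the
// 32-thread warp holds the full sum (lanes initialized to 0..31 all end at
// 496). reduce() shuffles *down* instead, so only lane 0 is guaranteed to
// hold the result.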
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
using Base = Reducer<T, 1, WARPS_M, 1>;
using Base = Reducer<T, 1, WARPS_M, 1>;
using Type = T;
using Type = T;
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true), smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N]))
, smem1_(smem0_ + WARPS_M * WARPS_N) {}
template <typename Params>
inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void *smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
use0_(true),
smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N])),
smem1_(smem0_ + WARPS_M * WARPS_N) {}
template<typename Op>
inline __device__ T allreduce(T data, const Op & op) {
T * const smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
data = Base::reduce(data, op);
if ( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
#pragma unroll
for ( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
return out;
template <typename Op>
inline __device__ T allreduce(T data, const Op &op) {
T *const smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
data = Base::reduce(data, op);
if (this->lane_ == 0) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
#pragma unroll
for (int it = 0; it < WARPS_N; it++) {
out = op(out, smem[it]);
}
return out;
}
template<typename Op>
inline __device__ T reduce(T data, const Op &op) {
T * const smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// only intra-CTA group leader holds the result!
data = Base::reduce(data, op);
if ( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
if ( this->warp_n_ == 0 && this->lane_ == 0 ) {
#pragma unroll
for ( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
}
return out;
template <typename Op>
inline __device__ T reduce(T data, const Op &op) {
T *const smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// only intra-CTA group leader holds the result!
data = Base::reduce(data, op);
if (this->lane_ == 0) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
if (this->warp_n_ == 0 && this->lane_ == 0) {
#pragma unroll
for (int it = 0; it < WARPS_N; it++) {
out = op(out, smem[it]);
}
}
return out;
}
T * const smem0_;
T * const smem1_;
bool use0_;
T *const smem0_;
T *const smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
template<typename Params>
inline __device__ DynamicReducer(const Params & params,
uint32_t bidm, uint32_t bidn,
uint32_t warp_m, uint32_t warp_n,
uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row)
, w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row) {}
template<typename Op>
inline __device__ T allreduce(T data, const Op &op) {
// Trivial case
if (inter_cta_.group_size_ == 1) {
return Base::allreduce(data, op);
}
data = Base::reduce(data, op);
// We switch workspace every iteration.
T * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if ( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
T total = Zeros<T>::get();
for ( int it = this->lane_;
it < inter_cta_.group_size_;
it += THREADS_PER_WARP ) {
total = op(total, workspace[it]);
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
template<typename Op>
inline __device__ T reduce(T data, const Op &op) {
return allreduce(data, op);
}
InterCTASync inter_cta_;
T * const w0_;
T * const w1_;
int bidn_;
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
template <typename Params>
inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
bidn_(bidn) // CTA id within the group.
,
w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row) {}
template <typename Op>
inline __device__ T allreduce(T data, const Op &op) {
// Trivial case
if (inter_cta_.group_size_ == 1) {
return Base::allreduce(data, op);
}
data = Base::reduce(data, op);
// We switch workspace every iteration.
T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if (this->warp_n_ == 0 && this->lane_ == 0) {
workspace[bidn_] = data;
}
inter_cta_.sync();
T total = Zeros<T>::get();
for (int it = this->lane_; it < inter_cta_.group_size_; it += THREADS_PER_WARP) {
total = op(total, workspace[it]);
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
template <typename Op>
inline __device__ T reduce(T data, const Op &op) {
return allreduce(data, op);
}
InterCTASync inter_cta_;
T *const w0_;
T *const w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -625,248 +607,249 @@ A detailed reference on the exact version implemented (with better numerical stability)
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
  int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

#pragma unroll
  for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - m_b;
    const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
}
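// Worked check (editor's illustration): merging (n=2, m=1, m2=0) with (n=2, m=3, m2=0)
// yields n=4, m=(2*1+2*3)/4=2, m2=0+0+(1-3)^2*2*2/4=4, i.e. variance m2/n=1 --
// exactly the variance of the combined samples {1, 1, 3, 3}.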
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we
  // would have to exchange 3 instead of 2 fields.

  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;

  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        bidn_(bidn),  // CTA id within the group.
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        warp_n_(warp_n),
        lane_(lane) {}

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    stats_t block_stats = block_stats_.compute(elts, block_rn);

    stats_t *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    if (warp_n_ == 0 && lane_ == 0) {
      workspace[bidn_] = block_stats;
    }

    // Wait for all CTAS_PER_ROW CTAs in the group to have written their result.
    inter_cta_.sync();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32);

    // Every warp does the final reduction locally.
    if (lane_ < CTAS_PER_ROW) {
      stats_t result = workspace[lane_];
      n = ELTS_PER_ROW_PER_CTA;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

    return {m, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *const w0_;
  stats_t *const w1_;
  int bidn_;
  int warp_n_;
  int lane_;
};
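// Editor's note: w0_/w1_ form a double buffer selected by the barrier's phase
// counter, so a CTA that races ahead into the next compute() call writes into
// the other buffer instead of clobbering partial results that slower CTAs in
// the group may still be reading.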
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {
    smem0_ = static_cast<stats_t *>(smem) + warp_m * WARPS_N;
    smem1_ = smem0_ + WARPS_M * WARPS_N;
  }

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    stats_t *smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    // Compute warp local for all WARPS_N
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    stats_t warp_stats = warp_stats_.compute(elts, warp_rn);

    // Each warp leader stores its warp's stats
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    const auto lane = warp_stats_.reducer_.lane_;
    if (lane == 0) {
      smem[warp_n] = warp_stats;
    }

    __syncthreads();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume that there are fewer than 32 warps, such that we can finalize with a single warp
    static_assert(WARPS_N <= 32);
    if (lane < WARPS_N) {
      stats_t result = smem[lane];
      n = N * THREADS_PER_WARP;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, WARPS_N);

    return {m, m2};
  }

  WarpStats warp_stats_;
  stats_t *smem0_;
  stats_t *smem1_;
  bool use0_;
};
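// Editor's note: the same ping-pong trick is applied in shared memory here --
// SMEM_BYTES reserves two stats_t staging areas and use0_ alternates between
// them, so a back-to-back compute() call cannot overwrite staging data from
// the previous call that another warp may still be reading.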
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  enum { SMEM_BYTES = 0 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto sum = Sum<T>();

    T m = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < N; it++) {
      m += elts[it];
    }
    m = reducer_.allreduce(m, sum) * rn;

    T m2 = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < N; it++) {
      T diff = (elts[it] - m);
      m2 += diff * diff;
    }
    m2 = reducer_.allreduce(m2, sum);

    return {m, m2};
  }

  Reducer reducer_;
};
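// Editor's note: compute() returns the mean m (already scaled by rn) together with
// the *unnormalized* sum of squared deviations m2; callers are expected to divide
// m2 by the element count (cf. warp_chan_upd_dynamic above) to obtain the variance.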
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float tmp = m;
#pragma unroll
  for (int delta = num_elems / 2; delta > 0; delta /= 2) {
    const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta);
    __builtin_assume(tmp >= 0);
    __builtin_assume(other_m >= 0);
    tmp = fmaxf(tmp, other_m);
  }
  return tmp;
}
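// Editor's note: the __builtin_assume hints are only valid because callers reduce
// non-negative values (absolute magnitudes, as in an amax reduction); feeding
// negative inputs would violate the assumption and the behavior is undefined.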
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  __shared__ float staging[num_warps];
  constexpr int warp_size = 32;
  const float my_max = m;
  const float my_warp_max = warp_reduce_max<warp_size>(my_max);
  if (threadIdx.x % 32 == 0) {
    staging[warpid] = my_warp_max;
  }
  __syncthreads();
  compute_t result = 0;
  if (warpid == 0) {
    const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    result = warp_reduce_max<num_warps>(my_max);
  }
  return result;
}
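// Editor's note: only threads in warp 0 receive the reduced maximum; every other
// warp returns 0, so the caller must consume the result from warp 0 (typically
// a single thread writes it out).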
// Works only on positive values
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}
// Works only on positive values
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  atomicMin(reinterpret_cast<int *>(addr), __float_as_int(value));
}
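// Editor's note: both atomics rely on the fact that non-negative IEEE-754 floats
// order the same way as their bit patterns interpreted as signed integers, so an
// integer atomicMax/atomicMin implements the float comparison. Negative values
// break this correspondence -- hence "works only on positive values".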
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  *value_inv = 1 / value;
}
} // namespace transformer_engine
......
......@@ -7,10 +7,11 @@ import warnings
from enum import Enum
warnings.filterwarnings(
    "module", category=DeprecationWarning, module="transformer_engine.common.utils"
)


class DeprecatedEnum:  # pylint: disable=too-few-public-methods
    """DeprecatedEnum"""

    def __init__(self, enum_cls, msg):
......@@ -33,7 +34,7 @@ def deprecate_wrapper(obj, msg):
        if issubclass(obj, Enum):
            return DeprecatedEnum(obj, msg)

        class DeprecatedCls(obj):  # pylint: disable=too-few-public-methods
            """DeprecatedCls"""

            def __init__(self, *args, **kwargs):
......@@ -51,4 +52,5 @@ def deprecate_wrapper(obj, msg):
        return deprecated

    raise NotImplementedError(
        f"deprecate_cls_wrapper only supports Class and Function, but got {type(obj)}."
    )
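# Editor's illustration (hypothetical names): wrapping a callable emits a
# DeprecationWarning at call time before dispatching to the original:
#
#   old_fn = deprecate_wrapper(new_fn, "old_fn is deprecated; use new_fn.")
#   old_fn()  # warns, then calls new_fn()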
......@@ -34,22 +34,24 @@ from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum
MajorShardingType = DeprecatedEnum(
    MajorShardingType, "MajorShardingType is deprecated and will be removed in the near future."
)
ShardingType = DeprecatedEnum(
    ShardingType, "ShardingType is deprecated and will be removed in the near future."
)
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)
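# Editor's illustration: with these wrappers in place, touching the old names
# warns but still works, e.g.
#
#   resource = ShardingResource()  # DeprecationWarning: renamed to MeshResource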
__all__ = [
    "NVTE_FP8_COLLECTION_NAME",
    "fp8_autocast",
    "update_collections",
    "get_delayed_scaling",
    "MeshResource",
    "MajorShardingType",
    "ShardingResource",
    "ShardingType",
    "flax",
    "praxis",
]