"tests/test_onnx_export.py" did not exist on "e2ad34e9bcaadfdbee441019ddb8c4c786de2973"
Commit 996ea169 authored by Przemek Tredak's avatar Przemek Tredak
Browse files

Initial code drop


Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/cast.h>
#include "../common.h"
#include "../utils.cuh"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
namespace detail {
struct Empty {};
__device__ inline fp32 identity(fp32 value, const Empty&) {
return value;
}
struct DequantizeParam {
const fp32 *scale_inv;
};
__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) {
return value * (*(param.scale_inv));
}
} // namespace detail
void fp8_quantize(const Tensor &input,
const Tensor &scale,
Tensor *output,
Tensor *amax,
Tensor *scale_inv,
cudaStream_t stream) {
NVTE_CHECK(input.dtype != DType::kFloat8E4M3 &&
input.dtype != DType::kFloat8E5M2,
"Input must be in higher precision.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(output->dtype == DType::kFloat8E4M3 ||
output->dtype == DType::kFloat8E5M2,
"Output must have FP8 type.");
NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale must have FP32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale must have 1 element.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX is not allocated.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX must have FP32 type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX must have 1 element.");
NVTE_CHECK(scale_inv->dptr != nullptr, "Inverted scale is not allocated.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");
const size_t N = product(input.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<const IType*>(input.dptr),
reinterpret_cast<OType*>(output->dptr),
reinterpret_cast<const fp32*>(scale.dptr),
reinterpret_cast<fp32*>(scale_inv->dptr),
reinterpret_cast<fp32*>(amax->dptr),
N,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void fp8_dequantize(const Tensor &input,
const Tensor &scale_inv,
Tensor *output,
cudaStream_t stream) {
NVTE_CHECK(input.dtype == DType::kFloat8E4M3 ||
input.dtype == DType::kFloat8E5M2,
"Input must have FP8 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(output->dtype != DType::kFloat8E4M3 &&
output->dtype != DType::kFloat8E5M2,
"Output must be in higher precision.");
NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");
NVTE_CHECK(scale_inv.dptr != nullptr, "Inverted scale is not allocated.");
NVTE_CHECK(scale_inv.dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
NVTE_CHECK(scale_inv.shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");
const size_t N = product(input.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32*>(scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType*>(input.dptr),
reinterpret_cast<OType*>(output->dptr),
nullptr,
nullptr,
nullptr,
N,
p,
stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_fp8_quantize(const NVTETensor input,
const NVTETensor scale,
NVTETensor output,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream) {
using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(output),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(scale_inv),
stream);
}
void nvte_fp8_dequantize(const NVTETensor input,
const NVTETensor scale_inv,
NVTETensor output,
cudaStream_t stream) {
using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(scale_inv),
reinterpret_cast<Tensor*>(output),
stream);
}
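/* A minimal numerical sketch (illustrative only, not part of the library) of what
 * the quantize/dequantize kernels above compute, ignoring FP8 rounding and
 * saturation and assuming a single element:
 *
 *   float x = 3.5f, scale = 4.0f;
 *   float amax      = fabsf(x);       // running max of |x| over the tensor
 *   float scale_inv = 1.0f / scale;   // written once per kernel launch
 *   float y         = x * scale;      // then cast to the FP8 output type
 *
 *   float x_approx  = y * scale_inv;  // fp8_dequantize: ~3.5 up to FP8 error
 */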
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_
#include <type_traits>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
/* \brief Helper class that enables storing multiple values of type DType
as 1 value of type LType.
*/
template <typename DType, int n>
class VectorizedStorage {
public:
using LType = typename transformer_engine::BytesToType<sizeof(DType) * n>::Type;
constexpr static int nvec = n;
union vectorized_storage {
LType aligned;
DType separate[nvec]; // NOLINT(*)
inline __device__ vectorized_storage() {}
inline __device__ ~vectorized_storage() {}
} scratch_;
inline __device__ VectorizedStorage() {}
inline __device__ VectorizedStorage(const VectorizedStorage<DType, n>& y2) {
scratch_.aligned = y2.scratch_.aligned;
}
inline __device__ VectorizedStorage(const LType &y2) {
scratch_.aligned = y2;
}
inline __device__ VectorizedStorage<DType, n>& operator+=(
const VectorizedStorage<DType, n>& rhs) {
#pragma unroll
for (int i = 0; i < nvec; ++i) {
scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]);
}
return *this;
}
inline __device__ ~VectorizedStorage() {}
};
// Returns const LType if DType is const
template <typename DType, typename LType>
struct select_const {
using type = LType;
};
template <typename DType, typename LType>
struct select_const<const DType, LType> {
using type = const LType;
};
/* \brief Helper class that enables accessing multiple values of type DType
as 1 value of type LType. The additional aligned template argument
allows performance optimizations if the pointer and the size of
the allocation are aligned to sizeof(LType) / sizeof(DType) elements.
*/
template <typename DType, int nvec, bool aligned = false>
class VectorizedAccessor {
public:
using StorageType = VectorizedStorage<typename std::remove_const<DType>::type,
nvec>;
using LType = typename select_const<DType, typename StorageType::LType>::type;
StorageType storage_;
LType* aligned_ptr_;
DType* unaligned_ptr_;
int alignment_;
size_t n_elems_;
inline __device__ VectorizedAccessor(DType* const ptr, const size_t size) {
unaligned_ptr_ = ptr;
if (aligned) {
alignment_ = 0;
aligned_ptr_ = reinterpret_cast<LType*>(ptr);
n_elems_ = (size + nvec - 1) / nvec;
} else {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
n_elems_ = (size + alignment_ + nvec - 1) / nvec;
}
}
/* \brief Alignment of the input pointer in elements. */
inline __device__ int alignment() const {
return alignment_;
}
/* \brief Access to separate elements. */
inline __device__ DType* separate() {
return storage_.scratch_.separate;
}
/* \brief Number of aligned elements that span the entire input tensor. */
inline __device__ size_t num_aligned_elements() const {
return n_elems_;
}
/* \brief Load values from the input.
\param id Aligned index of the element.
\param N size of the tensor.
*/
inline __device__ void load(const size_t id, const size_t N) {
if (aligned) {
storage_.scratch_.aligned = aligned_ptr_[id];
} else {
if (id > 0 && id < n_elems_ - 1) {
storage_.scratch_.aligned = aligned_ptr_[id];
} else {
#pragma unroll
for (int j = 0; j < nvec; ++j) {
DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
storage_.scratch_.separate[j] = *ptr;
} else {
storage_.scratch_.separate[j] = DType();
}
}
}
}
}
};
/* \brief Class used for vectorized read-only access. */
template <typename DType, int nvec, bool aligned = false>
class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> {
public:
inline __device__ VectorizedLoader(const DType* ptr, const size_t N) :
VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {
}
};
/* \brief Class used for vectorized writable access. */
template <typename DType, int nvec, bool aligned = false>
class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
public:
inline __device__ VectorizedStorer(DType* ptr, const size_t N) :
VectorizedAccessor<DType, nvec, aligned>(ptr, N) {
}
/* \brief Store values to the output.
\param id Aligned index of the element.
\param N size of the tensor.
*/
inline __device__ void store(const size_t id, const size_t N) {
if (aligned) {
this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
} else {
if (id > 0 && id < this->n_elems_ - 1) {
this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
} else {
#pragma unroll
for (int j = 0; j < nvec; ++j) {
DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
*ptr = this->storage_.scratch_.separate[j];
}
}
}
}
}
};
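/* Usage sketch (illustrative; mirrors the pattern used by unary_kernel below):
 *
 *   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
 *   VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
 *   for (size_t tid = thread_id; tid < loader.num_aligned_elements(); tid += stride) {
 *     loader.load(tid, N);                                  // one aligned LType read
 *     for (int i = 0; i < nvec; ++i) {
 *       storer.separate()[i] = f(loader.separate()[i]);     // elementwise op
 *     }
 *     storer.store(tid, N);                                 // one aligned LType write
 *   }
 */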
constexpr int unary_kernel_threads = 512;
template <int nvec, bool aligned,
typename ComputeType,
typename Param,
ComputeType (*OP)(ComputeType, const Param&),
typename InputType,
typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void unary_kernel(const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *scale_inv,
ComputeType *amax,
Param p,
const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
ComputeType temp = OP(val, p);
if constexpr (is_fp8<OutputType>::value) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
temp = temp * s;
}
storer.separate()[i] = static_cast<OutputType>(temp);
}
storer.store(tid, N);
}
if constexpr (is_fp8<OutputType>::value) {
/* warp tile amax reduce*/
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
}
namespace {
inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim,
const int nvec, const int size) {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
int alignment = (ptr_as_number % (nvec * size)) / size;
return DIVUP(lead_dim + alignment, static_cast<size_t>(nvec));
}
enum class Alignment {
SAME_ALIGNED, // All tensors aligned
SAME_UNALIGNED, // All tensors have the same misalignment
DIFFERENT // Tensors have different alignment
};
inline int CalcAlignment(const void *ptr, const int size) {
size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
return ptr_as_number % size;
}
/* \brief Check alignment of the inputs and outputs when using vectorized accesses.
\param lead_dim Leading dimension of the tensors.
\param other_dim The size of the other dimensions of the tensors.
\param nvec Length of the vector.
\param inputs Inputs to the operator.
\param outputs Outputs of the operator.
*/
template <typename InputType, typename OutputType>
Alignment CheckAlignment(const size_t lead_dim,
const int nvec,
const InputType *input,
const OutputType *output) {
int align = -1;
if (input != nullptr) {
int new_align = CalcAlignment(input, sizeof(InputType) * nvec);
if (align == -1) {
align = new_align;
} else {
if (align != new_align) {
return Alignment::DIFFERENT;
}
}
}
if (output != nullptr) {
int new_align = CalcAlignment(output, sizeof(OutputType) * nvec);
if (align == -1) {
align = new_align;
} else {
if (align != new_align) {
return Alignment::DIFFERENT;
}
}
}
if ((align == 0) &&
(lead_dim % nvec == 0)) {
return Alignment::SAME_ALIGNED;
} else {
return Alignment::SAME_UNALIGNED;
}
}
} // namespace
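/* Worked example (illustrative, assuming fp16 input and nvec = 16): one input
 * vector access spans 16 * sizeof(half) = 32 bytes. If the input and output
 * pointers both have a zero remainder modulo their vector width
 * (CalcAlignment == 0) and lead_dim % 16 == 0, CheckAlignment returns
 * SAME_ALIGNED. If both have the same non-zero byte remainder (e.g. 8 bytes),
 * it returns SAME_UNALIGNED. Otherwise it returns DIFFERENT and the launcher
 * below falls back to scalar (nvec = 1) accesses.
 */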
template <int nvec, typename Param,
fp32 (*OP)(fp32, const Param&),
typename InputType,
typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *scale_inv,
fp32 *amax,
const size_t N,
const Param params,
cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec,
sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);
switch (align) {
case Alignment::SAME_ALIGNED:
unary_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, params, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
unary_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, params, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, params, N, N);
break;
}
}
}
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cassert>
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
// NOLINTBEGIN
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_tuned_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
// NOLINTEND
inline __device__ float2 operator+(const float2 & a, const float2 & b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void operator+=(float2 & a, const float2 & b) { // NOLINT(*)
a.x += b.x;
a.y += b.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Sum {
inline __device__ Sum() {}
inline __device__ T operator()(const T &a, const T &b) const {
return a + b;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx) {
return __shfl_xor_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx) {
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
}
template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx) {
return __shfl_down_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx) {
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace transformer_engine {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint8 {
uint4 u;
uint4 v;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeToVec2 {};
template<>
struct TypeToVec2<float> {
using Type = float2;
};
template<>
struct TypeToVec2<half> {
using Type = half2;
};
template<>
struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int INDEX>
struct Get {
template<typename T, typename R>
static inline __device__ R of(const T &vec);
};
template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
return vec.x;
}
template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
return vec.y;
}
template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
return vec.z;
}
template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Src, typename Dst>
struct Converter{
static inline __device__ Dst convert(const Src &from) {
return Dst(from);
}
};
template<>
struct Converter<float2, half2>{
static inline __device__ half2 convert(const float2 &x) {
return __float22half2_rn(x);
}
};
template<>
struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x);
#else
union {
nv_bfloat162 raw;
nv_bfloat16 elt[2];
} tmp;
tmp.elt[0] = __float2bfloat16_rn(x.x);
tmp.elt[1] = __float2bfloat16_rn(x.y);
return tmp.raw;
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Zeros{
static inline __device__ T get() {
return T(0.f);
}
};
template<>
struct Zeros<float2>{
static inline __device__ float2 get() {
return make_float2(0.f, 0.f);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using type = Elt_type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) { // NOLINT(*)
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
// Pointer is cast to vector type
inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
// Pointer is cast to vector type
inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
// Pointer is cast to element type. Loads min(count, NUM_ELT)
// elements and any remaining elements are set to zero.
inline __device__ void load_from_elts(const void *base_ptr,
size_t idx = 0,
size_t count = NUM_ELT) {
const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
if ( count < NUM_ELT || idx % NUM_ELT != 0 ) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = (it < count
? elt_ptr[it]
: Elt_type(0.f));
}
} else {
this->load_from(elt_ptr);
}
}
// Pointer is cast to element type. Stores min(count, NUM_ELT)
// elements.
inline __device__ void store_to_elts(void *base_ptr,
size_t idx = 0,
size_t count = NUM_ELT) const {
Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
if ( count < NUM_ELT || idx % NUM_ELT != 0 ) {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
if ( it < count ) {
elt_ptr[it] = this->data.elt[it];
}
}
} else {
this->store_to(elt_ptr);
}
}
inline __device__ void clear() {
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = Elt_type(0.f);
}
}
};
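/* Usage sketch (illustrative only): Vec packs NUM_ELT elements so that a full
 * load/store is one vector transaction, with an element-wise fallback at the
 * edges of a row.
 *
 *   Vec<half, 8> v;                                // 16 bytes, Vec_type = uint4
 *   v.load_from_elts(src, row_offset, remaining);  // vectorized when possible
 *   Vec<float, 8> f;
 *   v.to(f);                                       // element-wise half -> float
 *   f.store_to_elts(dst, row_offset, remaining);
 */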
////////////////////////////////////////////////////////////////////////////////////////////////////
struct InterCTASync {
inline __device__ InterCTASync(int *barrier,
int group,
int num_groups,
int group_size)
: phase_counter_(0)
, b0_(barrier + group) // The barrier for this group of CTAs.
, b1_(barrier + group + num_groups) // The barrier for this group of CTAs.
, group_size_(group_size) {
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for ( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
}
inline __device__ void sync() {
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
// We decrement every other iteration.
bool dec = phase_counter_ & 0x2;
int step = dec ? -1 : 1;
int expected = dec ? 0 : group_size_;
// There are only 4 phases: up/down for b0/b1.
phase_counter_ = (phase_counter_ + 1) & 0x3;
if ( threadIdx.x == 0 ) {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
}
int phase_counter_;
int * b0_;
int * b1_;
int group_size_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES +
WS_DATA_BYTES };
template<typename Params>
inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {}
template<typename Op>
inline __device__ T allreduce(T data, const Op &op) {
data = Base::reduce(data, op);
// We switch workspace every iteration.
T * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if ( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
static_assert(CTAS_PER_ROW <= 32);
T total = Zeros<T>::get();
if (this->lane_ < CTAS_PER_ROW) {
total = workspace[this->lane_];
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
InterCTASync inter_cta_;
T * const w0_;
T * const w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
using Type = T;
enum { SMEM_BYTES = 0 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: warp_n_(warp_n)
, lane_(lane) {}
template<typename Op>
static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
for ( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
data = op(data, warp_shuffle_xor(data, it));
}
return data;
}
template<typename Op>
inline __device__ T allreduce(T data, const Op &op) {
return allreduce_(data, op);
}
template<typename Op>
inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!
#pragma unroll
for ( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
data = op(data, warp_shuffle_down(data, it));
}
return data;
}
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
using Base = Reducer<T, 1, WARPS_M, 1>;
using Type = T;
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true), smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N]))
, smem1_(smem0_ + WARPS_M * WARPS_N) {}
template<typename Op>
inline __device__ T allreduce(T data, const Op & op) {
T * const smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
data = Base::reduce(data, op);
if ( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
#pragma unroll
for ( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
return out;
}
template<typename Op>
inline __device__ T reduce(T data, const Op &op) {
T * const smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// only intra-CTA group leader holds the result!
data = Base::reduce(data, op);
if ( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
if ( this->warp_n_ == 0 && this->lane_ == 0 ) {
#pragma unroll
for ( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
}
return out;
}
T * const smem0_;
T * const smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
template<typename Params>
inline __device__ DynamicReducer(const Params & params,
uint32_t bidm, uint32_t bidn,
uint32_t warp_m, uint32_t warp_n,
uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row)
, w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row) {}
template<typename Op>
inline __device__ T allreduce(T data, const Op &op) {
// Trivial case
if (inter_cta_.group_size_ == 1) {
return Base::allreduce(data, op);
}
data = Base::reduce(data, op);
// We switch workspace every iteration.
T * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if ( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
T total = Zeros<T>::get();
for ( int it = this->lane_;
it < inter_cta_.group_size_;
it += THREADS_PER_WARP ) {
total = op(total, workspace[it]);
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
template<typename Op>
inline __device__ T reduce(T data, const Op &op) {
return allreduce(data, op);
}
InterCTASync inter_cta_;
T * const w0_;
T * const w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active) { // NOLINT(*)
// Assume at least leftmost is valid and
// init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
#pragma unroll
for ( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
// Exchange
T n_b = warp_shuffle_down(n_a, step);
T m_b = warp_shuffle_down(m_a, step);
T m2_b = warp_shuffle_down(m2_a, step);
// Update
const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
// Might have different n per thread, otherwise this would simplify :(
const T rn_ab = 1.f / n_ab;
const T delta = m_a - m_b;
const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
n_a = n_ab;
m_a = m_ab;
m2_a = m2_ab;
}
// Intra-warp broadcast (only lane 0 has valid stats).
m_a = __shfl_sync(uint32_t(-1), m_a, 0);
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}
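/* Worked example of a single merge step (illustrative): combining the partial
 * stats (n_a=2, m_a=1, m2_a=0) and (n_b=2, m_b=3, m2_b=0) gives
 *   n_ab  = 4
 *   m_ab  = (2*1 + 2*3) / 4           = 2
 *   m2_ab = 0 + 0 + (1-3)^2 * 2*2 / 4 = 4
 * which is the mean and the sum of squared deviations of {1, 1, 3, 3}.
 */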
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
// This could be done generically with the Reducer. But then we
// would have to exchange 3 instead of 2 fields.
using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
using stats_t = typename BlockStats::stats_t;
enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
template<typename Params>
inline __device__ Stats(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW)
, block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
, warp_n_(warp_n)
, lane_(lane) {}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
// TODO(ptredak) rn is not really needed here..
constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
stats_t block_stats = block_stats_.compute(elts, block_rn);
stats_t * const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
if ( warp_n_ == 0 && lane_ == 0 ) {
workspace[bidn_] = block_stats;
}
// Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
inter_cta_.sync();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume the CTA group size along N is at most 32, so that we can finalize with a single warp.
static_assert(CTAS_PER_ROW <= 32);
// Every warp does the final reduction locally.
if ( lane_ < CTAS_PER_ROW ) {
stats_t result = workspace[lane_];
n = ELTS_PER_ROW_PER_CTA;
m = transformer_engine::Get<0>::of<stats_t, T>(result);
m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
return { m, m2 };
}
InterCTASync inter_cta_;
BlockStats block_stats_;
stats_t * const w0_;
stats_t * const w1_;
int bidn_;
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
using WarpStats = Stats<T, 1, WARPS_M, 1>;
using stats_t = typename WarpStats::stats_t;
enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
template<typename Params>
inline __device__ Stats(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true) {
smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
stats_t * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// Compute warp local for all WARPS_N
constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
// Each warp leader stores its stats
const auto warp_n = warp_stats_.reducer_.warp_n_;
const auto lane = warp_stats_.reducer_.lane_;
if ( lane == 0 ) {
smem[warp_n] = warp_stats;
}
__syncthreads();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume there are at most 32 warps, so that we can finalize with a single warp
static_assert(WARPS_N <= 32);
if (lane < WARPS_N) {
stats_t result = smem[lane];
n = N * THREADS_PER_WARP;
m = transformer_engine::Get<0>::of<stats_t, T>(result);
m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, WARPS_N);
return { m, m2 };
}
WarpStats warp_stats_;
stats_t * smem0_;
stats_t * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
using stats_t = typename TypeToVec2<T>::Type;
// The simple Warp reducer.
using Reducer = Reducer<T, 1, WARPS_M, 1>;
enum { SMEM_BYTES = 0 };
template<typename Params>
inline __device__ Stats(const Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
uint32_t warp_n, uint32_t lane, void * smem)
: reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
auto sum = Sum<T>();
T m = Zeros<T>::get();
#pragma unroll
for ( int it = 0; it < N; it++ ) {
m += elts[it];
}
m = reducer_.allreduce(m, sum) * rn;
T m2 = Zeros<T>::get();
#pragma unroll
for ( int it = 0; it < N; it++ ) {
T diff = (elts[it] - m);
m2 += diff * diff;
}
m2 = reducer_.allreduce(m2, sum);
return {m, m2};
}
Reducer reducer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
float tmp = m;
#pragma unroll
for (int delta = num_elems/2; delta > 0; delta /= 2) {
const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta);
__builtin_assume(tmp >= 0);
__builtin_assume(other_m >= 0);
tmp = fmaxf(tmp, other_m);
}
return tmp;
}
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
__shared__ float staging[num_warps];
constexpr int warp_size = 32;
const float my_max = m;
const float my_warp_max = warp_reduce_max<warp_size>(my_max);
if (threadIdx.x % 32 == 0) {
staging[warpid] = my_warp_max;
}
__syncthreads();
compute_t result = 0;
if (warpid == 0) {
const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
result = warp_reduce_max<num_warps>(my_max);
}
return result;
}
// Works only on non-negative values: for those, the IEEE-754 bit pattern
// reinterpreted as a signed int preserves ordering, so an integer atomicMax
// yields the float maximum.
__device__ __forceinline__ void atomicMaxFloat(float * addr, const float value) {
atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}
// Works only on non-negative values (same bit-pattern ordering argument as above)
__device__ __forceinline__ void atomicMinFloat(float * addr, const float value) {
atomicMin(reinterpret_cast<int *>(addr), __float_as_int(value));
}
template <typename T>
__device__ __forceinline__ void reciprocal(T * value_inv, const T value) {
*value_inv = 1 / value;
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for pyTorch"""
from .module import LayerNormLinear
from .module import Linear
from .module import LayerNormMLP
from .module import LayerNorm
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Enums for e2e transformer"""
import torch
import transformer_engine_extensions as tex
"""
This is a map: torch.dtype -> int
Used for passing dtypes into cuda
extension. Has one to one mapping
with enum in transformer_engine.h
"""
TE_DType = {
torch.int8: tex.DType.kByte,
torch.int32: tex.DType.kInt32,
torch.float32: tex.DType.kFloat32,
torch.half: tex.DType.kFloat16,
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding")
AttnTypes = ("self", "cross")
LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None)
dist_group_type = torch._C._distributed_c10d.ProcessGroup
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
from .constants import TE_DType
def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
A_dtype: tex.DType,
B: torch.Tensor,
B_scale_inv: torch.Tensor,
B_dtype: tex.DType,
out_dtype: torch.dtype,
workspace: torch.Tensor,
accumulate: bool = False,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
fp32_output: bool = False,
use_split_accumulator: bool = False,
) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs."""
empty_tensor = torch.Tensor()
return_output = False
if out is None:
out = torch.empty(
B.shape[0],
A.shape[0],
dtype=torch.float32 if fp32_output else out_dtype,
device="cuda",
)
return_output = True
out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
tex.te_gemm(
A,
A_scale_inv,
A_dtype,
True, # transa
B,
B_scale_inv,
B_dtype,
False, # transb
out,
out_dtype,
bias if use_bias else empty_tensor,
empty_tensor,
False, # grad
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)
if return_output:
return out
return None
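# Hypothetical usage sketch (not part of this module; the names `a_fp8`, `b_fp8`,
# `a_scale_inv`, `b_scale_inv` and `workspace` are assumed to exist, and the FP8
# dtype member name is an assumption): with both operands already cast to FP8,
# an FP16-output GEMM could look like
#
#   out = fp8_gemm(
#       a_fp8, a_scale_inv, tex.DType.kFloat8E4M3,
#       b_fp8, b_scale_inv, tex.DType.kFloat8E4M3,
#       torch.float16,
#       workspace,
#       use_split_accumulator=True,
#   )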
def gemm(
A: torch.Tensor,
B: torch.Tensor,
dtype: torch.dtype,
workspace: torch.Tensor,
gelu: bool = False,
gelu_input: Optional[torch.Tensor] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
fp32_output: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Non FP8 GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
empty_tensor = torch.Tensor()
input_dtype = TE_DType[dtype]
output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
return_output = False
if out is None:
out = torch.empty(
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
dtype=torch.float32 if fp32_output else dtype,
device="cuda",
)
return_output = True
if gelu and not grad:
gelu_input = torch.empty_like(out, dtype=dtype)
elif not gelu:
gelu_input = empty_tensor
if grad and use_bias:
grad_bias = torch.empty(
B.shape[1], dtype=torch.float32 if fp32_output else dtype, device="cuda"
)
else:
grad_bias = empty_tensor
bias = bias if use_bias else empty_tensor
tex.te_gemm(
A,
empty_tensor,
input_dtype,
transa,
B,
empty_tensor,
input_dtype,
transb,
out,
output_dtype,
grad_bias if grad else bias,
gelu_input,
grad,
workspace,
workspace.shape[0],
accumulate,
False, # use_split_accumulator
)
if return_output:
return out, grad_bias, gelu_input
return None, grad_bias, gelu_input
def fp8_cast_transpose_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
cast_out: Optional[torch.Tensor] = None,
transpose_out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
"""Cast + Transpose with FP8 output"""
return_outputs = False
if cast_out is None or transpose_out is None:
cast_out = torch.empty_like(inp, dtype=torch.int8)
transpose_out = torch.empty(
inp.shape[1], inp.shape[0], device="cuda", dtype=torch.int8
)
return_outputs = True
tex.fused_cast_transpose(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
cast_out,
transpose_out,
otype,
)
if return_outputs:
return cast_out, transpose_out
return None
def fp8_cast_transpose_bgrad_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD with FP8 output"""
return tex.fused_cast_transpose_bgrad(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def fp8_cast_transpose_bgrad_dgelu_fused(
grad_output: torch.Tensor,
gelu_input: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD + DGELU with FP8 output"""
return tex.fused_cast_transpose_bgrad_dgelu(
grad_output,
gelu_input,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def fp8_gelu(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> torch.Tensor:
"""GeLU with FP8 output"""
return tex.fp8_gelu(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def layernorm_fwd_fp8(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""LayerNorm with FP8 output"""
return tex.layernorm_fwd_fp8(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def cast_to_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> torch.Tensor:
"""Cast input to FP8"""
return tex.cast_to_fp8(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
)
def cast_from_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
itype: tex.DType,
otype: tex.DType,
) -> torch.Tensor:
"""Cast input from FP8"""
return tex.cast_from_fp8(
inp,
fp8_meta_tensor.scale_inv[fp8_tensor],
itype,
otype,
)
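# Hypothetical round-trip sketch (illustrative; the construction of the
# FP8TensorMeta `meta` and the `tex.FP8FwdTensors.GEMM1_INPUT` index name are
# assumptions): casting to FP8 and back approximately recovers the input.
#
#   x_fp8 = cast_to_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT,
#                       tex.DType.kFloat8E4M3)
#   x_rec = cast_from_fp8(x_fp8, meta, tex.FP8FwdTensors.GEMM1_INPUT,
#                         tex.DType.kFloat8E4M3, TE_DType[torch.float32])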
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "common.h"
#include "transformer_engine/transformer_engine.h"
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string &fp8_recipe) {
// if e4m3 or hybrid + forward
if ( (fp8_recipe == "E4M3") || ( (fp8_recipe == "HYBRID") && e4m3_if_hybrid ) ) {
return transformer_engine::DType::kFloat8E4M3;
}
return transformer_engine::DType::kFloat8E5M2;
}
transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr,
const NVTEShape& shape,
const transformer_engine::DType type) {
return transformer_engine::TensorWrapper(data_ptr, shape, type);
}
transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr,
const std::vector<size_t>& shape,
const transformer_engine::DType type) {
return transformer_engine::TensorWrapper(data_ptr, shape, type);
}
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
std::vector<size_t> shape;
for (auto s : tensor.sizes()) {
shape.push_back(s);
}
return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
}
size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (auto s : shape) {
ret *= s;
}
return ret;
}
at::Tensor allocateSpace(const NVTEShape &shape,
const transformer_engine::DType type,
bool init_to_zeros) {
auto size = shape.ndim;
if (size == 2 && init_to_zeros) {
return at::zeros({static_cast<int64_t>(shape.data[0]),
static_cast<int64_t>(shape.data[1])},
at::CUDA(GetATenDType(type)));
} else if (size == 2) {
return at::empty({static_cast<int64_t>(shape.data[0]),
static_cast<int64_t>(shape.data[1])},
at::CUDA(GetATenDType(type)));
} else if (size == 1 && init_to_zeros) {
return at::zeros({static_cast<int64_t>(shape.data[0])}, at::CUDA(GetATenDType(type)));
} else if (size == 1) {
return at::empty({static_cast<int64_t>(shape.data[0])}, at::CUDA(GetATenDType(type)));
}
NVTE_CHECK(false, "Should never reach here! func: allocateSpace");
}
at::Tensor allocateTorchTensor(int M,
int N,
transformer_engine::DType dtype
) {
return at::empty({static_cast<int64_t>(M), static_cast<int64_t>(N)},
at::CUDA(GetATenDType(dtype)));
}
at::Tensor allocateTorchTensor(int M,
transformer_engine::DType dtype
) {
return at::empty({static_cast<int64_t>(M)},
at::CUDA(GetATenDType(dtype)));
}
void dispatch_layernorm(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gamma, // i
const std::vector<size_t>& gamma_shape,
const transformer_engine::DType gamma_type,
void* beta, // i
const std::vector<size_t>& beta_shape,
const transformer_engine::DType beta_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
const float epsilon, // i
void* z, // o
const std::vector<size_t>& z_shape,
const transformer_engine::DType z_type,
void* mu, // o
const std::vector<size_t>& mu_shape,
const transformer_engine::DType mu_type,
void* rsigma, // o
const std::vector<size_t>& rsigma_shape,
const transformer_engine::DType rsigma_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount,
const bool fp8_out
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
auto beta_cu = makeTransformerEngineTensor(beta, beta_shape, beta_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto z_cu = makeTransformerEngineTensor(z, z_shape, z_type);
auto mu_cu = makeTransformerEngineTensor(mu, mu_shape, mu_type);
auto rsigma_cu = makeTransformerEngineTensor(rsigma, rsigma_shape, rsigma_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
scale_cu.data(), epsilon,
z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data(), fp8_out);
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
scale_cu.data(), epsilon,
z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data(), fp8_out);
}
void dispatch_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output_cast, // o
const std::vector<size_t>& output_cast_shape,
const transformer_engine::DType output_cast_type,
void* output_transpose, // o
const std::vector<size_t>& output_transpose_shape,
const transformer_engine::DType output_transpose_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cast_cu = makeTransformerEngineTensor(output_cast, output_cast_shape,
output_cast_type);
auto output_transpose_cu = makeTransformerEngineTensor(output_transpose, output_transpose_shape,
output_transpose_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape,
scale_inv_type);
nvte_cast_transpose(input_cu.data(), scale_cu.data(),
output_cast_cu.data(), output_transpose_cu.data(),
amax_cu.data(), scale_inv_cu.data(),
at::cuda::getCurrentCUDAStream());
}
void dispatch_gelu(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);
nvte_gelu(input_cu.data(), output_cu.data(), scale_cu.data(),
amax_cu.data(), scale_inv_cu.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_transpose(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_bgrad_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv,
scale_inv_shape,
scale_inv_type);
transformer_engine::TensorWrapper workspace;
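  // The first nvte_cast_transpose_dbias call below is made with an empty
  // workspace tensor; it is used to query the workspace shape and dtype that
  // the fused kernel requires before any real work is done.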
nvte_cast_transpose_dbias(input_cu.data(), scale_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), amax_cu.data(),
dbias_cu.data(), scale_inv_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
  // Allocate the queried workspace and re-wrap it before the actual kernel call.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), scale_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), amax_cu.data(),
dbias_cu.data(), scale_inv_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
}
void dispatch_bgrad_dgelu_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gelu_input, // i
const std::vector<size_t>& gelu_input_shape,
const transformer_engine::DType gelu_input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input, gelu_input_shape,
gelu_input_type);
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv,
scale_inv_shape,
scale_inv_type);
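  // The first nvte_cast_transpose_dbias_dgelu call below runs with an empty
  // workspace tensor; it only queries the workspace shape and dtype that the
  // fused kernel requires.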
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), scale_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
amax_cu.data(), dbias_cu.data(), scale_inv_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
  // Allocate the queried workspace and re-wrap it before the actual kernel call.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), scale_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
amax_cu.data(), dbias_cu.data(), scale_inv_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/transpose.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/cast.h>
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <stdexcept>
#include <memory>
#include <iomanip>
#include <random>
#include <cstring>
#include <vector>
#include <iostream>
namespace transformer_engine {
// Each tensor here has shape (N,) and holds all scaling
// data for a single FP8 block, e.g. LayerNormLinear.
class FP8TensorMeta {
public:
at::Tensor scale;
at::Tensor scale_inv;
at::Tensor amax_history;
};
// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8FwdTensors {
GEMM1_INPUT = 0,
GEMM1_WEIGHT = 1,
GEMM2_INPUT = 2,
GEMM2_WEIGHT = 3
};
// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8BwdTensors {
GRAD_OUTPUT1 = 0,
GRAD_OUTPUT2 = 1
};
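// Illustrative usage (not an API guarantee): for an FP8TensorMeta `meta`
// describing the forward pass, the cast scale of the first GEMM's input
// would be read as meta.scale[FP8FwdTensors::GEMM1_INPUT] and its inverse
// as meta.scale_inv[FP8FwdTensors::GEMM1_INPUT].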
} // namespace transformer_engine
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string &fp8_recipe);
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt32:
case transformer_engine::DType::kFloat32:
return at::kFloat;
case transformer_engine::DType::kFloat16:
return at::kHalf;
case transformer_engine::DType::kBFloat16:
return at::kBFloat16;
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
return at::kByte;
default:
NVTE_ERROR("Invalid type");
}
}
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
switch (t) {
case at::kHalf:
return transformer_engine::DType::kFloat16;
case at::kFloat:
return transformer_engine::DType::kFloat32;
case at::kBFloat16:
return transformer_engine::DType::kBFloat16;
default:
NVTE_ERROR("Invalid type");
}
}
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
return static_cast<transformer_engine::DType>(DType_value);
}
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
const std::vector<size_t>& shape,
const transformer_engine::DType type
);
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
const NVTEShape& shape,
const transformer_engine::DType type
);
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);
size_t product(const std::vector<size_t> &shape);
at::Tensor allocateSpace(const NVTEShape &shape,
const transformer_engine::DType type,
bool init_to_zeros = false);
at::Tensor allocateTorchTensor(int M,
int N,
transformer_engine::DType dtype
);
at::Tensor allocateTorchTensor(int M,
transformer_engine::DType dtype
);
void dispatch_layernorm(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gamma, // i
const std::vector<size_t>& gamma_shape,
const transformer_engine::DType gamma_type,
void* beta, // i
const std::vector<size_t>& beta_shape,
const transformer_engine::DType beta_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
const float epsilon, // i
void* z, // o
const std::vector<size_t>& z_shape,
const transformer_engine::DType z_type,
void* mu, // o
const std::vector<size_t>& mu_shape,
const transformer_engine::DType mu_type,
void* rsigma, // o
const std::vector<size_t>& rsigma_shape,
const transformer_engine::DType rsigma_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount,
const bool fp8_out
);
void dispatch_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output_cast, // o
const std::vector<size_t>& output_cast_shape,
const transformer_engine::DType output_cast_type,
void* output_transpose, // o
const std::vector<size_t>& output_transpose_shape,
const transformer_engine::DType output_transpose_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_gelu(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_transpose(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type
);
void dispatch_bgrad_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_bgrad_dgelu_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gelu_input, // i
const std::vector<size_t>& gelu_input_shape,
const transformer_engine::DType gelu_input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
transformer_engine::DType D_type,
at::Tensor bias,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type);
auto te_A_scale_inverse = makeTransformerEngineTensor(A_scale_inverse.data_ptr(), {1},
GetTransformerEngineDType(
A_scale_inverse.scalar_type()));
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type);
auto te_B_scale_inverse = makeTransformerEngineTensor(B_scale_inverse.data_ptr(), {1},
GetTransformerEngineDType(
B_scale_inverse.scalar_type()));
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
GetTransformerEngineDType(bias.scalar_type()));
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(),
gelu_shape,
GetTransformerEngineDType(
pre_gelu_out.scalar_type()));
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
{workspaceSize},
DType::kByte);
nvte_cublas_gemm(te_A.data(),
te_A_scale_inverse.data(),
te_B.data(),
te_B_scale_inverse.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
transa,
transb,
grad,
te_workspace.data(),
accumulate,
use_split_accumulator,
at::cuda::getCurrentCUDAStream());
}
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
DType inp_type = GetTransformerEngineDType(input.scalar_type());
dispatch_cast_transpose_fusion(
input.data_ptr(), {M, N}, inp_type,
scale.data_ptr(), {1}, DType::kFloat32,
input_cast.data_ptr(), {M, N}, otype,
input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32);
}
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto grad_output_cast =
allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto grad_output_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
dispatch_bgrad_cast_transpose_fusion(
grad_output.data_ptr(), {M, N}, grad_output_type,
scale.data_ptr(), {1}, DType::kFloat32,
grad_output_cast.data_ptr(), {M, N}, otype,
grad_output_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
grad_bias.data_ptr(), {N}, grad_output_type,
scale_inv.data_ptr(), {1}, DType::kFloat32);
return {grad_bias, grad_output_cast, grad_output_transpose};
}
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor gelu_input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto dgelu =
allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto dgelu_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
dispatch_bgrad_dgelu_cast_transpose_fusion(
grad_output.data_ptr(), {M, N}, grad_output_type,
gelu_input.data_ptr(), {M, N}, grad_output_type,
scale.data_ptr(), {1}, DType::kFloat32,
dgelu.data_ptr(), {M, N}, otype,
dgelu_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
grad_bias.data_ptr(), {N}, grad_output_type,
scale_inv.data_ptr(), {1}, DType::kFloat32);
return {grad_bias, dgelu, dgelu_transpose};
}
at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_transpose =
allocateTorchTensor(input.size(1),
input.size(0),
DType::kByte);
dispatch_transpose(input.data_ptr(), {M, N}, otype,
input_transpose.data_ptr(), {N, M}, otype);
return input_transpose;
}
at::Tensor fp8_gelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
DType input_type = GetTransformerEngineDType(input.scalar_type());
auto output =
allocateTorchTensor(input.size(0),
input.size(1),
DType::kByte);
dispatch_gelu(input.data_ptr(), {M, N}, input_type,
scale.data_ptr(), {1}, DType::kFloat32,
output.data_ptr(), {M, N}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32);
return output;
}
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &x,
const at::Tensor &mu,
const at::Tensor &rsigma,
const at::Tensor &gamma
) {
auto dx = at::empty_like(x);
auto dgamma = at::empty_like(gamma);
auto dbeta = at::empty_like(gamma);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
auto dz_cu = makeTransformerEngineTensor(dz);
auto x_cu = makeTransformerEngineTensor(x);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
auto gamma_cu = makeTransformerEngineTensor(gamma);
auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This call populates tensors with the required config.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
workspace.data(), barrier.data());
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype());
auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(),
dgamma_part.shape(),
dgamma_part.dtype());
dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(),
dbeta_part.shape(),
dbeta_part.dtype());
// Actual call to bwd kernel.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
workspace.data(), barrier.data());
return { dx, dgamma, dbeta };
}
std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
dispatch_layernorm(
input.data_ptr(), {N, H}, itype,
weight.data_ptr(), {H}, itype,
bias.data_ptr(), {H}, itype,
scale.data_ptr(), {1}, DType::kFloat32,
eps,
ln_out.data_ptr(), {N, H}, otype,
mu.data_ptr(), {N}, DType::kFloat32,
rsigma.data_ptr(), {N}, DType::kFloat32,
amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
true);
return {ln_out, mu, rsigma};
}
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
dispatch_layernorm(input.data_ptr(), {N, H}, itype,
weight.data_ptr(), {H}, itype,
bias.data_ptr(), {H}, itype,
nullptr, {1}, DType::kFloat32,
eps,
ln_out.data_ptr(), {N, H}, itype,
mu.data_ptr(), {N}, DType::kFloat32,
rsigma.data_ptr(), {N}, DType::kFloat32,
nullptr, {1}, DType::kFloat32,
nullptr, {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
false);
return {ln_out, mu, rsigma};
}
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype);
auto scale_cu = makeTransformerEngineTensor(scale.data_ptr(), {1}, DType::kFloat32);
auto amax_cu = makeTransformerEngineTensor(amax.data_ptr(), {1}, DType::kFloat32);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);
nvte_fp8_quantize(input_cu.data(), scale_cu.data(), output_cu.data(),
amax_cu.data(), scale_inv_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype);
auto output_cu = makeTransformerEngineTensor(output);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);
nvte_fp8_dequantize(input_cu.data(), scale_inv_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
m.def("te_gemm", &te_gemm, "CublasLt GEMM");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
.def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale)
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
py::enum_<transformer_engine::DType>(m, "DType")
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
.value("kFloat32", transformer_engine::DType::kFloat32)
.value("kFloat16", transformer_engine::DType::kFloat16)
.value("kBFloat16", transformer_engine::DType::kBFloat16)
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2);
py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
.value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
.value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
.value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT);
py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2);
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "common.h"
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
transformer_engine::DType D_type,
at::Tensor bias,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator
);
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
);
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor gelu_input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
);
at::Tensor fp8_gelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &x,
const at::Tensor &mu,
const at::Tensor &rsigma,
const at::Tensor &gamma
);
std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
);
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,
transformer_engine::DType otype
);
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace transformer_engine
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&transformer_engine::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&transformer_engine::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&transformer_engine::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_SCALED_MASKED_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SCALED_MASKED_SOFTMAX_H_
#include <assert.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <c10/macros/Macros.h>
#include <cfloat>
#include <limits>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
*((half2*) dst) = *((half2*) src); // NOLINT(*)
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
 * Extended softmax (from native ATen PyTorch) with the following additional features
* 1) input scaling
*/
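// Sketch of the kernel below: each warp owns WARP_BATCH rows of
// element_count elements. It loads the scaled inputs into registers (padding
// missing elements with -infinity), computes a per-row maximum with a
// warp-shuffle reduction, accumulates exp(x - max) with a second reduction,
// and writes the normalized probabilities back to global memory.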
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))
+ threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count
+ it * WARP_SIZE, out);
} else {
break;
}
}
}
}
/*
 * Extended softmax (from native ATen PyTorch) with the following additional features
* 1) input scaling
* 2) Explicit masking
*/
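// Same warp-level algorithm as above, except that elements whose mask byte
// equals 1 are replaced with -10000.0 before the softmax, so they receive
// (near-)zero probability. pad_batches distinguishes a single shared mask
// (gpt2 style) from one mask per batch (bert style).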
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int element_count,
int pad_batches) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))
+ threadIdx.y) * WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y)
* WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -10000.0;
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count
+ it * WARP_SIZE, out);
} else {
break;
}
}
}
}
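// Backward pass of the scaled (masked) softmax: for softmax outputs y and
// incoming gradients dy, each row produces scale * y * (dy - sum_k y_k * dy_k).
// No mask is needed here, since masked positions already carry (near-)zero
// probability from the forward pass.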
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count
+ it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count
+ it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] *
output_reg[i][it + element];
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count
+ it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
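  // Worked example (illustrative): key_seq_len = 512 gives log2_elements = 9,
  // warp_size = 32, batches_per_warp = 1 and warps_per_block = 128 / 32 = 4,
  // so batches_per_block = 4.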
return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 1: // 2
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 2: // 4
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 3: // 8
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 4: // 16
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 5: // 32
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 6: // 64
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 7: // 128
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 8: // 256
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 9: // 512
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 10: // 1024
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 11: // 2048
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 12: // 4096
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 12: // 4096
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count/batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
default:
break;
}
}
}
#endif // TRANSFORMER_ENGINE_SCALED_MASKED_SOFTMAX_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
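// A mask batch dimension of 1 broadcasts the same padding mask to every input batch;
// otherwise each batch carries its own mask.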
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
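// Dispatch on fp16/bf16 input; the kernel accumulates in fp32 (acc_t = float).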
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads););
// backward pass is completely in-place
return output_grads;
}
} // end namespace scaled_masked_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(
torch::Tensor const& input,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_softmax
} // end namespace transformer_engine
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&transformer_engine::scaled_softmax::fwd,
"Self Multihead Attention scaled, softmax -- Forward.");
m.def("backward",
&transformer_engine::scaled_softmax::bwd,
"Self Multihead Attention scaled, softmax -- Backward.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor) {
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_softmax_forward",
dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads););
// backward pass is completely in-place
return output_grads;
}
} // end namespace scaled_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace transformer_engine
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&transformer_engine::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&transformer_engine::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_SCALED_UPPER_TRIANG_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SCALED_UPPER_TRIANG_SOFTMAX_H_
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
*((half2*) dst) = *((half2*) src); // NOLINT(*)
}
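// The 4-element specializations above perform a single vectorized access:
// float2 (8 bytes) covers four half/bf16 values, half2 (4 bytes) covers four uint8_t values.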
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) {
*dst = 0.0;
}
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) {
*((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
*((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*)
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features:
 * 1) input scaling
 * 2) implicit upper-triangular (causal time) masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count) {
// WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values
// computed in dispatch_scaled_upper_triang_masked_softmax_forward.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
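// blockIdx.x is the query position; with the implicit upper-triangular (causal) mask only
// the first local_seq key positions are unmasked, so the exp/sum loop below can stop once
// it reaches warp_iteration_limit.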
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data,
src + i*element_count*stride
+ it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
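// After the reduction every thread holds the row-wise maximum, which is subtracted
// before exponentiation for numerical stability.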
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
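// Row-wise sum of the exponentials across the warp, used to normalize the outputs below.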
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride
+ it * WARP_SIZE,
out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride
+ it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count) {
// WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values
// computed in dispatch_scaled_upper_triang_masked_softmax_backward.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad,
grad + i * element_count * stride
+ it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output,
output + i * element_count * stride
+ it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] *
output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
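// grad_reg holds y * dy and sum now holds the row-wise total of y * dy, so the store loop
// below computes the softmax gradient dx = scale * (y * dy - y * sum(y * dy)).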
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride
+ it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
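// Grid: blockIdx.x indexes the query position (seq_len) and blockIdx.y spreads the
// attention batches over blocks_per_seq blocks; each warp handles batches_per_warp rows.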
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
default:
break;
}
}
}
#endif // TRANSFORMER_ENGINE_SCALED_UPPER_TRIANG_SOFTMAX_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor) {
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches););
// backward pass is completely in-place
return output_grads;
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
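// Example usage (my_kernel is a hypothetical launcher, shown for illustration only):
// the macro introduces `scalar_t` in each case, so the body may be any statement
// templated on it, e.g.
//   DISPATCH_HALF_AND_BFLOAT(tensor.scalar_type(), "my_op",
//                            my_kernel<scalar_t>(tensor.data_ptr<scalar_t>(), n););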
#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch (TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
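// Note: mixed input/output types are only generated when the input type is fp32;
// half and bf16 inputs always pair with an output of the same type.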