Commit 996ea169 authored by Przemek Tredak

Initial code drop


Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/gemm.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include "../common.h"
namespace transformer_engine {
void cublas_gemm(void* A,
void* A_scale_inverse,
void* B,
void* B_scale_inverse,
void* D,
void* bias_ptr,
void* pre_gelu_out,
int m, int n, int k,
int lda, int ldb, int ldd,
cudaDataType_t A_type,
cudaDataType_t B_type,
cudaDataType_t D_type,
cudaDataType_t bias_type,
cublasOperation_t transa,
cublasOperation_t transb,
bool bias,
bool gelu,
bool grad,
void* workspace,
size_t workspaceSize,
bool use_fp8,
bool accumulate,
bool use_split_accumulator,
cudaStream_t stream
) {
// Check consistency of arguments:
// FP8 GEMM + GELU fusion is unavailable right now.
if (use_fp8) {
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
}
float one = 1.0;
float zero = 0.0;
float beta = (accumulate) ? one : zero;
cublasLtHandle_t handle;
NVTE_CHECK_CUBLAS(cublasLtCreate(&handle));
cublasLtMatmulDesc_t operationDesc = nullptr;
cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Ddesc = nullptr;
cublasLtMatmulPreference_t preference = nullptr;
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
int64_t ld_gelumat = (int64_t) ldd;
// default to tf32 except for e5m2 inputs where the config is not supported
cublasComputeType_t gemm_compute_type = (A_type == CUDA_R_8F_E5M2 || B_type == CUDA_R_8F_E5M2)
? CUBLAS_COMPUTE_32F
: CUBLAS_COMPUTE_32F_FAST_TF32;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type,
transa == CUBLAS_OP_N ? m : k,
transa == CUBLAS_OP_N ? k : m,
lda));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type,
transb == CUBLAS_OP_N ? k : n,
transb == CUBLAS_OP_N ? n : k,
ldb));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
&transb, sizeof(transb)));
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
// Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
if (use_fp8) {
// Split accumulator.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_FAST_ACCUM,
&fastAccuMode,
sizeof(fastAccuMode)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse,
sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse,
sizeof(B_scale_inverse)));
if (bias) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
}
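// Map the (bias, gelu, grad) flags onto a cuBLASLt epilogue:
// forward: BIAS / GELU_AUX / GELU_AUX_BIAS; backward: BGRADB / DGELU / DGELU_BGRAD.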
if (bias && gelu) {
if (grad) {
epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
} else {
epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
}
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
} else if (bias) {
if (grad) {
// grad output is always input B
epilogue = CUBLASLT_EPILOGUE_BGRADB;
} else {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
} else if (gelu) {
if (grad) {
epilogue = CUBLASLT_EPILOGUE_DGELU;
} else {
epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
}
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
}
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
Ddesc, preference, 1, &heuristicResult,
&returnedResults));
NVTE_CHECK(returnedResults != 0, "Unable to find any suitable algorithms");
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&heuristicResult.algo, /* algo */
workspace, /* workspace */
workspaceSize,
stream)); /* stream */
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
} // namespace transformer_engine
namespace {
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return CUDA_R_16F;
case DType::kFloat32:
return CUDA_R_32F;
case DType::kBFloat16:
return CUDA_R_16BF;
case DType::kFloat8E4M3:
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
default:
NVTE_ERROR("Invalid type");
}
}
bool is_fp8_dtype(const transformer_engine::DType t) {
return t == transformer_engine::DType::kFloat8E4M3 ||
t == transformer_engine::DType::kFloat8E5M2;
}
} // namespace
void nvte_cublas_gemm(const NVTETensor A,
const NVTETensor A_scale_inverse,
const NVTETensor B,
const NVTETensor B_scale_inverse,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
cudaStream_t stream) {
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
const Tensor *Ainvscale = reinterpret_cast<const Tensor*>(A_scale_inverse);
const Tensor *Binvscale = reinterpret_cast<const Tensor*>(B_scale_inverse);
Tensor *outputD = reinterpret_cast<Tensor*>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor*>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor*>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor*>(workspace);
const int m = transa ? inputA->shape[0] : inputA->shape[1];
const int k = transa ? inputA->shape[1] : inputA->shape[0];
const int n = transb ? inputB->shape[1] : inputB->shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
cublas_gemm(inputA->dptr, Ainvscale->dptr,
inputB->dptr, Binvscale->dptr,
outputD->dptr, biasTensor->dptr,
outputGelu->dptr,
m, n, k,
lda, ldb, ldd,
get_cuda_dtype(inputA->dtype),
get_cuda_dtype(inputB->dtype),
get_cuda_dtype(outputD->dtype),
get_cuda_dtype(biasTensor->dtype),
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
biasTensor->dptr != nullptr,
outputGelu->dptr != nullptr,
grad, wspace->dptr,
wspace->shape[0],
is_fp8_dtype(inputA->dtype) || is_fp8_dtype(inputB->dtype),
accumulate, use_split_accumulator,
stream);
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file activation.h
* \brief Activation functions.
*/
#ifndef TRANSFORMER_ENGINE_ACTIVATION_H_
#define TRANSFORMER_ENGINE_ACTIVATION_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Compute GELU activation of the input.
*
* \param[in] input Input tensor for GELU activation.
* \param[out] output Output tensor.
* \param[in] scale Scaling factor of the output tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_gelu(const NVTETensor input,
NVTETensor output,
const NVTETensor scale,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream);
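/*
* Illustrative usage sketch (not part of the API). The handles below --
* `input`, `output`, `scale`, `amax`, `scale_inv` -- are hypothetical
* NVTETensor objects assumed to wrap pre-allocated device buffers (see
* nvte_create_tensor in transformer_engine.h), and `stream` is an existing
* CUDA stream:
*
* nvte_gelu(input, output, scale, amax, scale_inv, stream);
*/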
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_ACTIVATION_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file cast.h
* \brief Functions to cast to/from FP8.
*/
#ifndef TRANSFORMER_ENGINE_CAST_H_
#define TRANSFORMER_ENGINE_CAST_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Cast tensor to FP8.
*
* \param[in] input Input tensor to be cast.
* \param[in] scale Scaling factor of the output tensor.
* \param[out] output Output FP8 tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_quantize(const NVTETensor input,
const NVTETensor scale,
NVTETensor output,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream);
/*! \brief Cast tensor from FP8.
*
* \param[in] input Input tensor to be cast.
* \param[in] scale_inv Inverse of the input's scaling factor.
* \param[out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_dequantize(const NVTETensor input,
const NVTETensor scale_inv,
NVTETensor output,
cudaStream_t stream);
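/*
* Illustrative round-trip sketch (not part of the API). `hp_in`, `fp8_out_t`,
* `hp_out`, `scale`, `amax` and `scale_inv` are hypothetical NVTETensor
* handles wrapping pre-allocated device buffers, with `fp8_out_t` holding an
* FP8 dtype (kNVTEFloat8E4M3 or kNVTEFloat8E5M2):
*
* nvte_fp8_quantize(hp_in, scale, fp8_out_t, amax, scale_inv, stream);
* nvte_fp8_dequantize(fp8_out_t, scale_inv, hp_out, stream);
*
* The scale_inv written by the quantize step is the one consumed by the
* dequantize step.
*/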
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_CAST_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gemm.h
* \brief Functions for matrix multiplication.
*/
#ifndef TRANSFORMER_ENGINE_GEMM_H_
#define TRANSFORMER_ENGINE_GEMM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB)` if `bias` is empty and `pre_gelu_out` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
* \param[in] A_scale_inverse The inverse of the A matrix's scaling factor.
* \param[in] B The B matrix.
* \param[in] B_scale_inverse The inverse of the B matrix's scaling factor.
* \param[out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm(const NVTETensor A,
const NVTETensor A_scale_inverse,
const NVTETensor B,
const NVTETensor B_scale_inverse,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
cudaStream_t stream
);
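/*
* Illustrative usage sketch (not part of the API). All tensors below are
* hypothetical NVTETensor handles; `empty` denotes a tensor created with a
* null data pointer (skipping the bias and GELU fusions), and `workspace`
* wraps a scratch byte buffer whose first shape dimension is its size:
*
* // D = A^T * B with no fusions:
* nvte_cublas_gemm(A, empty, B, empty, D, empty, empty,
* true, // transa
* false, // transb
* false, // grad
* workspace,
* false, // accumulate
* false, // use_split_accumulator
* stream);
*
* For FP8 inputs, pass the scale-inverse tensors of A and B instead of
* `empty` as the second and fourth arguments.
*/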
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_GEMM_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file layer_norm.h
* \brief LayerNorm functions.
*/
#ifndef TRANSFORMER_ENGINE_LAYER_NORM_H_
#define TRANSFORMER_ENGINE_LAYER_NORM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Compute LayerNorm on the input.
*
* Calling this function with the workspace and barrier set to empty tensors will not
* perform the operation, but will instead set the shape and type of the workspace
* and barrier tensors to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
* \param[in] beta Beta tensor of shape [H].
* \param[in] scale Scaling factor used for output.
* \param[in] epsilon Value added to denominator for numerical stability.
* \param[out] z Output tensor of shape [N, H].
* \param[out] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[out] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] fp8_out Whether to output FP8.
*/
void nvte_layernorm_fwd(const NVTETensor x,
const NVTETensor gamma,
const NVTETensor beta,
const NVTETensor scale,
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier,
NVTETensor amax,
NVTETensor scale_inv,
bool fp8_out);
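/*
* Illustrative two-phase call sketch (not part of the API), following the
* workspace-query convention described above. All tensors are hypothetical
* NVTETensor handles; `workspace` and `barrier` initially wrap null data
* pointers:
*
* // Phase 1: query -- fills in the required workspace/barrier shapes.
* nvte_layernorm_fwd(x, gamma, beta, scale, 1e-5f, z, mu, rsigma, stream,
* sm_count, workspace, barrier, amax, scale_inv, false);
* // ... allocate buffers matching the returned shapes/dtypes and recreate
* // the workspace and barrier tensors around them ...
* // Phase 2: run the kernel for real.
* nvte_layernorm_fwd(x, gamma, beta, scale, 1e-5f, z, mu, rsigma, stream,
* sm_count, workspace, barrier, amax, scale_inv, false);
*/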
/*! \brief Compute backward of LayerNorm.
*
* Calling this function with workspace, barrier, dgamma_part and dbeta_part set
* to empty tensors will not perform the operation, but will instead set the shape
* and type of these tensors to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[in] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dbeta Gradient for beta tensor of shape [H].
* \param[out] dgamma_part Storage for partial gamma gradient.
* \param[out] dbeta_part Storage for partial bias gradient.
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_LAYER_NORM_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_LOGGING_H_
#define TRANSFORMER_ENGINE_LOGGING_H_
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <string>
#include <stdexcept>
#define NVTE_ERROR(x) \
do { \
throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
" in function " + __func__ + ": " + x); \
} while (false)
#define NVTE_CHECK(x, ...) \
do { \
if (!(x)) { \
NVTE_ERROR(std::string("Assertion failed: " #x ". ") + std::string(__VA_ARGS__)); \
} \
} while (false)
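// Illustrative usage sketch (hypothetical condition and message): on failure
// both macros throw a std::runtime_error annotated with file, line and
// function, e.g.:
//
// NVTE_CHECK(x.shape.size() == 2, "Input must be a 2D tensor.");
// NVTE_ERROR("Type not supported.");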
namespace {
inline void check_cuda_(cudaError_t status) {
if ( status != cudaSuccess ) {
NVTE_ERROR("CUDA Error: " + std::string(cudaGetErrorString(status)));
}
}
inline void check_cublas_(cublasStatus_t status) {
if ( status != CUBLAS_STATUS_SUCCESS ) {
NVTE_ERROR("CUBLAS Error: " + std::string(cublasGetStatusString(status)));
}
}
} // namespace
#define NVTE_CHECK_CUDA(ans) do { check_cuda_(ans); } while (false)
#define NVTE_CHECK_CUBLAS(ans) do { check_cublas_(ans); } while (false)
#endif // TRANSFORMER_ENGINE_LOGGING_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file transformer_engine.h
* \brief Base classes and functions of Transformer Engine API.
*/
#ifndef TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
#define TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
#include <stddef.h>
#include <cuda_runtime_api.h>
#ifdef __cplusplus
extern "C" {
#endif
/*! \enum NVTEDType
* \brief TE datatype.
*/
enum NVTEDType {
kNVTEByte = 0, /*!< Byte */
kNVTEInt32 = 1, /*!< 32-bit integer */
kNVTEFloat32 = 2, /*!< 32-bit float */
kNVTEFloat16 = 3, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 4, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 5, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 6, /*!< 8-bit float (E5M2) */
kNVTENumTypes /*!< Number of supported types */
};
/*! \struct NVTEShape
* \brief Shape of the tensor.
*/
struct NVTEShape {
/*! \brief Shape data, of size ndim. */
const size_t *data;
/*! \brief Number of dimensions. */
size_t ndim;
};
/*! \brief TE Tensor type
*
* NVTETensor is a contiguous tensor type storing a pointer
* to data of a given shape and type. It does not own the
* memory it points to.
*/
typedef void* NVTETensor;
/*! \brief Create a new TE tensor.
*
* Create a new TE tensor with a given shape, datatype and data.
* TE tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor.
*
* \return A new TE tensor.
*/
NVTETensor nvte_create_tensor(void *dptr,
const NVTEShape shape,
const NVTEDType dtype);
/*! \brief Destroy a TE tensor.
*
* Since the TE tensor does not own memory, the underlying
* data is not freed during this operation.
*
* \param[in] tensor Tensor to be destroyed.
*/
void nvte_destroy_tensor(NVTETensor tensor);
/*! \brief Get a tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return The data type of the input tensor.
*/
NVTEDType nvte_tensor_type(const NVTETensor tensor);
/*! \brief Get a tensor's data shape.
*
* \param[in] tensor Tensor.
*
* \return The shape of the input tensor.
*/
NVTEShape nvte_tensor_shape(const NVTETensor tensor);
/*! \brief Get a raw pointer to the tensor's data.
*
* \param[in] tensor Tensor.
*
* \return A raw pointer to the tensor's data.
*/
void *nvte_tensor_data(const NVTETensor tensor);
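/*
* Illustrative usage sketch (not part of the API). `d_data` is a hypothetical
* device pointer to 256 * 1024 half-precision elements allocated elsewhere:
*
* const size_t dims[] = {256, 1024};
* NVTEShape shape = {dims, 2};
* NVTETensor t = nvte_create_tensor(d_data, shape, kNVTEFloat16);
* NVTEShape s = nvte_tensor_shape(t); // s.data[0] == 256
* void *ptr = nvte_tensor_data(t); // == d_data
* nvte_destroy_tensor(t); // d_data itself is not freed
*/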
#ifdef __cplusplus
} // extern "C"
#include <vector>
/*! \namespace transformer_engine
* \brief Namespace containing C++ API of Transformer Engine.
*/
namespace transformer_engine {
/*! \enum DType
* \brief TE datatype.
*/
enum class DType {
kByte = 0,
kInt32 = 1,
kFloat32 = 2,
kFloat16 = 3,
kBFloat16 = 4,
kFloat8E4M3 = 5,
kFloat8E5M2 = 6,
kNumTypes
};
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
*/
class TensorWrapper {
public:
/*! \brief Constructs a new TensorWrapper.
*
* Create a new TE tensor with a given shape, datatype and data.
* TE tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor.
*/
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype) :
tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype))) {}
/*! \brief Constructs a new TensorWrapper.
*
* Create a new TE tensor with a given shape, datatype and data.
* TE tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor.
*/
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype) :
TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype) {}
/*! \brief Constructs a new empty TensorWrapper.
*
* Creates a new empty TE tensor that holds nothing.
*/
TensorWrapper() : TensorWrapper(nullptr, std::vector<size_t>(), DType::kFloat32) {}
/*! \brief TensorWrapper destructor. */
~TensorWrapper() {
nvte_destroy_tensor(tensor_);
}
TensorWrapper& operator=(const TensorWrapper &other) = delete;
TensorWrapper(const TensorWrapper &other) = delete;
/*! \brief Constructs a new TensorWrapper from an existing TensorWrapper.
*
* Takes ownership of the TE tensor held by the other TensorWrapper.
*
* \param[in,out] other The source of the data.
*/
TensorWrapper(TensorWrapper &&other) {
tensor_ = other.tensor_;
other.tensor_ = nullptr;
}
/*! \brief Assign the data from an existing TensorWrapper.
*
* Takes ownership of the TE tensor held by the other TensorWrapper.
*
* \param[in,out] other The source of the data.
*/
TensorWrapper& operator=(TensorWrapper &&other) {
if (this == &other) return *this;
nvte_destroy_tensor(tensor_);
tensor_ = other.tensor_;
other.tensor_ = nullptr;
return *this;
}
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
*/
NVTETensor data() const noexcept {
return tensor_;
}
/*! \brief Get the shape of this TensorWrapper.
*
* \return Shape of this TensorWrapper.
*/
const NVTEShape shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
return nvte_tensor_shape(tensor_);
}
/*! \brief Get the data type of this TensorWrapper.
*
* \return Data type of this TensorWrapper.
*/
DType dtype() const noexcept {
if (tensor_ == nullptr) return DType::kNumTypes;
return static_cast<DType>(nvte_tensor_type(tensor_));
}
/*! \brief Get a raw pointer to the tensor's data.
*
* \return A raw pointer to the tensor's data.
*/
void *dptr() const noexcept {
if (tensor_ == nullptr) return nullptr;
return nvte_tensor_data(tensor_);
}
private:
/*! \brief Wrapped NVTETensor. */
NVTETensor tensor_ = nullptr;
};
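/*
* Illustrative usage sketch (not part of the API). `d_buf` is a hypothetical
* device pointer to rows * cols FP32 elements:
*
* std::vector<size_t> shape{rows, cols};
* TensorWrapper t(d_buf, shape, DType::kFloat32);
* NVTETensor handle = t.data(); // pass to the C API functions above
*
* The wrapper destroys its NVTETensor (but not d_buf) when it goes out of
* scope; move construction/assignment transfers that ownership.
*/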
} // namespace transformer_engine
#endif
#endif // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file transpose.h
* \brief Functions handling transposes.
*/
#ifndef TRANSFORMER_ENGINE_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_TRANSPOSE_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Cast and transpose the input.
*
* This function casts the input and produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] scale Scaling factor used for outputs.
* \param[out] cast_output Result of the cast. Shape: [N, H].
* \param[out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose(const NVTETensor input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream);
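/*
* Illustrative usage sketch (not part of the API). `input` wraps an [N, H]
* high-precision tensor; `cast_output` ([N, H]) and `transposed_output`
* ([H, N]) are hypothetical FP8 NVTETensor handles, and scale/amax/scale_inv
* wrap single-element FP32 buffers:
*
* nvte_cast_transpose(input, scale, cast_output, transposed_output,
* amax, scale_inv, stream);
*
* Both layouts of the cast tensor are produced in a single pass over the
* input.
*/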
/*! \brief Transpose the input.
*
* \param[in] input Input tensor of shape [N, H].
* \param[out] transposed_output Result of the transpose. Shape: [H, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_transpose(const NVTETensor input,
NVTETensor transposed_output,
cudaStream_t stream);
/*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension.
*
* This function casts the input and produces 3 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
* - `dbias` is the result of the reduction of the input along the first dimension.
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] scale Scaling factor used for outputs.
* \param[out] cast_output Result of the cast. Shape: [N, H].
* \param[out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] dbias Result of the reduction of the input along the
* first dimension. Shape: [H].
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias(const NVTETensor input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute backward of GELU operation on the input, then cast and transpose. Additionally,
* reduce the result of the GELU backward along the first dimension.
*
* This function produces 3 results:
* - `cast_output` is equal to `cast(dGELU(input))`
* - `transposed_output` is equal to `transpose(cast(dGELU(input)))`
* - `dbias` is equal to `reduce(dGELU(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gelu_input Tensor used as input to the forward of GELU operation.
* Shape [N, H].
* \param[in] scale Scaling factor used for outputs.
* \param[out] cast_output Result of the cast. Shape: [N, H].
* \param[out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] dbias Result of the reduction of the dGELU(input) along the
* first dimension. Shape: [H].
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor gelu_input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_TRANSPOSE_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
#include <vector>
#include <unordered_map>
#include "../common.h"
namespace transformer_engine {
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams {
size_t workspace_bytes;
size_t barrier_size;
int multiprocessorCount;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr) {}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Size of CTA group.
int ctas_per_row;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
, fp8_out(false) {}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
// Scaling factor
void *scale;
// Scaling factor inverse,
// needed for cublasLt fp8 gemm
void *scale_inv;
// AMax output
void *amax;
// Whether to compute scale and amax
bool fp8_out;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dx(nullptr)
, dbeta(nullptr)
, dgamma(nullptr) {}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdTunedRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdTunedRegistry = std::unordered_map<FunctionKey, BwdFunction>;
using FwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, FwdFunction>>;
using BwdGeneralRegistry = std::unordered_map<FunctionKey, std::map<uint64_t, BwdFunction>>;
extern FwdTunedRegistry FWD_TUNED_FUNCS;
extern BwdTunedRegistry BWD_TUNED_FUNCS;
extern FwdGeneralRegistry FWD_GENERAL_FUNCS;
extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
uint32_t hidden_size,
uint32_t batch_size);
} // namespace layer_norm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/layer_norm.h>
#include <vector>
#include "ln.h"
#include "../common.h"
/*
Supported type combinations:

input   compute   weights   output
==================================
fp32    fp32      fp32      fp32
fp16    fp32      fp16      fp16
bf16    fp32      bf16      bf16
fp32    fp32      fp16      fp16
fp32    fp32      bf16      bf16
bf16    fp32      bf16      fp8

Remarks:
Output type = Weight type (except for the FP8 output case above)
Compute always in FP32
*/
namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
// Create registries and provide runtime versions of config hash functions.
FwdTunedRegistry FWD_TUNED_FUNCS;
BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry FWD_GENERAL_FUNCS;
BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(DType dtype) {
if ( dtype == DType::kFloat16 ) {
return TypeId<fp16>::Value;
} else if ( dtype == DType::kBFloat16 ) {
return TypeId<bf16>::Value;
} else if ( dtype == DType::kFloat32 ) {
return TypeId<fp32>::Value;
} else if ( dtype == DType::kFloat8E4M3 ) {
return TypeId<fp8e4m3>::Value;
} else {
NVTE_ERROR("Type not supported.");
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) {
using namespace layer_norm;
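// Illustrative key layout (matching the shifts below): bits [32, 40) hold the
// four 2-bit type ids (wtype | itype<<2 | otype<<4 | ctype<<6) and bits
// [0, 32) hold the hidden size; hidden_size == 0 is used as the lookup key
// for the general-kernel registries.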
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) |
(get_type_id(otype) << 4) | (get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
uint32_t hidden_size,
uint32_t batch_size) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size);
if (batch_size % 4 == 0
&& layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::FWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("FWD: Unsupported types.");
}
auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
uint32_t hidden_size,
uint32_t batch_size) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size);
if (batch_size % 4 == 0
&& layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::BWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("BWD: Unsupported types.");
}
auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (auto s : shape) {
ret *= s;
}
return ret;
}
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const Tensor& gamma, // hidden_size
const Tensor& beta, // hidden_size
const Tensor& scale,
const float epsilon,
Tensor* z,
Tensor* mu,
Tensor* rsigma,
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier,
Tensor* amax,
Tensor *scale_inv,
bool fp8_out
) {
auto itype = x.dtype;
auto wtype = gamma.dtype;
auto otype = z->dtype;
auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.shape.size() == 2);
const size_t rows = x.shape[0];
const size_t cols = x.shape[1];
auto hidden_size = gamma.shape[0];
NVTE_CHECK(gamma.shape == beta.shape);
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->dptr != nullptr);
NVTE_CHECK(z->shape == x.shape);
NVTE_CHECK(mu->shape == std::vector<size_t>{ rows });
NVTE_CHECK(mu->dtype == ctype);
NVTE_CHECK(rsigma->shape == std::vector<size_t>{ rows });
NVTE_CHECK(rsigma->dtype == ctype);
if (fp8_out) {
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 });
NVTE_CHECK(scale.dptr != nullptr);
NVTE_CHECK(scale.dtype == ctype);
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 });
NVTE_CHECK(amax->dptr != nullptr);
NVTE_CHECK(amax->dtype == ctype);
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 });
NVTE_CHECK(scale_inv->dptr != nullptr);
NVTE_CHECK(scale_inv->dtype == ctype);
}
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream;
// Request the kernel launcher.
auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype,
hidden_size, rows);
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.dptr;
params.mu = mu->dptr;
params.rs = rsigma->dptr;
params.gamma = gamma.dptr;
params.beta = beta.dptr;
params.z = z->dptr;
params.epsilon = epsilon;
params.amax = amax->dptr;
params.scale = scale.dptr;
params.scale_inv = scale_inv->dptr;
params.fp8_out = fp8_out;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
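// If the workspace was not yet allocated, report the required workspace and
// barrier sizes (shape/dtype) to the caller instead of running the kernel.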
if (workspace->dptr == nullptr) {
NVTE_CHECK(barrier->dptr == nullptr);
workspace->dtype = layer_norm::DType::kByte;
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
workspace->shape = { launch_params.workspace_bytes };
barrier->dtype = layer_norm::DType::kInt32;
barrier->shape = { launch_params.barrier_size };
return;
}
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->dptr;
params.barrier = reinterpret_cast<int*>(barrier->dptr);
}
// Clear buffers
if ( params.fp8_out ) {
NVTE_CHECK_CUDA(cudaMemsetAsync(params.amax, 0,
layer_norm::product(amax->shape) *
typeToSize(amax->dtype), stream));
}
if ( launch_params.barrier_size > 0 ) {
NVTE_CHECK_CUDA(cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->shape) *
typeToSize(barrier->dtype), stream));
}
// Launch the kernel.
launcher(launch_params, false);
return;
}
void layernorm_bwd(const Tensor& dz,
const Tensor& x,
const Tensor& mu,
const Tensor& rsigma,
const Tensor& gamma,
Tensor* dx,
Tensor* dgamma,
Tensor* dbeta,
Tensor* dgamma_part,
Tensor* dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier
) {
using namespace transformer_engine;
auto itype = x.dtype;
auto wtype = gamma.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
NVTE_CHECK(dz.dtype == otype);
NVTE_CHECK(mu.dtype == ctype);
NVTE_CHECK(rsigma.dtype == ctype);
NVTE_CHECK(x.shape.size() == 2);
NVTE_CHECK(dz.shape == x.shape);
auto rows = x.shape[0];
auto cols = x.shape[1];
auto hidden_size = gamma.shape[0];
NVTE_CHECK(mu.shape[0] == rows);
NVTE_CHECK(mu.shape == rsigma.shape);
NVTE_CHECK(gamma.shape[0] == cols);
NVTE_CHECK(dx->shape == x.shape);
NVTE_CHECK(dx->dtype == x.dtype);
NVTE_CHECK(dx->dptr != nullptr);
NVTE_CHECK(dgamma->shape == gamma.shape);
NVTE_CHECK(dgamma->dtype == gamma.dtype);
NVTE_CHECK(dgamma->dptr != nullptr);
NVTE_CHECK(dbeta->shape == gamma.shape);
NVTE_CHECK(dbeta->dtype == gamma.dtype);
NVTE_CHECK(dbeta->dptr != nullptr);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount;
auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype,
hidden_size, rows);
// Set the kernel runtime parameters.
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.dptr;
params.mu = mu.dptr;
params.rs = rsigma.dptr;
params.gamma = gamma.dptr;
params.dz = dz.dptr;
params.dx = dx->dptr;
params.dbeta = dbeta->dptr;
params.dgamma = dgamma->dptr;
params.dbeta_part = dbeta_part->dptr;
params.dgamma_part = dgamma_part->dptr;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->dptr == nullptr) {
NVTE_CHECK(dbeta_part->dptr == nullptr);
dgamma_part->dtype = ctype;
dgamma_part->shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
dbeta_part->dtype = ctype;
dbeta_part->shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
workspace->dtype = layer_norm::DType::kByte;
workspace->shape = { launch_params.workspace_bytes };
barrier->dtype = layer_norm::DType::kInt32;
barrier->shape = { launch_params.barrier_size };
return;
}
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->dptr;
params.barrier = reinterpret_cast<int*>(barrier->dptr);
NVTE_CHECK_CUDA(cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->shape) *
typeToSize(barrier->dtype), stream));
}
// Launch the kernel.
launcher(launch_params, false);
}
} // namespace transformer_engine
void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const NVTETensor scale, // 1
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier,
NVTETensor amax,
NVTETensor scale_inv,
bool fp8_out) {
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta),
*reinterpret_cast<const Tensor*>(scale),
epsilon,
reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu),
reinterpret_cast<Tensor*>(rsigma),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(scale_inv),
fp8_out);
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu),
*reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma),
reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma),
reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part),
reinterpret_cast<Tensor*>(dbeta_part),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier));
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#include "ln.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_tuned_kernel(layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / static_cast<float>(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for ( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if ( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for ( int it = 0; it < ROWS_PER_CTA; it++ ) {
for ( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for ( int it = 0; it < ROWS_PER_CTA; it++ ) {
for ( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for ( int jt = 0; jt < NUM_RES; jt++ ) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for ( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
}
template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_tuned_kernel(BwdParams params) {
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for ( uint32_t col = c, col_out = c_out;
col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for ( uint32_t row = warp; row < params.ctas_per_col;
row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE
+ Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for ( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if ( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] =
Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] =
Converter<src_t, dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
}
}
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_general_kernel(layer_norm::BwdParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using compute_t = typename Ktraits::compute_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = (bidn * THREADS_PER_WARP
+ warp_n * params.ctas_per_row * THREADS_PER_WARP
+ lane); // Order threads by warp x cta x lane
// Objects for weight grads
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
// Objects for stats reductions
using reduce_t = typename Ktraits::Reducer::Type;
using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<reduce_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
for ( int cta_row = bidm * bdimm;
cta_row < params.rows;
cta_row += gdimm ) {
const int row = cta_row + warp_m;
const compute_t mu = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs = static_cast<const compute_t *>(params.rs)[row];
Cvec dy[LDGS];
Cvec y[LDGS];
compute_t mdy = 0.f;
compute_t mdyy = 0.f;
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_ij = x.data.elt[jt];
compute_t y_ij = rs * (x_ij - mu);
compute_t g_ij = gamma[it].data.elt[jt];
compute_t dz_ij = dz.data.elt[jt];
compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
mdy += dy_ij;
mdyy += dy_ij * y_ij;
dz_sum[it].data.elt[jt] += dz_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}
// Reduce over row
reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
mdy = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
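        // For y = (x - mu) * rs, the LayerNorm backward in terms of y is
        //   dx = rs * (dy - mean(dy) - y * mean(dy * y)),
        // with means taken over the row. mdy and mdyy now hold those two
        // means, so the loop below applies the formula elementwise.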
// Compute dx
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Ivec dx;
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy));
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
}
if constexpr ( WARPS_M == 1 ) {
// Write out local weight grad contributions
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
dz_sum[it].store_to_elts(params.dbeta_part,
bidm * params.cols + col,
params.cols - col);
dzy_sum[it].store_to_elts(params.dgamma_part,
bidm * params.cols + col,
params.cols - col);
}
} else {
// Reduce weight grad contributions within CTA before writing
__shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP+1];
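        // The +1 padding on the lane dimension staggers rows across
        // shared-memory banks, which helps keep the strided accesses in the
        // reduction below conflict-free.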
// Reduce dz
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
dz_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
__syncthreads();
#pragma unroll
for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS;
it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) {
#pragma unroll
for ( int kt = 0; kt < WARPS_M; kt++ ) {
if ( kt != warp_m ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dz_sum[it].data.elt[jt]
+= vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dz_sum[it].store_to_elts(params.dbeta_part,
bidm * params.cols + col,
params.cols - col);
}
// Reduce dzy
__syncthreads();
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
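            // Skip it == warp_m: that slice stays in this warp's registers,
            // and the reduction below skips kt == warp_m, so the entry
            // [it][warp_m] with it == warp_m would never be read.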
if ( it != warp_m ) {
dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
}
__syncthreads();
#pragma unroll
for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS;
it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) {
#pragma unroll
for ( int kt = 0; kt < WARPS_M; kt++ ) {
if ( kt != warp_m ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dzy_sum[it].data.elt[jt]
+= vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dzy_sum[it].store_to_elts(params.dgamma_part,
bidm * params.cols + col,
params.cols - col);
}
}
}
template<
typename weight_t,
typename compute_t,
uint32_t WARPS_M,
uint32_t WARPS_N,
uint32_t BYTES_PER_LDG,
uint32_t THREADS_PER_WARP
>
__global__ __launch_bounds__(WARPS_M * WARPS_N * THREADS_PER_WARP)
void ln_bwd_finalize_general_kernel(layer_norm::BwdParams params) {
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
const int lane = threadIdx.x % THREADS_PER_WARP;
const int warp_m = threadIdx.y;
const int warp_n = threadIdx.x / THREADS_PER_WARP;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
// Load grad contributions and accumulate locally
Cvec dgamma, dbeta;
dgamma.clear();
dbeta.clear();
for ( int row = warp_m;
row < params.ctas_per_col && col < params.cols;
row += WARPS_M ) {
Cvec dgamma_part, dbeta_part;
dgamma_part.load_from_elts(params.dgamma_part,
row * params.cols + col,
params.cols - col);
dbeta_part.load_from_elts(params.dbeta_part,
row * params.cols + col,
params.cols - col);
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dgamma.data.elt[jt] += dgamma_part.data.elt[jt];
dbeta.data.elt[jt] += dbeta_part.data.elt[jt];
}
}
// Reduce dgamma within CTA
__shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP+1];
dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for ( int nrows = WARPS_M / 2; nrows > 0; nrows /= 2 ) {
__syncthreads();
if ( warp_m < nrows ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt]
+= vecs_shared[warp_m+nrows][warp_n][lane].data.elt[jt];
}
}
}
if ( warp_m == 0 && col < params.cols ) {
Wvec dgamma_out;
vecs_shared[warp_m][warp_n][lane].to(dgamma_out);
dgamma_out.store_to_elts(params.dgamma, col, params.cols - col);
}
    // Reduce dbeta within CTA
__syncthreads();
dbeta.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for ( int nrows = WARPS_M / 2; nrows > 0; nrows /= 2 ) {
__syncthreads();
if ( warp_m < nrows ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt]
+= vecs_shared[warp_m+nrows][warp_n][lane].data.elt[jt];
}
}
}
if ( warp_m == 0 && col < params.cols ) {
Wvec dbeta_out;
vecs_shared[warp_m][warp_n][lane].to(dbeta_out);
dbeta_out.store_to_elts(params.dbeta, col, params.cols - col);
}
}
} // namespace layer_norm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "ln.h"
#include "../utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
using namespace transformer_engine::layer_norm;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_tuned_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
auto kernel = &ln_bwd_tuned_kernel<Kernel_traits>;
if ( configure_params ) {
int ctas_per_sm;
        NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col = launch_params.multiprocessorCount
* ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if ( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if ( ctas_per_row == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>
(launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
        NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel),
                                                    grid,
                                                    block,
                                                    reinterpret_cast<void **>(&params_),
                                                    Kernel_traits::SMEM_BYTES, stream));
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>
(launch_params.params);
}
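// Both launchers follow a two-phase protocol (a sketch; the call sites that
// drive them are outside this file):
//   launch_tuned_<...>(launch_params, /*configure_params=*/true);
//     -> fills ctas_per_row/ctas_per_col, barrier_size and workspace_bytes
//   ... caller allocates the barrier and workspace buffers ...
//   launch_tuned_<...>(launch_params, /*configure_params=*/false);
//     -> launches the main and finalize kernels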
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
1,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
auto kernel = &ln_bwd_general_kernel<Kernel_traits>;
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if ( configure_params ) {
int ctas_per_sm;
        NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M),
max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes = (ctas_per_col
* WARPS_M
* ctas_per_row
* sizeof(typename Kernel_traits::reduce_t)
* 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if ( ctas_per_row == 1 ) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
        NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel),
                                                    grid,
                                                    block,
                                                    reinterpret_cast<void **>(&params_),
                                                    0,
                                                    stream));
}
// Launch finalization kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL = (Kernel_traits::THREADS_PER_WARP
* WARPS_N_FINAL
* BYTES_PER_LDG_FINAL
/ sizeof(compute_t));
auto kernel_final = &ln_bwd_finalize_general_kernel<weight_t,
compute_t,
WARPS_M_FINAL,
WARPS_N_FINAL,
BYTES_PER_LDG_FINAL,
Kernel_traits::THREADS_PER_WARP>;
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
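// Illustrative finalize geometry: with BYTES_PER_LDG_FINAL = 4 and
// compute_t = float, each thread loads one element per pass, so
// ELTS_N_PER_CTA_FINAL = 32 * 1 * 4 / 4 = 32 columns per CTA, covered by a
// 32x4 thread block (WARPS_N_FINAL = 1, WARPS_M_FINAL = 4).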
// Create tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
// Create general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "ln.h"
#include "../utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh"
using namespace transformer_engine::layer_norm;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_tuned_(LaunchParams<FwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
auto kernel = &ln_fwd_tuned_kernel<Kernel_traits>;
if ( configure_params ) {
int ctas_per_sm;
        NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col = launch_params.multiprocessorCount *
ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if ( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if ( ctas_per_row == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA,
Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
        NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel),
                                                    grid,
                                                    block,
                                                    reinterpret_cast<void **>(&params_),
                                                    Kernel_traits::SMEM_BYTES_FWD, stream));
}
}
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_general_(LaunchParams<FwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
1,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
auto kernel = &ln_fwd_general_kernel<Kernel_traits>;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if ( configure_params ) {
int ctas_per_sm;
        NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M),
max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes = (ctas_per_col
* WARPS_M
* ctas_per_row
* sizeof(compute_t)
* 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if ( ctas_per_row == 1 ) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
        NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel),
                                                    grid,
                                                    block,
                                                    reinterpret_cast<void **>(&params_),
                                                    0,
                                                    stream));
}
}
// Create tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, fp8e4m3, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16);
// Create general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16);
REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16);
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_FWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_FWD_KERNELS_CUH_
#include <cfloat>
#include <cstdio>
#include "ln.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_tuned_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t*>(params.scale);
}
compute_t amax = 0;
for ( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
x[it].load_from(params.x, idx);
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if ( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
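        // m2 is the sum of squared deviations over the row, so rn * m2 is
        // the (biased) variance fed to rsqrtf below.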
compute_t rs = rsqrtf(rn * m2 + params.epsilon);
if ( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu);
compute_t g_ij = gamma[it].data.elt[jt];
compute_t b_ij = beta[it].data.elt[jt];
compute_t temp_output = g_ij * y_ij + b_ij;
if (params.fp8_out) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
temp_output = temp_output * scale;
}
z[it].data.elt[jt] = output_t(temp_output);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
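    // Finalize fp8 factors: reduce each thread's amax across the CTA, fold
    // it into the global amax with an atomic max, and publish 1/scale so
    // consumers can dequantize the scaled fp8 output.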
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t*>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t*>(params.scale_inv), scale);
}
}
}
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_general_kernel(FwdParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = (bidn * THREADS_PER_WARP
+ warp_n * params.ctas_per_row * THREADS_PER_WARP
+ lane); // Order threads by warp x cta x lane
// Objects for stats reductions
using Reducer = DynamicReducer<compute_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<compute_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
Cvec beta[LDGS];
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Wvec gamma_in, beta_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
beta_in.load_from_elts(params.beta, col, params.cols - col);
gamma_in.to(gamma[it]);
beta_in.to(beta[it]);
}
// fp8 factors
compute_t scale;
if ( params.fp8_out ) {
scale = *reinterpret_cast<compute_t*>(params.scale);
}
compute_t amax = 0;
for ( int cta_row = bidm * bdimm;
cta_row < params.rows;
cta_row += gdimm ) {
const int row = cta_row + warp_m;
// Load input
Cvec x[LDGS];
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Ivec x_in;
x_in.load_from_elts(params.x,
row * params.cols + col,
params.cols - col);
x_in.to(x[it]);
}
// Compute mean
compute_t mu = 0.f;
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
mu += x[it].data.elt[jt];
}
}
mu = reducer.allreduce(mu, sum) * rn;
// Compute variance
compute_t sqsigma = 0.f;
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
if ( col + jt < params.cols ) {
compute_t diff = x[it].data.elt[jt] - mu;
sqsigma += diff * diff;
}
}
}
sqsigma = reducer.allreduce(sqsigma, sum) * rn;
compute_t rs = rsqrtf(sqsigma + params.epsilon);
// Write statistics
if ( gidn == 0 && row < params.rows ) {
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
mu_ptr[row] = mu;
rs_ptr[row] = rs;
}
// Compute output
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
// Compute output values
Cvec z;
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = rs * (x[it].data.elt[jt] - mu);
compute_t g_ij = gamma[it].data.elt[jt];
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = g_ij * y_ij + b_ij;
}
// Apply fp8 factors
if ( params.fp8_out ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
if ( col + jt < params.cols ) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale;
}
}
}
// Store output
Ovec z_out;
z.to(z_out);
z_out.store_to_elts(params.z,
row * params.cols + col,
params.cols - col);
}
}
// Finalize fp8 factors
if ( params.fp8_out ) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if ( threadIdx.x == 0 ) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t*>(params.amax), amax);
reciprocal<compute_t>(reinterpret_cast<compute_t*>(params.scale_inv), scale);
}
}
}
} // namespace layer_norm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_FWD_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_
#include "../common.h"
#include "../utils.cuh"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace transformer_engine {
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert(static_cast<int>(ROWS_PER_CTA) <= static_cast<int>(Base::THREADS_PER_WARP));
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
    static_assert(sizeof(compute_t_) == 4,
                  "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP,
"We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
    // Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = transformer_engine::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename transformer_engine::TypeToVec2<compute_t>::Type;
using Reducer = transformer_engine::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = transformer_engine::Vec<input_t, NUM_ELTS>;
using Ovec = transformer_engine::Vec<output_t, NUM_ELTS>;
using Wvec = transformer_engine::Vec<weight_t, NUM_ELTS>;
using Cvec = transformer_engine::Vec<compute_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements
// in the output and weights as in the input.
static_assert(sizeof(input_t) >= sizeof(output_t));
static_assert(sizeof(input_t) >= sizeof(weight_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
// static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = transformer_engine::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
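// Illustrative instantiation (a sketch, not a shipped configuration): with
// HIDDEN_SIZE_ = 1024, input_t = fp16 and BYTES_PER_LDG_ = 16, NUM_ELTS is
// 16 / 2 = 8; with WARPS_M_ = 1 and WARPS_N_ = 4, THREADS_PER_ROW = 128, so
// VEC_COLS = 1024 / 8 = 128 and, for CTAS_PER_ROW_ = 1, each thread issues
// LDGS = 128 / 128 = 1 vectorized load per row.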
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""This module provides predefined FP8 recipes."""
from __future__ import annotations
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
from pydantic.dataclasses import dataclass
class _FormatHelper(NamedTuple):
"""
    Stores the maximum FP8 values for fprop and bprop for a `Format`.
"""
max_fwd: float
max_bwd: float
class Format(Enum):
"""
Supported FP8 formats.
Values
------
E4M3 :
All FP8 tensors are in e4m3 format
E5M2 :
All FP8 tensors are in e5m2 format
HYBRID :
FP8 tensors in the forward pass are in e4m3 format,
FP8 tensors in the backward pass are in e5m2 format
"""
E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
class _OverrideLinearPrecision(NamedTuple):
"""
    Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
    GEMMs in higher precision when using FP8.
"""
fprop: bool = False
dgrad: bool = False
wgrad: bool = False
@dataclass()
class DelayedScaling:
"""
    Use the delayed scaling factor strategy:
    use the scaling factor from a previous iteration,
    recompute it once every `interval` steps, and keep
    an amax history of the last `amax_history_len` steps.
Parameters
----------
margin : int, default = 0
Margin for the scaling factor computation.
interval : int, default = 1
Controls how often the scaling factor is recomputed.
    fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Controls the FP8 data format used during forward and backward
pass.
amax_history_len : int, default = 1
The length of the amax history window used for
scaling factor computation.
amax_compute_algo : {'max', 'most_recent', Callable}, default = 'most_recent'
Algorithm used for choosing the `amax` value for the
scaling factor computation. There are 2 predefined
choices: `max` chooses the largest `amax` in the history
window, while `most_recent` always chooses the most recently
seen value. Alternatively, one may pass a function of the
signature:
.. code-block:: python
def amax_compute(amax_history: Tensor) -> Tensor
where `Tensor` is a framework tensor type.
scaling_factor_compute_algo : Callable, default = None
Algorithm used for computing the new scaling
factor based on the value of `amax`. It should
be a function of the signature:
.. code-block:: python
def scaling_factor_compute(amax: Tensor,
old_scaling_factor: Tensor,
fp8_max: Tensor,
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
    override_linear_precision : Tuple[bool, bool, bool], default = (False, False, False)
              Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
              GEMMs (respectively) in higher precision when using FP8.
Notes
-----
* By default (when `scaling_factor_compute_algo` is left as `None`) the scaling
factor is computed from the final `amax` value using the formula:
.. code-block:: python
FP8_MAX = maximum_representable_value(fp8_format)
exp = get_exponent(FP8_MAX / amax) - margin
        new_scaling_factor = 2.0 ** exp
    * The scaling factor should always be a power of 2, so that the conversion
      from FP8 to a higher-precision format does not introduce numerical error.
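    * As an illustrative example: with `fp8_format=Format.E4M3` (so `FP8_MAX`
      is 448), `amax = 112`, and `margin = 0`, `exp = get_exponent(448 / 112) = 2`,
      giving a new scaling factor of `2.0 ** 2 = 4.0`.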
"""
margin: int = 0
interval: int = 1
fp8_format: Format = Format.HYBRID
amax_history_len: int = 1
amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "most_recent"
override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
scaling_factor_compute_algo: Optional[Callable] = None
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert self.override_linear_precision in (
(False, False, False),
(False, False, True),
), "Only wgrad GEMM override is currently supported."
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transformer_engine.h>
#include "common.h"
namespace transformer_engine {
size_t typeToSize(const transformer_engine::DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
return TypeInfo<T>::size;
); // NOLINT(*)
}
} // namespace transformer_engine
NVTETensor nvte_create_tensor(void *dptr,
const NVTEShape shape,
const NVTEDType dtype) {
transformer_engine::Tensor *ret = new transformer_engine::Tensor;
ret->dptr = dptr;
ret->shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
ret->dtype = static_cast<transformer_engine::DType>(dtype);
return ret;
}
void nvte_destroy_tensor(NVTETensor tensor) {
if (tensor == nullptr) return;
auto *t = reinterpret_cast<transformer_engine::Tensor *>(tensor);
delete t;
}
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
return static_cast<NVTEDType>(reinterpret_cast<const transformer_engine::Tensor*>(tensor)->dtype);
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
NVTEShape ret;
ret.data = t.shape.data();
ret.ndim = t.shape.size();
return ret;
}
void *nvte_tensor_data(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
return t.dptr;
}
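/* Illustrative usage of the C tensor API above (a sketch; the kNVTEFloat32
 * enumerator name is assumed from the public header):
 *
 *   size_t dims[2] = {num_rows, row_length};
 *   NVTEShape shape;
 *   shape.data = dims;
 *   shape.ndim = 2;
 *   NVTETensor t = nvte_create_tensor(device_ptr, shape, kNVTEFloat32);
 *   // ... pass t to Transformer Engine functions ...
 *   nvte_destroy_tensor(t);
 */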
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <iostream>
#include <cfloat>
#include "../utils.cuh"
#include "../common.h"
namespace transformer_engine {
template <bool full_tile, int nvec_in, int nvec_out, typename IVec, typename OVec, typename CType>
inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in],
typename OVec::type *output_cast_tile,
const size_t current_place,
const size_t stride,
CType &max, // NOLINT(*)
const CType scale,
const bool valid_store) {
using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = static_cast<CType>(in[i].data.elt[j]);
const T elt_o = T(scale * tmp);
out_cast.data.elt[j] = elt_o;
out_trans[j].data.elt[i] = elt_o; // thread tile transpose
__builtin_assume(max >= 0);
max = fmaxf(fabsf(tmp), max);
}
if (full_tile || valid_store) {
out_cast.store_to(output_cast_tile, current_place + stride * i);
}
}
}
// STUFF TO TUNE
constexpr unsigned int n_warps_per_tile = 4;
constexpr int desired_load_size = 8;
constexpr int desired_store_size = 8;
constexpr unsigned int max_threads_per_block = 256;
static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
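// Illustrative tile arithmetic (a sketch, not a fixed configuration): for a
// 2-byte input type nvec_in = 8 / 2 = 4 and for a 1-byte FP8 output type
// nvec_out = 8 / 1 = 8, so each tile covers (nvec_in * THREADS_PER_WARP) x
// (nvec_out * THREADS_PER_WARP) = 128 x 256 input elements, processed
// cooperatively by n_warps_per_tile = 4 warps.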
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel(const IType * const input,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
CType * const scale_inv,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
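  // Double-buffered input: iteration i consumes in[current_in ^ 1] while the
  // loads for iteration i + 1 land in in[current_in], hiding load latency.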
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *scale_ptr;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
cast_and_transpose_regs<true>(in[current_in ^ 1], out_trans, my_output_c_tile,
current_place, stride, max, scale, true);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
  /* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax, max);
reciprocal<float>(scale_inv, scale);
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel_notaligned(const IType * const input,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
CType * const scale_inv,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
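  // Edge tiles: tile_length and tile_height are this tile's valid extents in
  // vector units (clamped to THREADS_PER_WARP for full interior tiles); loads
  // and stores outside these bounds are masked off below.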
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *scale_ptr;
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
} else {
in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
} else {
in[current_in][j].clear();
}
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs<false>(in[current_in ^ 1], out_trans, my_output_c_tile,
current_place, stride, max, scale, valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
  /* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax, max);
reciprocal<float>(scale_inv, scale);
}
}
void cast_transpose(const Tensor &input,
const Tensor &scale,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *amax,
Tensor *scale_inv,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.shape == cast_output->shape, "Input and C output must have the same shape.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
"Both C and T outputs need to have the same type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size / itype_size;
constexpr int nvec_out = desired_store_size / otype_size;
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_kernel<nvec_in, nvec_out, fp32,
InputType, OutputType>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_kernel<nvec_in, nvec_out, fp32, InputType, OutputType>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.dptr),
reinterpret_cast<OutputType *>(cast_output->dptr),
reinterpret_cast<OutputType *>(transposed_output->dptr),
reinterpret_cast<const fp32 *>(scale.dptr),
reinterpret_cast<fp32 *>(amax->dptr),
reinterpret_cast<fp32 *>(scale_inv->dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
InputType, OutputType>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32, InputType, OutputType>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.dptr),
reinterpret_cast<OutputType *>(cast_output->dptr),
reinterpret_cast<OutputType *>(transposed_output->dptr),
reinterpret_cast<const fp32 *>(scale.dptr),
reinterpret_cast<fp32 *>(amax->dptr),
reinterpret_cast<fp32 *>(scale_inv->dptr),
row_length, num_rows, n_tiles);
}
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_cast_transpose(const NVTETensor input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream) {
using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(scale_inv),
stream);
}
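/* Illustrative call sequence for the C API above (a sketch; all NVTETensor
 * handles are assumed to be created with nvte_create_tensor over
 * preallocated device buffers):
 *
 *   NVTETensor input      = ...;  // 2D input, e.g. fp16
 *   NVTETensor scale      = ...;  // 1 fp32 element
 *   NVTETensor cast_out   = ...;  // same shape as input, FP8
 *   NVTETensor transp_out = ...;  // transposed shape, FP8
 *   NVTETensor amax       = ...;  // 1 fp32 element
 *   NVTETensor scale_inv  = ...;  // 1 fp32 element
 *   nvte_cast_transpose(input, scale, cast_out, transp_out,
 *                       amax, scale_inv, stream);
 */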
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <iostream>
#include <type_traits>
#include "../utils.cuh"
#include "../common.h"
namespace transformer_engine {
template <bool full_tile, int nvec_in, int nvec_out,
typename IVec, typename OVec, typename CVec, typename CType>
inline __device__ void cast_and_transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in],
CVec &out_dbias, // NOLINT(*)
typename OVec::type *output_cast_tile,
const size_t current_place,
const size_t stride,
CType &max, // NOLINT(*)
const CType scale,
const int dbias_shfl_src_lane,
const bool valid_store) {
using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>;
CVec step_dbias; step_dbias.clear();
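  // step_dbias holds this call's thread-tile column sums; it is then warp-
  // shuffled so that each lane's out_dbias accumulates a fixed set of columns,
  // even though the rotating access pattern visits a tile's columns from a
  // different lane on each iteration.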
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
      const CType tmp = static_cast<CType>(in[i].data.elt[j]);
const T elt_o = T(scale * tmp);
/* dbias: thread tile local accumulation */
step_dbias.data.elt[j] += tmp;
out_cast.data.elt[j] = elt_o;
out_trans[j].data.elt[i] = elt_o; // thread tile transpose
__builtin_assume(max >= 0);
max = fmaxf(fabsf(tmp), max);
}
if (full_tile || valid_store) {
out_cast.store_to(output_cast_tile, current_place + stride * i);
}
}
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j];
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in warp
out_dbias.data.elt[j] += elt;
}
}
// STUFF TO TUNE
constexpr unsigned int n_warps_per_tile = 4;
constexpr int desired_load_size = 8;
constexpr int desired_store_size = 8;
constexpr unsigned int max_threads_per_block = 256;
static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
namespace {
template <typename IType, typename OType, typename CType>
struct CTDBiasParam {
using InputType = IType;
using OutputType = OType;
using ComputeType = CType;
const IType *input;
OType *output_c;
OType *output_t;
const CType *scale_ptr;
CType *amax;
CType *scale_inv;
CType *workspace;
};
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDGeluParam {
using InputType = IType;
using InputType2 = IType2;
using OutputType = OType;
using ComputeType = CType;
const IType *input;
const IType2 *gelu_input;
OType *output_c;
OType *output_t;
const CType *scale_ptr;
CType *amax;
CType *scale_inv;
CType *workspace;
};
} // namespace
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_kernel(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
CVec partial_dbias;
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
partial_dbias.clear();
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
cast_and_transpose_regs_partial_dbias<true>(in[current_in ^ 1], out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
true);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
  /* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
}
}
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_kernel_notaligned(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
CVec partial_dbias;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
partial_dbias.clear();
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
} else {
in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
} else {
in[current_in][j].clear();
}
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs_partial_dbias<false>(in[current_in ^ 1], out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
  /* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
}
}
constexpr size_t reduce_dbias_num_threads = 256;
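// Second reduction stage for dbias: each thread owns nvec contiguous columns
// and serially sums the rows of partial sums produced by the
// cast_transpose_dbias kernels, then casts the fp32 accumulators to the
// output type.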
template<int nvec, typename ComputeType, typename OutputType>
__global__ void
__launch_bounds__(reduce_dbias_num_threads)
reduce_dbias_kernel(OutputType* const dbias_output,
const ComputeType* const dbias_partial,
const int row_length,
const int num_rows) {
using ComputeVec = Vec<ComputeType, nvec>;
using OutputVec = Vec<OutputType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= row_length) return;
const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec;
OutputType* const thread_out_base = dbias_output + thread_id * nvec;
const int stride_in_vec = row_length / nvec;
ComputeVec ldg_vec;
ComputeVec acc_vec; acc_vec.clear();
for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base, 0);
}
void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor* workspace,
const int nvec_out) {
const size_t row_length = cast_output.shape[1];
const size_t num_rows = cast_output.shape[0];
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
workspace->shape = {num_rows_partial_dbias, row_length};
workspace->dtype = DType::kFloat32;
}
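// Illustrative workspace sizing: with num_rows = 2048 and nvec_out = 8,
// tile_size_y = 8 * 32 = 256, so the partial-dbias workspace holds
// DIVUP(2048, 256) = 8 fp32 rows of row_length elements each.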
template <typename InputType>
void reduce_dbias(const Tensor &workspace, Tensor *dbias,
const size_t row_length, const size_t num_rows, const int nvec_out,
cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType);
NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_row_length = row_length;
const size_t reduce_dbias_num_rows = DIVUP(num_rows,
static_cast<size_t>(nvec_out *
THREADS_PER_WARP));
const size_t reduce_dbias_num_blocks = DIVUP(row_length,
reduce_dbias_num_threads * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, fp32, InputType>
<<<reduce_dbias_num_blocks,
reduce_dbias_num_threads,
0,
stream>>>(
reinterpret_cast<InputType *>(dbias->dptr),
reinterpret_cast<const fp32 *>(workspace.dptr),
reduce_dbias_row_length,
reduce_dbias_num_rows);
}
void cast_transpose_dbias(const Tensor &input,
const Tensor &scale,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *amax,
Tensor *dbias,
Tensor *scale_inv,
Tensor *workspace,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.shape == cast_output->shape, "Input and C output must have the same shape.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
"Both T and C outputs need to have the same type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
NVTE_CHECK(dbias->dptr != nullptr, "DBias is not allocated.");
NVTE_CHECK(dbias->dtype == input.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size / itype_size;
constexpr int nvec_out = desired_store_size / otype_size;
if (workspace->dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
using ComputeType = fp32;
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) *
sizeof(Vec<OutputType, nvec_out>);
constexpr size_t shared_size_dbias = cast_transpose_num_threads *
sizeof(Vec<ComputeType, nvec_in>);
static_assert(shared_size_transpose >= shared_size_dbias);
using Param = CTDBiasParam<InputType, OutputType, ComputeType>;
Param param;
param.input = reinterpret_cast<const InputType *>(input.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(amax->dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(scale_inv->dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->dptr);
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
}
reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
); // NOLINT(*)
); // NOLINT(*)
}
namespace {
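// Derivative of the tanh approximation of GELU,
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with 0.79788456f ~= sqrt(2/pi) and 0.1070322243f ~= 3 * 0.044715 * sqrt(2/pi).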
template <typename CType, typename IType>
__device__ inline CType dgelu(const IType val) {
CType cval = val;
const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) *
(0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
} // namespace
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dgelu_kernel(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType;
using IType2 = typename Param::InputType2;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>;
using IVec2 = Vec<IType2, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType2 * const my_gelu_input_tile = param.gelu_input +
(tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
IVec2 gelu_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
CVec partial_dbias;
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
partial_dbias.clear();
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
current_stride + my_place_in +
stride * (nvec_out + j));
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
CType(in[current_in ^ 1][j].data.elt[k]);
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
cast_and_transpose_regs_partial_dbias<true>(after_dgelu, out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
true);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
  /* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
}
}
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType;
using IType2 = typename Param::InputType2;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>;
using IVec2 = Vec<IType2, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType2 * const my_gelu_input_tile = param.gelu_input +
(tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
IVec2 gelu_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
CVec partial_dbias;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
partial_dbias.clear();
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride + my_place + stride * i);
} else {
in[0][i].clear();
gelu_in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
current_stride + my_place_in +
stride * (nvec_out + j));
} else {
in[current_in][j].clear();
gelu_in[current_in][j].clear();
}
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
CType(in[current_in ^ 1][j].data.elt[k]);
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs_partial_dbias<false>(after_dgelu, out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
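  // Cross-warp dbias epilogue: warp 0 sums the other warps' partial rows from
  // shared memory and writes one partial row per tile into the workspace;
  // reduce_dbias later folds those rows into the final dbias tensor.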
  /* Warp-tile amax reduction: fold each thread's running max across the
     block's warps; thread 0 then atomically updates the global amax and
     writes scale_inv = 1 / scale. */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
}
}
void cast_transpose_dbias_dgelu(const Tensor &input,
const Tensor &gelu_input,
const Tensor &scale,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *amax,
Tensor *dbias,
Tensor *scale_inv,
Tensor *workspace,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2,
"T output must have 2 dimensions.");
NVTE_CHECK(input.shape == cast_output->shape,
"Input and C output must have the same shape.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
"Both C and T outputs need to have the same type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(gelu_input.dptr != nullptr, "GeLU input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
NVTE_CHECK(dbias->dptr != nullptr, "DBias is not allocated.");
NVTE_CHECK(dbias->dtype == input.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
NVTE_CHECK(input.dtype == gelu_input.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.shape == gelu_input.shape, "Shapes of both inputs must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
using InputType2 = InputType;
    /* The dgelu fusion kernel uses more registers per thread, so use smaller
       4-byte vector loads and stores here (the plain transpose path uses
       8-byte ones). */
constexpr int desired_load_size_dgelu = 4;
constexpr int desired_store_size_dgelu = 4;
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size_dgelu / itype_size;
constexpr int nvec_out = desired_store_size_dgelu / otype_size;
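    /* Workspace query convention: a null workspace pointer turns the call
       into a shape query -- fill in the workspace shape/dtype for the caller
       to allocate, then return without launching anything. */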
if (workspace->dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
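    /* A full tile is (nvec_in * 32) x (nvec_out * 32) elements; shapes that
       are exact multiples take the unpredicated fast path below, while any
       other shape falls back to the _notaligned kernel, which masks its
       boundary loads and stores. */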
using ComputeType = fp32;
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) *
sizeof(Vec<OutputType, nvec_out>);
constexpr size_t shared_size_dbias = cast_transpose_num_threads *
sizeof(Vec<ComputeType, nvec_in>);
static_assert(shared_size_transpose >= shared_size_dbias);
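    /* A single dynamic shared-memory allocation serves both phases: first as
       the transpose staging buffer, then (after a barrier) as the per-thread
       dbias partials; the static_assert guarantees the transpose size also
       covers the dbias scratch. */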
using Param = CTDBiasDGeluParam<InputType, InputType2, OutputType, ComputeType>;
Param param;
param.input = reinterpret_cast<const InputType *>(input.dptr);
param.gelu_input = reinterpret_cast<const InputType2 *>(gelu_input.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(amax->dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(scale_inv->dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->dptr);
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_dgelu_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
}
reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_cast_transpose_dbias(const NVTETensor input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace,
cudaStream_t stream) {
using namespace transformer_engine;
cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(scale_inv),
reinterpret_cast<Tensor*>(workspace),
stream);
}
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor gelu_input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace,
cudaStream_t stream) {
using namespace transformer_engine;
cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input),
*reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(scale_inv),
reinterpret_cast<Tensor*>(workspace),
stream);
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <iostream>
#include <cfloat>
#include "../utils.cuh"
#include "../common.h"
namespace transformer_engine {
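// Register-level transpose of each thread's nvec_out x nvec_in element block:
// swapping the element indices (out_trans[j].elt[i] = in[i].elt[j]) transposes
// the block without touching shared or global memory.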
template <int nvec_in, int nvec_out, typename IVec, typename OVec>
inline __device__ void transpose_regs(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in]) {
using T = typename OVec::type;
using U = typename IVec::type;
static_assert(std::is_same<T, U>::value, "Types of input and output to transpose must match!");
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_trans[j].data.elt[i] = in[i].data.elt[j]; // thread tile transpose
}
}
}
// STUFF TO TUNE
constexpr unsigned int n_warps_per_tile = 4;
constexpr int desired_load_size = 8;
constexpr int desired_store_size = 8;
constexpr unsigned int max_threads_per_block = 256;
static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
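// Worked example (illustrative): for fp16 (2-byte) elements, nvec_in =
// nvec_out = 8 / 2 = 4, so a tile covers (4 * 32) x (4 * 32) = 128 x 128
// elements and each of the 4 warps per tile runs 32 / 4 = 8 iterations.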
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
transpose_kernel(const IType * const input,
OType * const output_t,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
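  // Each group of n_warps_per_tile consecutive warps cooperates on a single
  // tile; warp groups past the tile count have just exited above.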
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
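  // Per-tile scratch: THREADS_PER_WARP rows of THREADS_PER_WARP + 1 OVecs
  // each; the extra element is the usual padding that keeps the strided
  // accesses of the transpose off a single shared-memory bank.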
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
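  // Strides are measured in vectors, not elements: an input row holds
  // row_length / nvec_in IVecs and a transposed output row holds
  // num_rows / nvec_out OVecs.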
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs<nvec_in, nvec_out, IVec, OVec>(in[current_in ^ 1], out_trans);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
transpose_kernel_notaligned(const IType * const input,
OType * const output_t,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
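  // tile_length / tile_height give the number of valid vector columns / rows
  // in this (possibly partial) boundary tile; all loads and stores below are
  // predicated on them.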
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
} else {
in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
} else {
in[current_in][j].clear();
}
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs<nvec_in, nvec_out, IVec, OVec>(in[current_in ^ 1], out_trans);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
}
void transpose(const Tensor &input,
Tensor *transposed_output,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2, "Output must have 2 dimensions.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.dtype == transposed_output->dtype, "Input and output type must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.dtype, Type,
constexpr int type_size = sizeof(Type);
constexpr int nvec_in = desired_load_size / type_size;
constexpr int nvec_out = desired_store_size / type_size;
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
if (full_tile) {
cudaFuncSetAttribute(transpose_kernel<nvec_in, nvec_out, fp32, Type, Type>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
transpose_kernel<nvec_in, nvec_out, fp32, Type, Type>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>),
stream>>>(
reinterpret_cast<const Type *>(input.dptr),
reinterpret_cast<Type *>(transposed_output->dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(transpose_kernel_notaligned<nvec_in, nvec_out, fp32, Type, Type>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
transpose_kernel_notaligned<nvec_in, nvec_out, fp32, Type, Type>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>),
stream>>>(
reinterpret_cast<const Type *>(input.dptr),
reinterpret_cast<Type *>(transposed_output->dptr),
row_length, num_rows, n_tiles);
}
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_transpose(const NVTETensor input,
NVTETensor transposed_output,
cudaStream_t stream) {
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
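/* Usage sketch (illustrative only; make_tensor is a hypothetical helper --
 * see transformer_engine.h for the actual NVTETensor constructors):
 *
 *   // row-major [num_rows, row_length] input, [row_length, num_rows] output,
 *   // same dtype, both preallocated on the device
 *   NVTETensor in  = make_tensor(d_in,  {num_rows, row_length}, dtype);
 *   NVTETensor out = make_tensor(d_out, {row_length, num_rows}, dtype);
 *   nvte_transpose(in, out, stream);
 */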