Unverified commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
@@ -33,14 +33,11 @@ extern "C" {
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
NVTETensor output, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream);
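For orientation, a minimal usage sketch of the reformatted declaration; the NVTETensor handles (`input`, `freqs`, `output`) and the CUDA stream are assumed to have been created elsewhere, and the stride values simply describe a contiguous sbhd layout:
/* Sketch: apply RoPE to a contiguous [s, b, h, d] activation (handles assumed pre-created). */
const int s = 2048, b = 8, h = 16, d = 64, d2 = 64;
nvte_fused_rope_forward(input, freqs, output, s, b, h, d, d2,
                        /*stride_s=*/b * h * d, /*stride_b=*/h * d,
                        /*stride_h=*/d, /*stride_d=*/1,
                        /*o_stride_s=*/b * h * d, /*o_stride_b=*/h * d,
                        /*o_stride_h=*/d, /*o_stride_d=*/1, stream);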
/*! \brief Compute the backward of the fused rope.
*
@@ -62,14 +59,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_backward(const NVTETensor output_grads,
const NVTETensor freqs, NVTETensor input_grads,
const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
NVTETensor input_grads, const int s, const int b, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the input tensor in thd format.
*
@@ -90,14 +85,12 @@ void nvte_fused_rope_backward(const NVTETensor output_grads,
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_forward(const NVTETensor input,
const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
*
@@ -118,12 +111,12 @@ void nvte_fused_rope_thd_forward(const NVTETensor input,
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_backward(
const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream);
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
@@ -39,20 +39,10 @@ extern "C" {
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
cudaStream_t stream
);
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream);
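A hedged usage sketch of the collapsed declaration; all NVTETensor handles (including `bias`, `pre_gelu_out`, and `workspace`) are assumed to have been created beforehand, and `math_sm_count` is left at 0 so cuBLAS heuristics choose the SM count, as documented above:
/* Sketch: D = op(A) * op(B) with a TN layout, no accumulation, no epilogue fusion used. */
nvte_cublas_gemm(A, B, D, bias, pre_gelu_out,
                 /*transa=*/true, /*transb=*/false, /*grad=*/false,
                 workspace, /*accumulate=*/false, /*use_split_accumulator=*/false,
                 /*math_sm_count=*/0, stream);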
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
@@ -82,27 +72,14 @@ void nvte_cublas_gemm(const NVTETensor A,
* \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream
);
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
bool transb, bool grad, NVTETensor workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_GEMM_H_
@@ -42,16 +42,9 @@ extern "C" {
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm_fwd(const NVTETensor x,
const NVTETensor gamma,
const NVTETensor beta,
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier);
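The implementation further down in this commit shows that a call with empty workspace/barrier tensors only queries their required sizes; a common calling pattern is therefore the two-pass sketch below, with all other handles assumed pre-created:
/* First call: workspace/barrier carry null data pointers, so only their
 * required shapes/dtypes are filled in and the function returns. */
nvte_layernorm_fwd(x, gamma, beta, epsilon, z, mu, rsigma,
                   stream, multiprocessorCount, workspace, barrier);
/* ... allocate workspace/barrier buffers to the queried sizes ... */
/* Second call: performs the actual forward LayerNorm. */
nvte_layernorm_fwd(x, gamma, beta, epsilon, z, mu, rsigma,
                   stream, multiprocessorCount, workspace, barrier);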
/*! \brief Compute LayerNorm with zero-centered gamma on the input.
@@ -79,19 +72,11 @@ void nvte_layernorm_fwd(const NVTETensor x,
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm1p_fwd(const NVTETensor x,
const NVTETensor gamma,
const NVTETensor beta,
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
void nvte_layernorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute backward of LayerNorm.
*
* This function computes the gradient of function:
@@ -121,20 +106,14 @@ void nvte_layernorm1p_fwd(const NVTETensor x,
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part,
NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
/*! \brief Compute backward of LayerNorm with zero-centered gamma.
*
@@ -165,20 +144,14 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta,
NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier);
#ifdef __cplusplus
} // extern "C"
#endif
@@ -44,18 +44,11 @@ extern "C" {
* \param[in] margin Scaling factor margin.
* \param[in] stream CUDA stream.
*/
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
const NVTETensor scale,
const NVTETensor scale_inv,
const NVTETensor scale_inv_mask,
NVTETensor updated_amax_history,
NVTETensor updated_scale,
NVTETensor updated_scale_inv,
const char* amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream);
void nvte_delayed_scaling_recipe_amax_and_scale_update(
const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv,
const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale,
NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
cudaStream_t stream);
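A hedged sketch of the single-buffer update; the accepted strings for `amax_compute_algo` are not shown in this excerpt, so "max" below is an assumption, and all tensor handles are assumed pre-created:
/* Sketch: update delayed-scaling factors for one set of FP8 buffers.
 * "max" for amax_compute_algo is assumed, not taken from this header. */
nvte_delayed_scaling_recipe_amax_and_scale_update(
    amax_history, scale, scale_inv, scale_inv_mask,
    updated_amax_history, updated_scale, updated_scale_inv,
    /*amax_compute_algo=*/"max", kNVTEFloat8E4M3, /*margin=*/0.0f, stream);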
/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction.
*
@@ -85,15 +78,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his
* \param[in] stream CUDA stream.
*/
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer,
std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales,
std::vector<NVTETensor> scale_invs,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream);
const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
@@ -43,15 +43,9 @@ extern "C" {
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_rmsnorm_fwd(const NVTETensor x,
const NVTETensor gamma,
const float epsilon,
NVTETensor z,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z,
NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
/*! \brief Compute RMSNorm with zero-centered gamma on the input.
*
@@ -79,15 +73,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x,
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_rmsnorm1p_fwd(const NVTETensor x,
const NVTETensor gamma,
const float epsilon,
NVTETensor z,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon,
NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier);
/*! \brief Compute backward of RMSNorm.
*
@@ -118,18 +106,10 @@ void nvte_rmsnorm1p_fwd(const NVTETensor x,
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_rmsnorm_bwd(const NVTETensor dz,
const NVTETensor x,
const NVTETensor rsigma,
const NVTETensor gamma,
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dgamma_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier
);
void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,
const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
/*! \brief Compute backward of RMSNorm with zero-centered gamma.
*
@@ -160,18 +140,10 @@ void nvte_rmsnorm_bwd(const NVTETensor dz,
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_rmsnorm1p_bwd(const NVTETensor dz,
const NVTETensor x,
const NVTETensor rsigma,
const NVTETensor gamma,
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dgamma_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier
);
void nvte_rmsnorm1p_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,
const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier);
#ifdef __cplusplus
} // extern "C"
@@ -7,8 +7,9 @@
#ifndef TRANSFORMER_ENGINE_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SOFTMAX_H_
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "transformer_engine.h"
#ifdef __cplusplus
@@ -22,13 +23,8 @@ extern "C" {
* \param[in] scale_factor Scalar for the input tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_results,
float scale_factor, cudaStream_t stream);
/*! \brief Compute the backward of the scaled softmax activation.
*
@@ -42,14 +38,8 @@ void nvte_scaled_softmax_forward(
* \param[in] scale_factor Scalar for the output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_softmax_backward(
const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results,
NVTETensor output_grads, float scale_factor, cudaStream_t stream);
/*! \brief Compute scaled masked softmax activation on the input.
*
@@ -59,14 +49,9 @@ void nvte_scaled_softmax_backward(
* \param[in] scale_factor Scalar for the input tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_masked_softmax_forward(
const NVTETensor input,
const NVTETensor mask,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask,
NVTETensor softmax_results, float scale_factor,
cudaStream_t stream);
/*! \brief Compute the backward of the scaled masked softmax activation.
*
@@ -80,14 +65,9 @@ void nvte_scaled_masked_softmax_forward(
* \param[in] scale_factor Scalar for the output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_masked_softmax_backward(
const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
const NVTETensor softmax_results, NVTETensor output_grads,
float scale_factor, cudaStream_t stream);
/*! \brief Compute scaled softmax activation using a 2D upper triangular mask on the input.
*
@@ -96,13 +76,9 @@ void nvte_scaled_masked_softmax_backward(
* \param[in] scale_factor Scalar for the input tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_upper_triang_masked_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input,
NVTETensor softmax_results, float scale_factor,
cudaStream_t stream);
/*! \brief Compute the backward of the scaled softmax activation using a 2D upper triangular mask.
*
@@ -116,14 +92,10 @@ void nvte_scaled_upper_triang_masked_softmax_forward(
* \param[in] scale_factor Scalar for the output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_upper_triang_masked_softmax_backward(
const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads, float scale_factor,
cudaStream_t stream);
/*! \brief Compute scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
*
@@ -132,13 +104,9 @@ void nvte_scaled_upper_triang_masked_softmax_backward(
* \param[in] scale_factor Scalar for the input tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_aligned_causal_masked_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input,
NVTETensor softmax_results,
float scale_factor, cudaStream_t stream);
/*! \brief Compute the backward pass of the scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
*
@@ -152,13 +120,10 @@ void nvte_scaled_aligned_causal_masked_softmax_forward(
* \param[in] scale_factor Scalar for the output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_aligned_causal_masked_softmax_backward(
const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads, float scale_factor,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
@@ -11,8 +11,8 @@
#ifndef TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
#define TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
#include <stddef.h>
#include <cuda_runtime_api.h>
#include <stddef.h>
#ifdef __cplusplus
extern "C" {
@@ -22,15 +22,15 @@ extern "C" {
* \brief TE datatype.
*/
enum NVTEDType {
kNVTEByte = 0, /*!< Byte */
kNVTEInt32 = 1, /*!< 32-bit integer */
kNVTEInt64 = 2, /*!< 32-bit integer */
kNVTEFloat32 = 3, /*!< 32-bit float */
kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
kNVTENumTypes /*!< Number of supported types */
kNVTEByte = 0, /*!< Byte */
kNVTEInt32 = 1, /*!< 32-bit integer */
kNVTEInt64 = 2, /*!< 32-bit integer */
kNVTEFloat32 = 3, /*!< 32-bit float */
kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
kNVTENumTypes /*!< Number of supported types */
};
/*! \struct NVTEShape
@@ -49,7 +49,7 @@ struct NVTEShape {
* to data of a given shape and type. It does not own the
* memory it points to.
*/
typedef void* NVTETensor;
typedef void *NVTETensor;
/*! \brief Create a new TE tensor.
*
@@ -66,12 +66,8 @@ typedef void* NVTETensor;
*
* \return A new TE tensor.
*/
NVTETensor nvte_create_tensor(void *dptr,
const NVTEShape shape,
const NVTEDType dtype,
float *amax_dptr,
float *scale_dptr,
float *scale_inv_dptr);
NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype,
float *amax_dptr, float *scale_dptr, float *scale_inv_dptr);
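A minimal C sketch of creating and destroying a tensor handle; the device pointers are assumed to exist, and the two-member brace initialization of NVTEShape mirrors the NVTEShape{shape.data(), shape.size()} pattern used by TensorWrapper further below:
/* Sketch: wrap an existing FP8 device buffer in an NVTETensor. */
size_t dims[2] = {1024, 1024};
NVTEShape shape = {dims, 2};  /* pointer to dims, number of dims */
NVTETensor t = nvte_create_tensor(dptr, shape, kNVTEFloat8E4M3,
                                  amax_dptr, scale_dptr, scale_inv_dptr);
/* ... use t; the handle does not own the underlying memory ... */
nvte_destroy_tensor(t);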
/*! \brief Destroy a TE tensor.
*
@@ -144,11 +140,11 @@ struct NVTETensorPack {
/*! \brief Create `tensors` in NVTETensorPack.
*/
void nvte_tensor_pack_create(NVTETensorPack* pack);
void nvte_tensor_pack_create(NVTETensorPack *pack);
/*! \brief Destroy `tensors` in NVTETensorPack.
*/
void nvte_tensor_pack_destroy(NVTETensorPack* pack);
void nvte_tensor_pack_destroy(NVTETensorPack *pack);
#ifdef __cplusplus
} // extern "C"
@@ -164,12 +160,12 @@ namespace transformer_engine {
* \brief TE datatype.
*/
enum class DType {
kByte = 0,
kInt32 = 1,
kInt64 = 2,
kFloat32 = 3,
kFloat16 = 4,
kBFloat16 = 5,
kByte = 0,
kInt32 = 1,
kInt64 = 2,
kFloat32 = 3,
kFloat16 = 4,
kBFloat16 = 5,
kFloat8E4M3 = 6,
kFloat8E5M2 = 7,
kNumTypes
@@ -193,11 +189,10 @@ class TensorWrapper {
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
*/
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype,
float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr) :
tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype),
amax_dptr, scale_dptr, scale_inv_dptr)) {}
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr)
: tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype), amax_dptr,
scale_dptr, scale_inv_dptr)) {}
/*! \brief Constructs new TensorWrapper.
*
@@ -214,9 +209,9 @@
*/
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype,
float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr) :
TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype,
amax_dptr, scale_dptr, scale_inv_dptr) {}
float *scale_inv_dptr = nullptr)
: TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr,
scale_inv_dptr) {}
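A short C++ sketch of the wrapper in use; the device pointer and the calling context are assumptions about user code, not part of this header:
// Sketch: wrap a raw device buffer and hand the C handle to a kernel entry point.
std::vector<size_t> shape = {1024, 1024};
transformer_engine::TensorWrapper t(dev_ptr, shape, transformer_engine::DType::kBFloat16);
NVTETensor handle = t.data();  // handle is destroyed with `t`; the buffer itself is not owned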
/*! \brief Constructs new empty TensorWrapper.
*
@@ -225,11 +220,9 @@
TensorWrapper() : TensorWrapper(nullptr, std::vector<size_t>(), DType::kFloat32) {}
/*! \brief TensorWrapper destructor. */
~TensorWrapper() {
nvte_destroy_tensor(tensor_);
}
~TensorWrapper() { nvte_destroy_tensor(tensor_); }
TensorWrapper& operator=(const TensorWrapper &other) = delete;
TensorWrapper &operator=(const TensorWrapper &other) = delete;
TensorWrapper(const TensorWrapper &other) = delete;
/*! \brief Constructs new TensorWrapper from existing TensorWrapper.
@@ -249,7 +242,7 @@ class TensorWrapper {
*
* \param[in,out] other The source of the data.
*/
TensorWrapper& operator=(TensorWrapper &&other) {
TensorWrapper &operator=(TensorWrapper &&other) {
if (this == &other) return *this;
nvte_destroy_tensor(tensor_);
tensor_ = other.tensor_;
@@ -261,9 +254,7 @@ class TensorWrapper {
*
* \return NVTETensor held by this TensorWrapper.
*/
NVTETensor data() const noexcept {
return tensor_;
}
NVTETensor data() const noexcept { return tensor_; }
/*! \brief Get the shape of this TensorWrapper.
*
@@ -28,10 +28,8 @@ extern "C" {
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose(const NVTETensor input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, cudaStream_t stream);
/*! \brief Transpose the input.
*
@@ -39,9 +37,7 @@ void nvte_cast_transpose(const NVTETensor input,
* \param[out] transposed_output Result of the transpose. Shape: [H, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_transpose(const NVTETensor input,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, cudaStream_t stream);
/*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension.
*
@@ -61,11 +57,8 @@ void nvte_transpose(const NVTETensor input,
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias(const NVTETensor input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension.
@@ -84,11 +77,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_transpose_dbias(const NVTETensor input,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Cast and transpose multiple tensors.
*
@@ -105,10 +95,8 @@ void nvte_fp8_transpose_dbias(const NVTETensor input,
* of tensors in input_list.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_cast_transpose(size_t num_tensors,
const NVTETensor* input_list,
NVTETensor* cast_output_list,
NVTETensor* transposed_output_list,
void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
NVTETensor* cast_output_list, NVTETensor* transposed_output_list,
cudaStream_t stream);
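A hedged sketch of the batched call; each list is assumed to be an array of `num_tensors` pre-created NVTETensor handles with matching shapes:
/* Sketch: cast + transpose a batch of tensors in one launch. */
nvte_multi_cast_transpose(num_tensors, input_list,
                          cast_output_list, transposed_output_list, stream);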
/*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally,
@@ -131,50 +119,29 @@ void nvte_multi_cast_transpose(size_t num_tensors,
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
*/
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
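All five variants share the same calling convention; a hedged sketch for the GeLU case, with every handle assumed pre-created and `grad_output` naming the incoming gradient passed as the first argument:
/* Sketch: fused dGeLU + FP8 cast + transpose + dbias reduction. */
nvte_cast_transpose_dbias_dgelu(grad_output, act_input, cast_output,
                                transposed_output, dbias, workspace, stream);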
/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output.
*
@@ -189,38 +156,28 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
*/
void nvte_dgeglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dswiglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dreglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dqgeglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dsreglu_cast_transpose(const NVTETensor input,
const NVTETensor act_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
@@ -8,11 +8,12 @@
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
#include <vector>
#include <unordered_map>
#include <vector>
#include "../common.h"
@@ -21,113 +22,107 @@ namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t workspace_bytes;
size_t barrier_size;
template <typename Params>
struct LaunchParams {
size_t workspace_bytes;
size_t barrier_size;
int multiprocessorCount;
cudaStream_t stream;
int multiprocessorCount;
cudaStream_t stream;
Params params;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr)
, zero_centered_gamma(false) {}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Size of CTA group.
int ctas_per_row;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
// Whether gamma is centered around 0
bool zero_centered_gamma;
ParamsBase()
: ctas_per_col(0),
rows(0),
cols(0),
x(nullptr),
mu(nullptr),
rs(nullptr),
gamma(nullptr),
workspace(nullptr),
barrier(nullptr),
zero_centered_gamma(false) {}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Size of CTA group.
int ctas_per_row;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
// Whether gamma is centered around 0
bool zero_centered_gamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
, fp8_out(false) {}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
// Scaling factor
void *scale;
// AMax output
void *amax;
// Whether to compute scale and amax
bool fp8_out;
FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
// Scaling factor
void *scale;
// AMax output
void *amax;
// Whether to compute scale and amax
bool fp8_out;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dx(nullptr)
, dbeta(nullptr)
, dgamma(nullptr) {}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
BwdParams()
: ParamsBase(),
dz(nullptr),
dbeta_part(nullptr),
dgamma_part(nullptr),
dx(nullptr),
dbeta(nullptr),
dgamma(nullptr) {}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FwdFunction = std::function<void(LaunchParams<FwdParams> &, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams> &, const bool)>;
using FunctionKey = uint64_t;
using FwdTunedRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdTunedRegistry = std::unordered_map<FunctionKey, BwdFunction>;
@@ -141,96 +136,96 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template <typename T>
struct TypeId {};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
template <>
struct TypeId<fp16> {
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
template <>
struct TypeId<bf16> {
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
template <>
struct TypeId<fp32> {
constexpr static uint32_t Value = 2;
};
template<>
struct TypeId<fp8e4m3>{
constexpr static uint32_t Value = 3;
template <>
struct TypeId<fp8e4m3> {
constexpr static uint32_t Value = 3;
};
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
template <typename T, int S>
struct Type2Key {
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template <typename T>
struct WeightType2Key : public Type2Key<T, 0> {};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template <typename T>
struct InputType2Key : public Type2Key<T, 2> {};
template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};
template <typename T>
struct OutputType2Key : public Type2Key<T, 4> {};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};
template <typename T>
struct ComputeType2Key : public Type2Key<T, 6> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
template <typename W, typename I, typename O, typename C>
struct Types2Key {
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size) {
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
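In other words, the registry key packs the four 2-bit type IDs into the upper 32 bits and the hidden size into the lower 32 bits; a compile-time sketch using the type tags defined above (assumed to be evaluated inside namespace layer_norm):
// Sketch: key for (weight=fp16, input=fp16, output=fp16, compute=fp32), hidden size 4096.
constexpr uint64_t key = Types2Key<fp16, fp16, fp16, fp32>::get(4096);
// Upper word: TypeId<fp16> (0) at shifts 0, 2 and 4, TypeId<fp32> (2) at shift 6.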
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar{
explicit FwdTunedRegistrar(FwdFunction f){
uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({ key, f });
}
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({key, f});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar{
explicit FwdGeneralRegistrar(FwdFunction f){
uint64_t key = Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
}
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar{
explicit BwdTunedRegistrar(BwdFunction f){
uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({ key, f });
}
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({key, f});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar{
explicit BwdGeneralRegistrar(BwdFunction f){
uint64_t key = Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
}
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
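Each tuned or general kernel instantiation registers itself by defining a static registrar object at namespace scope; an illustrative (hypothetical) instantiation, assuming a launcher function `launch_fwd_fp16_4096` with the FwdFunction signature exists:
// Sketch (inside namespace layer_norm): register a tuned forward kernel for
// hidden size 4096; the launcher function name is illustrative only.
static FwdTunedRegistrar<fp16, fp16, fp16, fp32, 4096>
    reg_fwd_fp16_4096(&launch_fwd_fp16_4096);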
//////////////////////////////////////////////////////////////////////////////////////////////////
@@ -9,8 +9,8 @@
#include <cstdint>
#include <vector>
#include "ln.h"
#include "../common.h"
#include "ln.h"
/*
@@ -46,500 +46,411 @@ BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(DType dtype) {
if ( dtype == DType::kFloat16 ) {
return TypeId<fp16>::Value;
} else if ( dtype == DType::kBFloat16 ) {
return TypeId<bf16>::Value;
} else if ( dtype == DType::kFloat32 ) {
return TypeId<fp32>::Value;
} else if ( dtype == DType::kFloat8E4M3 ) {
return TypeId<fp8e4m3>::Value;
} else {
NVTE_ERROR("Type not supported.");
}
if (dtype == DType::kFloat16) {
return TypeId<fp16>::Value;
} else if (dtype == DType::kBFloat16) {
return TypeId<bf16>::Value;
} else if (dtype == DType::kFloat32) {
return TypeId<fp32>::Value;
} else if (dtype == DType::kFloat8E4M3) {
return TypeId<fp8e4m3>::Value;
} else {
NVTE_ERROR("Type not supported.");
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) |
(get_type_id(otype) << 4) | (get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) |
(get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
const layer_norm::FwdParams &params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.mu)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.beta)
&& is_aligned(params.z)
&& layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::FWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("FWD: Unsupported types.");
}
auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
const layer_norm::FwdParams& params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void* ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) &&
is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) &&
is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::FWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("FWD: Unsupported types.");
}
auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
const layer_norm::BwdParams &params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.mu)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.dz)
&& is_aligned(params.dx)
&& is_aligned(params.dbeta)
&& is_aligned(params.dgamma)
&& is_aligned(params.dbeta_part)
&& is_aligned(params.dgamma_part)
&& layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::BWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("BWD: Unsupported types.");
}
auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
const layer_norm::BwdParams& params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void* ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) &&
is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) &&
is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) &&
is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) &&
layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return layer_norm::BWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("BWD: Unsupported types.");
}
auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (auto s : shape) {
ret *= s;
}
return ret;
size_t product(const std::vector<size_t>& shape) {
size_t ret = 1;
for (auto s : shape) {
ret *= s;
}
return ret;
}
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const Tensor& gamma, // hidden_size
const Tensor& beta, // hidden_size
const float epsilon,
Tensor* z,
Tensor* mu,
Tensor* rsigma,
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier,
void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const Tensor& gamma, // hidden_size
const Tensor& beta, // hidden_size
const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream,
const int multiprocessorCount, Tensor* workspace, Tensor* barrier,
const bool zero_centered_gamma) {
const auto itype = x.data.dtype;
const auto wtype = gamma.data.dtype;
const auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(mu->data.shape == std::vector<size_t>{ rows });
NVTE_CHECK(mu->data.dtype == ctype);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{ rows });
NVTE_CHECK(rsigma->data.dtype == ctype);
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = mu->data.dptr;
params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr;
params.beta = beta.data.dptr;
params.z = z->data.dptr;
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = { launch_params.workspace_bytes };
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = { launch_params.barrier_size };
return;
} else {
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}
// Clear buffers
if ( params.fp8_out ) {
cudaMemsetAsync(params.amax, 0,
layer_norm::product(z->amax.shape) *
typeToSize(z->amax.dtype), stream);
}
if ( launch_params.barrier_size > 0 ) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
}
// Launch the kernel.
launcher(launch_params, false);
const auto itype = x.data.dtype;
const auto wtype = gamma.data.dtype;
const auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(mu->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(mu->data.dtype == ctype);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(rsigma->data.dtype == ctype);
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream;
// Set the kernel runtime parameters.
layer_norm::FwdParams& params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = mu->data.dptr;
params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr;
params.beta = beta.data.dptr;
params.z = z->data.dptr;
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}
// Clear buffers
if (params.fp8_out) {
cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
}
if (launch_params.barrier_size > 0) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
return;
}
void layernorm_bwd(const Tensor& dz,
const Tensor& x,
const Tensor& mu,
const Tensor& rsigma,
const Tensor& gamma,
Tensor* dx,
Tensor* dgamma,
Tensor* dbeta,
Tensor* dgamma_part,
Tensor* dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(mu.data.dtype == ctype);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
auto rows = x.data.shape[0];
auto cols = x.data.shape[1];
auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(mu.data.shape[0] == rows);
NVTE_CHECK(mu.data.shape == rsigma.data.shape);
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount;
// Set the kernel runtime parameters.
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = mu.data.dptr;
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr;
params.dz = dz.data.dptr;
params.dx = dx->data.dptr;
params.dbeta = dbeta->data.dptr;
params.dgamma = dgamma->data.dptr;
params.dbeta_part = dbeta_part->data.dptr;
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->data.dptr == nullptr) {
NVTE_CHECK(dbeta_part->data.dptr == nullptr);
dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
dbeta_part->data.dtype = ctype;
dbeta_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = { launch_params.workspace_bytes };
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = { launch_params.barrier_size };
return;
} else {
NVTE_CHECK(dbeta_part->data.dptr != nullptr);
auto pdw_shape = std::vector<size_t>{
static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
NVTE_CHECK(dbeta_part->data.dtype == ctype);
NVTE_CHECK(dbeta_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
}
// Launch the kernel.
launcher(launch_params, false);
void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma,
const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta,
Tensor* dgamma_part, Tensor* dbeta_part, cudaStream_t stream,
const int multiprocessorCount, Tensor* workspace, Tensor* barrier,
const bool zero_centered_gamma) {
using namespace transformer_engine;
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(mu.data.dtype == ctype);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
auto rows = x.data.shape[0];
auto cols = x.data.shape[1];
auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(mu.data.shape[0] == rows);
NVTE_CHECK(mu.data.shape == rsigma.data.shape);
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount;
// Set the kernel runtime parameters.
layer_norm::BwdParams& params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = mu.data.dptr;
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr;
params.dz = dz.data.dptr;
params.dx = dx->data.dptr;
params.dbeta = dbeta->data.dptr;
params.dgamma = dgamma->data.dptr;
params.dbeta_part = dbeta_part->data.dptr;
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->data.dptr == nullptr) {
NVTE_CHECK(dbeta_part->data.dptr == nullptr);
dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
dbeta_part->data.dtype = ctype;
dbeta_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(dbeta_part->data.dptr != nullptr);
auto pdw_shape =
std::vector<size_t>{static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
NVTE_CHECK(dbeta_part->data.dtype == ctype);
NVTE_CHECK(dbeta_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
}
} // namespace transformer_engine
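For reference, layernorm_bwd above is a two-pass API: when dgamma_part/dbeta_part (and workspace/barrier) arrive with null data pointers, it only runs the launcher in query mode, writes back the shapes and dtypes the caller must allocate, and returns; a second call with the buffers in place actually launches the kernels. Below is a minimal host-side sketch of that protocol. It assumes a default-constructed Tensor starts with a null data pointer, and the helper names (num_elements, alloc_if_needed, layernorm_bwd_two_pass_example) are illustrative, not part of the library.

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

namespace transformer_engine {

// Illustrative helpers only.
static size_t num_elements(const std::vector<size_t> &shape) {
  size_t n = 1;
  for (size_t d : shape) n *= d;
  return n;
}

static void alloc_if_needed(void **ptr, size_t bytes) {
  if (bytes > 0) cudaMalloc(ptr, bytes);
}

// dz, x, mu, rsigma, gamma, dx, dgamma, dbeta are assumed to be fully set up
// by the caller; only the auxiliary buffers are sized and allocated here.
void layernorm_bwd_two_pass_example(const Tensor &dz, const Tensor &x, const Tensor &mu,
                                    const Tensor &rsigma, const Tensor &gamma, Tensor *dx,
                                    Tensor *dgamma, Tensor *dbeta, cudaStream_t stream,
                                    int sm_count) {
  Tensor dgamma_part, dbeta_part, workspace, barrier;  // data pointers are null

  // Pass 1: query only; shapes/dtypes of the auxiliary tensors are filled in.
  layernorm_bwd(dz, x, mu, rsigma, gamma, dx, dgamma, dbeta, &dgamma_part, &dbeta_part,
                stream, sm_count, &workspace, &barrier, /*zero_centered_gamma=*/false);

  // Allocate exactly what was requested (fp32 parts, byte workspace, int32 barrier,
  // matching the dtypes layernorm_bwd itself assigned above).
  alloc_if_needed(&dgamma_part.data.dptr, num_elements(dgamma_part.data.shape) * sizeof(float));
  alloc_if_needed(&dbeta_part.data.dptr, num_elements(dbeta_part.data.shape) * sizeof(float));
  alloc_if_needed(&workspace.data.dptr, num_elements(workspace.data.shape));
  alloc_if_needed(&barrier.data.dptr, num_elements(barrier.data.shape) * sizeof(int));

  // Pass 2: run the kernels.
  layernorm_bwd(dz, x, mu, rsigma, gamma, dx, dgamma, dbeta, &dgamma_part, &dbeta_part,
                stream, sm_count, &workspace, &barrier, /*zero_centered_gamma=*/false);
}

}  // namespace transformer_engine

This mirrors the "Populate shape and dtypes for FW to allocate memory" comment in the function itself: the library never allocates these buffers, the calling framework does.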
void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta),
epsilon,
reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu),
reinterpret_cast<Tensor*>(rsigma),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
false);
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), false);
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part,
NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu),
*reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma),
reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma),
reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part),
reinterpret_cast<Tensor*>(dbeta_part),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
false);
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part), reinterpret_cast<Tensor*>(dbeta_part),
stream, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), false);
}
void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta),
epsilon,
reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu),
reinterpret_cast<Tensor*>(rsigma),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
true);
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), true);
}
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta,
NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu),
*reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma),
reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma),
reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part),
reinterpret_cast<Tensor*>(dbeta_part),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
true);
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part), reinterpret_cast<Tensor*>(dbeta_part),
stream, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier), true);
}
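The nvte_layernorm1p_* entry points above differ from the plain nvte_layernorm_* ones only in passing true for zero_centered_gamma, which the kernels later in this change implement as a +1 shift applied to gamma before it is used. A scalar reference sketch of the forward semantics that shift implies is given below; it is illustrative only (the fused kernels vectorize and reduce this per row) and the function name is hypothetical.

#include <cmath>
#include <cstddef>

// Reference semantics for one row of length n.  With zero_centered_gamma == true
// (the "1p" variants), the stored weight is shifted by +1 before use:
//   z = (gamma + 1) * (x - mu) * rsigma + beta.
void layernorm_row_reference(const float *x, const float *gamma, const float *beta,
                             float epsilon, bool zero_centered_gamma, size_t n, float *z,
                             float *mu_out, float *rsigma_out) {
  float mu = 0.f;
  for (size_t i = 0; i < n; ++i) mu += x[i];
  mu /= static_cast<float>(n);
  float var = 0.f;
  for (size_t i = 0; i < n; ++i) var += (x[i] - mu) * (x[i] - mu);
  var /= static_cast<float>(n);
  const float rsigma = 1.f / std::sqrt(var + epsilon);
  const float g_shift = zero_centered_gamma ? 1.f : 0.f;
  for (size_t i = 0; i < n; ++i) {
    z[i] = (gamma[i] + g_shift) * (x[i] - mu) * rsigma + beta[i];
  }
  *mu_out = mu;  // saved by the real forward for the backward pass
  *rsigma_out = rsigma;
}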
......@@ -7,605 +7,570 @@
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#include "ln.h"
#include "../utils.cuh"
#include "ln.h"
namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_tuned_kernel(layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / static_cast<float>(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel(
layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t *smem_wgrad = reinterpret_cast<compute_t *>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / static_cast<float>(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for ( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
const compute_t x_tmp = x[it].data.elt[jt];
const compute_t y_tmp = rs_r * (x_tmp - mu_r);
const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if ( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for ( int it = 0; it < ROWS_PER_CTA; it++ ) {
for ( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for ( int it = 0; it < ROWS_PER_CTA; it++ ) {
for ( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for ( int jt = 0; jt < NUM_RES; jt++ ) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
const compute_t x_tmp = x[it].data.elt[jt];
const compute_t y_tmp = rs_r * (x_tmp - mu_r);
const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for ( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
}
} // end: grid stride loop
if (WARPS_M == 1) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_tuned_kernel(BwdParams params) {
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for ( uint32_t col = c, col_out = c_out;
col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for ( uint32_t row = warp; row < params.ctas_per_col;
row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would probably be safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE
+ Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for ( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if ( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] =
Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] =
Converter<src_t, dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
}
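Each CTA column above writes its per-column partial sums into row bidm of the (ctas_per_col x cols) fp32 buffers dgamma_part / dbeta_part; the finalize kernels that follow reduce those rows into the final hidden_size-sized gradients. A scalar sketch of that reduction (illustrative only; the real finalize kernels use vectorized loads, a shared-memory transpose and a cast to the weight dtype):

#include <cstddef>

// dgamma_part / dbeta_part: row-major [ctas_per_col x cols] partial buffers.
void ln_bwd_finalize_reference(const float *dgamma_part, const float *dbeta_part,
                               size_t ctas_per_col, size_t cols, float *dgamma, float *dbeta) {
  for (size_t c = 0; c < cols; ++c) {
    float g = 0.f, b = 0.f;
    for (size_t r = 0; r < ctas_per_col; ++r) {
      g += dgamma_part[r * cols + c];
      b += dbeta_part[r * cols + c];
    }
    dgamma[c] = g;  // the real kernels convert to weight_t here
    dbeta[c] = b;
  }
}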
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_general_kernel(layer_norm::BwdParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using compute_t = typename Ktraits::compute_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = (bidn * THREADS_PER_WARP
+ warp_n * params.ctas_per_row * THREADS_PER_WARP
+ lane); // Order threads by warp x cta x lane
// Objects for weight grads
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
// Objects for stats reductions
using reduce_t = typename Ktraits::Reducer::Type;
using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<reduce_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
template <typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_tuned_kernel(
BwdParams params) {
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS;
col += COL_STRIDE, col_out += COL_STRIDE / 2) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
for ( int cta_row = bidm * bdimm;
cta_row < params.rows;
cta_row += gdimm ) {
const int row = cta_row + warp_m;
compute_t mu = 0.f;
compute_t rs = 0.f;
if ( row < params.rows ) {
mu = static_cast<const compute_t *>(params.mu)[row];
rs = static_cast<const compute_t *>(params.rs)[row];
}
void *smem_gamma = smem_;
void *smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
Cvec dy[LDGS];
Cvec y[LDGS];
compute_t mdy = 0.f;
compute_t mdyy = 0.f;
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
const compute_t x_ij = x.data.elt[jt];
const compute_t y_ij = rs * (x_ij - mu);
const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
const compute_t dz_ij = dz.data.elt[jt];
const compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
mdy += dy_ij;
mdyy += dy_ij * y_ij;
dz_sum[it].data.elt[jt] += dz_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
// Reduce over row
reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
mdy = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
// Compute dx
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Ivec dx;
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy));
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would probably be safe to reuse the first row of smem_beta and smem_gamma
void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void *smem_beta_out =
&smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
if constexpr ( WARPS_M == 1 ) {
// Write out local weight grad contributions
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
dz_sum[it].store_to_elts(params.dbeta_part,
bidm * params.cols + col,
params.cols - col);
dzy_sum[it].store_to_elts(params.dgamma_part,
bidm * params.cols + col,
params.cols - col);
}
} else {
// Reduce weight grad contributions within CTA before writing
__shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP+1];
// Reduce dz
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
dz_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
__syncthreads();
#pragma unroll
for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS;
it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) {
#pragma unroll
for ( int kt = 0; kt < WARPS_M; kt++ ) {
if ( kt != warp_m ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dz_sum[it].data.elt[jt]
+= vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dz_sum[it].store_to_elts(params.dbeta_part,
bidm * params.cols + col,
params.cols - col);
}
// All writes done.
__syncthreads();
// Reduce dzy
__syncthreads();
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
if ( it != warp_m ) {
dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
}
__syncthreads();
#pragma unroll
for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS;
it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) {
#pragma unroll
for ( int kt = 0; kt < WARPS_M; kt++ ) {
if ( kt != warp_m ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dzy_sum[it].data.elt[jt]
+= vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dzy_sum[it].store_to_elts(params.dgamma_part,
bidm * params.cols + col,
params.cols - col);
}
// Pack and store: 2-wide stores with half the threads.
if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
dgamma_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
}
}
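The finalize kernel above stages the per-warp dgamma/dbeta partials through shared memory with an XOR swizzle (write_col = lane ^ write_row on the way in, read_col = w ^ read_row on the transposed way out), presumably so that both access patterns avoid shared-memory bank conflicts. The small host-side check below (illustrative only) verifies the two properties the kernel relies on: within each row the swizzle is a permutation of the 32 columns, and it is its own inverse, so reading back with the same swizzle recovers exactly the element that was written.

#include <cassert>
#include <cstdio>

int main() {
  const int W = 32;  // THREADS_PER_WARP
  for (int row = 0; row < W; ++row) {
    bool seen[32] = {false};
    for (int col = 0; col < W; ++col) {
      const int swizzled = col ^ row;
      assert(!seen[swizzled]);          // col -> col ^ row is a permutation per row
      seen[swizzled] = true;
      assert((swizzled ^ row) == col);  // and it is an involution
    }
  }
  std::printf("xor swizzle: per-row permutation, self-inverse\n");
  return 0;
}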
template<
typename weight_t,
typename compute_t,
uint32_t WARPS_M,
uint32_t WARPS_N,
uint32_t BYTES_PER_LDG,
uint32_t THREADS_PER_WARP
>
__global__ __launch_bounds__(WARPS_M * WARPS_N * THREADS_PER_WARP)
void ln_bwd_finalize_general_kernel(layer_norm::BwdParams params) {
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
const int lane = threadIdx.x % THREADS_PER_WARP;
const int warp_m = threadIdx.y;
const int warp_n = threadIdx.x / THREADS_PER_WARP;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
// Load grad contributions and accumulate locally
Cvec dgamma, dbeta;
dgamma.clear();
dbeta.clear();
for ( int row = warp_m;
row < params.ctas_per_col && col < params.cols;
row += WARPS_M ) {
Cvec dgamma_part, dbeta_part;
dgamma_part.load_from_elts(params.dgamma_part,
row * params.cols + col,
params.cols - col);
dbeta_part.load_from_elts(params.dbeta_part,
row * params.cols + col,
params.cols - col);
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dgamma.data.elt[jt] += dgamma_part.data.elt[jt];
dbeta.data.elt[jt] += dbeta_part.data.elt[jt];
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kernel(
layer_norm::BwdParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using compute_t = typename Ktraits::compute_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
lane); // Order threads by warp x cta x lane
// Objects for weight grads
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
// Objects for stats reductions
using reduce_t = typename Ktraits::Reducer::Type;
using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<reduce_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
compute_t mu = 0.f;
compute_t rs = 0.f;
if (row < params.rows) {
mu = static_cast<const compute_t *>(params.mu)[row];
rs = static_cast<const compute_t *>(params.rs)[row];
}
Cvec dy[LDGS];
Cvec y[LDGS];
compute_t mdy = 0.f;
compute_t mdyy = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
const compute_t x_ij = x.data.elt[jt];
const compute_t y_ij = rs * (x_ij - mu);
const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
const compute_t dz_ij = dz.data.elt[jt];
const compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
mdy += dy_ij;
mdyy += dy_ij * y_ij;
dz_sum[it].data.elt[jt] += dz_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}
// Reduce over row
reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
mdy = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
// Compute dx
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec dx;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy));
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
}
if constexpr (WARPS_M == 1) {
// Write out local weight grad contributions
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
dz_sum[it].store_to_elts(params.dbeta_part, bidm * params.cols + col, params.cols - col);
dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
}
} else {
// Reduce weight grad contributions within CTA before writing
__shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
// Reduce dz
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
dz_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
__syncthreads();
#pragma unroll
for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) {
#pragma unroll
for (int kt = 0; kt < WARPS_M; kt++) {
if (kt != warp_m) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dz_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dz_sum[it].store_to_elts(params.dbeta_part, bidm * params.cols + col, params.cols - col);
}
// Reduce dgamma within CTA
__shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP+1];
dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for ( int nrows = WARPS_M / 2; nrows > 0; nrows /= 2 ) {
__syncthreads();
if ( warp_m < nrows ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt]
+= vecs_shared[warp_m+nrows][warp_n][lane].data.elt[jt];
}
// Reduce dzy
__syncthreads();
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
if (it != warp_m) {
dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
}
__syncthreads();
#pragma unroll
for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) {
#pragma unroll
for (int kt = 0; kt < WARPS_M; kt++) {
if (kt != warp_m) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
}
if ( warp_m == 0 && col < params.cols ) {
Wvec dgamma_out;
vecs_shared[warp_m][warp_n][lane].to(dgamma_out);
dgamma_out.store_to_elts(params.dgamma, col, params.cols - col);
}
}
template <typename weight_t, typename compute_t, uint32_t WARPS_M, uint32_t WARPS_N,
uint32_t BYTES_PER_LDG, uint32_t THREADS_PER_WARP>
__global__
__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel(
layer_norm::BwdParams params) {
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
const int lane = threadIdx.x % THREADS_PER_WARP;
const int warp_m = threadIdx.y;
const int warp_n = threadIdx.x / THREADS_PER_WARP;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
// Load grad contributions and accumulate locally
Cvec dgamma, dbeta;
dgamma.clear();
dbeta.clear();
for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) {
Cvec dgamma_part, dbeta_part;
dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col);
dbeta_part.load_from_elts(params.dbeta_part, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dgamma.data.elt[jt] += dgamma_part.data.elt[jt];
dbeta.data.elt[jt] += dbeta_part.data.elt[jt];
}
}
// Reduce dbeta within CTA
// Reduce dgamma within CTA
__shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) {
__syncthreads();
dbeta.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for ( int nrows = WARPS_M / 2; nrows > 0; nrows /= 2 ) {
__syncthreads();
if ( warp_m < nrows ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt]
+= vecs_shared[warp_m+nrows][warp_n][lane].data.elt[jt];
}
}
if (warp_m < nrows) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt] +=
vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt];
}
}
if ( warp_m == 0 && col < params.cols ) {
Wvec dbeta_out;
vecs_shared[warp_m][warp_n][lane].to(dbeta_out);
dbeta_out.store_to_elts(params.dbeta, col, params.cols - col);
}
if (warp_m == 0 && col < params.cols) {
Wvec dgamma_out;
vecs_shared[warp_m][warp_n][lane].to(dgamma_out);
dgamma_out.store_to_elts(params.dgamma, col, params.cols - col);
}
// Reduce dbeta within CTA
__syncthreads();
dbeta.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) {
__syncthreads();
if (warp_m < nrows) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt] +=
vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt];
}
}
}
if (warp_m == 0 && col < params.cols) {
Wvec dbeta_out;
vecs_shared[warp_m][warp_n][lane].to(dbeta_out);
dbeta_out.store_to_elts(params.dbeta, col, params.cols - col);
}
}
} // namespace layer_norm
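Stripped of the tiling, vectorized loads and warp/CTA reductions, the kernels in this header compute the standard LayerNorm backward per row: y = rs * (x - mu), dy = (gamma [+ 1]) * dz, the two row means mdy and mdyy, then dx = rs * (dy - (mdyy * y + mdy)), with dgamma accumulating dz * y and dbeta accumulating dz across rows. A scalar reference sketch (illustrative only; the function name is hypothetical and dgamma_acc/dbeta_acc are assumed zero-initialized by the caller before the first row):

#include <cstddef>

void ln_bwd_row_reference(const float *x, const float *dz, const float *gamma, float mu,
                          float rs, bool zero_centered_gamma, size_t n, float *dx,
                          float *dgamma_acc, float *dbeta_acc) {
  const float g_shift = zero_centered_gamma ? 1.f : 0.f;
  float mdy = 0.f, mdyy = 0.f;
  for (size_t i = 0; i < n; ++i) {
    const float y = rs * (x[i] - mu);
    const float dy = (gamma[i] + g_shift) * dz[i];
    mdy += dy;
    mdyy += dy * y;
    dgamma_acc[i] += dz[i] * y;  // summed over rows -> dgamma
    dbeta_acc[i] += dz[i];       // summed over rows -> dbeta
  }
  mdy /= static_cast<float>(n);
  mdyy /= static_cast<float>(n);
  for (size_t i = 0; i < n; ++i) {
    const float y = rs * (x[i] - mu);
    const float dy = (gamma[i] + g_shift) * dz[i];
    dx[i] = rs * (dy - (mdyy * y + mdy));
  }
}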
......
......@@ -5,233 +5,154 @@
************************************************************************/
#include "ln.h"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
#include "ln_kernel_traits.h"
using namespace transformer_engine::layer_norm;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_tuned_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
auto kernel = &ln_bwd_tuned_kernel<Kernel_traits>;
if ( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col = launch_params.multiprocessorCount
* ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if ( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &ln_bwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::reduce_t) * 2;
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if ( ctas_per_row == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>
(launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel),
grid,
block,
reinterpret_cast<void **>(&params_),
Kernel_traits::SMEM_BYTES, stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>
(launch_params.params);
}
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
1,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
auto kernel = &ln_bwd_general_kernel<Kernel_traits>;
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if ( configure_params ) {
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M),
max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes = (ctas_per_col
* WARPS_M
* ctas_per_row
* sizeof(typename Kernel_traits::reduce_t)
* 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
return;
}
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if ( ctas_per_row == 1 ) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel),
grid,
block,
reinterpret_cast<void **>(&params_),
0,
stream);
}
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t,
output_t, compute_t, index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
}
// Launch finalization kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL = (Kernel_traits::THREADS_PER_WARP
* WARPS_N_FINAL
* BYTES_PER_LDG_FINAL
/ sizeof(compute_t));
auto kernel_final = &ln_bwd_finalize_general_kernel<weight_t,
compute_t,
WARPS_M_FINAL,
WARPS_N_FINAL,
BYTES_PER_LDG_FINAL,
Kernel_traits::THREADS_PER_WARP>;
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &ln_bwd_general_kernel<Kernel_traits>;
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
Kernel_traits::THREADS_PER_CTA, 0);
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
}
// Launch finalization kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL =
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t));
auto kernel_final =
&ln_bwd_finalize_general_kernel<weight_t, compute_t, WARPS_M_FINAL, WARPS_N_FINAL,
BYTES_PER_LDG_FINAL, Kernel_traits::THREADS_PER_WARP>;
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
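The launchers use a two-pass protocol: the first call with configure_params == true only fills in ctas_per_row/ctas_per_col plus barrier_size and workspace_bytes, and the second call performs the launch once the caller has provided those buffers. A hypothetical driver might look like the sketch below; the workspace and barrier field names are assumptions for illustration, not taken from this diff.

// Illustrative two-pass driver; `workspace` and `barrier` field names are assumed.
void run_ln_bwd(LaunchParams<BwdParams> &launch_params,
                void (*launcher)(LaunchParams<BwdParams> &, const bool)) {
  launcher(launch_params, true);  // pass 1: query grid shape and buffer sizes

  void *workspace = nullptr, *barrier = nullptr;
  if (launch_params.params.ctas_per_row > 1) {
    NVTE_CHECK_CUDA(cudaMalloc(&workspace, launch_params.workspace_bytes));
    NVTE_CHECK_CUDA(cudaMalloc(&barrier, launch_params.barrier_size * sizeof(int)));
    NVTE_CHECK_CUDA(cudaMemsetAsync(barrier, 0, launch_params.barrier_size * sizeof(int),
                                    launch_params.stream));
    launch_params.params.workspace = workspace;                       // assumed field
    launch_params.params.barrier = reinterpret_cast<int *>(barrier);  // assumed field
  }

  launcher(launch_params, false);  // pass 2: launch the main and finalize kernels
}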
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, \
configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
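For readability, one invocation from the list below, REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4), expands (modulo whitespace) to a named launcher plus a static registrar whose constructor records that launcher for the given hidden size and dtype combination:

void ln_bwd_tuned_1024_bf16_bf16_bf16_fp32(LaunchParams<BwdParams> &launch_params,
                                           const bool configure_params) {
  launch_tuned_<bf16, bf16, bf16, fp32, uint32_t, 1024, 1, 4, 1, 16, 4>(launch_params,
                                                                        configure_params);
}
static BwdTunedRegistrar<bf16, bf16, bf16, fp32, 1024>
    reg_tuned_1024_bf16_bf16_bf16_fp32(ln_bwd_tuned_1024_bf16_bf16_bf16_fp32);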
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
************************************************************************/
#include "ln.h"
#include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh"
#include "ln_kernel_traits.h"
using namespace transformer_engine::layer_norm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<FwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &ln_fwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::Stats::stats_t) * 2;
}
return;
}
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*)
Kernel_traits::SMEM_BYTES_FWD, stream);
}
}
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<FwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &ln_fwd_general_kernel<Kernel_traits>;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
}
}
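To make the configuration arithmetic above concrete, the following stand-alone sketch reproduces it on the host with made-up numbers; the problem shape, traits values, and occupancy result are all illustrative.

// Small host-side sketch of the configuration arithmetic (illustrative values only).
#include <algorithm>
#include <cstdio>

int main() {
  auto ceil_div = [](int x, int y) { return (x + y - 1) / y; };
  const int rows = 16384, cols = 5120;        // problem shape (example)
  const int HIDDEN_SIZE = 4096, WARPS_M = 4;  // kernel-traits parameters (example)
  const int sm_count = 108, ctas_per_sm = 2;  // from the occupancy query (example)

  const int ctas_per_row = ceil_div(cols, HIDDEN_SIZE);        // 2 CTAs cooperate on one row
  const int max_ctas = sm_count * ctas_per_sm;                 // 216 resident CTAs at most
  const int ctas_per_col = std::min(ceil_div(rows, WARPS_M),
                                    max_ctas / ctas_per_row);  // 108 row blocks in parallel
  std::printf("grid = %d CTAs (%d x %d)\n",
              ctas_per_row * ctas_per_col, ctas_per_row, ctas_per_col);
  return 0;
}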
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
WARPS_N, BYTES_PER_LDG>(launch_params, configure_params); \
} \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG) \
void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG>(launch_params, configure_params); \
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
#include <cfloat>
#include <cstdio>
#include "ln.h"
#include "../utils.cuh"
#include "ln.h"
namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for (int it = 0; it < LDGS; ++it) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
compute_t scale = 1.f;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
x[it].load_from(params.x, idx);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if (bidn == 0 && warp_n == 0 && lane == 0) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(rn * m2 + params.epsilon);
if (bidn == 0 && warp_n == 0 && lane == 0) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t b_ij = beta[it].data.elt[jt];
compute_t temp_output = g_ij * y_ij + b_ij;
if (params.fp8_out) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
temp_output = temp_output * scale;
}
z[it].data.elt[jt] = output_t(temp_output);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
}
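The tuned kernel above is a vectorized, multi-warp version of the plain per-row LayerNorm below. This scalar reference is an illustration only, not part of the file; it shows the same statistics, the optional zero-centered gamma, and the FP8 path in which the absolute maximum is recorded before scaling.

// Scalar reference for one row; illustrative only.
#include <cmath>
#include <cstddef>

void layernorm_row_reference(const float *x, const float *gamma, const float *beta,
                             float *z, std::size_t cols, float epsilon,
                             bool zero_centered_gamma, bool fp8_out, float scale,
                             float *mu_out, float *rs_out, float *amax_out) {
  float mu = 0.f;
  for (std::size_t i = 0; i < cols; ++i) mu += x[i];
  mu /= static_cast<float>(cols);

  float m2 = 0.f;
  for (std::size_t i = 0; i < cols; ++i) m2 += (x[i] - mu) * (x[i] - mu);
  const float rs = 1.f / std::sqrt(m2 / static_cast<float>(cols) + epsilon);

  float amax = 0.f;
  for (std::size_t i = 0; i < cols; ++i) {
    const float g = zero_centered_gamma ? gamma[i] + 1.f : gamma[i];
    float y = g * (x[i] - mu) * rs + beta[i];
    if (fp8_out) {
      amax = std::fmax(amax, std::fabs(y));  // amax is taken before scaling
      y *= scale;
    }
    z[i] = y;
  }
  *mu_out = mu;
  *rs_out = rs;
  if (fp8_out) *amax_out = std::fmax(*amax_out, amax);
}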
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kernel(
FwdParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
lane); // Order threads by warp x cta x lane
// Objects for stats reductions
using Reducer = DynamicReducer<compute_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<compute_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
Cvec beta[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
++it, col += gdimn * NUM_ELTS) {
Wvec gamma_in, beta_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
beta_in.load_from_elts(params.beta, col, params.cols - col);
gamma_in.to(gamma[it]);
beta_in.to(beta[it]);
}
// fp8 factors
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
// Load input
Cvec x[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x_in;
x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col);
x_in.to(x[it]);
}
// Compute mean
compute_t mu = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
mu += x[it].data.elt[jt];
}
}
mu = reducer.allreduce(mu, sum) * rn;
// Compute variance
compute_t sqsigma = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t diff = x[it].data.elt[jt] - mu;
sqsigma += diff * diff;
}
}
}
sqsigma = reducer.allreduce(sqsigma, sum) * rn;
compute_t rs = rsqrtf(sqsigma + params.epsilon);
// Write statistics
if (gidn == 0 && row < params.rows) {
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
mu_ptr[row] = mu;
rs_ptr[row] = rs;
}
// Compute output
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
// Compute output values
Cvec z;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (x[it].data.elt[jt] - mu);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = g_ij * y_ij + b_ij;
}
// Apply fp8 factors
if (params.fp8_out) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale;
}
}
}
// Store output
Ovec z_out;
z.to(z_out);
z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col);
}
}
// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
}
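The amax reduction at the end of both kernels funnels per-thread maxima through reduce_max and then one atomic update per CTA. atomicMaxFloat is a helper defined elsewhere in the codebase; one common way to realize a float atomic max for non-negative values (which is all that is needed here, since amax >= 0) is sketched below, though the actual helper may be implemented differently.

__device__ inline float atomic_max_nonneg_float(float *addr, float value) {
  // For non-negative IEEE-754 floats the integer bit patterns preserve ordering,
  // so an integer atomicMax yields the floating-point maximum.
  const int old = atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
  return __int_as_float(old);
}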
} // namespace layer_norm
namespace transformer_engine {
namespace layer_norm {
template <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_,
compute_t_, index_t_, THREADS_PER_CTA_> >
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert(static_cast<int>(ROWS_PER_CTA) <= static_cast<int>(Base::THREADS_PER_WARP));
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(BYTES_PER_LDG) == 4,
"Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP,
"We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = transformer_engine::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,
typename index_t_, uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_,
uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16,
typename Base =
Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_, index_t_,
WARPS_M_ * WARPS_N_ * THREADS_PER_WARP> >
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename transformer_engine::TypeToVec2<compute_t>::Type;
using Reducer = transformer_engine::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = transformer_engine::Vec<input_t, NUM_ELTS>;
using Ovec = transformer_engine::Vec<output_t, NUM_ELTS>;
using Wvec = transformer_engine::Vec<weight_t, NUM_ELTS>;
using Cvec = transformer_engine::Vec<compute_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements
// in the output and weights as in the input.
static_assert(sizeof(input_t) >= sizeof(output_t));
static_assert(sizeof(input_t) >= sizeof(weight_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
// static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = transformer_engine::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -7,19 +7,16 @@
#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_
#define TRANSFORMER_ENGINE_COMMON_NVTX_H_
#include <string>
#include <nvToolsExt.h>
#include <string>
namespace transformer_engine::nvtx {
struct NVTXWrapper {
explicit NVTXWrapper(const std::string &name) {
nvtxRangePush(name.c_str());
}
explicit NVTXWrapper(const std::string &name) { nvtxRangePush(name.c_str()); }
~NVTXWrapper() {
nvtxRangePop();
}
~NVTXWrapper() { nvtxRangePop(); }
};
} // namespace transformer_engine::nvtx
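The wrapper above scopes an NVTX range to a C++ block via RAII; a minimal usage sketch (the function and range name below are illustrative, not taken from this change):

void run_forward_pass() {
  transformer_engine::nvtx::NVTXWrapper range("forward_pass");  // pushes the range
  // ... kernel launches covered by the range ...
}  // range goes out of scope here, which pops the NVTX range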
......
......@@ -7,12 +7,12 @@
#include <transformer_engine/recipe.h>
#include <cmath>
#include <string>
#include <limits>
#include <string>
#include "../common.h"
#include "../util/logging.h"
#include "../util/cuda_runtime.h"
#include "../util/logging.h"
namespace transformer_engine {
namespace delayed_scaling_recipe {
......@@ -24,18 +24,19 @@ enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX };
const char* dtype_name(DType dtype) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type,
return TypeInfo<Type>::name;
); // NOLINT(*)
return TypeInfo<Type>::name;); // NOLINT(*)
return "";
}
// Maximum representable value of an FP8 dtype
inline float fp8_dtype_max(DType dtype) {
switch (dtype) {
case DType::kFloat8E4M3: return 448;
case DType::kFloat8E5M2: return 57344;
default:
NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype));
case DType::kFloat8E4M3:
return 448;
case DType::kFloat8E5M2:
return 57344;
default:
NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype));
}
return 0;
}
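// For reference, the two constants above follow from the OCP FP8 formats:
//   E4M3: (1 + 6/8) * 2^(15-7)  = 1.75 * 256   = 448    (the all-ones mantissa with the
//         top exponent is reserved for NaN, so 110 is the largest usable mantissa)
//   E5M2: (1 + 3/4) * 2^(30-15) = 1.75 * 32768 = 57344  (the top exponent is reserved
//         for Inf/NaN)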
......@@ -58,12 +59,12 @@ struct OtherParams {
#if CUDART_VERSION >= 12010
constexpr size_t max_constant_memory_per_kernel = 32768;
constexpr size_t AMAX_PARAMS_LIMIT = (
max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
constexpr size_t AMAX_PARAMS_LIMIT =
(max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#else
constexpr size_t max_constant_memory_per_kernel = 4096;
constexpr size_t AMAX_PARAMS_LIMIT = (
max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
constexpr size_t AMAX_PARAMS_LIMIT =
(max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#endif
struct AmaxParams {
......@@ -82,17 +83,10 @@ constexpr size_t bsize = 256;
* Grid dims: num_scales x 1 x 1
*/
__global__ void __launch_bounds__(bsize)
kernel(const float* amax_history_ptr,
const float* scale_ptr,
const float* scale_inv_ptr,
const unsigned char* scale_inv_mask_ptr,
float* updated_amax_history_ptr,
float* updated_scale_ptr,
float* updated_scale_inv_ptr,
size_t amax_history_length,
size_t amax_history_stride,
AmaxComputeAlgo amax_compute_algo,
float scaled_max) {
kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr,
const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr,
float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length,
size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) {
const size_t tid = threadIdx.x;
const size_t bid = blockIdx.x;
......@@ -109,22 +103,21 @@ kernel(const float* amax_history_ptr,
const size_t i = off + tid;
float a = 0;
if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
amax = fmaxf(amax, a);
}
__syncthreads(); // In case roll is in-place
if (i < length) {
updated_amax_history[i*stride] = (i > 0) ? a : 0;
updated_amax_history[i * stride] = (i > 0) ? a : 0;
}
}
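// Illustration: with length = 4 and history [h0, h1, h2, h3], the loop above
// produces updated = [0, h2, h3, last_amax]: every entry moves one slot toward
// the front (the one reaching slot 0 is overwritten with 0, ready for the next
// recorded amax), and the final slot is filled with last_amax.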
// Compute amax to use for scaling factor
switch (amax_compute_algo) {
case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax;
break;
case AmaxComputeAlgo::MAX:
{
case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax;
break;
case AmaxComputeAlgo::MAX: {
__shared__ float shared_amax[bsize];
shared_amax[tid] = amax;
__syncthreads();
......@@ -136,10 +129,9 @@ kernel(const float* amax_history_ptr,
__syncthreads();
}
amax = shared_amax[tid];
}
break;
default:
amax = 0;
} break;
default:
amax = 0;
}
}
......@@ -157,7 +149,7 @@ kernel(const float* amax_history_ptr,
// amax won't get mapped to the FP8 max representable value, but rather to
// something below it; this is the best we can do.
if (isinf(scale)) {
scale = std::numeric_limits<float>::max();
scale = std::numeric_limits<float>::max();
}
updated_scale_ptr[bid] = scale;
......@@ -179,12 +171,8 @@ kernel(const float* amax_history_ptr,
* Grid dims: num_tensors x 1 x 1
*/
__global__ void __launch_bounds__(bsize)
kernel_bulk(
float* amax_reduction_buffer,
AmaxParams p,
size_t amax_history_length,
AmaxComputeAlgo amax_compute_algo,
float scaled_max) {
kernel_bulk(float* amax_reduction_buffer, AmaxParams p, size_t amax_history_length,
AmaxComputeAlgo amax_compute_algo, float scaled_max) {
const size_t bid = blockIdx.x;
const size_t tid = threadIdx.x;
const int num_scale = p.param[bid].num_scale;
......@@ -201,32 +189,32 @@ kernel_bulk(
// Roll amax history
const auto& length = amax_history_length;
const auto& stride = p.param[bid].num_scale;
auto* amax_history = p.param[bid].amax_history+count;
const auto last_amax = ((amax_reduction_buffer != nullptr)
&& (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ?
amax_reduction_buffer[offset_in_buffer+count] : amax_history[0];
auto* amax_history = p.param[bid].amax_history + count;
const auto last_amax = ((amax_reduction_buffer != nullptr) &&
(amax_reduction_buffer[offset_in_buffer + count] != 0.0f))
? amax_reduction_buffer[offset_in_buffer + count]
: amax_history[0];
if (last_amax != 0.0f) {
for (size_t off = 0; off < length; off += bsize) {
const size_t i = off + tid;
float a = 0;
if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
amax = fmaxf(amax, a);
}
__syncthreads(); // Inplace roll
if (i < length) {
amax_history[i*stride] = (i > 0) ? a : 0;
amax_history[i * stride] = (i > 0) ? a : 0;
}
}
}
// Compute amax to use for scaling factor
switch (amax_compute_algo) {
case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax;
break;
case AmaxComputeAlgo::MAX:
{
case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax;
break;
case AmaxComputeAlgo::MAX: {
__shared__ float shared_amax[bsize];
shared_amax[tid] = amax;
__syncthreads();
......@@ -238,10 +226,9 @@ kernel_bulk(
__syncthreads();
}
amax = shared_amax[tid];
}
break;
default:
amax = 0;
} break;
default:
amax = 0;
}
}
......@@ -269,7 +256,7 @@ kernel_bulk(
// amax won't get mapped to the FP8 max representable value, but rather to
// something below it; this is the best we can do.
if (isinf(scale)) {
scale = std::numeric_limits<float>::max();
scale = std::numeric_limits<float>::max();
}
p.param[bid].scale[count] = scale;
p.param[bid].scale_inv[count] = 1 / scale;
......@@ -281,24 +268,17 @@ kernel_bulk(
} // namespace
void amax_and_scale_update(const Tensor &amax_history,
const Tensor &scale,
const Tensor &scale_inv,
const Tensor &scale_inv_mask,
Tensor *updated_amax_history_,
Tensor *updated_scale_,
Tensor *updated_scale_inv_,
const std::string &amax_compute_algo,
DType fp8_dtype,
float margin,
void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv,
const Tensor& scale_inv_mask, Tensor* updated_amax_history_,
Tensor* updated_scale_, Tensor* updated_scale_inv_,
const std::string& amax_compute_algo, DType fp8_dtype, float margin,
cudaStream_t stream) {
auto& updated_amax_history = *updated_amax_history_;
auto& updated_scale = *updated_scale_;
auto& updated_scale_inv = *updated_scale_inv_;
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
auto numel = [](const Tensor& tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
......@@ -307,48 +287,40 @@ void amax_and_scale_update(const Tensor &amax_history,
};
// Check tensors
NVTE_CHECK(amax_history.data.shape.size() == 2,
"Found ", amax_history.data.shape.size(), " dims");
NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(),
" dims");
const size_t amax_history_length = amax_history.data.shape[0];
const size_t num_scales = amax_history.data.shape[1];
NVTE_CHECK(amax_history.data.dtype == DType::kFloat32,
"Found ", dtype_name(amax_history.data.dtype), ".");
NVTE_CHECK(numel(scale) == num_scales,
"Expected ", num_scales, " elements, ",
"but found ", numel(scale), ".");
NVTE_CHECK(scale.data.dtype == DType::kFloat32,
"Found ", dtype_name(scale.data.dtype), ".");
NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ",
dtype_name(amax_history.data.dtype), ".");
NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
numel(scale), ".");
NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), ".");
if (scale_inv_mask.data.dptr != nullptr) {
NVTE_CHECK(numel(scale_inv) == num_scales,
"Expected ", num_scales, " elements, ",
"but found ", numel(scale_inv), ".");
NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
numel(scale_inv), ".");
NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32);
NVTE_CHECK(numel(scale_inv_mask) == num_scales,
"Expected ", num_scales, " elements, ",
NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ",
"but found ", numel(scale_inv_mask), ".");
NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte,
"Found ", dtype_name(scale_inv_mask.data.dtype), ".");
NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ",
dtype_name(scale_inv_mask.data.dtype), ".");
}
NVTE_CHECK(updated_amax_history.data.shape.size() == 2,
"Found ", updated_amax_history.data.shape.size(), " dims.");
NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length,
"Expected ", amax_history_length, ", ",
"but found ", updated_amax_history.data.shape[0]);
NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales,
"Expected ", num_scales, ", ",
NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ",
updated_amax_history.data.shape.size(), " dims.");
NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ",
amax_history_length, ", ", "but found ", updated_amax_history.data.shape[0]);
NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, "Expected ", num_scales, ", ",
"but found ", updated_amax_history.data.shape[1]);
NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32,
"Got ", dtype_name(updated_amax_history.data.dtype), ".");
NVTE_CHECK(numel(updated_scale) == num_scales,
"Expected ", num_scales, " elements, ",
NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ",
dtype_name(updated_amax_history.data.dtype), ".");
NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale), ".");
NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32,
"Got ", dtype_name(updated_scale.data.dtype), ".");
NVTE_CHECK(numel(updated_scale_inv) == num_scales,
"Expected ", num_scales, " elements, ",
NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ",
dtype_name(updated_scale.data.dtype), ".");
NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale_inv), ".");
NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32,
"Got ", dtype_name(updated_scale_inv.data.dtype), ".");
NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ",
dtype_name(updated_scale_inv.data.dtype), ".");
// amax value to use for updating scaling factor
AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
......@@ -366,31 +338,23 @@ void amax_and_scale_update(const Tensor &amax_history,
// Launch CUDA kernel
constexpr size_t block_size = amax_and_scale_update_impl::bsize;
const size_t grid_size = num_scales;
amax_and_scale_update_impl::kernel
<<<grid_size, block_size, 0, stream>>>(
static_cast<const float*>(amax_history.data.dptr),
static_cast<const float*>(scale.data.dptr),
amax_and_scale_update_impl::kernel<<<grid_size, block_size, 0, stream>>>(
static_cast<const float*>(amax_history.data.dptr), static_cast<const float*>(scale.data.dptr),
static_cast<const float*>(scale_inv.data.dptr),
static_cast<const unsigned char*>(scale_inv_mask.data.dptr),
static_cast<float*>(updated_amax_history.data.dptr),
static_cast<float*>(updated_scale.data.dptr),
static_cast<float*>(updated_scale_inv.data.dptr),
amax_history_length,
num_scales,
amax_compute_algo_,
scaled_max);
static_cast<float*>(updated_scale_inv.data.dptr), amax_history_length, num_scales,
amax_compute_algo_, scaled_max);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
std::vector<Tensor*> amax_histories,
std::vector<Tensor*> scales,
std::vector<Tensor*> scale_invs,
const std::string &amax_compute_algo,
DType fp8_dtype,
float margin,
cudaStream_t stream) {
const std::string& amax_compute_algo, DType fp8_dtype,
float margin, cudaStream_t stream) {
using namespace transformer_engine;
// amax value to use for updating scaling factor
......@@ -407,7 +371,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);
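// For example, with fp8_dtype = E4M3 (max 448) and margin = 1 this gives
// scaled_max = 448 * 2^-1 = 224; the new scale is then chosen so that the
// observed amax maps to at most this value.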
// Number of elements in tensor
auto numel = [] (const Tensor *tensor) -> size_t {
auto numel = [](const Tensor* tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor->data.shape) {
acc *= dim;
......@@ -418,7 +382,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
// Number of tensors in the bulk
const size_t num_tensors = amax_histories.size();
size_t num_remaining_tensors = num_tensors;
const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT;
const int num_kernels = (num_tensors + AMAX_PARAMS_LIMIT - 1) / AMAX_PARAMS_LIMIT;
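// Ceiling division: e.g. 5 tensors with AMAX_PARAMS_LIMIT = 2 gives 3 launches.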
size_t amax_history_length = 0;
if (num_tensors > 0) {
amax_history_length = amax_histories[0]->data.shape[0];
......@@ -429,27 +393,26 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
AmaxParams p;
for (int iter = 0; iter < num_kernels; iter++) {
size_t kernel_num_scales = 0;
size_t kernel_num_tensors = (iter == (num_kernels - 1))
? num_remaining_tensors: AMAX_PARAMS_LIMIT;
size_t kernel_num_tensors =
(iter == (num_kernels - 1)) ? num_remaining_tensors : AMAX_PARAMS_LIMIT;
for (size_t pi = 0; pi < kernel_num_tensors; pi++) {
size_t i = iter * AMAX_PARAMS_LIMIT + pi;
// Check tensors
int num_scale = amax_histories[i]->data.shape[1];
NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32,
"Found ", dtype_name(amax_histories[i]->data.dtype), ".");
NVTE_CHECK(amax_histories[i]->data.shape.size() == 2,
"Found ", amax_histories[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale,
"Expected ", amax_history_length * num_scale, " elements, ",
"but found ", numel(amax_histories[i]), ".");
NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32,
"Found ", dtype_name(scales[i]->data.dtype), ".");
NVTE_CHECK(scales[i]->data.shape.size() == 1,
"Found ", scales[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(scales[i]) == num_scale,
"Expected ", num_scale, " elements, ",
"Found ", numel(scales[i]), ".");
NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ",
dtype_name(amax_histories[i]->data.dtype), ".");
NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ",
amax_histories[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ",
amax_history_length * num_scale, " elements, ", "but found ",
numel(amax_histories[i]), ".");
NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ",
dtype_name(scales[i]->data.dtype), ".");
NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(),
" dims");
NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ",
numel(scales[i]), ".");
// amax parameters
kernel_num_scales += num_scale;
......@@ -462,13 +425,8 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
// Launch CUDA kernel
size_t grid_size = kernel_num_tensors;
const size_t block_size = amax_and_scale_update_impl::bsize;
amax_and_scale_update_impl::kernel_bulk
<<<grid_size, block_size, 0, stream>>>(
amax_buffer,
p,
amax_history_length,
amax_compute_algo_,
scaled_max);
amax_and_scale_update_impl::kernel_bulk<<<grid_size, block_size, 0, stream>>>(
amax_buffer, p, amax_history_length, amax_compute_algo_, scaled_max);
NVTE_CHECK_CUDA(cudaGetLastError());
// shift amax buffer pointer
......@@ -482,44 +440,25 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
} // namespace delayed_scaling_recipe
} // namespace transformer_engine
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
const NVTETensor scale,
const NVTETensor scale_inv,
const NVTETensor scale_inv_mask,
NVTETensor updated_amax_history,
NVTETensor updated_scale,
NVTETensor updated_scale_inv,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream) {
void nvte_delayed_scaling_recipe_amax_and_scale_update(
const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv,
const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale,
NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
using namespace transformer_engine;
delayed_scaling_recipe::amax_and_scale_update(
*reinterpret_cast<const Tensor*>(amax_history),
*reinterpret_cast<const Tensor*>(scale),
*reinterpret_cast<const Tensor*>(scale_inv),
*reinterpret_cast<const Tensor*>(scale_inv_mask),
reinterpret_cast<Tensor*>(updated_amax_history),
reinterpret_cast<Tensor*>(updated_scale),
reinterpret_cast<Tensor*>(updated_scale_inv),
amax_compute_algo,
static_cast<DType>(fp8_dtype),
margin,
stream);
*reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
*reinterpret_cast<const Tensor*>(scale_inv), *reinterpret_cast<const Tensor*>(scale_inv_mask),
reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
reinterpret_cast<Tensor*>(updated_scale_inv), amax_compute_algo,
static_cast<DType>(fp8_dtype), margin, stream);
}
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer,
std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales,
std::vector<NVTETensor> scale_invs,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream) {
const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
using namespace transformer_engine;
size_t num_tensors = amax_histories.size();
......@@ -530,12 +469,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
t_scale_invs.push_back(reinterpret_cast<Tensor*>(scale_invs[i]));
}
delayed_scaling_recipe::amax_and_scale_update_after_reduction(
*reinterpret_cast<const Tensor*>(amax_reduction_buffer),
t_amax_histories,
t_scales,
t_scale_invs,
amax_compute_algo,
static_cast<DType>(fp8_dtype),
margin,
stream);
*reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales,
t_scale_invs, amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}
......@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include <map>
#include <stdexcept>
......@@ -15,7 +16,6 @@
#include <vector>
#include "../common.h"
#include "../layer_norm/ln.h"
namespace transformer_engine {
......@@ -47,40 +47,40 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({key, f});
}
explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({key, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({key, f});
}
explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({key, f});
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
}
};
} // namespace rmsnorm
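The four registrar templates above implement the usual static-registration idiom: defining a namespace-scope registrar object inserts a launcher into the matching FWD_/BWD_ registry during static initialization, before any lookup runs. A self-contained sketch of the same pattern with generic types (the key, function type, and names here are illustrative, not TE's actual signatures):

#include <cstdint>
#include <functional>
#include <map>

using Fn = std::function<void(int rows)>;
static std::map<uint64_t, Fn> REGISTRY;  // keyed on an encoded <types, hidden size>

template <uint64_t KEY>
struct Registrar {
  explicit Registrar(Fn f) { REGISTRY.insert({KEY, f}); }
};

// A namespace-scope object: its constructor runs at program start-up and
// registers the launcher under KEY = 4096.
static Registrar<4096> register_hidden_4096(
    [](int rows) { (void)rows; /* launch the 4096-column kernel here */ });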
......
......@@ -4,14 +4,13 @@
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/rmsnorm.h"
#include <cstdint>
#include <numeric>
#include <vector>
#include "rmsnorm.h"
#include "../common.h"
#include "rmsnorm.h"
#include "transformer_engine/rmsnorm.h"
/*
......@@ -49,85 +48,70 @@ BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry FWD_GENERAL_FUNCS;
BwdGeneralRegistry BWD_GENERAL_FUNCS;
FwdFunction &get_fwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
const layer_norm::FwdParams &params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.z)
&& FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return FWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (FWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("FWD: Unsupported types.");
}
auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) &&
is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) {
return FWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (FWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("FWD: Unsupported types.");
}
auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
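// Illustration of the general-kernel lookup above (the registered hidden sizes
// here are hypothetical): if the registry for this type combination holds
// kernels for {1024, 4096, 8192} columns, then params.cols = 3000 selects the
// 4096 entry via lower_bound, while params.cols = 10000 falls past the end of
// the map and uses the largest registered entry (8192), which then runs in
// multi-CTA mode.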
////////////////////////////////////////////////////////////////////////////////////////////////////
BwdFunction &get_bwd_launcher(DType wtype,
DType itype,
DType otype,
DType ctype,
BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
const layer_norm::BwdParams &params) {
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0
&& is_aligned(params.x)
&& is_aligned(params.rs)
&& is_aligned(params.gamma)
&& is_aligned(params.dz)
&& is_aligned(params.dx)
&& is_aligned(params.dgamma)
&& is_aligned(params.dgamma_part)
&& layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return BWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (BWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("BWD: Unsupported types.");
}
auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
// Look for tuned kernel
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
auto is_aligned = [](const void *ptr) -> bool {
// Assume vectorized memory accesses are <=16B
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
};
if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) &&
is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) &&
is_aligned(params.dgamma) && is_aligned(params.dgamma_part) &&
layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
return BWD_TUNED_FUNCS.at(tuned_key);
}
// Pick general kernel
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
if (BWD_GENERAL_FUNCS.count(general_key) == 0) {
NVTE_ERROR("BWD: Unsupported types.");
}
auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
}
// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline size_t product(const std::vector<size_t> &shape) {
return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
}
} // namespace rmsnorm
......@@ -137,213 +121,211 @@ inline size_t product(const std::vector<size_t> &shape) {
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount,
Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) {
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype);
auto ctype = DType::kFloat32;
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(rsigma->data.dtype == ctype);
rmsnorm::LaunchParams<rmsnorm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream;
// Set the kernel runtime parameters.
rmsnorm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = nullptr;
params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr;
params.beta = nullptr;
params.z = z->data.dptr;
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
}
// Clear buffers
if (params.fp8_out) {
cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
}
if (launch_params.barrier_size > 0) {
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype);
auto ctype = DType::kFloat32;
NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(rsigma->data.dtype == ctype);
rmsnorm::LaunchParams<rmsnorm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream;
// Set the kernel runtime parameters.
rmsnorm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = nullptr;
params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr;
params.beta = nullptr;
params.z = z->data.dptr;
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
}
// Clear buffers
if (params.fp8_out) {
cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
}
if (launch_params.barrier_size > 0) {
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
return;
}
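// Typical call pattern implied by the workspace logic above (tensor setup and
// device allocation are caller-specific and omitted):
//
//   Tensor workspace, barrier;                        // data.dptr == nullptr
//   rmsnorm_fwd(x, gamma, eps, &z, &rsigma, stream,   // 1st call: only fills in
//               sm_count, &workspace, &barrier,       // workspace/barrier shapes
//               false);                               // and returns
//   /* allocate workspace.data.dptr and barrier.data.dptr to those shapes */
//   rmsnorm_fwd(x, gamma, eps, &z, &rsigma, stream,   // 2nd call: runs the kernel
//               sm_count, &workspace, &barrier, false);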
void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma,
Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream,
const int multiprocessorCount, Tensor *workspace, Tensor *barrier,
const bool zero_centered_gamma) {
using namespace transformer_engine;
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
const auto rows = x.data.shape[0];
const auto cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
rmsnorm::LaunchParams<rmsnorm::BwdParams> launch_params;
launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount;
// Set the kernel runtime parameters.
rmsnorm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = nullptr;
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr;
params.dz = dz.data.dptr;
params.dx = dx->data.dptr;
params.dbeta = nullptr;
params.dgamma = dgamma->data.dptr;
params.dbeta_part = nullptr;
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->data.dptr == nullptr) {
dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
workspace->data.dtype = DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
auto pdw_shape = std::vector<size_t>{
static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
using namespace transformer_engine;
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
const auto rows = x.data.shape[0];
const auto cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
rmsnorm::LaunchParams<rmsnorm::BwdParams> launch_params;
launch_params.stream = stream;
launch_params.multiprocessorCount = multiprocessorCount;
// Set the kernel runtime parameters.
rmsnorm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data.dptr;
params.mu = nullptr;
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr;
params.dz = dz.data.dptr;
params.dx = dx->data.dptr;
params.dbeta = nullptr;
params.dgamma = dgamma->data.dptr;
params.dbeta_part = nullptr;
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher.
auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->data.dptr == nullptr) {
dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
workspace->data.dtype = DType::kByte;
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
auto pdw_shape =
std::vector<size_t>{static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
}
} // namespace transformer_engine
......@@ -364,18 +346,15 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier),
false);
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier), false);
}
void nvte_rmsnorm1p_fwd(const NVTETensor x, // Nxhidden_size
......@@ -394,9 +373,8 @@ void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace,
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part,
cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm1p_bwd);
using namespace transformer_engine;
......@@ -404,6 +382,5 @@ void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier),
true);
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier), true);
}
......@@ -16,466 +16,462 @@ using namespace transformer_engine;
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel(
BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
compute_t *smem_wgrad = reinterpret_cast<compute_t *>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / static_cast<float>(COLS);
Wvec gamma[LDGS];
index_t idx = c;
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
compute_t *smem_wgrad = reinterpret_cast<compute_t *>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / static_cast<float>(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
for (int it = 0; it < LDGS; it++) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdyy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp);
const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp);
const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdyy_local += dy_tmp * y_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
}
}
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
}
}
reduce_t result = reducer.allreduce({0, mdyy_local}, sum);
mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;
reduce_t result = reducer.allreduce({0, mdyy_local}, sum);
mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if (WARPS_M == 1) {
idx = r * Ktraits::VEC_COLS + c;
if (WARPS_M == 1) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
for (int it = 0; it < LDGS; it++) {
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
for (int it = 0; it < LDGS; it++) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
}
}
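// Note: dgamma is produced in two stages. Each CTA of the kernel above reduces
// its rows into a partial sum stored in params.dgamma_part; the finalize kernel
// below then sums those per-CTA partials over ctas_per_col and converts the
// result to the weight dtype before writing params.dgamma.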
template <typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel(
BwdParams params) {
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS;
col += COL_STRIDE, col_out += COL_STRIDE / 2) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dgamma_part;
dgamma_part.load_from(params.dgamma_part, idx);
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS;
col += COL_STRIDE, col_out += COL_STRIDE / 2) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dgamma_part;
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
}
}
for (int it = 0; it < NUM_ELT; it++) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
}
}
void *smem_gamma = smem_;
void *smem_gamma = smem_;
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
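// The lane ^ row swizzle staggers each warp's writes across shared-memory banks
// (a common bank-conflict-avoidance pattern); the matching read below undoes it
// via read_col = w ^ read_row.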
dgamma_local.store_to(smem_gamma, write_idx);
dgamma_local.store_to(smem_gamma, write_idx);
__syncthreads();
__syncthreads();
// It would probably be safe to reuse the first row of smem_gamma
void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
// It would probably be safe to reuse the first row of smem_gamma
void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
// More than one iter iff ROWS_PER_CTA < 32.
for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
// More than one iter iff ROWS_PER_CTA < 32.
for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dgamma_local.load_from(smem_gamma, read_idx);
}
// Load gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
compute_t g_i = dgamma_local.data.elt[it];
g_i = reducer.allreduce(g_i, sum);
for (int it = 0; it < NUM_ELT; it++) {
compute_t g_i = dgamma_local.data.elt[it];
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
}
dgamma_local.data.elt[it] = g_i;
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
}
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dgamma_vec2;
Vec<dst_t, NUM_ELT> dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
dgamma_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
}
}
}
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel(
BwdParams params) {
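  // General (arbitrary hidden size) RMSNorm backward: CTAs stride over both rows and
  // columns, the per-row statistics are reduced with a DynamicReducer, and the per-CTA
  // dgamma contributions are staged in params.dgamma_part for the finalize kernel
  // further down.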
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using compute_t = typename Ktraits::compute_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
lane); // Order threads by warp x cta x lane
// Objects for weight grads
Cvec dzy_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
// Objects for stats reductions
using reduce_t = typename Ktraits::Reducer::Type;
using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<reduce_t> sum;
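  // rn is 1 / cols; it turns the row-wide reduced sums into means below.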
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
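  // Stride this CTA over the rows; each iteration processes one row per warp row,
  // i.e. WARPS_M rows per CTA per pass.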
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
compute_t rs = 0.f;
if (row < params.rows) {
rs = static_cast<const compute_t *>(params.rs)[row];
}
Cvec dy[LDGS];
Cvec y[LDGS];
compute_t mdy = 0.f;
compute_t mdyy = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = x.data.elt[jt];
compute_t y_ij = rs * (x_ij);
const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
compute_t dz_ij = dz.data.elt[jt];
compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
mdy += dy_ij;
mdyy += dy_ij * y_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}
// Reduce over row
reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
mdy = Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = Get<1>::of<reduce_t, compute_t>(result) * rn;
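      // For RMSNorm, dx_i = rs * (dy_i - y_i * mean_j(dy_j * y_j)); mdy is reduced
      // alongside mdyy but does not enter the dx expression below.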
// Compute dx
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec dx;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij));
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
}
if constexpr (WARPS_M == 1) {
// Write out local weight grad contributions
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
    }
  } else {
    // Reduce weight grad contributions within CTA before writing
    __shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
// Reduce dzy
__syncthreads();
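    // Each warp row owns the LDGS indices congruent to its warp_m; the other warp rows
    // stage their dzy_sum vectors in shared memory (the +1 padding helps avoid bank
    // conflicts) and the owning warp row accumulates them below.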
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
if (it != warp_m) {
dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
}
__syncthreads();
#pragma unroll
for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) {
#pragma unroll
for (int kt = 0; kt < WARPS_M; kt++) {
if (kt != warp_m) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
}
}
}
template <typename weight_t, typename compute_t, uint32_t WARPS_M, uint32_t WARPS_N,
uint32_t BYTES_PER_LDG, uint32_t THREADS_PER_WARP>
__global__ __launch_bounds__(
WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) {
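  // Finalize step for the general kernel: each thread accumulates its column's partial
  // dgamma rows from params.dgamma_part, the warp rows are then combined through shared
  // memory, and warp row 0 converts the result to weight_t and stores it to params.dgamma.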
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
const int lane = threadIdx.x % THREADS_PER_WARP;
const int warp_m = threadIdx.y;
const int warp_n = threadIdx.x / THREADS_PER_WARP;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
// Load grad contributions and accumulate locally
Cvec dgamma;
dgamma.clear();
for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) {
Cvec dgamma_part;
dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dgamma.data.elt[jt] += dgamma_part.data.elt[jt];
}
}
// Reduce dgamma within CTA
__shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]);
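  // Tree reduction across warp rows: at each step the upper half of the remaining warp
  // rows folds its partial sums into the lower half.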
#pragma unroll
for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) {
__syncthreads();
if (warp_m < nrows) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt] +=
vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt];
}
}
}
if (warp_m == 0 && col < params.cols) {
Wvec dgamma_out;
vecs_shared[warp_m][warp_n][lane].to(dgamma_out);
dgamma_out.store_to_elts(params.dgamma, col, params.cols - col);
}
}
} // namespace rmsnorm