Unverified commit 9416519d authored by Kirthi Shankar Sivamani and committed by GitHub

Apply formatting (#929)



* Apply formatting

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
@@ -33,14 +33,11 @@ extern "C" {
 *  \param[in] o_stride_d   Stride of the d dimension of output.
 *  \param[in] stream       CUDA stream used for the operation.
 */
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
                             const int s, const int b, const int h, const int d, const int d2,
                             const int stride_s, const int stride_b, const int stride_h,
                             const int stride_d, const int o_stride_s, const int o_stride_b,
                             const int o_stride_h, const int o_stride_d, cudaStream_t stream);

/*! \brief Compute the backward of the fused rope.
 *
@@ -62,14 +59,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
 *  \param[in] o_stride_d   Stride of the d dimension of input_grads.
 *  \param[in] stream       CUDA stream used for the operation.
 */
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
                              NVTETensor input_grads, const int s, const int b, const int h,
                              const int d, const int d2, const int stride_s, const int stride_b,
                              const int stride_h, const int stride_d, const int o_stride_s,
                              const int o_stride_b, const int o_stride_h, const int o_stride_d,
                              cudaStream_t stream);

/*! \brief Apply rotary positional embedding to the input tensor in thd format.
 *
@@ -90,14 +85,12 @@ void nvte_fused_rope_backward(const NVTETensor output_grads,
 *  \param[in] o_stride_d   Stride of the d dimension of output.
 *  \param[in] stream       CUDA stream used for the operation.
 */
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
                                 const NVTETensor freqs, NVTETensor output, const int max_s,
                                 const int b, const int h, const int d, const int d2,
                                 const int stride_t, const int stride_h, const int stride_d,
                                 const int o_stride_t, const int o_stride_h, const int o_stride_d,
                                 cudaStream_t stream);

/*! \brief Compute the backward of the fused rope in thd format.
 *
@@ -118,12 +111,12 @@ void nvte_fused_rope_thd_forward(const NVTETensor input,
 *  \param[in] o_stride_d   Stride of the d dimension of input_grads.
 *  \param[in] stream       CUDA stream used for the operation.
 */
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
                                  const NVTETensor freqs, NVTETensor input_grads, const int max_s,
                                  const int b, const int h, const int d, const int d2,
                                  const int stride_t, const int stride_h, const int stride_d,
                                  const int o_stride_t, const int o_stride_h, const int o_stride_d,
                                  cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
...
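For orientation, a minimal usage sketch of the reformatted forward entry point (not part of this commit's diff): it assumes an sbhd-shaped bf16 input with contiguous strides, d2 == d, precomputed FP32 frequencies, and device buffers owned by the caller; the TensorWrapper helper declared later in this diff builds the NVTETensor handles, and the fused_rope.h include path is an assumption.

#include <transformer_engine/fused_rope.h>  // assumed include path for the declarations above
#include <transformer_engine/transformer_engine.h>

#include <vector>

// Sketch: apply fused RoPE forward to a contiguous [s, b, h, d] tensor, rotating the full
// head dimension (d2 == d). All pointers are device buffers owned by the caller.
void fused_rope_forward_sketch(void *in_ptr, void *freqs_ptr, void *out_ptr, int s, int b, int h,
                               int d, cudaStream_t stream) {
  using transformer_engine::DType;
  using transformer_engine::TensorWrapper;

  const std::vector<size_t> sbhd{size_t(s), size_t(b), size_t(h), size_t(d)};
  TensorWrapper input(in_ptr, sbhd, DType::kBFloat16);
  TensorWrapper output(out_ptr, sbhd, DType::kBFloat16);
  // Frequencies are assumed precomputed as [s, 1, 1, d] in FP32.
  TensorWrapper freqs(freqs_ptr, std::vector<size_t>{size_t(s), 1, 1, size_t(d)}, DType::kFloat32);

  // Contiguous sbhd strides, in elements.
  const int stride_d = 1, stride_h = d, stride_b = h * d, stride_s = b * h * d;

  nvte_fused_rope_forward(input.data(), freqs.data(), output.data(), s, b, h, d, /*d2=*/d,
                          stride_s, stride_b, stride_h, stride_d,
                          /*o_stride_s=*/stride_s, /*o_stride_b=*/stride_b,
                          /*o_stride_h=*/stride_h, /*o_stride_d=*/stride_d, stream);
}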
@@ -39,20 +39,10 @@ extern "C" {
 *  \param[in] math_sm_count   Number of GPU SMs to use (default=0: use cuBLAS heuristics)
 *  \param[in] stream          CUDA stream used for the operation.
 */
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream);

/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
 *
@@ -82,27 +72,14 @@ void nvte_cublas_gemm(const NVTETensor A,
 *  \param[in,out] counter     counter[chunk_i]=0 indicates chunk_i has been produced.
 *  \param[in]     stream      CUDA stream used for the operation.
 */
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                             bool transb, bool grad, NVTETensor workspace, bool accumulate,
                             bool use_split_accumulator, int math_sm_count, int m_split,
                             int n_split, bool gemm_producer, const NVTETensor counter,
                             cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_GEMM_H_
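A rough sketch of a plain D = op(A) * op(B) call through the reformatted nvte_cublas_gemm (not part of this commit's diff): the bias and GELU fusions are disabled by passing default-constructed (empty) TensorWrapper handles, the transa/transb choice is illustrative, and the workspace is assumed to be sized to cuBLASLt's requirements elsewhere.

#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

// Sketch: GEMM without bias or GELU fusion; A, B, D and workspace are pre-built handles.
void cublas_gemm_sketch(const transformer_engine::TensorWrapper &A,
                        const transformer_engine::TensorWrapper &B,
                        transformer_engine::TensorWrapper &D,
                        transformer_engine::TensorWrapper &workspace, cudaStream_t stream) {
  transformer_engine::TensorWrapper bias;          // empty handle: no bias fusion
  transformer_engine::TensorWrapper pre_gelu_out;  // empty handle: no GELU fusion

  nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(),
                   /*transa=*/true, /*transb=*/false, /*grad=*/false, workspace.data(),
                   /*accumulate=*/false, /*use_split_accumulator=*/false,
                   /*math_sm_count=*/0, stream);
}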
@@ -42,16 +42,9 @@ extern "C" {
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
                        const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
                        cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
                        NVTETensor barrier);

/*! \brief Compute LayerNorm with zero-centered gamma on the input.
@@ -79,19 +72,11 @@ void nvte_layernorm_fwd(const NVTETensor x,
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_layernorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
                          const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
                          cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
                          NVTETensor barrier);

/*! \brief Compute backward of LayerNorm.
 *
 * This function computes the gradient of function:
@@ -121,20 +106,14 @@ void nvte_layernorm1p_fwd(const NVTETensor x,
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_layernorm_bwd(const NVTETensor dz,      // BxSxhidden_size
                        const NVTETensor x,       // BxSxhidden_size
                        const NVTETensor mu,      // BxS, FP32!
                        const NVTETensor rsigma,  // BxS, FP32!
                        const NVTETensor gamma,   // hidden_size
                        NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part,
                        NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount,
                        NVTETensor workspace, NVTETensor barrier);

/*! \brief Compute backward of LayerNorm with zero-centered gamma.
 *
@@ -165,20 +144,14 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_layernorm1p_bwd(const NVTETensor dz,      // BxSxhidden_size
                          const NVTETensor x,       // BxSxhidden_size
                          const NVTETensor mu,      // BxS, FP32!
                          const NVTETensor rsigma,  // BxS, FP32!
                          const NVTETensor gamma,   // hidden_size
                          NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta,
                          NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream,
                          const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier);

#ifdef __cplusplus
}  // extern "C"
#endif
...
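To show how the reformatted forward signature is used, a minimal sketch (not part of this commit's diff): it assumes the workspace and barrier tensors have already been sized (TE reports the required sizes when they are passed in empty on an initial query call, a step omitted here), that <transformer_engine/transformer_engine.h> and the header above are included, and that the epsilon value is illustrative.

// Sketch: one LayerNorm forward launch with pre-sized workspace/barrier tensors.
void layernorm_fwd_sketch(const transformer_engine::TensorWrapper &x,
                          const transformer_engine::TensorWrapper &gamma,
                          const transformer_engine::TensorWrapper &beta,
                          transformer_engine::TensorWrapper &z,
                          transformer_engine::TensorWrapper &mu,
                          transformer_engine::TensorWrapper &rsigma,
                          transformer_engine::TensorWrapper &workspace,
                          transformer_engine::TensorWrapper &barrier, int sm_count,
                          cudaStream_t stream) {
  const float epsilon = 1e-5f;  // illustrative value
  nvte_layernorm_fwd(x.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(),
                     rsigma.data(), stream, sm_count, workspace.data(), barrier.data());
}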
@@ -44,18 +44,11 @@ extern "C" {
 *  \param[in] margin   Scaling factor margin.
 *  \param[in] stream   CUDA stream.
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update(
    const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv,
    const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale,
    NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
    cudaStream_t stream);

/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction.
 *
@@ -85,15 +78,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his
 *  \param[in] stream   CUDA stream.
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
    const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
    std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
    const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
...
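A hedged sketch of a single delayed-scaling update through the first entry point above (not part of this commit's diff): the "max" amax-compute algorithm string, the E4M3 target dtype, the zero margin, and the empty scale_inv_mask are all illustrative assumptions, and the surrounding headers are assumed to be included.

// Sketch: update one set of FP8 scaling factors from its amax history.
void amax_scale_update_sketch(const transformer_engine::TensorWrapper &amax_history,
                              const transformer_engine::TensorWrapper &scale,
                              const transformer_engine::TensorWrapper &scale_inv,
                              transformer_engine::TensorWrapper &updated_amax_history,
                              transformer_engine::TensorWrapper &updated_scale,
                              transformer_engine::TensorWrapper &updated_scale_inv,
                              cudaStream_t stream) {
  transformer_engine::TensorWrapper scale_inv_mask;  // empty: no masking of scale_inv updates
  nvte_delayed_scaling_recipe_amax_and_scale_update(
      amax_history.data(), scale.data(), scale_inv.data(), scale_inv_mask.data(),
      updated_amax_history.data(), updated_scale.data(), updated_scale_inv.data(),
      /*amax_compute_algo=*/"max" /* assumed algorithm name */, kNVTEFloat8E4M3,
      /*margin=*/0.0f, stream);
}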
@@ -43,15 +43,9 @@ extern "C" {
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z,
                      NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount,
                      NVTETensor workspace, NVTETensor barrier);

/*! \brief Compute RMSNorm with zero-centered gamma on the input.
 *
@@ -79,15 +73,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x,
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon,
                        NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
                        const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier);

/*! \brief Compute backward of RMSNorm.
 *
@@ -118,18 +106,10 @@ void nvte_rmsnorm1p_fwd(const NVTETensor x,
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,
                      const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma,
                      NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount,
                      NVTETensor workspace, NVTETensor barrier);

/*! \brief Compute backward of RMSNorm with zero-centered gamma.
 *
@@ -160,18 +140,10 @@ void nvte_rmsnorm_bwd(const NVTETensor dz,
 *  \param[out] workspace   Workspace tensor.
 *  \param[out] barrier     Barrier tensor.
 */
void nvte_rmsnorm1p_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,
                        const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma,
                        NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount,
                        NVTETensor workspace, NVTETensor barrier);

#ifdef __cplusplus
}  // extern "C"
...
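The RMSNorm forward mirrors the LayerNorm sketch above but drops beta and mu; again a minimal sketch (not part of this commit's diff) that assumes pre-sized workspace and barrier tensors, the required headers, and an illustrative epsilon.

// Sketch: one RMSNorm forward launch; workspace/barrier are assumed pre-sized.
void rmsnorm_fwd_sketch(const transformer_engine::TensorWrapper &x,
                        const transformer_engine::TensorWrapper &gamma,
                        transformer_engine::TensorWrapper &z,
                        transformer_engine::TensorWrapper &rsigma,
                        transformer_engine::TensorWrapper &workspace,
                        transformer_engine::TensorWrapper &barrier, int sm_count,
                        cudaStream_t stream) {
  nvte_rmsnorm_fwd(x.data(), gamma.data(), /*epsilon=*/1e-5f, z.data(), rsigma.data(), stream,
                   sm_count, workspace.data(), barrier.data());
}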
@@ -7,8 +7,9 @@
#ifndef TRANSFORMER_ENGINE_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SOFTMAX_H_

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "transformer_engine.h"

#ifdef __cplusplus
@@ -22,13 +23,8 @@ extern "C" {
 *  \param[in] scale_factor   Scalar for the input tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_results,
                                 float scale_factor, cudaStream_t stream);

/*! \brief Compute the backward of the scaled softmax activation.
 *
@@ -42,14 +38,8 @@ void nvte_scaled_softmax_forward(
 *  \param[in] scale_factor   Scalar for the output tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results,
                                  NVTETensor output_grads, float scale_factor, cudaStream_t stream);

/*! \brief Compute scaled masked softmax activation on the input.
 *
@@ -59,14 +49,9 @@ void nvte_scaled_softmax_backward(
 *  \param[in] scale_factor   Scalar for the input tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask,
                                        NVTETensor softmax_results, float scale_factor,
                                        cudaStream_t stream);

/*! \brief Compute the backward of the scaled masked softmax activation.
 *
@@ -80,14 +65,9 @@ void nvte_scaled_masked_softmax_forward(
 *  \param[in] scale_factor   Scalar for the output tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
                                         const NVTETensor softmax_results, NVTETensor output_grads,
                                         float scale_factor, cudaStream_t stream);

/*! \brief Compute scaled softmax activation using a 2D upper triangular mask on the input.
 *
@@ -96,13 +76,9 @@ void nvte_scaled_masked_softmax_backward(
 *  \param[in] scale_factor   Scalar for the input tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input,
                                                     NVTETensor softmax_results, float scale_factor,
                                                     cudaStream_t stream);

/*! \brief Compute the backward of the scaled softmax activation using a 2D upper triangular mask.
 *
@@ -116,14 +92,10 @@ void nvte_scaled_upper_triang_masked_softmax_forward(
 *  \param[in] scale_factor   Scalar for the output tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads,
                                                      const NVTETensor softmax_results,
                                                      NVTETensor output_grads, float scale_factor,
                                                      cudaStream_t stream);

/*! \brief Compute scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
 *
@@ -132,13 +104,9 @@ void nvte_scaled_upper_triang_masked_softmax_backward(
 *  \param[in] scale_factor   Scalar for the input tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input,
                                                       NVTETensor softmax_results,
                                                       float scale_factor, cudaStream_t stream);

/*! \brief Compute the backward pass of the scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
 *
@@ -152,13 +120,10 @@ void nvte_scaled_aligned_causal_masked_softmax_forward(
 *  \param[in] scale_factor   Scalar for the output tensor.
 *  \param[in] stream         CUDA stream used for the operation.
 */
void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incoming_grads,
                                                        const NVTETensor softmax_results,
                                                        NVTETensor output_grads, float scale_factor,
                                                        cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
...
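A minimal sketch of the unmasked variant above (not part of this commit's diff): the [b, h, sq, sk] logits layout, the FP16 dtype, and the 1/sqrt(d_head) scale factor are illustrative assumptions; the transformer_engine header is assumed to be included.

#include <cmath>
#include <vector>

// Sketch: scaled softmax forward over attention logits.
void scaled_softmax_sketch(void *logits_ptr, void *probs_ptr, size_t b, size_t h, size_t sq,
                           size_t sk, float d_head, cudaStream_t stream) {
  using transformer_engine::DType;
  using transformer_engine::TensorWrapper;

  const std::vector<size_t> bhss{b, h, sq, sk};
  TensorWrapper logits(logits_ptr, bhss, DType::kFloat16);
  TensorWrapper probs(probs_ptr, bhss, DType::kFloat16);

  nvte_scaled_softmax_forward(logits.data(), probs.data(),
                              /*scale_factor=*/1.0f / std::sqrt(d_head), stream);
}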
@@ -11,8 +11,8 @@
#ifndef TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
#define TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_

#include <cuda_runtime_api.h>
#include <stddef.h>

#ifdef __cplusplus
extern "C" {
@@ -22,15 +22,15 @@ extern "C" {
 * \brief TE datatype.
 */
enum NVTEDType {
  kNVTEByte = 0,       /*!< Byte */
  kNVTEInt32 = 1,      /*!< 32-bit integer */
  kNVTEInt64 = 2,      /*!< 64-bit integer */
  kNVTEFloat32 = 3,    /*!< 32-bit float */
  kNVTEFloat16 = 4,    /*!< 16-bit float (E5M10) */
  kNVTEBFloat16 = 5,   /*!< 16-bit bfloat (E8M7) */
  kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
  kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
  kNVTENumTypes        /*!< Number of supported types */
};
/*! \struct NVTEShape
@@ -49,7 +49,7 @@ struct NVTEShape {
 * to data of a given shape and type. It does not own the
 * memory it points to.
 */
typedef void *NVTETensor;

/*! \brief Create a new TE tensor.
 *
@@ -66,12 +66,8 @@ typedef void* NVTETensor;
 *
 * \return A new TE tensor.
 */
NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype,
                              float *amax_dptr, float *scale_dptr, float *scale_inv_dptr);

/*! \brief Destroy a TE tensor.
 *
@@ -144,11 +140,11 @@ struct NVTETensorPack {
/*! \brief Create `tensors` in NVTETensorPack.
 */
void nvte_tensor_pack_create(NVTETensorPack *pack);

/*! \brief Destroy `tensors` in NVTETensorPack.
 */
void nvte_tensor_pack_destroy(NVTETensorPack *pack);

#ifdef __cplusplus
}  // extern "C"
@@ -164,12 +160,12 @@ namespace transformer_engine {
 * \brief TE datatype.
 */
enum class DType {
  kByte = 0,
  kInt32 = 1,
  kInt64 = 2,
  kFloat32 = 3,
  kFloat16 = 4,
  kBFloat16 = 5,
  kFloat8E4M3 = 6,
  kFloat8E5M2 = 7,
  kNumTypes
@@ -193,11 +189,10 @@ class TensorWrapper {
 * \param[in] scale_dptr      Pointer to the scale value.
 * \param[in] scale_inv_dptr  Pointer to the inverse of scale value.
 */
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
              float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr)
    : tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype), amax_dptr,
                                 scale_dptr, scale_inv_dptr)) {}

/*! \brief Constructs new TensorWrapper.
 *
@@ -214,9 +209,9 @@ class TensorWrapper {
 */
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype,
              float *amax_dptr = nullptr, float *scale_dptr = nullptr,
              float *scale_inv_dptr = nullptr)
    : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr,
                    scale_inv_dptr) {}

/*! \brief Constructs new empty TensorWrapper.
 *
@@ -225,11 +220,9 @@ class TensorWrapper {
TensorWrapper() : TensorWrapper(nullptr, std::vector<size_t>(), DType::kFloat32) {}

/*! \brief TensorWrapper destructor. */
~TensorWrapper() { nvte_destroy_tensor(tensor_); }

TensorWrapper &operator=(const TensorWrapper &other) = delete;
TensorWrapper(const TensorWrapper &other) = delete;

/*! \brief Constructs new TensorWrapper from existing TensorWrapper.
@@ -249,7 +242,7 @@ class TensorWrapper {
 *
 * \param[in,out] other The source of the data.
 */
TensorWrapper &operator=(TensorWrapper &&other) {
  if (this == &other) return *this;
  nvte_destroy_tensor(tensor_);
  tensor_ = other.tensor_;
@@ -261,9 +254,7 @@ class TensorWrapper {
 *
 * \return NVTETensor held by this TensorWrapper.
 */
NVTETensor data() const noexcept { return tensor_; }

/*! \brief Get the shape of this TensorWrapper.
 *
...
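To contrast the C handle API with the C++ RAII wrapper reformatted above, a minimal sketch (not part of this commit's diff); the FP8 metadata pointers are left null because the illustrated tensor is not FP8, and the device pointer is assumed to be allocated elsewhere.

#include <transformer_engine/transformer_engine.h>

#include <vector>

// Sketch: create and destroy the same tensor description through both interfaces.
void tensor_creation_sketch(void *dptr) {
  const std::vector<size_t> dims{128, 1024};

  // C API: the caller owns the handle and must destroy it explicitly.
  NVTEShape shape{dims.data(), dims.size()};
  NVTETensor t = nvte_create_tensor(dptr, shape, kNVTEBFloat16,
                                    /*amax_dptr=*/nullptr, /*scale_dptr=*/nullptr,
                                    /*scale_inv_dptr=*/nullptr);
  nvte_destroy_tensor(t);

  // C++ wrapper: destruction happens in ~TensorWrapper().
  transformer_engine::TensorWrapper wrapped(dptr, dims, transformer_engine::DType::kBFloat16);
  (void)wrapped.data();  // NVTETensor handle usable with any nvte_* call
}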
@@ -28,10 +28,8 @@ extern "C" {
 *  \param[in,out] transposed_output   Result of the cast and transpose. Shape: [H, N].
 *  \param[in]     stream              CUDA stream used for the operation.
 */
void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
                         NVTETensor transposed_output, cudaStream_t stream);

/*! \brief Transpose the input.
 *
@@ -39,9 +37,7 @@ void nvte_cast_transpose(const NVTETensor input,
 *  \param[out] transposed_output   Result of the transpose. Shape: [H, N].
 *  \param[in]  stream              CUDA stream used for the operation.
 */
void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, cudaStream_t stream);

/*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension.
 *
@@ -61,11 +57,8 @@ void nvte_transpose(const NVTETensor input,
 *  \param[out] workspace   Workspace tensor.
 *  \param[in]  stream      CUDA stream used for the operation.
 */
void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output,
                               NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream);

/*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension.
@@ -84,11 +77,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
 *  \param[out] workspace   Workspace tensor.
 *  \param[in]  stream      CUDA stream used for the operation.
 */
void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output,
                              NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);

/*! \brief Cast and transpose multiple tensors.
 *
@@ -105,10 +95,8 @@ void nvte_fp8_transpose_dbias(const NVTETensor input,
 *                  of tensors in input_list.
 *  \param[in] stream   CUDA stream used for the operation.
 */
void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
                               NVTETensor* cast_output_list, NVTETensor* transposed_output_list,
                               cudaStream_t stream);

/*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally,
@@ -131,50 +119,29 @@ void nvte_multi_cast_transpose(size_t num_tensors,
 *                          first dimension. Shape: [H].
 *  \param[out] workspace   Workspace tensor.
 *  \param[in]  stream      CUDA stream used for the operation.
 *  Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
 */
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
                                     NVTETensor cast_output, NVTETensor transposed_output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);

void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input,
                                     NVTETensor cast_output, NVTETensor transposed_output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);

void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input,
                                     NVTETensor cast_output, NVTETensor transposed_output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);

void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input,
                                      NVTETensor cast_output, NVTETensor transposed_output,
                                      NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);

void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input,
                                      NVTETensor cast_output, NVTETensor transposed_output,
                                      NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);

/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output.
 *
@@ -189,38 +156,28 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
 *  \param[in,out] transposed_output   Result of the cast and transpose. Shape: [H * 2, N].
 *  \param[in]     stream              CUDA stream used for the operation.
 *  Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
 */
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
                                NVTETensor cast_output, NVTETensor transposed_output,
                                cudaStream_t stream);

void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
                                 NVTETensor cast_output, NVTETensor transposed_output,
                                 cudaStream_t stream);

void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
                                NVTETensor cast_output, NVTETensor transposed_output,
                                cudaStream_t stream);

void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
                                 NVTETensor cast_output, NVTETensor transposed_output,
                                 cudaStream_t stream);

void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
                                 NVTETensor cast_output, NVTETensor transposed_output,
                                 cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
...
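A minimal sketch of the basic fused FP8 cast + transpose above (not part of this commit's diff): the [N, H] input shape, the bf16-to-E4M3 cast, and the externally managed amax/scale/scale_inv pointers attached to both FP8 outputs are illustrative assumptions; the transformer_engine header is assumed to be included.

#include <vector>

// Sketch: cast an [N, H] bf16 tensor to FP8 and write its [H, N] transpose in one call.
void cast_transpose_sketch(void *in_ptr, void *cast_ptr, void *t_ptr, float *amax, float *scale,
                           float *scale_inv, size_t N, size_t H, cudaStream_t stream) {
  using transformer_engine::DType;
  using transformer_engine::TensorWrapper;

  TensorWrapper input(in_ptr, std::vector<size_t>{N, H}, DType::kBFloat16);
  TensorWrapper cast_output(cast_ptr, std::vector<size_t>{N, H}, DType::kFloat8E4M3,
                            amax, scale, scale_inv);
  TensorWrapper transposed_output(t_ptr, std::vector<size_t>{H, N}, DType::kFloat8E4M3,
                                  amax, scale, scale_inv);

  nvte_cast_transpose(input.data(), cast_output.data(), transposed_output.data(), stream);
}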
@@ -8,11 +8,12 @@
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_

#include <transformer_engine/transformer_engine.h>

#include <functional>
#include <map>
#include <stdexcept>
#include <unordered_map>
#include <vector>

#include "../common.h"

@@ -21,113 +22,107 @@ namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Params>
struct LaunchParams {
  size_t workspace_bytes;
  size_t barrier_size;

  int multiprocessorCount;
  cudaStream_t stream;

  Params params;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct ParamsBase {
  ParamsBase()
      : ctas_per_col(0),
        rows(0),
        cols(0),
        x(nullptr),
        mu(nullptr),
        rs(nullptr),
        gamma(nullptr),
        workspace(nullptr),
        barrier(nullptr),
        zero_centered_gamma(false) {}

  // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
  int ctas_per_col;
  // Size of CTA group.
  int ctas_per_row;

  // Input is interpreted as matrix. We normalize across columns.
  int rows;
  int cols;

  // Common data pointers.
  void *x;
  void *mu;
  void *rs;
  void *gamma;

  // Multi-CTA workspace in gmem.
  void *workspace;

  // Multi-CTA sync barriers in gmem.
  int *barrier;

  // Whether gamma is centered around 0
  bool zero_centered_gamma;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct FwdParams : public ParamsBase {
  FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {}

  // Output of LN FWD.
  void *z;
  void *beta;
  float epsilon;

  // Scaling factor
  void *scale;

  // AMax output
  void *amax;

  // Whether to compute scale and amax
  bool fp8_out;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct BwdParams : public ParamsBase {
  BwdParams()
      : ParamsBase(),
        dz(nullptr),
        dbeta_part(nullptr),
        dgamma_part(nullptr),
        dx(nullptr),
        dbeta(nullptr),
        dgamma(nullptr) {}

  // Input: gradient wrt. LN FWD output.
  void *dz;

  // Workspace for Wgrad pre-reduction.
  void *dbeta_part;
  void *dgamma_part;

  // Output: Dgrad.
  void *dx;

  // Output: Wgrad.
  void *dbeta;
  void *dgamma;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

using FwdFunction = std::function<void(LaunchParams<FwdParams> &, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams> &, const bool)>;
using FunctionKey = uint64_t;
using FwdTunedRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdTunedRegistry = std::unordered_map<FunctionKey, BwdFunction>;

@@ -141,96 +136,96 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct TypeId {};

template <>
struct TypeId<fp16> {
  constexpr static uint32_t Value = 0;
};

template <>
struct TypeId<bf16> {
  constexpr static uint32_t Value = 1;
};

template <>
struct TypeId<fp32> {
  constexpr static uint32_t Value = 2;
};

template <>
struct TypeId<fp8e4m3> {
  constexpr static uint32_t Value = 3;
};

template <typename T, int S>
struct Type2Key {
  constexpr static uint32_t Value = TypeId<T>::Value << S;
};

template <typename T>
struct WeightType2Key : public Type2Key<T, 0> {};

template <typename T>
struct InputType2Key : public Type2Key<T, 2> {};

template <typename T>
struct OutputType2Key : public Type2Key<T, 4> {};

template <typename T>
struct ComputeType2Key : public Type2Key<T, 6> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename W, typename I, typename O, typename C>
struct Types2Key {
  constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
                                    OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
  constexpr static inline uint64_t get(const uint64_t hidden_size) {
    constexpr uint64_t type_key = Value;
    return (type_key << 32) | hidden_size;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar {
  explicit FwdTunedRegistrar(FwdFunction f) {
    uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
    FWD_TUNED_FUNCS.insert({key, f});
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar {
  explicit FwdGeneralRegistrar(FwdFunction f) {
    uint64_t key = Types2Key<W, I, O, C>::get(0);
    FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar {
  explicit BwdTunedRegistrar(BwdFunction f) {
    uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
    BWD_TUNED_FUNCS.insert({key, f});
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar {
  explicit BwdGeneralRegistrar(BwdFunction f) {
    uint64_t key = Types2Key<W, I, O, C>::get(0);
    BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
  }
};

//////////////////////////////////////////////////////////////////////////////////////////////////
...
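To make the key packing above concrete, a small worked sketch (not part of this commit's diff) of the 64-bit registry key for a bf16 weight/input/output with fp32 compute and hidden size 1024, following the shifts defined by Type2Key and the packing in Types2Key::get():

#include <cstdint>

// Sketch: four 2-bit type ids are packed into the low byte, and the hidden size
// occupies the low 32 bits of the final 64-bit key.
constexpr uint32_t kWeight = 1u /*bf16*/ << 0;   // WeightType2Key: shift 0
constexpr uint32_t kInput = 1u /*bf16*/ << 2;    // InputType2Key: shift 2
constexpr uint32_t kOutput = 1u /*bf16*/ << 4;   // OutputType2Key: shift 4
constexpr uint32_t kCompute = 2u /*fp32*/ << 6;  // ComputeType2Key: shift 6
constexpr uint32_t kTypeKey = kWeight | kInput | kOutput | kCompute;

// Tuned kernels are keyed on (type_key, hidden_size); general kernels use hidden_size 0.
constexpr uint64_t kTunedKey = (uint64_t(kTypeKey) << 32) | 1024;
static_assert(kTunedKey == ((uint64_t(0b10010101) << 32) | 1024), "key layout check");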
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "ln.h"
#include "../common.h" #include "../common.h"
#include "ln.h"
/* /*
...@@ -46,500 +46,411 @@ BwdGeneralRegistry BWD_GENERAL_FUNCS; ...@@ -46,500 +46,411 @@ BwdGeneralRegistry BWD_GENERAL_FUNCS;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(DType dtype) { uint32_t get_type_id(DType dtype) {
if ( dtype == DType::kFloat16 ) { if (dtype == DType::kFloat16) {
return TypeId<fp16>::Value; return TypeId<fp16>::Value;
} else if ( dtype == DType::kBFloat16 ) { } else if (dtype == DType::kBFloat16) {
return TypeId<bf16>::Value; return TypeId<bf16>::Value;
} else if ( dtype == DType::kFloat32 ) { } else if (dtype == DType::kFloat32) {
return TypeId<fp32>::Value; return TypeId<fp32>::Value;
} else if ( dtype == DType::kFloat8E4M3 ) { } else if (dtype == DType::kFloat8E4M3) {
return TypeId<fp8e4m3>::Value; return TypeId<fp8e4m3>::Value;
} else { } else {
NVTE_ERROR("Type not supported."); NVTE_ERROR("Type not supported.");
} }
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) { uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) {
using namespace layer_norm; using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) |
(get_type_id(otype) << 4) | (get_type_id(ctype) << 6); (get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size; uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key; return launcher_key;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(DType wtype, layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
DType itype, const layer_norm::FwdParams& params) {
DType otype, // Look for tuned kernel
DType ctype, auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
const layer_norm::FwdParams &params) { auto is_aligned = [](const void* ptr) -> bool {
// Look for tuned kernel // Assume vectorized memory accesses are <=16B
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
auto is_aligned = [](const void *ptr) -> bool { };
// Assume vectorized memory accesses are <=16B if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) &&
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0; is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) &&
}; is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) {
if (params.rows % 4 == 0 return layer_norm::FWD_TUNED_FUNCS.at(tuned_key);
&& is_aligned(params.x) }
&& is_aligned(params.mu)
&& is_aligned(params.rs) // Pick general kernel
&& is_aligned(params.gamma) auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
&& is_aligned(params.beta) if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) {
&& is_aligned(params.z) NVTE_ERROR("FWD: Unsupported types.");
&& layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { }
return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
} auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Pick general kernel // Hidden size is too big, need to use multi-CTA
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); return general_func_map.rbegin()->second;
if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) { } else {
NVTE_ERROR("FWD: Unsupported types."); return func_iter->second;
} }
auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(DType wtype, layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
DType itype, const layer_norm::BwdParams& params) {
DType otype, // Look for tuned kernel
DType ctype, auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
const layer_norm::BwdParams &params) { auto is_aligned = [](const void* ptr) -> bool {
// Look for tuned kernel // Assume vectorized memory accesses are <=16B
auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
auto is_aligned = [](const void *ptr) -> bool { };
// Assume vectorized memory accesses are <=16B if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) &&
return reinterpret_cast<uintptr_t>(ptr) % 16 == 0; is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) &&
}; is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) &&
if (params.rows % 4 == 0 is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) &&
&& is_aligned(params.x) layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
&& is_aligned(params.mu) return layer_norm::BWD_TUNED_FUNCS.at(tuned_key);
&& is_aligned(params.rs) }
&& is_aligned(params.gamma)
&& is_aligned(params.dz) // Pick general kernel
&& is_aligned(params.dx) auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
&& is_aligned(params.dbeta) if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) {
&& is_aligned(params.dgamma) NVTE_ERROR("BWD: Unsupported types.");
&& is_aligned(params.dbeta_part) }
&& is_aligned(params.dgamma_part) auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
&& layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { auto func_iter = general_func_map.lower_bound(params.cols);
return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); if (func_iter == general_func_map.end()) {
} // Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
// Pick general kernel } else {
auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); return func_iter->second;
if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) { }
NVTE_ERROR("BWD: Unsupported types.");
}
auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
} else {
return func_iter->second;
}
} }
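Both launchers only take the tuned path when every pointer is 16-byte aligned and the row count is a multiple of 4, since the tuned kernels rely on vectorized accesses of up to 16 bytes. A standalone sketch of the same alignment predicate (no CUDA needed):

#include <cstdint>
#include <cstdlib>
#include <iostream>

bool is_aligned_16(const void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
}

int main() {
  // aligned_alloc guarantees the requested alignment; offsetting by one float breaks it.
  float* buf = static_cast<float*>(std::aligned_alloc(16, 64 * sizeof(float)));
  std::cout << is_aligned_16(buf) << " " << is_aligned_16(buf + 1) << "\n";  // 1 0
  std::free(buf);
}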
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
size_t product(const std::vector<size_t>& shape) {
size_t product(const std::vector<size_t> &shape) { size_t ret = 1;
size_t ret = 1; for (auto s : shape) {
for (auto s : shape) { ret *= s;
ret *= s; }
} return ret;
return ret;
} }
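product() is a plain left fold over the shape vector; for reference, the same reduction written with std::accumulate:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

size_t product(const std::vector<size_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
}

int main() {
  std::cout << product({2, 128, 768}) << "\n";  // 196608
}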
} // namespace layer_norm } // namespace layer_norm
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
void layernorm_fwd(const Tensor& x, // BxSxhidden_size void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const Tensor& gamma, // hidden_size const Tensor& gamma, // hidden_size
const Tensor& beta, // hidden_size const Tensor& beta, // hidden_size
const float epsilon, const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream,
Tensor* z, const int multiprocessorCount, Tensor* workspace, Tensor* barrier,
Tensor* mu,
Tensor* rsigma,
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
const auto itype = x.data.dtype; const auto itype = x.data.dtype;
const auto wtype = gamma.data.dtype; const auto wtype = gamma.data.dtype;
const auto otype = z->data.dtype; const auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype); const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32; const auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(x.data.shape.size() == 2);
const size_t rows = x.data.shape[0]; const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1]; const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0]; const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(gamma.data.shape == beta.data.shape); NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(hidden_size == cols); NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(epsilon >= 0.f); NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape); NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(mu->data.shape == std::vector<size_t>{ rows }); NVTE_CHECK(mu->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(mu->data.dtype == ctype); NVTE_CHECK(mu->data.dtype == ctype);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{ rows }); NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
NVTE_CHECK(rsigma->data.dtype == ctype); NVTE_CHECK(rsigma->data.dtype == ctype);
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params; layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.multiprocessorCount = multiprocessorCount; launch_params.multiprocessorCount = multiprocessorCount;
launch_params.stream = stream; launch_params.stream = stream;
// Set the kernel runtime parameters. // Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params; layer_norm::FwdParams& params = launch_params.params;
params.rows = rows; params.rows = rows;
params.cols = cols; params.cols = cols;
params.x = x.data.dptr; params.x = x.data.dptr;
params.mu = mu->data.dptr; params.mu = mu->data.dptr;
params.rs = rsigma->data.dptr; params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr; params.gamma = gamma.data.dptr;
params.beta = beta.data.dptr; params.beta = beta.data.dptr;
params.z = z->data.dptr; params.z = z->data.dptr;
params.epsilon = epsilon; params.epsilon = epsilon;
params.amax = z->amax.dptr; params.amax = z->amax.dptr;
params.scale = z->scale.dptr; params.scale = z->scale.dptr;
params.fp8_out = fp8_out; params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma; params.zero_centered_gamma = zero_centered_gamma;
// Request the kernel launcher. // Request the kernel launcher.
auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params); auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params);
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) { if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1; launch_params.workspace_bytes = 1;
} }
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr); NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = layer_norm::DType::kByte; workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = { launch_params.workspace_bytes }; workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = layer_norm::DType::kInt32; barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = { launch_params.barrier_size }; barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}
// Clear buffers
if ( params.fp8_out ) {
cudaMemsetAsync(params.amax, 0,
layer_norm::product(z->amax.shape) *
typeToSize(z->amax.dtype), stream);
}
if ( launch_params.barrier_size > 0 ) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
}
// Launch the kernel.
launcher(launch_params, false);
return; return;
} else {
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}
// Clear buffers
if (params.fp8_out) {
cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
}
if (launch_params.barrier_size > 0) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
}
// Launch the kernel.
launcher(launch_params, false);
return;
} }
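layernorm_fwd follows a query-then-run contract: when the workspace and barrier tensors arrive with null data pointers it only records the required shape and dtype and returns, and the caller allocates and calls again. A sketch of that caller-side flow with a hypothetical fused_op standing in for the real entry point:

#include <cstddef>
#include <cstdlib>
#include <iostream>

// Hypothetical wrapper following the same query-then-run contract:
// pass workspace == nullptr to learn the required size, allocate, then call again.
void fused_op(std::size_t rows, std::size_t cols, void* workspace, std::size_t* workspace_bytes) {
  const std::size_t needed = rows * sizeof(float);  // assumed per-row scratch
  if (workspace == nullptr) {
    *workspace_bytes = needed;  // first pass: report only
    return;
  }
  // second pass: the caller guaranteed *workspace_bytes >= needed
  std::cout << "running with " << *workspace_bytes << " bytes of scratch\n";
}

int main() {
  std::size_t bytes = 0;
  fused_op(4096, 1024, nullptr, &bytes);  // query
  void* scratch = std::malloc(bytes);     // allocate exactly what was requested
  fused_op(4096, 1024, scratch, &bytes);  // run
  std::free(scratch);
}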
void layernorm_bwd(const Tensor& dz, void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma,
const Tensor& x, const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta,
const Tensor& mu, Tensor* dgamma_part, Tensor* dbeta_part, cudaStream_t stream,
const Tensor& rsigma, const int multiprocessorCount, Tensor* workspace, Tensor* barrier,
const Tensor& gamma, const bool zero_centered_gamma) {
Tensor* dx, using namespace transformer_engine;
Tensor* dgamma,
Tensor* dbeta, auto itype = x.data.dtype;
Tensor* dgamma_part, auto wtype = gamma.data.dtype;
Tensor* dbeta_part, auto otype = wtype;
cudaStream_t stream, auto ctype = DType::kFloat32;
const int multiprocessorCount,
Tensor* workspace, NVTE_CHECK(dz.data.dtype == otype);
Tensor* barrier, NVTE_CHECK(mu.data.dtype == ctype);
const bool zero_centered_gamma NVTE_CHECK(rsigma.data.dtype == ctype);
) {
using namespace transformer_engine; NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
auto itype = x.data.dtype; auto rows = x.data.shape[0];
auto wtype = gamma.data.dtype; auto cols = x.data.shape[1];
auto otype = wtype;
auto ctype = DType::kFloat32; auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(dz.data.dtype == otype); NVTE_CHECK(mu.data.shape[0] == rows);
NVTE_CHECK(mu.data.dtype == ctype); NVTE_CHECK(mu.data.shape == rsigma.data.shape);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape); NVTE_CHECK(dx->data.shape == x.data.shape);
auto rows = x.data.shape[0]; NVTE_CHECK(dx->data.dtype == x.data.dtype);
auto cols = x.data.shape[1];
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
auto hidden_size = gamma.data.shape[0]; NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
NVTE_CHECK(mu.data.shape[0] == rows); NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(mu.data.shape == rsigma.data.shape); NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
NVTE_CHECK(gamma.data.shape[0] == cols); layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = stream;
NVTE_CHECK(dx->data.shape == x.data.shape); launch_params.multiprocessorCount = multiprocessorCount;
NVTE_CHECK(dx->data.dtype == x.data.dtype);
// Set the kernel runtime parameters.
NVTE_CHECK(dgamma->data.shape == gamma.data.shape); layer_norm::BwdParams& params = launch_params.params;
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); params.rows = rows;
params.cols = cols;
NVTE_CHECK(dbeta->data.shape == gamma.data.shape); params.x = x.data.dptr;
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); params.mu = mu.data.dptr;
params.rs = rsigma.data.dptr;
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params; params.gamma = gamma.data.dptr;
launch_params.stream = stream; params.dz = dz.data.dptr;
launch_params.multiprocessorCount = multiprocessorCount; params.dx = dx->data.dptr;
params.dbeta = dbeta->data.dptr;
// Set the kernel runtime parameters. params.dgamma = dgamma->data.dptr;
layer_norm::BwdParams &params = launch_params.params; params.dbeta_part = dbeta_part->data.dptr;
params.rows = rows; params.dgamma_part = dgamma_part->data.dptr;
params.cols = cols; params.zero_centered_gamma = zero_centered_gamma;
params.x = x.data.dptr;
params.mu = mu.data.dptr; auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params);
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr; // Query the kernel-specific launch parameters.
params.dz = dz.data.dptr; launcher(launch_params, true);
params.dx = dx->data.dptr;
params.dbeta = dbeta->data.dptr; // Populate shape and dtypes for FW to allocate memory
params.dgamma = dgamma->data.dptr; if (dgamma_part->data.dptr == nullptr) {
params.dbeta_part = dbeta_part->data.dptr; NVTE_CHECK(dbeta_part->data.dptr == nullptr);
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma; dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params); hidden_size};
// Query the kernel-specific launch parameters. dbeta_part->data.dtype = ctype;
launcher(launch_params, true); dbeta_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
hidden_size};
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->data.dptr == nullptr) { workspace->data.dtype = layer_norm::DType::kByte;
NVTE_CHECK(dbeta_part->data.dptr == nullptr); workspace->data.shape = {launch_params.workspace_bytes};
dgamma_part->data.dtype = ctype; barrier->data.dtype = layer_norm::DType::kInt32;
dgamma_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col), barrier->data.shape = {launch_params.barrier_size};
hidden_size };
return;
dbeta_part->data.dtype = ctype; } else {
dbeta_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col), NVTE_CHECK(dbeta_part->data.dptr != nullptr);
hidden_size }; auto pdw_shape =
std::vector<size_t>{static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = { launch_params.workspace_bytes }; NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
barrier->data.dtype = layer_norm::DType::kInt32; NVTE_CHECK(dbeta_part->data.dtype == ctype);
barrier->data.shape = { launch_params.barrier_size }; NVTE_CHECK(dbeta_part->data.shape == pdw_shape);
}
return;
} else { if (launch_params.barrier_size > 0) {
NVTE_CHECK(dbeta_part->data.dptr != nullptr); NVTE_CHECK(barrier->data.dptr != nullptr);
auto pdw_shape = std::vector<size_t>{ NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size}; NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
}
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape); if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(dbeta_part->data.dtype == ctype); NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(dbeta_part->data.shape == pdw_shape); NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
} NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr); // Tensor checks are delayed here in order to recover workspace sizes with null data
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); CheckInputTensor(dz, "dz");
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size }); CheckInputTensor(x, "x");
} CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
if (launch_params.workspace_bytes > 0) { CheckInputTensor(gamma, "gamma");
NVTE_CHECK(workspace->data.dptr != nullptr); CheckOutputTensor(*dx, "dx");
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); CheckOutputTensor(*dgamma, "dgamma");
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes }); CheckOutputTensor(*dbeta, "dbeta");
}
if (launch_params.barrier_size > 0) {
// Tensor checks are delayed here in order to recover workspace sizes with null data params.workspace = workspace->data.dptr;
CheckInputTensor(dz, "dz"); params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
CheckInputTensor(x, "x"); cudaMemsetAsync(params.barrier, 0,
CheckInputTensor(mu, "mu"); layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
CheckInputTensor(rsigma, "rsigma"); stream);
CheckInputTensor(gamma, "gamma"); }
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma"); // Launch the kernel.
CheckOutputTensor(*dbeta, "dbeta"); launcher(launch_params, false);
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
}
// Launch the kernel.
launcher(launch_params, false);
} }
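The backward kernels below reduce two per-row statistics, mdy = mean(dy) and mdyy = mean(dy * y) with dy = gamma * dz and y = rs * (x - mu), and then form dx = rs * (dy - (mdyy * y + mdy)). A scalar CPU reference of that formula for a single row, for illustration only:

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Single-row CPU reference for the LayerNorm data gradient computed by the kernels.
void ln_bwd_row(const std::vector<float>& x, const std::vector<float>& dz,
                const std::vector<float>& gamma, float mu, float rs, std::vector<float>* dx) {
  const std::size_t n = x.size();
  float mdy = 0.f, mdyy = 0.f;
  std::vector<float> y(n), dy(n);
  for (std::size_t j = 0; j < n; ++j) {
    y[j] = rs * (x[j] - mu);   // normalized input
    dy[j] = gamma[j] * dz[j];  // gradient w.r.t. the normalized input
    mdy += dy[j];
    mdyy += dy[j] * y[j];
  }
  mdy /= n;
  mdyy /= n;
  for (std::size_t j = 0; j < n; ++j) {
    (*dx)[j] = rs * (dy[j] - (mdyy * y[j] + mdy));
  }
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f, 4.f}, dz{0.1f, -0.2f, 0.3f, 0.0f}, gamma(4, 1.f), dx(4);
  float mu = 2.5f, rs = 1.f / std::sqrt(1.25f);  // mean and 1/std of x (eps omitted)
  ln_bwd_row(x, dz, gamma, mu, rs, &dx);
  for (float v : dx) std::cout << v << " ";
  std::cout << "\n";
}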
} // namespace transformer_engine } // namespace transformer_engine
void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size const NVTETensor beta, // hidden_size
const float epsilon, const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
NVTETensor z, cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) { NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd); NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(gamma), *reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
*reinterpret_cast<const Tensor*>(beta), reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma), stream,
epsilon, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(z), reinterpret_cast<Tensor*>(barrier), false);
reinterpret_cast<Tensor*>(mu),
reinterpret_cast<Tensor*>(rsigma),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
false);
} }
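The nvte_* entry points treat NVTETensor as an opaque handle and reinterpret_cast it back to the internal Tensor before dispatching. A minimal sketch of that opaque-handle idiom; Widget is a stand-in type, not the library's Tensor:

#include <iostream>

typedef void* WidgetHandle;  // C-facing opaque handle

struct Widget {  // internal C++ type hidden from the C API
  int value;
};

extern "C" void widget_print(const WidgetHandle h) {
  const Widget& w = *reinterpret_cast<const Widget*>(h);  // recover the real type
  std::cout << w.value << "\n";
}

int main() {
  Widget w{42};
  widget_print(&w);  // the C API only ever sees a void*
}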
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32! const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32! const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part,
NVTETensor dgamma, NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount,
NVTETensor dbeta, NVTETensor workspace, NVTETensor barrier) {
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_bwd); NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
*reinterpret_cast<const Tensor*>(rsigma), reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dgamma_part), reinterpret_cast<Tensor*>(dbeta_part),
reinterpret_cast<Tensor*>(dx), stream, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(barrier), false);
reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part),
reinterpret_cast<Tensor*>(dbeta_part),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
false);
} }
void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size const NVTETensor beta, // hidden_size
const float epsilon, const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
NVTETensor z, cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) { NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_fwd); NVTE_API_CALL(nvte_layernorm1p_fwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(gamma), *reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
*reinterpret_cast<const Tensor*>(beta), reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma), stream,
epsilon, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(z), reinterpret_cast<Tensor*>(barrier), true);
reinterpret_cast<Tensor*>(mu),
reinterpret_cast<Tensor*>(rsigma),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
true);
} }
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32! const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32! const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta,
NVTETensor dgamma, NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream,
NVTETensor dbeta, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_bwd); NVTE_API_CALL(nvte_layernorm1p_bwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
*reinterpret_cast<const Tensor*>(rsigma), reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dgamma_part), reinterpret_cast<Tensor*>(dbeta_part),
reinterpret_cast<Tensor*>(dx), stream, multiprocessorCount, reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(barrier), true);
reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part),
reinterpret_cast<Tensor*>(dbeta_part),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
true);
} }
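The layernorm1p variants differ from the plain ones only in the trailing boolean: with zero_centered_gamma the kernels add 1 to the stored gamma before scaling (the dy_tmp_shift and g_ij_shift terms in the kernels that follow), so a weight stored as 0 behaves as a scale of 1. A scalar sketch of that forward effect with made-up values:

#include <iostream>

// Effective scale applied to the normalized activation.
float effective_gamma(float stored_gamma, bool zero_centered_gamma) {
  return zero_centered_gamma ? stored_gamma + 1.0f : stored_gamma;
}

int main() {
  // With zero-centered storage, initializing gamma to 0 gives an identity scale.
  std::cout << effective_gamma(0.0f, true) << " "     // 1
            << effective_gamma(0.0f, false) << "\n";  // 0
}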
@@ -7,605 +7,570 @@
#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ #ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_
#include "ln.h"
#include "../utils.cuh" #include "../utils.cuh"
#include "ln.h"
namespace transformer_engine { namespace transformer_engine {
namespace layer_norm { namespace layer_norm {
using namespace transformer_engine; using namespace transformer_engine;
template<typename Ktraits> template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel(
void ln_bwd_tuned_kernel(layer_norm::BwdParams params) { layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS }; enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS }; enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t; using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t; using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec; using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec; using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec; using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec; using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer; using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type; using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[]; extern __shared__ char smem_[];
const index_t tidx = threadIdx.x; const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW; const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW; const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP; const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP; const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N; const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N; const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane; const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS]; Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS]; Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum)); memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum)); memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_); compute_t *smem_wgrad = reinterpret_cast<compute_t *>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum; Sum<reduce_t> sum;
constexpr float rn = 1.f / static_cast<float>(COLS); constexpr float rn = 1.f / static_cast<float>(COLS);
Wvec gamma[LDGS]; Wvec gamma[LDGS];
index_t idx = c; index_t idx = c;
#pragma unroll #pragma unroll
for ( int it = 0; it < LDGS; it++ ) { for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx); gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG; idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
} }
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for ( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
const compute_t x_tmp = x[it].data.elt[jt];
const compute_t y_tmp = rs_r * (x_tmp - mu_r);
const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if ( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for ( int it = 0; it < ROWS_PER_CTA; it++ ) {
for ( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r; compute_t dy[LDGS * NUM_ELTS];
#pragma unroll compute_t y[LDGS * NUM_ELTS];
for ( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx); compute_t mdy_local = 0.f;
idx += THREADS_PER_ROW; compute_t mdyy_local = 0.f;
} #pragma unroll
__syncthreads(); for (int it = 0; it < LDGS; it++) {
compute_t cta_dzy_sum[NUM_RES]; #pragma unroll
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); for (int jt = 0; jt < NUM_ELTS; jt++) {
for ( int it = 0; it < ROWS_PER_CTA; it++ ) { const compute_t x_tmp = x[it].data.elt[jt];
for ( int jt = 0; jt < NUM_RES; jt++ ) { const compute_t y_tmp = rs_r * (x_tmp - mu_r);
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
} compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
} dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for ( int jt = 0; jt < NUM_RES; jt++ ) { mdy_local += dy_tmp;
*dgamma_part = cta_dzy_sum[jt]; mdyy_local += dy_tmp * y_tmp;
dgamma_part += Ktraits::THREADS_PER_CTA;
} dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx; reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
for ( int jt = 0; jt < NUM_RES; jt++ ) { mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
*dbeta_part = cta_dz_sum[jt]; mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
dbeta_part += Ktraits::THREADS_PER_CTA;
} Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
} }
} } // end: grid stride loop
if (WARPS_M == 1) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
"Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
template<typename Kernel_traits> idx = warp_m * Ktraits::VEC_COLS + tid_r;
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) #pragma unroll
void ln_bwd_finalize_tuned_kernel(BwdParams params) { for (int it = 0; it < LDGS; it++) {
using compute_t = typename Kernel_traits::compute_t; dzy_sum[it].store_to(smem_wgrad, idx);
using weight_t = typename Kernel_traits::weight_t; idx += THREADS_PER_ROW;
using index_t = typename Kernel_traits::index_t; }
using Reducer = typename Kernel_traits::Reducer; __syncthreads();
using reduce_t = typename Reducer::Type; compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
Sum<reduce_t> sum; for (int it = 0; it < ROWS_PER_CTA; it++) {
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; for (int jt = 0; jt < NUM_RES; jt++) {
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; }
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for ( uint32_t col = c, col_out = c_out;
col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for ( uint32_t row = warp; row < params.ctas_per_col;
row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
void * smem_gamma = smem_; compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dzy_sum[jt];
const int write_row = warp; dgamma_part += Ktraits::THREADS_PER_CTA;
const int write_col = lane ^ write_row; }
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE
+ Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for ( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
// All writes done. compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
__syncthreads(); for (int jt = 0; jt < NUM_RES; jt++) {
*dbeta_part = cta_dz_sum[jt];
// Pack and store: 2-wide stores with half the threads. dbeta_part += Ktraits::THREADS_PER_CTA;
if ( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for ( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] =
Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] =
Converter<src_t, dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
} }
}
} }
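The finalize kernel stages per-warp dgamma/dbeta partials through shared memory and XORs the column index with the row index so the transposed reads do not serialize on a single bank. A host-side index sketch showing that the swizzle never collides within a row and that the transposed read recovers the intended element; THREADS_PER_WARP and the smem array here are just a model:

#include <cstdint>
#include <iostream>

constexpr uint32_t W = 32;  // threads per warp

// Flat shared-memory index with the XOR swizzle used by the finalize kernel.
uint32_t idx(uint32_t row, uint32_t col) { return row * W + (col ^ row); }

int main() {
  // Fill a WxW scratch "shared memory"; col ^ row is a bijection per row, so nothing is overwritten.
  uint32_t smem[W * W];
  for (uint32_t row = 0; row < W; ++row)
    for (uint32_t lane = 0; lane < W; ++lane)
      smem[idx(row, lane)] = row * W + lane;

  // Transposed read: pass `w` of lane `l` reads what warp l's lane w wrote.
  uint32_t w = 5, l = 17;
  uint32_t got = smem[l * W + (w ^ l)];
  std::cout << (got == l * W + w ? "transposed read OK" : "mismatch") << "\n";
}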
template<typename Ktraits> template <typename Kernel_traits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_tuned_kernel(
void ln_bwd_general_kernel(layer_norm::BwdParams params) { BwdParams params) {
enum { LDGS = Ktraits::LDGS }; using compute_t = typename Kernel_traits::compute_t;
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; using weight_t = typename Kernel_traits::weight_t;
enum { WARPS_M = Ktraits::WARPS_M }; using index_t = typename Kernel_traits::index_t;
enum { WARPS_N = Ktraits::WARPS_N }; using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t; Sum<reduce_t> sum;
using compute_t = typename Ktraits::compute_t; enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
using output_t = typename Ktraits::output_t; enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec; __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec; constexpr uint32_t bidm = 0;
using Cvec = typename Ktraits::Cvec;
const uint32_t bidn = blockIdx.x;
const index_t tidx = threadIdx.x; const uint32_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP; const uint32_t warp = tidx / THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP; const uint32_t lane = tidx % THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N; Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const index_t bdimm = WARPS_M; const uint32_t c = bidn * THREADS_PER_WARP + lane;
const index_t bdimn = WARPS_N * THREADS_PER_WARP; const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
const index_t bidm = blockIdx.x / params.ctas_per_row; constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
const index_t bidn = blockIdx.x % params.ctas_per_row; for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS;
col += COL_STRIDE, col_out += COL_STRIDE / 2) {
const index_t gdimm = bdimm * params.ctas_per_col; // Each thread sums over NUM_ELT columns.
const index_t gdimn = bdimn * params.ctas_per_row; Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
const index_t gidm = bidm * bdimm + warp_m; memset(&dgamma_local, 0, sizeof(dgamma_local));
const index_t gidn = (bidn * THREADS_PER_WARP memset(&dbeta_local, 0, sizeof(dbeta_local));
+ warp_n * params.ctas_per_row * THREADS_PER_WARP for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {
+ lane); // Order threads by warp x cta x lane index_t idx = row * Kernel_traits::COLS + col;
// Objects for weight grads Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
Cvec dzy_sum[LDGS]; dbeta_part.load_from(params.dbeta_part, idx);
Cvec dz_sum[LDGS]; dgamma_part.load_from(params.dgamma_part, idx);
memset(dzy_sum, 0, sizeof(dzy_sum)); #pragma unroll
memset(dz_sum, 0, sizeof(dz_sum)); for (int it = 0; it < NUM_ELT; it++) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
// Objects for stats reductions dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
using reduce_t = typename Ktraits::Reducer::Type; }
using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<reduce_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
} }
for ( int cta_row = bidm * bdimm; void *smem_gamma = smem_;
cta_row < params.rows; void *smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
cta_row += gdimm ) {
const int row = cta_row + warp_m;
compute_t mu = 0.f;
compute_t rs = 0.f;
if ( row < params.rows ) {
mu = static_cast<const compute_t *>(params.mu)[row];
rs = static_cast<const compute_t *>(params.rs)[row];
}
Cvec dy[LDGS]; const int write_row = warp;
Cvec y[LDGS]; const int write_col = lane ^ write_row;
compute_t mdy = 0.f; const int write_idx = write_row * THREADS_PER_WARP + write_col;
compute_t mdyy = 0.f;
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
const compute_t x_ij = x.data.elt[jt];
const compute_t y_ij = rs * (x_ij - mu);
const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
const compute_t dz_ij = dz.data.elt[jt];
const compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
mdy += dy_ij;
mdyy += dy_ij * y_ij;
dz_sum[it].data.elt[jt] += dz_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}
// Reduce over row dgamma_local.store_to(smem_gamma, write_idx);
reduce_t result = reducer.allreduce({mdy, mdyy}, sum); dbeta_local.store_to(smem_beta, write_idx);
mdy = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn; __syncthreads();
// Compute dx // It would be probably safe to reuse the first row of smem_beta and smem_gamma
#pragma unroll void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
for ( int it = 0, col = gidn * NUM_ELTS; void *smem_beta_out =
it < LDGS && row < params.rows && col < params.cols; &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
it++, col += gdimn * NUM_ELTS ) {
Ivec dx; // More than one iter iff ROWS_PER_CTA < 32.
#pragma unroll for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) {
for ( int jt = 0; jt < NUM_ELTS; jt++ ) { const int read_row = lane;
compute_t dy_ij = dy[it].data.elt[jt]; const int read_col = w ^ read_row;
compute_t y_ij = y[it].data.elt[jt]; const int read_idx = read_row * THREADS_PER_WARP + read_col;
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy));
} memset(&dbeta_local, 0, sizeof(dbeta_local));
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); memset(&dgamma_local, 0, sizeof(dgamma_local));
}
// Load beta and gamma transposed
if (read_row < Kernel_traits::ROWS_PER_CTA) {
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if (lane == 0) {
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
} }
if constexpr ( WARPS_M == 1 ) { // All writes done.
// Write out local weight grad contributions __syncthreads();
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
dz_sum[it].store_to_elts(params.dbeta_part,
bidm * params.cols + col,
params.cols - col);
dzy_sum[it].store_to_elts(params.dgamma_part,
bidm * params.cols + col,
params.cols - col);
}
} else {
// Reduce weight grad contributions within CTA before writing
__shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP+1];
// Reduce dz
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
dz_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
__syncthreads();
#pragma unroll
for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS;
it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) {
#pragma unroll
for ( int kt = 0; kt < WARPS_M; kt++ ) {
if ( kt != warp_m ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dz_sum[it].data.elt[jt]
+= vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dz_sum[it].store_to_elts(params.dbeta_part,
bidm * params.cols + col,
params.cols - col);
}
// Reduce dzy // Pack and store: 2-wide stores with half the threads.
__syncthreads(); if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) {
#pragma unroll using src_t = typename TypeToVec2<compute_t>::Type;
for ( int it = 0, col = gidn * NUM_ELTS; using dst_t = typename TypeToVec2<weight_t>::Type;
it < LDGS && col < params.cols; Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
it++, col += gdimn * NUM_ELTS ) { Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
if ( it != warp_m ) {
dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]); dgamma_vec2.load_from(smem_gamma_out, lane);
} dbeta_vec2.load_from(smem_beta_out, lane);
} #pragma unroll
__syncthreads(); for (int it = 0; it < NUM_ELT; it++) {
#pragma unroll dgamma_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
for ( int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; dbeta_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dbeta_vec2.data.elt[it]);
it < LDGS && col < params.cols; }
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS ) { dgamma_out2.store_to(params.dgamma, col_out);
#pragma unroll dbeta_out2.store_to(params.dbeta, col_out);
for ( int kt = 0; kt < WARPS_M; kt++ ) {
if ( kt != warp_m ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dzy_sum[it].data.elt[jt]
+= vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
}
}
dzy_sum[it].store_to_elts(params.dgamma_part,
bidm * params.cols + col,
params.cols - col);
}
} }
}
} }
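In both backward variants the weight gradients are produced in two stages: each CTA writes one row of partial sums into dgamma_part/dbeta_part (ctas_per_col x cols), and a finalize kernel then reduces over those rows per column. A CPU sketch of that second stage with illustrative sizes:

#include <cstddef>
#include <iostream>
#include <vector>

// Second-stage reduction: collapse the per-CTA partials (rows) into the final per-column gradient.
std::vector<float> finalize(const std::vector<float>& part, std::size_t ctas_per_col,
                            std::size_t cols) {
  std::vector<float> out(cols, 0.f);
  for (std::size_t row = 0; row < ctas_per_col; ++row)
    for (std::size_t col = 0; col < cols; ++col)
      out[col] += part[row * cols + col];
  return out;
}

int main() {
  const std::size_t ctas = 3, cols = 4;
  std::vector<float> part = {1, 1, 1, 1,
                             2, 2, 2, 2,
                             3, 3, 3, 3};
  for (float v : finalize(part, ctas, cols)) std::cout << v << " ";  // 6 6 6 6
  std::cout << "\n";
}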
template< template <typename Ktraits>
typename weight_t, __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kernel(
typename compute_t, layer_norm::BwdParams params) {
uint32_t WARPS_M, enum { LDGS = Ktraits::LDGS };
uint32_t WARPS_N, enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
uint32_t BYTES_PER_LDG, enum { WARPS_M = Ktraits::WARPS_M };
uint32_t THREADS_PER_WARP enum { WARPS_N = Ktraits::WARPS_N };
>
__global__ __launch_bounds__(WARPS_M * WARPS_N * THREADS_PER_WARP) using input_t = typename Ktraits::input_t;
void ln_bwd_finalize_general_kernel(layer_norm::BwdParams params) { using weight_t = typename Ktraits::weight_t;
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using compute_t = typename Ktraits::compute_t;
using Wvec = Vec<weight_t, NUM_ELTS>; using output_t = typename Ktraits::output_t;
using Cvec = Vec<compute_t, NUM_ELTS>; using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
const int lane = threadIdx.x % THREADS_PER_WARP; using Ovec = typename Ktraits::Ovec;
const int warp_m = threadIdx.y; using Wvec = typename Ktraits::Wvec;
const int warp_n = threadIdx.x / THREADS_PER_WARP; using Cvec = typename Ktraits::Cvec;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
const index_t tidx = threadIdx.x;
// Load grad contributions and accumulate locally const index_t lane = tidx % THREADS_PER_WARP;
Cvec dgamma, dbeta; const index_t warp = tidx / THREADS_PER_WARP;
dgamma.clear(); const index_t warp_m = warp / WARPS_N;
dbeta.clear(); const index_t warp_n = warp % WARPS_N;
for ( int row = warp_m;
row < params.ctas_per_col && col < params.cols; const index_t bdimm = WARPS_M;
row += WARPS_M ) { const index_t bdimn = WARPS_N * THREADS_PER_WARP;
Cvec dgamma_part, dbeta_part; const index_t bidm = blockIdx.x / params.ctas_per_row;
dgamma_part.load_from_elts(params.dgamma_part, const index_t bidn = blockIdx.x % params.ctas_per_row;
row * params.cols + col,
params.cols - col); const index_t gdimm = bdimm * params.ctas_per_col;
dbeta_part.load_from_elts(params.dbeta_part, const index_t gdimn = bdimn * params.ctas_per_row;
row * params.cols + col, const index_t gidm = bidm * bdimm + warp_m;
params.cols - col); const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
#pragma unroll lane); // Order threads by warp x cta x lane
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
dgamma.data.elt[jt] += dgamma_part.data.elt[jt]; // Objects for weight grads
dbeta.data.elt[jt] += dbeta_part.data.elt[jt]; Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
// Objects for stats reductions
using reduce_t = typename Ktraits::Reducer::Type;
using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<reduce_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
compute_t mu = 0.f;
compute_t rs = 0.f;
if (row < params.rows) {
mu = static_cast<const compute_t *>(params.mu)[row];
rs = static_cast<const compute_t *>(params.rs)[row];
}
Cvec dy[LDGS];
Cvec y[LDGS];
compute_t mdy = 0.f;
compute_t mdyy = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
const compute_t x_ij = x.data.elt[jt];
const compute_t y_ij = rs * (x_ij - mu);
const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
const compute_t dz_ij = dz.data.elt[jt];
const compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
mdy += dy_ij;
mdyy += dy_ij * y_ij;
dz_sum[it].data.elt[jt] += dz_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}
// Reduce over row
reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
mdy = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
// Compute dx
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec dx;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy));
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
}
if constexpr (WARPS_M == 1) {
// Write out local weight grad contributions
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
dz_sum[it].store_to_elts(params.dbeta_part, bidm * params.cols + col, params.cols - col);
dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
}
} else {
// Reduce weight grad contributions within CTA before writing
__shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
// Reduce dz
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
dz_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
}
__syncthreads();
#pragma unroll
for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols;
it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) {
#pragma unroll
for (int kt = 0; kt < WARPS_M; kt++) {
if (kt != warp_m) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dz_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt];
}
        }
      }
      dz_sum[it].store_to_elts(params.dbeta_part, bidm * params.cols + col, params.cols - col);
    }

    // Reduce dzy
    __syncthreads();
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      if (it != warp_m) {
        dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
      }
    }
    __syncthreads();
#pragma unroll
    for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols;
         it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) {
#pragma unroll
      for (int kt = 0; kt < WARPS_M; kt++) {
        if (kt != warp_m) {
#pragma unroll
          for (int jt = 0; jt < NUM_ELTS; jt++) {
            dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt];
          }
        }
      }
      dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
    }
  }
}

template <typename weight_t, typename compute_t, uint32_t WARPS_M, uint32_t WARPS_N,
uint32_t BYTES_PER_LDG, uint32_t THREADS_PER_WARP>
__global__
__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel(
layer_norm::BwdParams params) {
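  // Second reduction stage for the weight gradients: the main backward kernel wrote one partial
  // dgamma/dbeta row per CTA into params.dgamma_part / params.dbeta_part; this kernel sums those
  // ctas_per_col partial rows per column and emits the final dgamma and dbeta vectors.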
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
using Wvec = Vec<weight_t, NUM_ELTS>;
using Cvec = Vec<compute_t, NUM_ELTS>;
const int lane = threadIdx.x % THREADS_PER_WARP;
const int warp_m = threadIdx.y;
const int warp_n = threadIdx.x / THREADS_PER_WARP;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
// Load grad contributions and accumulate locally
Cvec dgamma, dbeta;
dgamma.clear();
dbeta.clear();
for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) {
Cvec dgamma_part, dbeta_part;
dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col);
dbeta_part.load_from_elts(params.dbeta_part, row * params.cols + col, params.cols - col);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dgamma.data.elt[jt] += dgamma_part.data.elt[jt];
dbeta.data.elt[jt] += dbeta_part.data.elt[jt];
    }
  }

  // Reduce dgamma within CTA
__shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) {
    __syncthreads();
    if (warp_m < nrows) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        vecs_shared[warp_m][warp_n][lane].data.elt[jt] +=
            vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt];
      }
    }
  }
  if (warp_m == 0 && col < params.cols) {
    Wvec dgamma_out;
    vecs_shared[warp_m][warp_n][lane].to(dgamma_out);
    dgamma_out.store_to_elts(params.dgamma, col, params.cols - col);
}
  // Reduce dbeta within CTA
__syncthreads();
dbeta.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) {
__syncthreads();
if (warp_m < nrows) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
vecs_shared[warp_m][warp_n][lane].data.elt[jt] +=
vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt];
}
    }
}
if (warp_m == 0 && col < params.cols) {
Wvec dbeta_out;
vecs_shared[warp_m][warp_n][lane].to(dbeta_out);
dbeta_out.store_to_elts(params.dbeta, col, params.cols - col);
}
}

}  // namespace layer_norm
...
@@ -5,233 +5,154 @@
 ************************************************************************/

#include "ln.h"
#include "ln_bwd_kernels.cuh"
#include "ln_kernel_traits.h"

using namespace transformer_engine::layer_norm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
          typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
          int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BwdParams> &launch_params,
                   const bool configure_params) {  // NOLINT(*)
  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                      CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
  auto kernel = &ln_bwd_tuned_kernel<Kernel_traits>;

  if (configure_params) {
    int ctas_per_sm;
    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
    launch_params.params.ctas_per_row = CTAS_PER_ROW;
    launch_params.params.ctas_per_col =
        launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
    launch_params.barrier_size = 0;
    launch_params.workspace_bytes = 0;
    if (Kernel_traits::CTAS_PER_ROW > 1) {
      launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
      launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
                                      Kernel_traits::CTAS_PER_ROW *
                                      sizeof(typename Kernel_traits::reduce_t) * 2;
    }
    return;
  }

  if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
    NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                         Kernel_traits::SMEM_BYTES));
  }
  auto stream = launch_params.stream;
  auto ctas_per_col = launch_params.params.ctas_per_col;
  auto ctas_per_row = launch_params.params.ctas_per_row;

  if (ctas_per_row == 1) {
    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
        launch_params.params);
  } else {
    dim3 grid(ctas_per_row * ctas_per_col);
    dim3 block(Kernel_traits::THREADS_PER_CTA);
    void *params_ = reinterpret_cast<void *>(&launch_params.params);
    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
                                reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
                                stream);
  }

  using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t,
                                                             output_t, compute_t, index_t,
                                                             32 * 32,  // THREADS_PER_CTA
                                                             BYTES_PER_LDG_FINAL>;

  auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;
  kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
      launch_params.params);
}

template <typename weight_t, typename input_t, typename output_t, typename compute_t,
          typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
          int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BwdParams> &launch_params,
                     const bool configure_params) {  // NOLINT(*)
  auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };

  // Instantiate kernel
  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                      1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
  auto kernel = &ln_bwd_general_kernel<Kernel_traits>;

  // Configure kernel params
  const int rows = launch_params.params.rows;
  const int cols = launch_params.params.cols;
  int ctas_per_col = launch_params.params.ctas_per_col;
  int ctas_per_row = launch_params.params.ctas_per_row;
  if (configure_params) {
    int ctas_per_sm;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
                                                  Kernel_traits::THREADS_PER_CTA, 0);
    const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
    ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
    ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
    launch_params.params.ctas_per_row = ctas_per_row;
    launch_params.params.ctas_per_col = ctas_per_col;

    launch_params.barrier_size = 0;
    launch_params.workspace_bytes = 0;
    if (launch_params.params.ctas_per_row > 1) {
      launch_params.barrier_size = 2 * ctas_per_col;
      launch_params.workspace_bytes =
          (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2);
    }
    return;
  }

  // Launch kernel
  auto stream = launch_params.stream;
  dim3 grid(ctas_per_row * ctas_per_col);
  dim3 block(Kernel_traits::THREADS_PER_CTA);
  if (ctas_per_row == 1) {
    kernel<<<grid, block, 0, stream>>>(launch_params.params);
  } else {
    void *params_ = reinterpret_cast<void *>(&launch_params.params);
    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
                                reinterpret_cast<void **>(&params_), 0, stream);
  }

  // Launch finalization kernel
  constexpr uint32_t WARPS_M_FINAL = 4;
  constexpr uint32_t WARPS_N_FINAL = 1;
  constexpr uint32_t ELTS_N_PER_CTA_FINAL =
      (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t));
  auto kernel_final =
      &ln_bwd_finalize_general_kernel<weight_t, compute_t, WARPS_M_FINAL, WARPS_N_FINAL,
                                      BYTES_PER_LDG_FINAL, Kernel_traits::THREADS_PER_WARP>;
  dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
  dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
  kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
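// Worked example of the sizing logic above (hypothetical numbers, not taken from this change):
// with cols = 6144 and HIDDEN_SIZE = 4096, ctas_per_row = ceil_div(6144, 4096) = 2, so the
// cooperative-launch path is used; if ctas_per_col comes out as 16 with WARPS_M = 4, the
// inter-CTA reduction then needs barrier_size = 2 * 16 = 32 and
// workspace_bytes = 16 * 4 * 2 * sizeof(reduce_t) * 2 bytes of scratch.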
////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW,   \
                                    WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
  void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                   \
      LaunchParams<BwdParams> &launch_params, const bool configure_params) {                 \
    launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M,  \
                  WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params,             \
                                                                  configure_params);         \
  }                                                                                          \
  static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                          \
      reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                       \
          ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)

#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
                                      BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)                     \
  void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                     \
      LaunchParams<BwdParams> &launch_params, const bool configure_params) {                     \
    launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N,         \
                    BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);     \
  }                                                                                              \
  static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                            \
      reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                         \
          ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
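// For illustration, a registration such as
//   REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// pastes its arguments into a launcher name and a static registrar, expanding roughly to:
//   void ln_bwd_tuned_1024_bf16_bf16_bf16_fp32(LaunchParams<BwdParams> &launch_params,
//                                              const bool configure_params) {
//     launch_tuned_<bf16, bf16, bf16, fp32, uint32_t, 1024, 1, 4, 1, 16, 4>(launch_params,
//                                                                           configure_params);
//   }
//   static BwdTunedRegistrar<bf16, bf16, bf16, fp32, 1024>
//       reg_tuned_1024_bf16_bf16_bf16_fp32(ln_bwd_tuned_1024_bf16_bf16_bf16_fp32);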
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -252,9 +173,9 @@ REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
@@ -263,11 +184,11 @@ REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
@@ -318,16 +239,16 @@ REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
@@ -336,9 +257,9 @@ REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
...
@@ -5,176 +5,128 @@
 ************************************************************************/

#include "ln.h"
#include "ln_fwd_kernels.cuh"
#include "ln_kernel_traits.h"

using namespace transformer_engine::layer_norm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
          typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
          int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<FwdParams> &launch_params,
                   const bool configure_params) {  // NOLINT(*)
  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                      CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
  auto kernel = &ln_fwd_tuned_kernel<Kernel_traits>;

  if (configure_params) {
    int ctas_per_sm;
    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
    launch_params.params.ctas_per_row = CTAS_PER_ROW;
    launch_params.params.ctas_per_col =
        launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
    launch_params.barrier_size = 0;
    launch_params.workspace_bytes = 0;
    if (Kernel_traits::CTAS_PER_ROW > 1) {
      launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
      launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
                                      Kernel_traits::CTAS_PER_ROW *
                                      sizeof(typename Kernel_traits::Stats::stats_t) * 2;
    }
    return;
  }

  if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
    NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                         Kernel_traits::SMEM_BYTES_FWD));
  }
  auto stream = launch_params.stream;
  auto ctas_per_col = launch_params.params.ctas_per_col;
  auto ctas_per_row = launch_params.params.ctas_per_row;

  if (ctas_per_row == 1) {
    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
        launch_params.params);
  } else {
    dim3 grid(ctas_per_row * ctas_per_col);
    dim3 block(Kernel_traits::THREADS_PER_CTA);
    void *params_ = reinterpret_cast<void *>(&launch_params.params);
    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_,  // NOLINT(*)
                                Kernel_traits::SMEM_BYTES_FWD, stream);
  }
}

template <typename weight_t, typename input_t, typename output_t, typename compute_t,
          typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<FwdParams> &launch_params,
                     const bool configure_params) {  // NOLINT(*)
  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                      1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
  auto kernel = &ln_fwd_general_kernel<Kernel_traits>;
  auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };

  // Configure kernel params
  const int rows = launch_params.params.rows;
  const int cols = launch_params.params.cols;
  int ctas_per_col = launch_params.params.ctas_per_col;
  int ctas_per_row = launch_params.params.ctas_per_row;
  if (configure_params) {
    int ctas_per_sm;
    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
    const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
    ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
    ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
    launch_params.params.ctas_per_row = ctas_per_row;
    launch_params.params.ctas_per_col = ctas_per_col;

    launch_params.barrier_size = 0;
    launch_params.workspace_bytes = 0;
    if (launch_params.params.ctas_per_row > 1) {
      launch_params.barrier_size = 2 * ctas_per_col;
      launch_params.workspace_bytes =
          (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2);
    }
    return;
  }

  // Launch kernel
  auto stream = launch_params.stream;
  dim3 grid(ctas_per_row * ctas_per_col);
  dim3 block(Kernel_traits::THREADS_PER_CTA);
  if (ctas_per_row == 1) {
    kernel<<<grid, block, 0, stream>>>(launch_params.params);
  } else {
    void *params_ = reinterpret_cast<void *>(&launch_params.params);
    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
                                reinterpret_cast<void **>(&params_), 0, stream);
  }
}
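// Note: when ctas_per_row > 1, several CTAs cooperate on a single row, so the kernel is launched
// with cudaLaunchCooperativeKernel and the barrier/workspace buffers sized in configure_params
// are used for the inter-CTA statistics reduction.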
////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW,  \
                                    WARPS_M, WARPS_N, BYTES_PER_LDG)                        \
  void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                  \
      LaunchParams<FwdParams> &launch_params, const bool configure_params) {                \
    launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
                  WARPS_N, BYTES_PER_LDG>(launch_params, configure_params);                 \
  }                                                                                         \
  static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                         \
      reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                      \
          ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)

#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
                                      BYTES_PER_LDG)                                             \
  void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                     \
      LaunchParams<FwdParams> &launch_params, const bool configure_params) {                     \
    launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N,         \
                    BYTES_PER_LDG>(launch_params, configure_params);                             \
  }                                                                                              \
  static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                            \
      reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                         \
          ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -187,21 +139,21 @@ REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16);
@@ -213,21 +165,21 @@ REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16);
@@ -239,21 +191,21 @@ REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16);
@@ -295,11 +247,11 @@ REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
@@ -337,17 +289,17 @@ REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
@@ -373,17 +325,17 @@ REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
...
@@ -9,306 +9,294 @@
#include <cfloat>
#include <cstdio>

#include "../utils.cuh"
#include "ln.h"

namespace transformer_engine {
namespace layer_norm {
using namespace transformer_engine;
template<typename Ktraits> template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(FwdParams params) {
void ln_fwd_tuned_kernel(FwdParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_M = Ktraits::WARPS_M }; enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; enum { LDGS = Ktraits::LDGS };
enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using output_t = typename Ktraits::output_t; using index_t = typename Ktraits::index_t;
  using index_t = typename Ktraits::index_t;
  using compute_t = typename Ktraits::compute_t;
  using Ivec = typename Ktraits::Ivec;
  using Ovec = typename Ktraits::Ovec;
  using Wvec = typename Ktraits::Wvec;
  using Stats = typename Ktraits::Stats;
  using stats_t = typename Stats::stats_t;

  extern __shared__ char smem_[];

  const index_t tidx = threadIdx.x;
  const index_t bidn = blockIdx.x % CTAS_PER_ROW;
  const index_t bidm = blockIdx.x / CTAS_PER_ROW;
  const index_t lane = tidx % THREADS_PER_WARP;
  const index_t warp = tidx / THREADS_PER_WARP;
  const index_t warp_m = warp / WARPS_N;
  const index_t warp_n = warp % WARPS_N;

  const index_t r = bidm * ROWS_PER_CTA + warp_m;
  const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

  Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

  compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
  compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

  Wvec gamma[LDGS];
  Wvec beta[LDGS];
  index_t idx = c;
#pragma unroll
  for (int it = 0; it < LDGS; ++it) {
    gamma[it].load_from(params.gamma, idx);
    beta[it].load_from(params.beta, idx);
    idx += VEC_COLS_PER_LDG;
  }

  constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);

  compute_t scale = 1.f;
  if (params.fp8_out) {
    scale = *reinterpret_cast<compute_t *>(params.scale);
  }
  compute_t amax = 0;

  for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
    Ivec x[LDGS];
    index_t idx = row * Ktraits::VEC_COLS + c;
    compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
      x[it].load_from(params.x, idx);
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t x_ij = compute_t(x[it].data.elt[jt]);
        xf[it * NUM_ELTS + jt] = x_ij;
      }
      idx += VEC_COLS_PER_LDG;
    }

    stats_t s = stats.compute(xf, rn);

    compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
    compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);

    if (bidn == 0 && warp_n == 0 && lane == 0) {
      mu_ptr[row] = mu;
    }

    compute_t rs = rsqrtf(rn * m2 + params.epsilon);

    if (bidn == 0 && warp_n == 0 && lane == 0) {
      rs_ptr[row] = rs;
    }

    Ovec z[LDGS];
    idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu);
        compute_t g_ij = gamma[it].data.elt[jt];
        if (params.zero_centered_gamma) {
          g_ij += 1;
        }
        compute_t b_ij = beta[it].data.elt[jt];
        compute_t temp_output = g_ij * y_ij + b_ij;
        if (params.fp8_out) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(temp_output));
          temp_output = temp_output * scale;
        }
        z[it].data.elt[jt] = output_t(temp_output);
      }
      z[it].store_to(params.z, idx);
      idx += VEC_COLS_PER_LDG;
    }
  }

  if (params.fp8_out) {
    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      static_assert(std::is_same<compute_t, float>::value);
      atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
    }
  }
}
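For reference, the per-row math that the tuned kernel above vectorizes can be written as a plain scalar routine. This is a minimal sketch, not part of this commit; the function name and signature are illustrative, and it assumes float compute throughout.

// Illustrative only: a scalar reference for one row of the fused kernel above.
// Function and parameter names are hypothetical.
#include <algorithm>
#include <cmath>
#include <cstddef>

float layernorm_row_reference(const float *x, const float *gamma, const float *beta, float *z,
                              std::size_t cols, float epsilon, bool zero_centered_gamma,
                              bool fp8_out, float scale, float *mu_out, float *rs_out) {
  // Mean and (biased) variance over the row.
  float mu = 0.f;
  for (std::size_t j = 0; j < cols; ++j) mu += x[j];
  mu /= cols;
  float m2 = 0.f;
  for (std::size_t j = 0; j < cols; ++j) m2 += (x[j] - mu) * (x[j] - mu);
  const float rs = 1.f / std::sqrt(m2 / cols + epsilon);
  *mu_out = mu;
  *rs_out = rs;
  // Normalize, apply the affine parameters, and (optionally) track the FP8 amax.
  float amax = 0.f;
  for (std::size_t j = 0; j < cols; ++j) {
    const float g = zero_centered_gamma ? gamma[j] + 1.f : gamma[j];
    float out = g * (rs * (x[j] - mu)) + beta[j];
    if (fp8_out) {
      amax = std::max(amax, std::fabs(out));
      out *= scale;  // quantization scale applied before the FP8 cast
    }
    z[j] = out;
  }
  return amax;  // the kernel folds this into the global amax with an atomic max
}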
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kernel(
    FwdParams params) {
  enum { LDGS = Ktraits::LDGS };
  enum { NUM_ELTS = Ktraits::NUM_ELTS };
  enum { WARPS_M = Ktraits::WARPS_M };
  enum { WARPS_N = Ktraits::WARPS_N };
  using input_t = typename Ktraits::input_t;
  using weight_t = typename Ktraits::weight_t;
  using output_t = typename Ktraits::output_t;
  using index_t = typename Ktraits::index_t;
  using compute_t = typename Ktraits::compute_t;
  using Ivec = typename Ktraits::Ivec;
  using Ovec = typename Ktraits::Ovec;
  using Wvec = typename Ktraits::Wvec;
  using Cvec = typename Ktraits::Cvec;

  const index_t tidx = threadIdx.x;
  const index_t lane = tidx % THREADS_PER_WARP;
  const index_t warp = tidx / THREADS_PER_WARP;
  const index_t warp_m = warp / WARPS_N;
  const index_t warp_n = warp % WARPS_N;
  const index_t bdimm = WARPS_M;
  const index_t bdimn = WARPS_N * THREADS_PER_WARP;
  const index_t bidm = blockIdx.x / params.ctas_per_row;
  const index_t bidn = blockIdx.x % params.ctas_per_row;
  const index_t gdimm = bdimm * params.ctas_per_col;
  const index_t gdimn = bdimn * params.ctas_per_row;
  const index_t gidm = bidm * bdimm + warp_m;
  const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
                        lane);  // Order threads by warp x cta x lane

  // Objects for stats reductions
  using Reducer = DynamicReducer<compute_t, WARPS_M, WARPS_N>;
  constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
  __shared__ char smem_[SMEM_BYTES];
  Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
  Sum<compute_t> sum;
  const compute_t rn = 1.f / static_cast<compute_t>(params.cols);

  // Load weights
  Cvec gamma[LDGS];
  Cvec beta[LDGS];
#pragma unroll
  for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
       ++it, col += gdimn * NUM_ELTS) {
    Wvec gamma_in, beta_in;
    gamma_in.load_from_elts(params.gamma, col, params.cols - col);
    beta_in.load_from_elts(params.beta, col, params.cols - col);
    gamma_in.to(gamma[it]);
    beta_in.to(beta[it]);
  }

  // fp8 factors
  compute_t scale;
  if (params.fp8_out) {
    scale = *reinterpret_cast<compute_t *>(params.scale);
  }
  compute_t amax = 0;

  for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
    const int row = cta_row + warp_m;

    // Load input
    Cvec x[LDGS];
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      Ivec x_in;
      x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col);
      x_in.to(x[it]);
    }

    // Compute mean
    compute_t mu = 0.f;
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        mu += x[it].data.elt[jt];
      }
    }
    mu = reducer.allreduce(mu, sum) * rn;

    // Compute variance
    compute_t sqsigma = 0.f;
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        if (col + jt < params.cols) {
          compute_t diff = x[it].data.elt[jt] - mu;
          sqsigma += diff * diff;
        }
      }
    }
    sqsigma = reducer.allreduce(sqsigma, sum) * rn;
    compute_t rs = rsqrtf(sqsigma + params.epsilon);

    // Write statistics
    if (gidn == 0 && row < params.rows) {
      compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
      compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
      mu_ptr[row] = mu;
      rs_ptr[row] = rs;
    }

    // Compute output
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      // Compute output values
      Cvec z;
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t y_ij = rs * (x[it].data.elt[jt] - mu);
        compute_t g_ij = gamma[it].data.elt[jt];
        if (params.zero_centered_gamma) {
          g_ij += 1;
        }
        compute_t b_ij = beta[it].data.elt[jt];
        z.data.elt[jt] = g_ij * y_ij + b_ij;
      }

      // Apply fp8 factors
      if (params.fp8_out) {
#pragma unroll
        for (int jt = 0; jt < NUM_ELTS; jt++) {
          if (col + jt < params.cols) {
            compute_t z_ij = z.data.elt[jt];
            __builtin_assume(amax >= 0);
            amax = fmaxf(amax, fabsf(z_ij));
            z.data.elt[jt] = z_ij * scale;
          }
        }
      }

      // Store output
      Ovec z_out;
      z.to(z_out);
      z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col);
    }
  }

  // Finalize fp8 factors
  if (params.fp8_out) {
    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
    if (threadIdx.x == 0) {
      static_assert(std::is_same<compute_t, float>::value);
      atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
    }
  }
}

}  // namespace layer_norm
......
...@@ -14,154 +14,119 @@ ...@@ -14,154 +14,119 @@
namespace transformer_engine { namespace transformer_engine {
namespace layer_norm { namespace layer_norm {
template <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
          typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_>
struct Kernel_traits_base {
  using weight_t = weight_t_;
  using input_t = input_t_;
  using output_t = output_t_;
  using compute_t = compute_t_;
  using index_t = index_t_;

  enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
  enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
  enum { THREADS_PER_WARP = 32 };
};
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_,
          typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_,
          uint32_t BYTES_PER_LDG_,
          typename Base = Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_,
                                             compute_t_, index_t_, THREADS_PER_CTA_> >
struct Kernel_traits_finalize : public Base {
  enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
  static_assert(static_cast<int>(ROWS_PER_CTA) <= static_cast<int>(Base::THREADS_PER_WARP));
  // Bytes per global load from the input.
  enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
  // Number of elements fetched by a global load.
  enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
  // Bytes per global store of the weights.
  enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
  static_assert(sizeof(BYTES_PER_LDG) == 4,
                "Conflict-free smem transpose only implemented for 4B compute type!");
  static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP,
                "We assume one warp per row!");
  // The total number of BYTES_PER_LDG-wide words in a hidden vector.
  enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
  static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
  // Shared memory size to transpose the CTA result.
  enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
  // Shared memory size to coalesce the CTA result.
  enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
  // Shared memory requirement per CTA.
  enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
  // The type of the reducer.
  using Reducer = transformer_engine::Reducer<compute_t_, 1, 1, 1>;
  // Condition for the whole CTA to participate in syncthreads.
  static_assert(COLS % Base::THREADS_PER_WARP == 0);
  enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,
          typename index_t_, uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_,
          uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16,
          typename Base =
              Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_, index_t_,
                                 WARPS_M_ * WARPS_N_ * THREADS_PER_WARP> >
struct Kernel_traits : public Base {
  using input_t = typename Base::input_t;
  using weight_t = typename Base::weight_t;
  using compute_t = typename Base::compute_t;
  using output_t = typename Base::output_t;
  using index_t = typename Base::index_t;

  enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
  enum { WARPS_M = WARPS_M_ };
  enum { WARPS_N = WARPS_N_ };
  enum { COLS = HIDDEN_SIZE_ };
  enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
  enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
  enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };

  enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
  enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
  enum { ROWS_PER_CTA = WARPS_M };

  enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
  enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
  // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
  enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
  static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);

  using reduce_t = typename transformer_engine::TypeToVec2<compute_t>::Type;
  using Reducer = transformer_engine::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;

  enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
  enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };

  using Ivec = transformer_engine::Vec<input_t, NUM_ELTS>;
  using Ovec = transformer_engine::Vec<output_t, NUM_ELTS>;
  using Wvec = transformer_engine::Vec<weight_t, NUM_ELTS>;
  using Cvec = transformer_engine::Vec<compute_t, NUM_ELTS>;
  enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };

  // Assume that each thread can handle the same number of elements
  // in the output and weights as in the input.
  static_assert(sizeof(input_t) >= sizeof(output_t));
  static_assert(sizeof(input_t) >= sizeof(weight_t));
  // The number of columns fetched per load from input: one per thread.
  enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
  // The total number of vectorized loads/stores per hidden vector.
  enum { VEC_COLS = COLS / ELTS_PER_LDG };
  // The number of loads per thread for the input.
  enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
  static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
  // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");

  using Stats = transformer_engine::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
  enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
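To make the derived constants above concrete, here is a hypothetical instantiation; the numbers below are illustrative assumptions, not values taken from this header. With a 1024-wide hidden dimension, 2-byte inputs, 16-byte loads, one CTA per row, and one warp in the N dimension, each thread fetches 8 elements per load and issues 4 loads per row.

// Hypothetical instantiation of the trait arithmetic (values are assumptions).
#include <cstdint>

using input_t = std::uint16_t;  // stand-in for a 2-byte input type such as fp16

constexpr std::uint32_t THREADS_PER_WARP = 32;
constexpr std::uint32_t HIDDEN_SIZE = 1024;
constexpr std::uint32_t BYTES_PER_LDG = 16;
constexpr std::uint32_t CTAS_PER_ROW = 1, WARPS_M = 4, WARPS_N = 1;

constexpr std::uint32_t NUM_ELTS = BYTES_PER_LDG / sizeof(input_t);         // 8 elements per load
constexpr std::uint32_t THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP;       // 32 threads per row
constexpr std::uint32_t VEC_COLS = HIDDEN_SIZE / NUM_ELTS;                  // 128 vector columns
constexpr std::uint32_t VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW;  // 32 columns per load
constexpr std::uint32_t LDGS = VEC_COLS / VEC_COLS_PER_LDG;                 // 4 loads per thread

static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS, "hidden size must tile evenly");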
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -7,19 +7,16 @@ ...@@ -7,19 +7,16 @@
#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_
#define TRANSFORMER_ENGINE_COMMON_NVTX_H_

#include <nvToolsExt.h>

#include <string>

namespace transformer_engine::nvtx {

struct NVTXWrapper {
  explicit NVTXWrapper(const std::string &name) { nvtxRangePush(name.c_str()); }

  ~NVTXWrapper() { nvtxRangePop(); }
};

}  // namespace transformer_engine::nvtx
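Typical usage of the wrapper above is scope-based: construct it at the top of a region to open an NVTX range and let the destructor close it. The snippet below is an illustrative sketch; the caller name and include path are assumptions.

#include <string>

#include "nvtx.h"  // assumed path for the header above

void run_profiled_region() {  // hypothetical caller
  transformer_engine::nvtx::NVTXWrapper range("run_profiled_region");
  // ... work to be measured; the range covers everything in this scope ...
}  // ~NVTXWrapper fires here and calls nvtxRangePop()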
......
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include <transformer_engine/recipe.h>

#include <cmath>
#include <limits>
#include <string>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/logging.h"

namespace transformer_engine {
namespace delayed_scaling_recipe {
...@@ -24,18 +24,19 @@ enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX }; ...@@ -24,18 +24,19 @@ enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX };
const char* dtype_name(DType dtype) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type,
                                     return TypeInfo<Type>::name;);  // NOLINT(*)
  return "";
}

// Maximum representable value of an FP8 dtype
inline float fp8_dtype_max(DType dtype) {
  switch (dtype) {
    case DType::kFloat8E4M3:
      return 448;
    case DType::kFloat8E5M2:
      return 57344;
    default:
      NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype));
  }
  return 0;
}
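The kernels below turn the reduced amax into a new scaling factor as scale = (fp8_max * 2^-margin) / amax, clamping an infinite result to the largest finite float. A minimal host-side sketch of that update follows; the helper name is illustrative, and the guard on amax mirrors the device code but is stated here as an assumption.

// Host-side sketch of the scale update performed by the kernels below (illustrative).
#include <cmath>
#include <limits>

inline float updated_scale_reference(float amax, float fp8_max, float margin) {
  const float scaled_max = fp8_max * std::pow(2.f, -margin);
  float scale = 1.f;
  if (std::isfinite(amax) && amax > 0.f) {
    scale = scaled_max / amax;
  }
  // Match the kernel's guard: an infinite scale is clamped to the largest float.
  if (std::isinf(scale)) {
    scale = std::numeric_limits<float>::max();
  }
  return scale;  // scale_inv would then be 1 / scale
}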
...@@ -58,12 +59,12 @@ struct OtherParams { ...@@ -58,12 +59,12 @@ struct OtherParams {
#if CUDART_VERSION >= 12010
constexpr size_t max_constant_memory_per_kernel = 32768;
constexpr size_t AMAX_PARAMS_LIMIT =
    (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#else
constexpr size_t max_constant_memory_per_kernel = 4096;
constexpr size_t AMAX_PARAMS_LIMIT =
    (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#endif
struct AmaxParams { struct AmaxParams {
...@@ -82,17 +83,10 @@ constexpr size_t bsize = 256; ...@@ -82,17 +83,10 @@ constexpr size_t bsize = 256;
* Grid dims: num_scales x 1 x 1 * Grid dims: num_scales x 1 x 1
*/ */
__global__ void __launch_bounds__(bsize)
    kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr,
           const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr,
           float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length,
           size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) {
const size_t tid = threadIdx.x; const size_t tid = threadIdx.x;
const size_t bid = blockIdx.x; const size_t bid = blockIdx.x;
...@@ -109,22 +103,21 @@ kernel(const float* amax_history_ptr, ...@@ -109,22 +103,21 @@ kernel(const float* amax_history_ptr,
const size_t i = off + tid; const size_t i = off + tid;
float a = 0; float a = 0;
if (i < length) { if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
amax = fmaxf(amax, a); amax = fmaxf(amax, a);
} }
__syncthreads(); // In case roll is in-place __syncthreads(); // In case roll is in-place
if (i < length) { if (i < length) {
updated_amax_history[i*stride] = (i > 0) ? a : 0; updated_amax_history[i * stride] = (i > 0) ? a : 0;
} }
} }
    // Compute amax to use for scaling factor
    switch (amax_compute_algo) {
      case AmaxComputeAlgo::MOST_RECENT:
        amax = last_amax;
        break;
      case AmaxComputeAlgo::MAX: {
        __shared__ float shared_amax[bsize];
        shared_amax[tid] = amax;
        __syncthreads();
...@@ -136,10 +129,9 @@ kernel(const float* amax_history_ptr, ...@@ -136,10 +129,9 @@ kernel(const float* amax_history_ptr,
          __syncthreads();
        }
        amax = shared_amax[tid];
      } break;
      default:
        amax = 0;
    }
  }
...@@ -157,7 +149,7 @@ kernel(const float* amax_history_ptr, ...@@ -157,7 +149,7 @@ kernel(const float* amax_history_ptr,
// amax won't get mapped to the FP8 max representable, but rather // amax won't get mapped to the FP8 max representable, but rather
// something below that, but this is the best thing we can do. // something below that, but this is the best thing we can do.
if (isinf(scale)) { if (isinf(scale)) {
scale = std::numeric_limits<float>::max(); scale = std::numeric_limits<float>::max();
} }
updated_scale_ptr[bid] = scale; updated_scale_ptr[bid] = scale;
...@@ -179,12 +171,8 @@ kernel(const float* amax_history_ptr, ...@@ -179,12 +171,8 @@ kernel(const float* amax_history_ptr,
* Grid dims: num_tensors x 1 x 1 * Grid dims: num_tensors x 1 x 1
*/ */
__global__ void __launch_bounds__(bsize)
    kernel_bulk(float* amax_reduction_buffer, AmaxParams p, size_t amax_history_length,
                AmaxComputeAlgo amax_compute_algo, float scaled_max) {
const size_t bid = blockIdx.x; const size_t bid = blockIdx.x;
const size_t tid = threadIdx.x; const size_t tid = threadIdx.x;
const int num_scale = p.param[bid].num_scale; const int num_scale = p.param[bid].num_scale;
...@@ -201,32 +189,32 @@ kernel_bulk( ...@@ -201,32 +189,32 @@ kernel_bulk(
// Roll amax history // Roll amax history
const auto& length = amax_history_length; const auto& length = amax_history_length;
const auto& stride = p.param[bid].num_scale; const auto& stride = p.param[bid].num_scale;
auto* amax_history = p.param[bid].amax_history+count; auto* amax_history = p.param[bid].amax_history + count;
const auto last_amax = ((amax_reduction_buffer != nullptr) const auto last_amax = ((amax_reduction_buffer != nullptr) &&
&& (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ? (amax_reduction_buffer[offset_in_buffer + count] != 0.0f))
amax_reduction_buffer[offset_in_buffer+count] : amax_history[0]; ? amax_reduction_buffer[offset_in_buffer + count]
: amax_history[0];
if (last_amax != 0.0f) { if (last_amax != 0.0f) {
for (size_t off = 0; off < length; off += bsize) { for (size_t off = 0; off < length; off += bsize) {
const size_t i = off + tid; const size_t i = off + tid;
float a = 0; float a = 0;
if (i < length) { if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
amax = fmaxf(amax, a); amax = fmaxf(amax, a);
} }
__syncthreads(); // Inplace roll __syncthreads(); // Inplace roll
if (i < length) { if (i < length) {
amax_history[i*stride] = (i > 0) ? a : 0; amax_history[i * stride] = (i > 0) ? a : 0;
} }
} }
} }
// Compute amax to use for scaling factor // Compute amax to use for scaling factor
switch (amax_compute_algo) { switch (amax_compute_algo) {
case AmaxComputeAlgo::MOST_RECENT: case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax; amax = last_amax;
break; break;
case AmaxComputeAlgo::MAX: case AmaxComputeAlgo::MAX: {
{
__shared__ float shared_amax[bsize]; __shared__ float shared_amax[bsize];
shared_amax[tid] = amax; shared_amax[tid] = amax;
__syncthreads(); __syncthreads();
...@@ -238,10 +226,9 @@ kernel_bulk( ...@@ -238,10 +226,9 @@ kernel_bulk(
__syncthreads(); __syncthreads();
} }
amax = shared_amax[tid]; amax = shared_amax[tid];
} } break;
break; default:
default: amax = 0;
amax = 0;
} }
} }
...@@ -269,7 +256,7 @@ kernel_bulk( ...@@ -269,7 +256,7 @@ kernel_bulk(
// amax won't get mapped to the FP8 max representable, but rather // amax won't get mapped to the FP8 max representable, but rather
// something below that, but this is the best thing we can do. // something below that, but this is the best thing we can do.
if (isinf(scale)) { if (isinf(scale)) {
scale = std::numeric_limits<float>::max(); scale = std::numeric_limits<float>::max();
} }
p.param[bid].scale[count] = scale; p.param[bid].scale[count] = scale;
p.param[bid].scale_inv[count] = 1 / scale; p.param[bid].scale_inv[count] = 1 / scale;
...@@ -281,24 +268,17 @@ kernel_bulk( ...@@ -281,24 +268,17 @@ kernel_bulk(
} // namespace } // namespace
void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv,
                           const Tensor& scale_inv_mask, Tensor* updated_amax_history_,
                           Tensor* updated_scale_, Tensor* updated_scale_inv_,
                           const std::string& amax_compute_algo, DType fp8_dtype, float margin,
                           cudaStream_t stream) {
auto& updated_amax_history = *updated_amax_history_; auto& updated_amax_history = *updated_amax_history_;
auto& updated_scale = *updated_scale_; auto& updated_scale = *updated_scale_;
auto& updated_scale_inv = *updated_scale_inv_; auto& updated_scale_inv = *updated_scale_inv_;
// Number of elements in tensor // Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t { auto numel = [](const Tensor& tensor) -> size_t {
size_t acc = 1; size_t acc = 1;
for (const auto& dim : tensor.data.shape) { for (const auto& dim : tensor.data.shape) {
acc *= dim; acc *= dim;
...@@ -307,48 +287,40 @@ void amax_and_scale_update(const Tensor &amax_history, ...@@ -307,48 +287,40 @@ void amax_and_scale_update(const Tensor &amax_history,
}; };
// Check tensors // Check tensors
NVTE_CHECK(amax_history.data.shape.size() == 2, NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(),
"Found ", amax_history.data.shape.size(), " dims"); " dims");
const size_t amax_history_length = amax_history.data.shape[0]; const size_t amax_history_length = amax_history.data.shape[0];
const size_t num_scales = amax_history.data.shape[1]; const size_t num_scales = amax_history.data.shape[1];
NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ",
"Found ", dtype_name(amax_history.data.dtype), "."); dtype_name(amax_history.data.dtype), ".");
NVTE_CHECK(numel(scale) == num_scales, NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
"Expected ", num_scales, " elements, ", numel(scale), ".");
"but found ", numel(scale), "."); NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), ".");
NVTE_CHECK(scale.data.dtype == DType::kFloat32,
"Found ", dtype_name(scale.data.dtype), ".");
if (scale_inv_mask.data.dptr != nullptr) { if (scale_inv_mask.data.dptr != nullptr) {
NVTE_CHECK(numel(scale_inv) == num_scales, NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
"Expected ", num_scales, " elements, ", numel(scale_inv), ".");
"but found ", numel(scale_inv), ".");
NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32); NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32);
NVTE_CHECK(numel(scale_inv_mask) == num_scales, NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ",
"Expected ", num_scales, " elements, ",
"but found ", numel(scale_inv_mask), "."); "but found ", numel(scale_inv_mask), ".");
NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ",
"Found ", dtype_name(scale_inv_mask.data.dtype), "."); dtype_name(scale_inv_mask.data.dtype), ".");
} }
NVTE_CHECK(updated_amax_history.data.shape.size() == 2, NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ",
"Found ", updated_amax_history.data.shape.size(), " dims."); updated_amax_history.data.shape.size(), " dims.");
NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ",
"Expected ", amax_history_length, ", ", amax_history_length, ", ", "but found ", updated_amax_history.data.shape[0]);
"but found ", updated_amax_history.data.shape[0]); NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, "Expected ", num_scales, ", ",
NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales,
"Expected ", num_scales, ", ",
"but found ", updated_amax_history.data.shape[1]); "but found ", updated_amax_history.data.shape[1]);
NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ",
"Got ", dtype_name(updated_amax_history.data.dtype), "."); dtype_name(updated_amax_history.data.dtype), ".");
NVTE_CHECK(numel(updated_scale) == num_scales, NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ",
"Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale), "."); "but found ", numel(updated_scale), ".");
NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ",
"Got ", dtype_name(updated_scale.data.dtype), "."); dtype_name(updated_scale.data.dtype), ".");
NVTE_CHECK(numel(updated_scale_inv) == num_scales, NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ",
"Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale_inv), "."); "but found ", numel(updated_scale_inv), ".");
NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ",
"Got ", dtype_name(updated_scale_inv.data.dtype), "."); dtype_name(updated_scale_inv.data.dtype), ".");
// amax value to use for updating scaling factor // amax value to use for updating scaling factor
AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
...@@ -366,31 +338,23 @@ void amax_and_scale_update(const Tensor &amax_history, ...@@ -366,31 +338,23 @@ void amax_and_scale_update(const Tensor &amax_history,
  // Launch CUDA kernel
  constexpr size_t block_size = amax_and_scale_update_impl::bsize;
  const size_t grid_size = num_scales;
  amax_and_scale_update_impl::kernel<<<grid_size, block_size, 0, stream>>>(
      static_cast<const float*>(amax_history.data.dptr), static_cast<const float*>(scale.data.dptr),
      static_cast<const float*>(scale_inv.data.dptr),
      static_cast<const unsigned char*>(scale_inv_mask.data.dptr),
      static_cast<float*>(updated_amax_history.data.dptr),
      static_cast<float*>(updated_scale.data.dptr),
      static_cast<float*>(updated_scale_inv.data.dptr), amax_history_length, num_scales,
      amax_compute_algo_, scaled_max);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
                                           std::vector<Tensor*> amax_histories,
                                           std::vector<Tensor*> scales,
                                           std::vector<Tensor*> scale_invs,
                                           const std::string& amax_compute_algo, DType fp8_dtype,
                                           float margin, cudaStream_t stream) {
using namespace transformer_engine; using namespace transformer_engine;
// amax value to use for updating scaling factor // amax value to use for updating scaling factor
...@@ -407,7 +371,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, ...@@ -407,7 +371,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);
// Number of elements in tensor // Number of elements in tensor
auto numel = [] (const Tensor *tensor) -> size_t { auto numel = [](const Tensor* tensor) -> size_t {
size_t acc = 1; size_t acc = 1;
for (const auto& dim : tensor->data.shape) { for (const auto& dim : tensor->data.shape) {
acc *= dim; acc *= dim;
...@@ -418,7 +382,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, ...@@ -418,7 +382,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
// Number of tensors in the bulk // Number of tensors in the bulk
const size_t num_tensors = amax_histories.size(); const size_t num_tensors = amax_histories.size();
size_t num_remaining_tensors = num_tensors; size_t num_remaining_tensors = num_tensors;
const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT; const int num_kernels = (num_tensors + AMAX_PARAMS_LIMIT - 1) / AMAX_PARAMS_LIMIT;
size_t amax_history_length = 0; size_t amax_history_length = 0;
if (num_tensors > 0) { if (num_tensors > 0) {
amax_history_length = amax_histories[0]->data.shape[0]; amax_history_length = amax_histories[0]->data.shape[0];
...@@ -429,27 +393,26 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, ...@@ -429,27 +393,26 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
AmaxParams p; AmaxParams p;
for (int iter = 0; iter < num_kernels; iter++) { for (int iter = 0; iter < num_kernels; iter++) {
size_t kernel_num_scales = 0; size_t kernel_num_scales = 0;
size_t kernel_num_tensors = (iter == (num_kernels - 1)) size_t kernel_num_tensors =
? num_remaining_tensors: AMAX_PARAMS_LIMIT; (iter == (num_kernels - 1)) ? num_remaining_tensors : AMAX_PARAMS_LIMIT;
for (size_t pi = 0; pi < kernel_num_tensors; pi++) { for (size_t pi = 0; pi < kernel_num_tensors; pi++) {
size_t i = iter * AMAX_PARAMS_LIMIT + pi; size_t i = iter * AMAX_PARAMS_LIMIT + pi;
// Check tensors // Check tensors
int num_scale = amax_histories[i]->data.shape[1]; int num_scale = amax_histories[i]->data.shape[1];
NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ",
"Found ", dtype_name(amax_histories[i]->data.dtype), "."); dtype_name(amax_histories[i]->data.dtype), ".");
NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ",
"Found ", amax_histories[i]->data.shape.size(), " dims"); amax_histories[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ",
"Expected ", amax_history_length * num_scale, " elements, ", amax_history_length * num_scale, " elements, ", "but found ",
"but found ", numel(amax_histories[i]), "."); numel(amax_histories[i]), ".");
NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ",
"Found ", dtype_name(scales[i]->data.dtype), "."); dtype_name(scales[i]->data.dtype), ".");
NVTE_CHECK(scales[i]->data.shape.size() == 1, NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(),
"Found ", scales[i]->data.shape.size(), " dims"); " dims");
NVTE_CHECK(numel(scales[i]) == num_scale, NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ",
"Expected ", num_scale, " elements, ", numel(scales[i]), ".");
"Found ", numel(scales[i]), ".");
// amax parameters // amax parameters
kernel_num_scales += num_scale; kernel_num_scales += num_scale;
...@@ -462,13 +425,8 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, ...@@ -462,13 +425,8 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
// Launch CUDA kernel // Launch CUDA kernel
size_t grid_size = kernel_num_tensors; size_t grid_size = kernel_num_tensors;
const size_t block_size = amax_and_scale_update_impl::bsize; const size_t block_size = amax_and_scale_update_impl::bsize;
amax_and_scale_update_impl::kernel_bulk amax_and_scale_update_impl::kernel_bulk<<<grid_size, block_size, 0, stream>>>(
<<<grid_size, block_size, 0, stream>>>( amax_buffer, p, amax_history_length, amax_compute_algo_, scaled_max);
amax_buffer,
p,
amax_history_length,
amax_compute_algo_,
scaled_max);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
// shift amax buffer pointer // shift amax buffer pointer
...@@ -482,44 +440,25 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, ...@@ -482,44 +440,25 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
} // namespace delayed_scaling_recipe } // namespace delayed_scaling_recipe
} // namespace transformer_engine } // namespace transformer_engine
void nvte_delayed_scaling_recipe_amax_and_scale_update(
    const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv,
    const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale,
    NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
    cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
using namespace transformer_engine; using namespace transformer_engine;
  delayed_scaling_recipe::amax_and_scale_update(
      *reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
      *reinterpret_cast<const Tensor*>(scale_inv), *reinterpret_cast<const Tensor*>(scale_inv_mask),
      reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
      reinterpret_cast<Tensor*>(updated_scale_inv), amax_compute_algo,
      static_cast<DType>(fp8_dtype), margin, stream);
}
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
    const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
    std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
    const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
using namespace transformer_engine; using namespace transformer_engine;
size_t num_tensors = amax_histories.size(); size_t num_tensors = amax_histories.size();
...@@ -530,12 +469,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( ...@@ -530,12 +469,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
t_scale_invs.push_back(reinterpret_cast<Tensor*>(scale_invs[i])); t_scale_invs.push_back(reinterpret_cast<Tensor*>(scale_invs[i]));
} }
  delayed_scaling_recipe::amax_and_scale_update_after_reduction(
      *reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales,
      t_scale_invs, amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <functional> #include <functional>
#include <map> #include <map>
#include <stdexcept> #include <stdexcept>
...@@ -15,7 +16,6 @@ ...@@ -15,7 +16,6 @@
#include <vector> #include <vector>
#include "../common.h" #include "../common.h"
#include "../layer_norm/ln.h" #include "../layer_norm/ln.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -47,40 +47,40 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS; ...@@ -47,40 +47,40 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE> template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar { struct FwdTunedRegistrar {
explicit FwdTunedRegistrar(FwdFunction f) { explicit FwdTunedRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE); uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
FWD_TUNED_FUNCS.insert({key, f}); FWD_TUNED_FUNCS.insert({key, f});
} }
}; };
////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE> template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar { struct FwdGeneralRegistrar {
explicit FwdGeneralRegistrar(FwdFunction f) { explicit FwdGeneralRegistrar(FwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0); uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
} }
}; };
////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE> template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar { struct BwdTunedRegistrar {
explicit BwdTunedRegistrar(BwdFunction f) { explicit BwdTunedRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE); uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
BWD_TUNED_FUNCS.insert({key, f}); BWD_TUNED_FUNCS.insert({key, f});
} }
}; };
////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////
template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE> template <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar { struct BwdGeneralRegistrar {
explicit BwdGeneralRegistrar(BwdFunction f) { explicit BwdGeneralRegistrar(BwdFunction f) {
uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0); uint64_t key = layer_norm::Types2Key<W, I, O, C>::get(0);
BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f});
} }
}; };
} // namespace rmsnorm } // namespace rmsnorm
......
...@@ -4,14 +4,13 @@ ...@@ -4,14 +4,13 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "transformer_engine/rmsnorm.h"
#include <cstdint> #include <cstdint>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "rmsnorm.h"
#include "../common.h" #include "../common.h"
#include "rmsnorm.h"
#include "transformer_engine/rmsnorm.h"
/* /*
...@@ -49,85 +48,70 @@ BwdTunedRegistry BWD_TUNED_FUNCS; ...@@ -49,85 +48,70 @@ BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry FWD_GENERAL_FUNCS; FwdGeneralRegistry FWD_GENERAL_FUNCS;
BwdGeneralRegistry BWD_GENERAL_FUNCS; BwdGeneralRegistry BWD_GENERAL_FUNCS;
FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
                              const layer_norm::FwdParams &params) {
  // Look for tuned kernel
  auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
  auto is_aligned = [](const void *ptr) -> bool {
    // Assume vectorized memory accesses are <=16B
    return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
  };
  if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) &&
      is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) {
    return FWD_TUNED_FUNCS.at(tuned_key);
  }

  // Pick general kernel
  auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
  if (FWD_GENERAL_FUNCS.count(general_key) == 0) {
    NVTE_ERROR("FWD: Unsupported types.");
  }
  auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key);
  auto func_iter = general_func_map.lower_bound(params.cols);
  if (func_iter == general_func_map.end()) {
    // Hidden size is too big, need to use multi-CTA
    return general_func_map.rbegin()->second;
  } else {
    return func_iter->second;
  }
}
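The general-kernel lookup above keys a map by hidden size and uses lower_bound to select the smallest registered kernel that still covers params.cols, falling back to the largest (multi-CTA) variant when the hidden size exceeds every single-pass kernel. A minimal sketch of that dispatch, with a stand-in function-pointer type:

#include <cstdint>
#include <map>

using LaunchFn = void (*)();  // stand-in for the real FwdFunction type

// Assumes the map is non-empty (the caller above checks this before dispatching).
inline LaunchFn pick_general_kernel(const std::map<std::uint64_t, LaunchFn> &funcs,
                                    std::uint64_t cols) {
  auto it = funcs.lower_bound(cols);  // smallest registered hidden size >= cols
  if (it == funcs.end()) {
    return funcs.rbegin()->second;  // hidden size too big: use the multi-CTA variant
  }
  return it->second;
}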
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype,
                              const layer_norm::BwdParams &params) {
  // Look for tuned kernel
  auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
  auto is_aligned = [](const void *ptr) -> bool {
    // Assume vectorized memory accesses are <=16B
    return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
  };
  if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) &&
      is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) &&
      is_aligned(params.dgamma) && is_aligned(params.dgamma_part) &&
      layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
    return BWD_TUNED_FUNCS.at(tuned_key);
  }

  // Pick general kernel
  auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0);
  if (BWD_GENERAL_FUNCS.count(general_key) == 0) {
    NVTE_ERROR("BWD: Unsupported types.");
  }
  auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key);
  auto func_iter = general_func_map.lower_bound(params.cols);
  if (func_iter == general_func_map.end()) {
    // Hidden size is too big, need to use multi-CTA
    return general_func_map.rbegin()->second;
  } else {
    return func_iter->second;
  }
}
// //////////////////////////////////////////////////////////////////////////////////////////////////// // ////////////////////////////////////////////////////////////////////////////////////////////////////
inline size_t product(const std::vector<size_t> &shape) { inline size_t product(const std::vector<size_t> &shape) {
return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>()); return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
} }
} // namespace rmsnorm } // namespace rmsnorm
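Both launchers above follow the same dispatch pattern: an exact-match "tuned" kernel is used only when the row count and every pointer satisfy the 16B alignment assumption and a kernel was compiled for exactly params.cols; otherwise a size-ordered "general" table is consulted, where std::map::lower_bound picks the smallest supported hidden size that still covers params.cols, falling back to the largest (multi-CTA) entry when the hidden size exceeds them all. A minimal, self-contained sketch of that fallback logic, using hypothetical kernel names and sizes rather than the library's actual tables:

#include <cstdio>
#include <map>
#include <string>

int main() {
  // Hypothetical "general" table keyed by the largest hidden size each kernel supports.
  std::map<int, std::string> general = {
      {1024, "general_1024"}, {4096, "general_4096"}, {8192, "general_8192"}};
  for (int cols : {768, 4096, 10240}) {
    auto it = general.lower_bound(cols);
    // Past-the-end means no single-CTA kernel covers this hidden size,
    // so fall back to the largest (multi-CTA) variant, as the code above does.
    const std::string &name = (it == general.end()) ? general.rbegin()->second : it->second;
    std::printf("cols=%d -> %s\n", cols, name.c_str());
  }
  return 0;
}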
@@ -137,213 +121,211 @@ inline size_t product(const std::vector<size_t> &shape) {

void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
                 Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount,
                 Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) {
  auto itype = x.data.dtype;
  auto wtype = gamma.data.dtype;
  auto otype = z->data.dtype;
  const bool fp8_out = is_fp8_dtype(otype);
  auto ctype = DType::kFloat32;

  NVTE_CHECK(x.data.shape.size() == 2);

  const size_t rows = x.data.shape[0];
  const size_t cols = x.data.shape[1];
  const auto hidden_size = gamma.data.shape[0];

  NVTE_CHECK(hidden_size == cols);
  NVTE_CHECK(epsilon >= 0.f);

  NVTE_CHECK(z->data.shape == x.data.shape);

  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{rows});
  NVTE_CHECK(rsigma->data.dtype == ctype);

  rmsnorm::LaunchParams<rmsnorm::FwdParams> launch_params;

  launch_params.multiprocessorCount = multiprocessorCount;
  launch_params.stream = stream;

  // Set the kernel runtime parameters.
  rmsnorm::FwdParams &params = launch_params.params;
  params.rows = rows;
  params.cols = cols;
  params.x = x.data.dptr;
  params.mu = nullptr;
  params.rs = rsigma->data.dptr;
  params.gamma = gamma.data.dptr;
  params.beta = nullptr;
  params.z = z->data.dptr;
  params.epsilon = epsilon;
  params.amax = z->amax.dptr;
  params.scale = z->scale.dptr;
  params.fp8_out = fp8_out;
  params.zero_centered_gamma = zero_centered_gamma;

  // Request the kernel launcher.
  auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params);

  // Query the kernel-specific launch parameters.
  launcher(launch_params, true);
  if (launch_params.workspace_bytes == 0) {
    launch_params.workspace_bytes = 1;
  }

  if (workspace->data.dptr == nullptr) {
    NVTE_CHECK(barrier->data.dptr == nullptr);

    workspace->data.dtype = DType::kByte;
    workspace->data.shape = {launch_params.workspace_bytes};

    barrier->data.dtype = DType::kInt32;
    barrier->data.shape = {launch_params.barrier_size};

    return;
  } else {
    NVTE_CHECK(workspace->data.dtype == DType::kByte);
    NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
  }
  if (launch_params.barrier_size > 0) {
    NVTE_CHECK(barrier->data.dptr != nullptr);
    NVTE_CHECK(barrier->data.dtype == DType::kInt32);
    NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
  }

  // Tensor checks are delayed here in order to recover workspace sizes with null data
  CheckInputTensor(x, "x");
  CheckInputTensor(gamma, "gamma");
  CheckOutputTensor(*z, "z");
  CheckOutputTensor(*rsigma, "rsigma");

  if (launch_params.barrier_size > 0) {
    params.workspace = workspace->data.dptr;
    params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
  }

  // Clear buffers
  if (params.fp8_out) {
    cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
                    stream);
  }
  if (launch_params.barrier_size > 0) {
    cudaMemsetAsync(params.barrier, 0,
                    rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
                    stream);
  }

  // Launch the kernel.
  launcher(launch_params, false);

  return;
}
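rmsnorm_fwd is deliberately a two-pass API: when workspace->data.dptr is null it only fills in the required workspace and barrier shapes/dtypes and returns, and the caller is expected to allocate those buffers and call again to validate the inputs and launch the kernel. A rough caller-side sketch of that query-then-launch protocol, using stand-in types (Buffer, rmsnorm_fwd_like) rather than the real Tensor machinery:

#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-ins for the query-then-launch protocol: the first call (null data
// pointer) only reports required sizes; the second call actually launches.
struct Buffer {
  void *ptr = nullptr;
  std::vector<size_t> shape;
};

void rmsnorm_fwd_like(Buffer *workspace, Buffer *barrier, bool *launched) {
  const size_t workspace_bytes = 4096;  // in the real code this comes from the launcher query
  const size_t barrier_size = 16;
  if (workspace->ptr == nullptr) {
    workspace->shape = {workspace_bytes};  // report requirements and return early
    barrier->shape = {barrier_size};
    return;
  }
  *launched = true;  // buffers provided: validate and launch for real
}

int main() {
  Buffer workspace, barrier;
  bool launched = false;
  rmsnorm_fwd_like(&workspace, &barrier, &launched);  // pass 1: query sizes
  std::vector<std::byte> ws(workspace.shape[0]);      // allocate what was requested
  std::vector<int32_t> bar(barrier.shape[0]);
  workspace.ptr = ws.data();
  barrier.ptr = bar.data();
  rmsnorm_fwd_like(&workspace, &barrier, &launched);  // pass 2: launch
  return launched ? 0 : 1;
}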
void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma,
                 Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream,
                 const int multiprocessorCount, Tensor *workspace, Tensor *barrier,
                 const bool zero_centered_gamma) {
  using namespace transformer_engine;
  auto itype = x.data.dtype;
  auto wtype = gamma.data.dtype;
  auto otype = wtype;
  auto ctype = DType::kFloat32;

  NVTE_CHECK(dz.data.dtype == otype);
  NVTE_CHECK(rsigma.data.dtype == ctype);

  NVTE_CHECK(x.data.shape.size() == 2);
  NVTE_CHECK(dz.data.shape == x.data.shape);

  const auto rows = x.data.shape[0];
  const auto cols = x.data.shape[1];
  const auto hidden_size = gamma.data.shape[0];

  NVTE_CHECK(gamma.data.shape[0] == cols);

  NVTE_CHECK(dx->data.shape == x.data.shape);
  NVTE_CHECK(dx->data.dtype == x.data.dtype);

  NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
  NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);

  rmsnorm::LaunchParams<rmsnorm::BwdParams> launch_params;
  launch_params.stream = stream;
  launch_params.multiprocessorCount = multiprocessorCount;

  // Set the kernel runtime parameters.
  rmsnorm::BwdParams &params = launch_params.params;
  params.rows = rows;
  params.cols = cols;
  params.x = x.data.dptr;
  params.mu = nullptr;
  params.rs = rsigma.data.dptr;
  params.gamma = gamma.data.dptr;
  params.dz = dz.data.dptr;
  params.dx = dx->data.dptr;
  params.dbeta = nullptr;
  params.dgamma = dgamma->data.dptr;
  params.dbeta_part = nullptr;
  params.dgamma_part = dgamma_part->data.dptr;
  params.zero_centered_gamma = zero_centered_gamma;

  // Request the kernel launcher.
  auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params);

  // Query the kernel-specific launch parameters.
  launcher(launch_params, true);

  // Populate shape and dtypes for FW to allocate memory
  if (dgamma_part->data.dptr == nullptr) {
    dgamma_part->data.dtype = ctype;
    dgamma_part->data.shape = {static_cast<uint64_t>(launch_params.params.ctas_per_col),
                               hidden_size};

    workspace->data.dtype = DType::kByte;
    workspace->data.shape = {launch_params.workspace_bytes};

    barrier->data.dtype = DType::kInt32;
    barrier->data.shape = {launch_params.barrier_size};

    return;
  } else {
    auto pdw_shape =
        std::vector<size_t>{static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
    NVTE_CHECK(dgamma_part->data.dtype == ctype);
    NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
  }

  if (launch_params.barrier_size > 0) {
    NVTE_CHECK(barrier->data.dptr != nullptr);
    NVTE_CHECK(barrier->data.dtype == DType::kInt32);
    NVTE_CHECK(barrier->data.shape == std::vector<size_t>{launch_params.barrier_size});
  }

  if (launch_params.workspace_bytes > 0) {
    NVTE_CHECK(workspace->data.dptr != nullptr);
    NVTE_CHECK(workspace->data.dtype == DType::kByte);
    NVTE_CHECK(workspace->data.shape == std::vector<size_t>{launch_params.workspace_bytes});
  }

  // Tensor checks are delayed here in order to recover workspace sizes with null data
  CheckInputTensor(dz, "dz");
  CheckInputTensor(x, "x");
  CheckInputTensor(rsigma, "rsigma");
  CheckInputTensor(gamma, "gamma");
  CheckOutputTensor(*dx, "dx");
  CheckOutputTensor(*dgamma, "dgamma");

  if (launch_params.barrier_size > 0) {
    params.workspace = workspace->data.dptr;
    params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
    cudaMemsetAsync(params.barrier, 0,
                    rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
                    stream);
  }

  // Launch the kernel.
  launcher(launch_params, false);
}

} // namespace transformer_engine
@@ -364,18 +346,15 @@ void nvte_rmsnorm_bwd(const NVTETensor dz,  // Nxhidden_size
                      const NVTETensor x,       // Nxhidden_size
                      const NVTETensor rsigma,  // N, FP32!
                      const NVTETensor gamma,   // hidden_size
                      NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
                      const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
  NVTE_API_CALL(nvte_rmsnorm_bwd);
  using namespace transformer_engine;
  rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
              *reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
              reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
              reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
              reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier), false);
}
void nvte_rmsnorm1p_fwd(const NVTETensor x,  // Nxhidden_size
@@ -394,9 +373,8 @@ void nvte_rmsnorm1p_bwd(const NVTETensor dz,  // Nxhidden_size
                        const NVTETensor x,       // Nxhidden_size
                        const NVTETensor rsigma,  // N, FP32!
                        const NVTETensor gamma,   // hidden_size
                        NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part,
                        cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace,
                        NVTETensor barrier) {
  NVTE_API_CALL(nvte_rmsnorm1p_bwd);
  using namespace transformer_engine;
@@ -404,6 +382,5 @@ void nvte_rmsnorm1p_bwd(const NVTETensor dz,  // Nxhidden_size
              *reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
              reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
              reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
              reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier), true);
}
@@ -16,466 +16,462 @@ using namespace transformer_engine;

template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel(
    BwdParams params) {
  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
  enum { WARPS_M = Ktraits::WARPS_M };
  enum { WARPS_N = Ktraits::WARPS_N };
  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
  enum { COLS = Ktraits::COLS };
  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
  enum { LDGS = Ktraits::LDGS };
  enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
  enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
  enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

  using compute_t = typename Ktraits::compute_t;
  using index_t = typename Ktraits::index_t;
  using Ivec = typename Ktraits::Ivec;
  using Ovec = typename Ktraits::Ovec;
  using Wvec = typename Ktraits::Wvec;
  using Cvec = typename Ktraits::Cvec;
  using Reducer = typename Ktraits::Reducer;
  using reduce_t = typename Reducer::Type;

  extern __shared__ char smem_[];

  const index_t tidx = threadIdx.x;
  const index_t bidn = blockIdx.x % CTAS_PER_ROW;
  const index_t bidm = blockIdx.x / CTAS_PER_ROW;
  const index_t lane = tidx % THREADS_PER_WARP;
  const index_t warp = tidx / THREADS_PER_WARP;
  const index_t warp_m = warp / Ktraits::WARPS_N;
  const index_t warp_n = warp % Ktraits::WARPS_N;
  const index_t tid_r = warp_n * THREADS_PER_WARP + lane;

  const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
  const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

  static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);

  Cvec dzy_sum[LDGS];
  memset(dzy_sum, 0, sizeof(dzy_sum));

  compute_t *smem_wgrad = reinterpret_cast<compute_t *>(smem_);
  char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;

  Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);

  Sum<reduce_t> sum;

  constexpr float rn = 1.f / static_cast<float>(COLS);
  Wvec gamma[LDGS];
  index_t idx = c;
#pragma unroll
  for (int it = 0; it < LDGS; it++) {
    gamma[it].load_from(params.gamma, idx);
    idx += Ktraits::VEC_COLS_PER_LDG;
  }
  // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
  // last blocks with syncthreads!
  // grid stride over rows
#pragma unroll 1
  for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
    const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
    Ivec x[LDGS];
    Ovec dz[LDGS];
    index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
      dz[it].load_from(params.dz, idx);
      x[it].load_from(params.x, idx);
      idx += Ktraits::VEC_COLS_PER_LDG;
    }

    compute_t dy[LDGS * NUM_ELTS];
    compute_t y[LDGS * NUM_ELTS];

    compute_t mdyy_local = 0.f;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t x_tmp = x[it].data.elt[jt];
        compute_t y_tmp = rs_r * (x_tmp);
        const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
        compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
        dy_tmp *= compute_t(dz[it].data.elt[jt]);
        compute_t dz_tmp = dz[it].data.elt[jt];

        mdyy_local += dy_tmp * y_tmp;

        dy[it * NUM_ELTS + jt] = dy_tmp;
        y[it * NUM_ELTS + jt] = y_tmp;

        dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
      }
    }

    reduce_t result = reducer.allreduce({0, mdyy_local}, sum);
    mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;

    Ivec dx[LDGS];
    idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t dy_tmp = dy[it * NUM_ELTS + jt];
        compute_t y_tmp = y[it * NUM_ELTS + jt];
        compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp));
        dx[it].data.elt[jt] = dx_tmp;
      }
      dx[it].store_to(params.dx, idx);
      idx += Ktraits::VEC_COLS_PER_LDG;
    }
  }  // end: grid stride loop

  if (WARPS_M == 1) {
    idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
      dzy_sum[it].store_to(params.dgamma_part, idx);
      idx += Ktraits::VEC_COLS_PER_LDG;
    }
  } else {
    static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1,
                  "Multiple rows per CTA not supported for Multi-CTA.");
    // Finalize reduction of part dgamma and dbeta for this CTA
    // by reducing over the rows held across the WARPS_M warps
    // Assumption: blockSize divides hidden size.
    enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
    static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");

    idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
      dzy_sum[it].store_to(smem_wgrad, idx);
      idx += THREADS_PER_ROW;
    }
    __syncthreads();
    compute_t cta_dzy_sum[NUM_RES];
    memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
    for (int it = 0; it < ROWS_PER_CTA; it++) {
      for (int jt = 0; jt < NUM_RES; jt++) {
        cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
      }
    }

    compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
    for (int jt = 0; jt < NUM_RES; jt++) {
      *dgamma_part = cta_dzy_sum[jt];
      dgamma_part += Ktraits::THREADS_PER_CTA;
    }
  }
}
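For reference, the quantities in the kernel above map onto the standard RMSNorm backward; the reciprocal RMS rs_r is read back from the forward pass rather than recomputed. A sketch of the math in the kernel's notation (this derivation is the usual one, not taken verbatim from the source; with zero_centered_gamma the weight enters as gamma_i + 1):

\[
y_i = r\,x_i, \qquad r = \frac{1}{\sqrt{\tfrac{1}{n}\sum_j x_j^2 + \varepsilon}}, \qquad z_i = \gamma_i\, y_i,
\]
\[
dy_i = \gamma_i\, dz_i, \qquad \overline{dy\,y} = \frac{1}{n}\sum_j dy_j\, y_j \quad (\texttt{mdyy\_local after the all-reduce, scaled by } rn),
\]
\[
\frac{\partial L}{\partial x_i} = r\left(dy_i - y_i\,\overline{dy\,y}\right), \qquad
\frac{\partial L}{\partial \gamma_i} = \sum_{\text{rows}} dz_i\, y_i \quad (\texttt{dzy\_sum}).
\]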
template <typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel(
    BwdParams params) {
  using compute_t = typename Kernel_traits::compute_t;
  using weight_t = typename Kernel_traits::weight_t;
  using index_t = typename Kernel_traits::index_t;
  using Reducer = typename Kernel_traits::Reducer;
  using reduce_t = typename Reducer::Type;

  Sum<reduce_t> sum;
  enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
  enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };

  __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];

  constexpr uint32_t bidm = 0;
  const uint32_t bidn = blockIdx.x;
  const uint32_t tidx = threadIdx.x;
  const uint32_t warp = tidx / THREADS_PER_WARP;
  const uint32_t lane = tidx % THREADS_PER_WARP;

  Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);

  const uint32_t c = bidn * THREADS_PER_WARP + lane;
  const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
  constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
  for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS;
       col += COL_STRIDE, col_out += COL_STRIDE / 2) {
    // Each thread sums over NUM_ELT columns.
    Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
    memset(&dgamma_local, 0, sizeof(dgamma_local));
    for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {
      index_t idx = row * Kernel_traits::COLS + col;

      Vec<compute_t, NUM_ELT> dgamma_part;
      dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
      }
    }

    void *smem_gamma = smem_;

    const int write_row = warp;
    const int write_col = lane ^ write_row;
    const int write_idx = write_row * THREADS_PER_WARP + write_col;

    dgamma_local.store_to(smem_gamma, write_idx);

    __syncthreads();

    // It would be probably safe to reuse the first row of smem_gamma
    void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];

    // More than one iter iff ROWS_PER_CTA < 32.
    for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) {
      const int read_row = lane;
      const int read_col = w ^ read_row;
      const int read_idx = read_row * THREADS_PER_WARP + read_col;

      memset(&dgamma_local, 0, sizeof(dgamma_local));

      // Load gamma transposed
      if (read_row < Kernel_traits::ROWS_PER_CTA) {
        dgamma_local.load_from(smem_gamma, read_idx);
      }

      // Call reducer on the loaded value(s) and convert.
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        compute_t g_i = dgamma_local.data.elt[it];
        g_i = reducer.allreduce(g_i, sum);
        dgamma_local.data.elt[it] = g_i;
      }

      // Leader stores the result at the current column.
      if (lane == 0) {
        dgamma_local.store_to(smem_gamma_out, w);
      }
    }

    // All writes done.
    __syncthreads();

    // Pack and store: 2-wide stores with half the threads.
    if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) {
      using src_t = typename TypeToVec2<compute_t>::Type;
      using dst_t = typename TypeToVec2<weight_t>::Type;
      Vec<src_t, NUM_ELT> dgamma_vec2;
      Vec<dst_t, NUM_ELT> dgamma_out2;

      dgamma_vec2.load_from(smem_gamma_out, lane);
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        dgamma_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
      }
      dgamma_out2.store_to(params.dgamma, col_out);
    }
  }
}
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel(
    BwdParams params) {
  enum { LDGS = Ktraits::LDGS };
  enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
  enum { WARPS_M = Ktraits::WARPS_M };
  enum { WARPS_N = Ktraits::WARPS_N };
  using input_t = typename Ktraits::input_t;
  using weight_t = typename Ktraits::weight_t;
  using compute_t = typename Ktraits::compute_t;
  using output_t = typename Ktraits::output_t;
  using index_t = typename Ktraits::index_t;
  using Ivec = typename Ktraits::Ivec;
  using Ovec = typename Ktraits::Ovec;
  using Wvec = typename Ktraits::Wvec;
  using Cvec = typename Ktraits::Cvec;

  const index_t tidx = threadIdx.x;
  const index_t lane = tidx % THREADS_PER_WARP;
  const index_t warp = tidx / THREADS_PER_WARP;
  const index_t warp_m = warp / WARPS_N;
  const index_t warp_n = warp % WARPS_N;
  const index_t bdimm = WARPS_M;
  const index_t bdimn = WARPS_N * THREADS_PER_WARP;
  const index_t bidm = blockIdx.x / params.ctas_per_row;
  const index_t bidn = blockIdx.x % params.ctas_per_row;
  const index_t gdimm = bdimm * params.ctas_per_col;
  const index_t gdimn = bdimn * params.ctas_per_row;
  const index_t gidm = bidm * bdimm + warp_m;
  const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
                        lane);  // Order threads by warp x cta x lane

  // Objects for weight grads
  Cvec dzy_sum[LDGS];
  memset(dzy_sum, 0, sizeof(dzy_sum));

  // Objects for stats reductions
  using reduce_t = typename Ktraits::Reducer::Type;
  using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
  constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
  __shared__ char smem_[SMEM_BYTES];
  Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
  Sum<reduce_t> sum;
  const compute_t rn = 1.f / static_cast<compute_t>(params.cols);

  // Load weights
  Cvec gamma[LDGS];
#pragma unroll
  for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
       it++, col += gdimn * NUM_ELTS) {
    Wvec gamma_in;
    gamma_in.load_from_elts(params.gamma, col, params.cols - col);
    gamma_in.to(gamma[it]);
  }

  for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
    const int row = cta_row + warp_m;
    compute_t rs = 0.f;
    if (row < params.rows) {
      rs = static_cast<const compute_t *>(params.rs)[row];
    }

    Cvec dy[LDGS];
    Cvec y[LDGS];
    compute_t mdy = 0.f;
    compute_t mdyy = 0.f;
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      Ivec x;
      Ovec dz;
      x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
      dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t x_ij = x.data.elt[jt];
        compute_t y_ij = rs * (x_ij);
        const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
        compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
        compute_t dz_ij = dz.data.elt[jt];
        compute_t dy_ij = g_ij * dz_ij;

        y[it].data.elt[jt] = y_ij;
        dy[it].data.elt[jt] = dy_ij;

        mdy += dy_ij;
        mdyy += dy_ij * y_ij;
        dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
      }
    }

    // Reduce over row
    reduce_t result = reducer.allreduce({mdy, mdyy}, sum);
    mdy = Get<0>::of<reduce_t, compute_t>(result) * rn;
    mdyy = Get<1>::of<reduce_t, compute_t>(result) * rn;

    // Compute dx
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      Ivec dx;
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t dy_ij = dy[it].data.elt[jt];
        compute_t y_ij = y[it].data.elt[jt];
        dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij));
      }
      dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
    }
  }

  if constexpr (WARPS_M == 1) {
    // Write out local weight grad contributions
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
    }
  } else {
    // Reduce weight grad contributions within CTA before writing
    __shared__ Cvec vecs_shared[LDGS][WARPS_M][WARPS_N][THREADS_PER_WARP + 1];

    // Reduce dzy
    __syncthreads();
#pragma unroll
    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
      if (it != warp_m) {
        dzy_sum[it].store_to(&vecs_shared[it][warp_m][warp_n][lane]);
      }
    }
    __syncthreads();
#pragma unroll
    for (int it = warp_m, col = (gidn + it * gdimn) * NUM_ELTS; it < LDGS && col < params.cols;
         it += WARPS_M, col += WARPS_M * gdimn * NUM_ELTS) {
#pragma unroll
      for (int kt = 0; kt < WARPS_M; kt++) {
        if (kt != warp_m) {
#pragma unroll
          for (int jt = 0; jt < NUM_ELTS; jt++) {
            dzy_sum[it].data.elt[jt] += vecs_shared[it][kt][warp_n][lane].data.elt[jt];
          }
        }
      }
      dzy_sum[it].store_to_elts(params.dgamma_part, bidm * params.cols + col, params.cols - col);
    }
  }
}
template <typename weight_t, typename compute_t, uint32_t WARPS_M, uint32_t WARPS_N,
          uint32_t BYTES_PER_LDG, uint32_t THREADS_PER_WARP>
__global__ __launch_bounds__(
    WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) {
  enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) };
  using Wvec = Vec<weight_t, NUM_ELTS>;
  using Cvec = Vec<compute_t, NUM_ELTS>;

  const int lane = threadIdx.x % THREADS_PER_WARP;
  const int warp_m = threadIdx.y;
  const int warp_n = threadIdx.x / THREADS_PER_WARP;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;

  // Load grad contributions and accumulate locally
  Cvec dgamma;
  dgamma.clear();
  for (int row = warp_m; row < params.ctas_per_col && col < params.cols; row += WARPS_M) {
    Cvec dgamma_part;
    dgamma_part.load_from_elts(params.dgamma_part, row * params.cols + col, params.cols - col);
#pragma unroll
    for (int jt = 0; jt < NUM_ELTS; jt++) {
      dgamma.data.elt[jt] += dgamma_part.data.elt[jt];
    }
  }

  // Reduce dgamma within CTA
  __shared__ Cvec vecs_shared[WARPS_M][WARPS_N][THREADS_PER_WARP + 1];
  dgamma.store_to(&vecs_shared[warp_m][warp_n][lane]);
#pragma unroll
  for (int nrows = WARPS_M / 2; nrows > 0; nrows /= 2) {
    __syncthreads();
    if (warp_m < nrows) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        vecs_shared[warp_m][warp_n][lane].data.elt[jt] +=
            vecs_shared[warp_m + nrows][warp_n][lane].data.elt[jt];
      }
    }
  }

  if (warp_m == 0 && col < params.cols) {
    Wvec dgamma_out;
    vecs_shared[warp_m][warp_n][lane].to(dgamma_out);
    dgamma_out.store_to_elts(params.dgamma, col, params.cols - col);
  }
}
} // namespace rmsnorm
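Across these kernels, dgamma is produced by a two-stage reduction: each CTA accumulates dz * y over the rows it owns and writes a partial sum into dgamma_part (laid out as ctas_per_col x cols), and a separate finalize kernel then reduces those partials column-wise into dgamma. A tiny CPU sketch of that structure, with toy sizes chosen only for illustration:

#include <cstdio>
#include <vector>

int main() {
  const int rows = 8, cols = 4, ctas_per_col = 2;  // toy sizes, not the real launch config
  const int rows_per_cta = rows / ctas_per_col;
  std::vector<float> dzy(rows * cols, 1.0f);       // stands in for the per-element dz * y terms
  std::vector<float> dgamma_part(ctas_per_col * cols, 0.0f);
  std::vector<float> dgamma(cols, 0.0f);

  // Stage 1: each "CTA" sums dz * y over the rows it owns into its dgamma_part row.
  for (int cta = 0; cta < ctas_per_col; ++cta)
    for (int r = cta * rows_per_cta; r < (cta + 1) * rows_per_cta; ++r)
      for (int c = 0; c < cols; ++c) dgamma_part[cta * cols + c] += dzy[r * cols + c];

  // Stage 2: the finalize pass reduces the partials column-wise into dgamma.
  for (int cta = 0; cta < ctas_per_col; ++cta)
    for (int c = 0; c < cols; ++c) dgamma[c] += dgamma_part[cta * cols + c];

  for (int c = 0; c < cols; ++c) std::printf("dgamma[%d] = %g\n", c, dgamma[c]);
  return 0;
}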