Commit 87e3e56e authored by yuguo

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
@@ -36,7 +36,8 @@ enum class CommOverlapAlgo {
   SPLIT_PIPELINED_RS_P2P = 4,
   ATOMIC_GEMM_RS = 5,
   ATOMIC_GEMM_AG_P2P = 6,
-  ATOMIC_GEMM_RS_P2P = 7
+  ATOMIC_GEMM_RS_P2P = 7,
+  EXTERNAL_BULK_OVERLAP_AG = 8,
 };
 class CommOverlapCore {
@@ -133,6 +134,11 @@ class CommOverlapCore {
                            cudaStream_t stream_main) {
     NVTE_ERROR("Operation is not implemented.");
   }
+  virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
+                                        cudaStream_t stream_main) {
+    NVTE_ERROR("Operation is not implemented.");
+  }
 };  // CommOverlapCore
 class CommOverlapBase : public CommOverlapCore {
@@ -200,6 +206,9 @@ class CommOverlapBase : public CommOverlapCore {
                  TensorWrapper &workspace, bool grad, bool accumulate,
                  bool use_split_accumulator, TensorWrapper &rs_output,
                  cudaStream_t stream_main) override;
+  void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
+                                cudaStream_t stream_main) override;
 };  // CommOverlapBase
 class CommOverlapP2PBase : public CommOverlapCore {
@@ -281,6 +290,15 @@ class CommOverlapP2PBase : public CommOverlapCore {
                  TensorWrapper &workspace, bool grad, bool accumulate,
                  bool use_split_accumulator, TensorWrapper &rs_output,
                  cudaStream_t stream_main) override;
+  /*
+  ** This function overlaps the AG for the current communicator object with the GEMM for the
+  ** overlap_gemm object. The GEMM for overlap_gemm is assumed to have been previously started.
+  */
+  void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
+                                cudaStream_t stream_main) override {
+    NVTE_ERROR("Operation not supported.");
+  }
 };  // CommOverlapP2PBase
 }  // namespace transformer_engine
...
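A minimal usage sketch of the new hook (not part of this commit). It assumes `comm` is an already-initialized CommOverlapBase, that the producer GEMM this all-gather overlaps with was launched earlier on `gemm_stream`, and that the header path, stream flags, and helper name are illustrative assumptions only.

#include <cuda_runtime.h>
#include <transformer_engine/comm_gemm_overlap.h>  // assumed header providing CommOverlapBase

// Hypothetical helper: overlap this communicator's all-gather with an externally launched GEMM.
void overlap_external_ag_sketch(transformer_engine::CommOverlapBase &comm,
                                cudaStream_t gemm_stream) {
  // Dedicated non-blocking streams for the send/recv sides of the exchange (illustrative).
  cudaStream_t send_stream, recv_stream;
  cudaStreamCreateWithFlags(&send_stream, cudaStreamNonBlocking);
  cudaStreamCreateWithFlags(&recv_stream, cudaStreamNonBlocking);

  // The overlapped GEMM is assumed to have been started previously on gemm_stream,
  // as required by the comment on bulk_overlap_external_ag above.
  comm.bulk_overlap_external_ag(send_stream, recv_stream, /*stream_main=*/gemm_stream);

  // Stream destruction is deferred by the runtime until enqueued work completes.
  cudaStreamDestroy(send_stream);
  cudaStreamDestroy(recv_stream);
}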
@@ -44,6 +44,36 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                       int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
+/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
+ *         allowing a scaling factor for the GEMM result and the accumulation input.
+ *
+ * Computes:
+ *  - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
+ *  - `D = alpha*AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
+ *  - `D = GELU(alpha*AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
+ *
+ * \param[in]     A                      The A matrix.
+ * \param[in]     B                      The B matrix.
+ * \param[in,out] D                      Output matrix.
+ * \param[in]     bias                   Bias tensor.
+ * \param[in,out] pre_gelu_out           Output matrix before GELU activation.
+ * \param[in]     transa                 Whether A matrix is transposed.
+ * \param[in]     transb                 Whether B matrix is transposed.
+ * \param[in]     grad                   Whether this operation is part of the gradient computation.
+ * \param[out]    workspace              Workspace tensor.
+ * \param[in]     alpha                  Scaling factor applied to the result of the GEMM.
+ * \param[in]     beta                   Scaling factor applied to the original value of D when
+ *                                       accumulating into it. beta=0 means no accumulation.
+ * \param[in]     use_split_accumulator  Whether to use the split accumulator in the FP8 GEMM.
+ * \param[in]     math_sm_count          Number of GPU SMs to use (default=0: use cuBLAS heuristics).
+ * \param[in]     stream                 CUDA stream used for the operation.
+ */
+void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
+                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
+                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
+                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream);
 /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
  *
  * \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
...
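A hedged usage sketch for the new scaled GEMM entry point (not part of this commit). The NVTETensor handles are assumed to have been created and quantized elsewhere; the transpose flags and the alpha/beta values are illustrative only.

#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>  // assumed header declaring nvte_cublas_gemm_scaled

// Hypothetical sketch: A, B, D, bias, pre_gelu_out and workspace are valid NVTETensor handles
// created elsewhere; passing empty bias/pre_gelu_out handles disables those fusions.
void gemm_scaled_sketch(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor bias,
                        NVTETensor pre_gelu_out, NVTETensor workspace, cudaStream_t stream) {
  const float alpha = 0.5f;  // scales the GEMM result
  const float beta = 1.0f;   // beta != 0 accumulates into the existing contents of D
  // Computes D = alpha * op(A) op(B) (+ beta * D); the transa/transb choice below is illustrative.
  nvte_cublas_gemm_scaled(A, B, D, bias, pre_gelu_out,
                          /*transa=*/true, /*transb=*/false, /*grad=*/false, workspace,
                          alpha, beta, /*use_split_accumulator=*/false,
                          /*math_sm_count=*/0, stream);
}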
@@ -20,7 +20,6 @@ extern "C" {
 /*! \brief Computes L2 norm for a list of tensors.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -33,22 +32,19 @@ extern "C" {
 * \param[out] ret_per_tensor L2 norm for each tensor.
 * \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
 * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                    const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                    NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
                                    NVTETensor ret_per_tensor, int per_tensor,
-                                   int max_chunks_per_tensor, const int device_id,
-                                   cudaStream_t stream);
+                                   int max_chunks_per_tensor, cudaStream_t stream);
 /*! \brief Computes L2 norm for a list of tensors after unscaling.
  *
  * Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
 * \param[in] inv_scale Scalar for the unscaling operation.
 * \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
 * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
@@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                            NVTETensor output_per_tensor, NVTETensor ret,
                                            NVTETensor ret_per_tensor, NVTETensor inv_scale,
                                            int per_tensor, int max_chunks_per_tensor,
-                                           const int device_id, cudaStream_t stream);
+                                           cudaStream_t stream);
 /*! \brief Compute and apply gradient update to parameters for Adam optimizer.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
 * \param[in] mode Whether to use AdamW (L2 penalty applied to params).
 * \param[in] bias_correction Whether to apply correction factor for moment estimates.
 * \param[in] weight_decay L2 penalty for weight decay.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
@@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
                                  const float lr, const float beta1, const float beta2,
                                  const float epsilon, const int step, const int mode,
                                  const int bias_correction, const float weight_decay,
-                                 const int device_id, cudaStream_t stream);
+                                 cudaStream_t stream);
 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         where the master parameters only store the remainder bits.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
 * \param[in] mode Whether to use AdamW (L2 penalty applied to params).
 * \param[in] bias_correction Whether to apply correction factor for moment estimates.
 * \param[in] weight_decay L2 penalty for weight decay.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_adam_param_remainder_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
     const float epsilon, const int step, const int mode, const int bias_correction,
-    const float weight_decay, const int device_id, cudaStream_t stream);
+    const float weight_decay, cudaStream_t stream);
 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         when model parameters are in Float8 precision.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
 * \param[in] bias_correction Whether to apply correction factor for moment estimates.
 * \param[in] weight_decay L2 penalty for weight decay.
 * \param[in] fp8_dtype FP8 data type for model parameters.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
@@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                      const float beta1, const float beta2, const float epsilon,
                                      const int step, const int mode, const int bias_correction,
                                      const float weight_decay, const NVTEDType fp8_dtype,
-                                     const int device_id, cudaStream_t stream);
+                                     cudaStream_t stream);
 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         with CUDA graph support and LR scheduling.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
 * \param[in] bias_correction Whether to apply correction factor for moment estimates.
 * \param[in] weight_decay L2 penalty for weight decay.
 * \param[in] inv_scale Scalar for the unscaling operation.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_adam_capturable_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         with CUDA graph support, LR scheduling, and FP32 master weights.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda(
 * \param[in] bias_correction Whether to apply correction factor for moment estimates.
 * \param[in] weight_decay L2 penalty for weight decay.
 * \param[in] inv_scale Scalar for the unscaling operation.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_adam_capturable_master_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
 /*! \brief Compute and apply gradient update to parameters for SGD optimizer.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
 * \param[in] first_run Whether momentum buffers have been initialized.
 * \param[in] wd_after_momentum Whether to applied weight decay after momentum update.
 * \param[in] scale Scalar for the scaling operation.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                 const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                 float wd, float momentum, float dampening, float lr, int nesterov,
                                 int first_run, int wd_after_momentum, float scale,
-                                const int device_id, cudaStream_t stream);
+                                cudaStream_t stream);
 /*! \brief Check overflow and scale a list of tensors.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
 * \param[in] num_tensor_lists Size (dim0) of tensor_lists.
 * \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
 * \param[in] scale Scalar for the scaling operation.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
 void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                   const size_t num_tensor_lists, const size_t num_tensors_per_list,
-                                  float scale, const int device_id, cudaStream_t stream);
+                                  float scale, cudaStream_t stream);
 /*! \brief Check overflow and scale a list of tensors.
  *
  * \warning This API is **experimental** and subject to change.
- * \warning Argument device_id is deprecated and will be removed in a future release.
  *
  * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
  * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
@@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
 * \param[in] max_fp8 Maximum representible value in underlying FP8 format.
 * \param[in] force_pow_2_scales Ensure scaling factors are a power of 2.
 * \param[in] epsilon Term added to the denominator for numerical stability.
-* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
 * \param[in] stream CUDA stream used for this operation.
 */
-void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
-    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
-    const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
-    const int device_id, cudaStream_t stream);
+void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
+                                                        NVTETensor **tensor_lists,
+                                                        const size_t num_tensor_lists,
+                                                        const size_t num_tensors_per_list,
+                                                        float max_fp8, int force_pow_2_scales,
+                                                        float epsilon, cudaStream_t stream);
 #ifdef __cplusplus
 }  // extern "C"
...
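A hedged sketch of the updated, device_id-free Adam entry point (not part of this commit). The four tensor lists follow the g, p, m, v ordering noted in the implementation below; the handles, chunk size, mode value, and hyperparameters are illustrative assumptions.

#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>  // assumed header declaring the multi-tensor API

// Hypothetical sketch: grads/params/exp_avg/exp_avg_sq are arrays of valid NVTETensor handles,
// one entry per parameter tensor, prepared elsewhere.
void adam_step_sketch(NVTETensor noop_flag, NVTETensor *grads, NVTETensor *params,
                      NVTETensor *exp_avg, NVTETensor *exp_avg_sq, size_t num_params,
                      int step, cudaStream_t stream) {
  // Four tensor lists in the order the kernel expects: g, p, m, v.
  NVTETensor *tensor_lists[4] = {grads, params, exp_avg, exp_avg_sq};
  nvte_multi_tensor_adam_cuda(/*chunk_size=*/65536, noop_flag, tensor_lists,
                              /*num_tensor_lists=*/4, /*num_tensors_per_list=*/num_params,
                              /*lr=*/1e-3f, /*beta1=*/0.9f, /*beta2=*/0.999f, /*epsilon=*/1e-8f,
                              step, /*mode=*/0,  // mode selects Adam vs. AdamW; 0 is a placeholder
                              /*bias_correction=*/1, /*weight_decay=*/0.01f, stream);
}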
@@ -30,6 +30,20 @@ extern "C" {
  */
 void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);
+/*! \brief Swizzle scaling factors into the interleaved layout required for GEMM.
+ *
+ * \param[in]     inputs       Input tensors with non-swizzled scale_inv.
+ * \param[in,out] outputs      Output tensors which host the swizzled scale_inv.
+ * \param[in]     num_tensors  Number of tensors in the input/output lists.
+ * \param[in]     stream       CUDA stream used for the operation.
+ *
+ * Requirements:
+ *  - scale_inv is stored in row-major order.
+ *  - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
+ *  - data is quantized along the K-dimension, i.e. the 1D scaling block lies along the K-dimension.
+ */
+void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
+                                               const size_t num_tensors, cudaStream_t stream);
 #ifdef __cplusplus
 }  // extern "C"
 #endif
...
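A short sketch contrasting the existing per-tensor call with the new batched entry point (not part of this commit). The handles are assumed to be valid and to already satisfy the row-major and padding requirements listed above; the header path is an assumption.

#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>  // assumed header declaring the swizzle API

void swizzle_batch_sketch(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors,
                          cudaStream_t stream) {
  // Per-tensor path that already existed, one launch per tensor:
  // for (size_t i = 0; i < num_tensors; ++i)
  //   nvte_swizzle_scaling_factors(inputs[i], outputs[i], stream);

  // Batched path added by this commit, handling the whole list in one call:
  nvte_multi_tensor_swizzle_scaling_factors(inputs, outputs, num_tensors, stream);
}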
@@ -318,6 +318,14 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
 void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
                                  NVTETensor output, cudaStream_t stream);
+/*! \brief Swap the first two tensor dimensions.
+ *
+ * \param[in] input Input tensor of shape [M, N, ...].
+ * \param[out] output Output tensor of shape [N, M, ...].
+ * \param[in] stream CUDA stream used for the operation.
+ */
+void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 #ifdef __cplusplus
 }  // extern "C"
 #endif
...
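A minimal sketch of the new dimension-swap helper (not part of this commit), e.g. turning a [S, B, H] activation into [B, S, H]. Both handles are assumed to be valid NVTETensors with matching dtype and element count; the header path is an assumption.

#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>  // assumed header declaring nvte_swap_first_dims

void swap_sequence_and_batch_sketch(NVTETensor input_sbh, NVTETensor output_bsh,
                                    cudaStream_t stream) {
  // Only the first two shape dimensions are exchanged; trailing dimensions are untouched.
  nvte_swap_first_dims(input_sbh, output_bsh, stream);
}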
@@ -8,6 +8,7 @@
     transformer_engine::cuda::stream_priority_range*;
     transformer_engine::cuda::current_device*;
     transformer_engine::cuda_driver::get_symbol*;
+    transformer_engine::cuda_driver::ensure_context_exists*;
     transformer_engine::ubuf_built_with_mpi*;
     *transformer_engine::rtc*;
     transformer_engine::nvte_cudnn_handle_init*;
...
@@ -577,7 +577,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
                             std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                             const float beta1, const float beta2, const float epsilon,
                             const int step, const int mode, const int bias_correction,
-                            const float weight_decay, const int device_id, cudaStream_t stream) {
+                            const float weight_decay, cudaStream_t stream) {
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -644,9 +644,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag,
                                           tensor_lists,
-                                          AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id,
-                                          stream, beta1, beta2, bias_correction1, bias_correction2,
-                                          epsilon, lr, (adamMode_t)mode, weight_decay);));
+                                          AdamFunctor<p_in_type, g_in_type, float, int64_t>(), stream,
+                                          beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
+                                          (adamMode_t)mode, weight_decay);));
   } else {
     // g, p, m, v, p_master
     TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
@@ -655,7 +655,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5>(
             (int64_t)chunk_size, noop_flag, tensor_lists,
-            AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
+            AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), stream,
             beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
             weight_decay);));
   }
@@ -667,7 +667,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                          AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id,
+                                          AdamFunctor<p_in_type, g_in_type, float, int32_t>(),
                                           stream, beta1, beta2, bias_correction1, bias_correction2,
                                           epsilon, lr, (adamMode_t)mode, weight_decay);));
   } else {
@@ -678,9 +678,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
                                           AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
-                                          device_id, stream, beta1, beta2, bias_correction1,
-                                          bias_correction2, epsilon, lr, (adamMode_t)mode,
-                                          weight_decay);));
+                                          stream, beta1, beta2, bias_correction1, bias_correction2,
+                                          epsilon, lr, (adamMode_t)mode, weight_decay);));
   }
 }
 NVTE_CHECK_CUDA(cudaGetLastError());
@@ -691,7 +690,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
                                             const float lr, const float beta1, const float beta2,
                                             const float epsilon, const int step, const int mode,
                                             const int bias_correction, const float weight_decay,
-                                            const int device_id, cudaStream_t stream) {
+                                            cudaStream_t stream) {
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -733,7 +732,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       g_in_type_te, g_in_type,
       multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
-                                        AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
+                                        AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(),
                                         stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
                                         (adamMode_t)mode, weight_decay););
   NVTE_CHECK_CUDA(cudaGetLastError());
@@ -744,7 +743,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
                                 const float beta1, const float beta2, const float epsilon,
                                 const int step, const int mode, const int bias_correction,
                                 const float weight_decay, const DType fp8_dtype,
-                                const int device_id, cudaStream_t stream) {
+                                cudaStream_t stream) {
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -814,7 +813,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5, true>(
             (int64_t)chunk_size, noop_flag, tensor_lists,
-            AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
+            AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), stream, beta1,
             beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
             weight_decay);));
   } else {
@@ -824,9 +823,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5, true>(chunk_size, noop_flag, tensor_lists,
                                                 AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
-                                                device_id, stream, beta1, beta2, bias_correction1,
-                                                bias_correction2, epsilon, lr, (adamMode_t)mode,
-                                                weight_decay);));
+                                                stream, beta1, beta2, bias_correction1, bias_correction2,
+                                                epsilon, lr, (adamMode_t)mode, weight_decay);));
   }
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -836,7 +834,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
                                        const float beta1, const float beta2, const float epsilon,
                                        Tensor step, const int mode, const int bias_correction,
                                        const float weight_decay, Tensor inv_scale,
-                                       const int device_id, cudaStream_t stream) {
+                                       cudaStream_t stream) {
   // Check tensor list sizes
   // 4 tensor lists: g, p, m, v
   const size_t num_tensor_lists = tensor_lists.size();
@@ -868,7 +866,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                        AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2,
+                                        AdamCapturableFunctor<dtype, float>(), stream, beta1, beta2,
                                         reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
                                         reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
                                         reinterpret_cast<float *>(inv_scale.data.dptr));)
@@ -881,8 +879,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
                                               Tensor lr, const float beta1, const float beta2,
                                               const float epsilon, Tensor step, const int mode,
                                               const int bias_correction, const float weight_decay,
-                                              Tensor inv_scale, const int device_id,
-                                              cudaStream_t stream) {
+                                              Tensor inv_scale, cudaStream_t stream) {
   // Check tensor list sizes
   // 4 tensor lists: g, p, m, v, p_master
   const size_t num_tensor_lists = tensor_lists.size();
@@ -917,7 +914,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
-                                        AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1,
+                                        AdamCapturableMasterFunctor<dtype, float>(), stream, beta1,
                                         beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction,
                                         epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
                                         weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
@@ -933,28 +930,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
                                  const float lr, const float beta1, const float beta2,
                                  const float epsilon, const int step, const int mode,
                                  const int bias_correction, const float weight_decay,
-                                 const int device_id, cudaStream_t stream) {
+                                 cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
   using namespace transformer_engine;
   multi_tensor_adam::multi_tensor_adam_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
-      epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
+      epsilon, step, mode, bias_correction, weight_decay, stream);
 }
 void nvte_multi_tensor_adam_param_remainder_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
     const float epsilon, const int step, const int mode, const int bias_correction,
-    const float weight_decay, const int device_id, cudaStream_t stream) {
+    const float weight_decay, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
   using namespace transformer_engine;
   multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
-      epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
+      epsilon, step, mode, bias_correction, weight_decay, stream);
 }
 void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
@@ -963,22 +960,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                      const float beta1, const float beta2, const float epsilon,
                                      const int step, const int mode, const int bias_correction,
                                      const float weight_decay, const NVTEDType fp8_dtype,
-                                     const int device_id, cudaStream_t stream) {
+                                     cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
   using namespace transformer_engine;
   multi_tensor_adam::multi_tensor_adam_fp8_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
-      epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
-      stream);
+      epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), stream);
 }
 void nvte_multi_tensor_adam_capturable_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
   using namespace transformer_engine;
@@ -986,14 +982,14 @@ void nvte_multi_tensor_adam_capturable_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
-      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
+      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
 }
 void nvte_multi_tensor_adam_capturable_master_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
   using namespace transformer_engine;
@@ -1001,5 +997,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
-      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
+      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
 }
@@ -61,7 +61,7 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
                                                    float epsilon, const int device_id,
                                                    cudaStream_t stream) {
   multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
-                                    ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
+                                    ComputeScaleAndScaleInvFunctor(), stream, max_fp8,
                                     force_pow_2_scales, epsilon);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -69,15 +69,17 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
 }  // namespace multi_tensor_compute_scale
 }  // namespace transformer_engine
-void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
-    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
-    const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
-    const int device_id, cudaStream_t stream) {
+void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
+                                                        NVTETensor **tensor_lists,
+                                                        const size_t num_tensor_lists,
+                                                        const size_t num_tensors_per_list,
+                                                        float max_fp8, int force_pow_2_scales,
+                                                        float epsilon, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
   using namespace transformer_engine;
   multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
-      force_pow_2_scales, epsilon, device_id, stream);
+      force_pow_2_scales, epsilon, stream);
 }
@@ -401,12 +401,11 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
 void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
                               std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
                               Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
-                              bool per_tensor, int max_chunks_per_tensor, const int device_id,
-                              cudaStream_t stream) {
+                              bool per_tensor, int max_chunks_per_tensor, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 1>(
-          chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
+          chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(),
           stream, reinterpret_cast<float *>(output.data.dptr),
           per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
           max_chunks_per_tensor);)
@@ -416,7 +415,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  const OptionalCUDAGuard device_guard(device_id);
   cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
       reinterpret_cast<float *>(output.data.dptr),
       per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
@@ -429,12 +427,11 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
                                       std::vector<std::vector<Tensor *>> tensor_lists,
                                       Tensor output, Tensor output_per_tensor, Tensor ret,
                                       Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
-                                      int max_chunks_per_tensor, const int device_id,
-                                      cudaStream_t stream) {
+                                      int max_chunks_per_tensor, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 1>(
-          chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
+          chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(),
           stream, reinterpret_cast<float *>(inv_scale.data.dptr),
           reinterpret_cast<float *>(output.data.dptr),
           per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
@@ -445,7 +442,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  const OptionalCUDAGuard device_guard(device_id);
   cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
       reinterpret_cast<float *>(output.data.dptr),
       per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
@@ -461,8 +457,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
                                    const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                    NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
                                    NVTETensor ret_per_tensor, int per_tensor,
-                                   int max_chunks_per_tensor, const int device_id,
-                                   cudaStream_t stream) {
+                                   int max_chunks_per_tensor, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
   using namespace transformer_engine;
@@ -471,7 +466,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
       *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
-      max_chunks_per_tensor, device_id, stream);
+      max_chunks_per_tensor, stream);
 }
 void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
@@ -480,7 +475,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                            NVTETensor output_per_tensor, NVTETensor ret,
                                            NVTETensor ret_per_tensor, NVTETensor inv_scale,
                                            int per_tensor, int max_chunks_per_tensor,
-                                           const int device_id, cudaStream_t stream) {
+                                           cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
   using namespace transformer_engine;
@@ -489,5 +484,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
      *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
-      *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
+      *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, stream);
} }
...@@ -14,53 +14,6 @@ ...@@ -14,53 +14,6 @@
// This header is the one-stop shop for all your multi-tensor apply needs. // This header is the one-stop shop for all your multi-tensor apply needs.
// Change device if needed.
class OptionalCUDAGuard {
public:
explicit OptionalCUDAGuard(int new_device) {
if (new_device < 0) return;
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
if (new_device != current_device) {
NVTE_CHECK_CUDA(cudaSetDevice(new_device));
device_changed_ = true;
prev_device_ = current_device;
}
}
OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
: prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
other.device_changed_ = false;
}
OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
if (this != &other) {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
prev_device_ = other.prev_device_;
device_changed_ = other.device_changed_;
other.device_changed_ = false;
}
return *this;
}
~OptionalCUDAGuard() {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
}
private:
int prev_device_;
bool device_changed_ = false;
};
// TODO: Kernel arg size limit may be <4KB for some other cards (e.g. Jetson) // TODO: Kernel arg size limit may be <4KB for some other cards (e.g. Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
...@@ -94,7 +47,7 @@ template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typen ...@@ -94,7 +47,7 @@ template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typen
void multi_tensor_apply(int64_t chunk_size, void multi_tensor_apply(int64_t chunk_size,
const transformer_engine::Tensor &noop_flag, const transformer_engine::Tensor &noop_flag,
std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists, std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
T callable, const int device_id, cudaStream_t stream, ArgTypes... args) { T callable, cudaStream_t stream, ArgTypes... args) {
const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size(); const size_t num_tensors_per_list = tensor_lists[0].size();
...@@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t chunk_size, ...@@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t chunk_size,
TensorListMetadata<depth, USE_FP8> tl; TensorListMetadata<depth, USE_FP8> tl;
const OptionalCUDAGuard device_guard(device_id);
tl.start_tensor_this_launch = 0; tl.start_tensor_this_launch = 0;
int loc_block_info = 0; int loc_block_info = 0;
int loc_tensor_info = 0; int loc_tensor_info = 0;
......
...@@ -104,13 +104,13 @@ struct ScaleFunctor { ...@@ -104,13 +104,13 @@ struct ScaleFunctor {
void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float scale, std::vector<std::vector<Tensor *>> tensor_lists, float scale,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), p_in_type, tensor_lists[0][0]->dtype(), p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type, tensor_lists[1][0]->dtype(), g_in_type,
multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);)) ScaleFunctor<p_in_type, g_in_type>(), stream, scale);))
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, ...@@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) { float scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda); NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_scale::multi_tensor_scale_cuda( multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
stream);
} }
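// Hedged usage sketch (not part of this patch): with device_id dropped from the C API,
// the caller is expected to have the correct device current on the calling thread and
// to pass only the stream. The tensor handles and the two-level list layout are assumed
// to be set up elsewhere; chunk_size and scale below are illustrative values.
static void example_multi_tensor_scale(NVTETensor noop_flag, NVTETensor **tensor_lists,
                                       size_t num_tensor_lists, size_t num_tensors_per_list,
                                       cudaStream_t stream) {
  const int chunk_size = 65536;  // illustrative chunk size in elements
  const float scale = 0.125f;    // illustrative scale factor
  nvte_multi_tensor_scale_cuda(chunk_size, noop_flag, tensor_lists, num_tensor_lists,
                               num_tensors_per_list, scale, stream);
}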
...@@ -127,8 +127,7 @@ struct SGDFunctor { ...@@ -127,8 +127,7 @@ struct SGDFunctor {
void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum, std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float dampening, float lr, bool nesterov, bool first_run, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale, const int device_id, bool wd_after_momentum, float scale, cudaStream_t stream) {
cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size(); const size_t num_tensors_per_list = tensor_lists[0].size();
...@@ -154,28 +153,28 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, ...@@ -154,28 +153,28 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
// Case 1. fp16, fp16, fp16, No // Case 1. fp16, fp16, fp16, No
if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) { if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening, SGDFunctor<3, fp16, fp16>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 2. fp32, fp32, fp32, No // Case 2. fp32, fp32, fp32, No
else if (grad_type == DType::kFloat32 && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 3) { weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening, SGDFunctor<3, float, float>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 3. fp16, fp32, fp32, Yes // Case 3. fp16, fp32, fp32, Yes
else if (grad_type == DType::kFloat16 && // NOLINT(*) else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening, SGDFunctor<4, fp16, float>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 4. fp32, fp32, fp32, Yes // Case 4. fp32, fp32, fp32, Yes
else if (grad_type == DType::kFloat32 && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening, SGDFunctor<4, float, float>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} else { } else {
NVTE_ERROR("Unsupported combination of weight and gradient types."); NVTE_ERROR("Unsupported combination of weight and gradient types.");
...@@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor ...@@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov, float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale, int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda); NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda( multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream); dampening, lr, nesterov, first_run, wd_after_momentum, scale, stream);
} }
...@@ -72,6 +72,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -72,6 +72,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#endif #endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support amax output for non-quantized outputs
}
bool gamma_in_weight_dtype = false; bool gamma_in_weight_dtype = false;
if (cudnn_backend) { if (cudnn_backend) {
// TODO: add check for GPU ARCH // TODO: add check for GPU ARCH
......
...@@ -75,6 +75,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( ...@@ -75,6 +75,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
scale = *reinterpret_cast<compute_t *>(params.scale); scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0; compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS]; Ivec x[LDGS];
...@@ -120,9 +121,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( ...@@ -120,9 +121,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
compute_t b_ij = beta[it].data.elt[jt]; compute_t b_ij = beta[it].data.elt[jt];
compute_t temp_output = g_ij * y_ij + b_ij; compute_t temp_output = g_ij * y_ij + b_ij;
if (params.fp8_out) { if (requires_amax) {
__builtin_assume(amax >= 0); __builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output)); amax = fmaxf(amax, fabsf(temp_output));
}
if (params.fp8_out) {
temp_output = temp_output * scale; temp_output = temp_output * scale;
} }
...@@ -132,9 +135,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( ...@@ -132,9 +135,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
idx += VEC_COLS_PER_LDG; idx += VEC_COLS_PER_LDG;
} }
} }
if (params.fp8_out) {
// Reduce amax over block // Reduce amax over block
if (params.amax != nullptr) { if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value); static_assert(std::is_same<compute_t, float>::value);
...@@ -142,6 +145,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( ...@@ -142,6 +145,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
} }
} }
if (params.fp8_out) {
// Update scale-inverse // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale); reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...@@ -211,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -211,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
scale = *reinterpret_cast<compute_t *>(params.scale); scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0; compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m; const int row = cta_row + warp_m;
...@@ -279,17 +284,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -279,17 +284,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
} }
// Apply fp8 factors // Apply fp8 factors
if (params.fp8_out) { if (params.fp8_out || requires_amax) {
#pragma unroll #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) { if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt]; compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0); __builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij)); amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale; z.data.elt[jt] = z_ij * scale;
} }
} }
} }
}
// Store output // Store output
Ovec z_out; Ovec z_out;
...@@ -298,10 +305,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -298,10 +305,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
} }
} }
// Finalize fp8 factors
if (params.fp8_out) {
// Reduce amax over block // Reduce amax over block
if (params.amax != nullptr) { if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value); static_assert(std::is_same<compute_t, float>::value);
...@@ -309,6 +314,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -309,6 +314,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
} }
} }
if (params.fp8_out) {
// Update scale-inverse // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale); reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
...@@ -58,6 +58,10 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -58,6 +58,10 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#endif #endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support amax output for non-quantized outputs
}
bool training = bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
......
...@@ -71,6 +71,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke ...@@ -71,6 +71,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
scale = *reinterpret_cast<compute_t *>(params.scale); scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0; compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS]; Ivec x[LDGS];
...@@ -112,9 +113,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke ...@@ -112,9 +113,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
} }
compute_t temp_output = g_ij * y_ij; compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) { if (requires_amax) {
__builtin_assume(amax >= 0); __builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output)); amax = fmaxf(amax, fabsf(temp_output));
}
if (params.fp8_out) {
temp_output = temp_output * scale; temp_output = temp_output * scale;
} }
...@@ -124,9 +127,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke ...@@ -124,9 +127,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
idx += VEC_COLS_PER_LDG; idx += VEC_COLS_PER_LDG;
} }
} }
if (params.fp8_out) {
// Reduce amax over block // Reduce amax over block
if (params.amax != nullptr) { if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value); static_assert(std::is_same<compute_t, float>::value);
...@@ -134,6 +137,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke ...@@ -134,6 +137,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
} }
} }
if (params.fp8_out) {
// Update scale-inverse // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale); reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...@@ -201,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -201,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
scale = *reinterpret_cast<compute_t *>(params.scale); scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0; compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m; const int row = cta_row + warp_m;
...@@ -254,17 +259,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -254,17 +259,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
} }
// Apply fp8 factors // Apply fp8 factors
if (params.fp8_out) { if (params.fp8_out || requires_amax) {
#pragma unroll #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) { if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt]; compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0); __builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij)); amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale; z.data.elt[jt] = z_ij * scale;
} }
} }
} }
}
// Store output // Store output
Ovec z_out; Ovec z_out;
...@@ -273,10 +280,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -273,10 +280,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
} }
} }
// Finalize fp8 factors
if (params.fp8_out) {
// Reduce amax over block // Reduce amax over block
if (params.amax != nullptr) { if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value); static_assert(std::is_same<compute_t, float>::value);
...@@ -284,6 +289,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -284,6 +289,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
} }
} }
if (params.fp8_out) {
// Update scale-inverse // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale); reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
...@@ -15,15 +15,17 @@ ...@@ -15,15 +15,17 @@
#include "../util/logging.h" #include "../util/logging.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace { namespace {
constexpr int TB_DIM = 32; constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
constexpr int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int TB_DIM = 32;
constexpr int N_SF_PER_TD_PER_TILE = 4; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
// output is in ~K-major interleaved blocks // output is in ~K-major interleaved blocks
constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr int NEW_SF_TILE_DIM_M_I32 = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32;
template <typename LType> template <typename LType>
__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
...@@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { ...@@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
} }
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K> template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M,
const int K) { const int K, const int original_M,
const int original_K, const int bid_x,
const int bid_y, const int grid_dim_x,
const int grid_dim_y) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
...@@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons ...@@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
int m_tiles_in_tb = N_TILE_PER_TD; int m_tiles_in_tb = N_TILE_PER_TD;
int k_tiles_in_tb = TB_DIM; int k_tiles_in_tb = TB_DIM;
if (blockIdx.x == gridDim.x - 1) { if (bid_x == grid_dim_x - 1) {
k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
} }
if (blockIdx.y == gridDim.y - 1) { if (bid_y == grid_dim_y - 1) {
m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
} }
const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
const int input_offset =
bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) + input_offset;
int32_t* output_i32[N_TILE_PER_TD]; int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll #pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) { for (int i = 0; i < m_tiles_in_tb; i++) {
output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + output_i32[i] = reinterpret_cast<int32_t*>(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 +
(blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
} }
extern __shared__ int slm[]; extern __shared__ int slm[];
...@@ -90,85 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons ...@@ -90,85 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
threadIdx.y < k_tiles_in_tb) { threadIdx.y < k_tiles_in_tb) {
#pragma unroll #pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = __ldg(reinterpret_cast<const LType*>( const int thread_offset =
input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD;
} regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
// Pad zeros
// local shuffle if (padding_m || padding_k) {
regs_shuffle_with_bit_shifts(regs_vec); for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
const int index = (input_offset + thread_offset) * sizeof(int) + j;
// store, regs -> shared if (index / M >= original_K || index % M >= original_M) {
int tM = threadIdx.x * N_SF_PER_TD; reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
for (int i = 0; i < N_SF_PER_TD; i++) {
/* TODO rotate_i */
slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
reinterpret_cast<int*>(regs_vec)[i];
}
}
__syncthreads();
// store, shared -> global
int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) {
__align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
__align__(16) int4* slm_v4i =
reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
j += blockDim.x * blockDim.y) {
output_v4i[j] = slm_v4i[j];
} }
} }
}
#ifdef __HIP_PLATFORM_AMD__
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel_int(const void* input, void* output, const int M,
const int K) {
constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
// input is in M-major
constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;
const int M_i32 = M / 4;
const int K_i32 = K;
int m_tiles_in_tb = N_TILE_PER_TD;
int k_tiles_in_tb = TB_DIM;
if (blockIdx.x == gridDim.x - 1) {
k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
}
if (blockIdx.y == gridDim.y - 1) {
m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
} }
const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) {
output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
(blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
}
extern __shared__ int slm[];
// load, global -> regs
int regs_vec[N_SF_PER_TD_PER_TILE];
if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
threadIdx.y < k_tiles_in_tb) {
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = *reinterpret_cast<const int*>(
input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD);
} }
// local shuffle // local shuffle
...@@ -202,7 +143,14 @@ __global__ void swizzle_col_scaling_kernel_int(const void* input, void* output, ...@@ -202,7 +143,14 @@ __global__ void swizzle_col_scaling_kernel_int(const void* input, void* output,
} }
} }
} }
#endif
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K,
const int original_M, const int original_K) {
swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}
template <typename LType> template <typename LType>
__device__ inline void regs_shuffle(LType* regs_vec) { __device__ inline void regs_shuffle(LType* regs_vec) {
...@@ -221,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) { ...@@ -221,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) {
} }
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K> template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M,
const int K) { const int K, const int original_M,
const int original_K, const int bid_x,
const int bid_y, const int grid_dim_x,
const int grid_dim_y) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
...@@ -232,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons ...@@ -232,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
int n_tiles_in_tb = N_TILES_IN_TB; int n_tiles_in_tb = N_TILES_IN_TB;
const int K_i32 = K / 4; const int K_i32 = K / 4;
if (blockIdx.x == gridDim.x - 1) { if (bid_x == grid_dim_x - 1) {
n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
} }
const int* input_i32 = reinterpret_cast<const int*>(input) + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB;
const int* input_i32 = reinterpret_cast<const int*>(input) + input_offset;
int* output_i32 = reinterpret_cast<int*>(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 +
bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
extern __shared__ int4 slm_v4i[]; extern __shared__ int4 slm_v4i[];
...@@ -248,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons ...@@ -248,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll #pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = __ldg(reinterpret_cast<const LType*>( const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD;
input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
if (padding_m || padding_k) {
// Pad zeros
for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
const int index = (input_offset + thread_offset) * sizeof(int) + j;
if (index / K >= original_M || index % K >= original_K) {
reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
}
}
}
} }
// shuffle regs // shuffle regs
...@@ -274,64 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons ...@@ -274,64 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
} }
} }
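// Illustrative note (assumption, not in the patch): in the row-wise impl above, the
// scale-inv buffer is K-major with a padded row length of K bytes, so a flat byte index
// maps to (row = index / K, col = index % K); any element with row >= original_M or
// col >= original_K lies in the padding region and is zeroed before swizzling.
// A minimal standalone sketch of that predicate:
__host__ __device__ inline bool is_row_scaling_pad(int byte_index, int K, int original_M,
                                                   int original_K) {
  const int row = byte_index / K;  // row in the padded [M x K] layout
  const int col = byte_index % K;  // column within the padded row
  return row >= original_M || col >= original_K;
}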
#ifdef __HIP_PLATFORM_AMD__ template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K> __global__ void __launch_bounds__(TB_DIM* TB_DIM)
__global__ void swizzle_row_scaling_kernel_int(const void* input, void* output, const int M, swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K,
const int K) { const int original_M, const int original_K) {
constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int); swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}
// input is in K-major constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; struct MultiSwizzleArgs {
constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M; // (input) Data buffers for input scaling factors
void* input_list[kMaxTensorsPerKernel];
// (output) Data buffers for swizzled scaling factors
void* output_list[kMaxTensorsPerKernel];
// Input scaling factor m
int m_list[kMaxTensorsPerKernel];
// Input scaling factor k
int k_list[kMaxTensorsPerKernel];
// Input scaling factor m before padding
int original_m_list[kMaxTensorsPerKernel];
// Input scaling factor k before padding
int original_k_list[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of CUDA blocks needed for each
// tensor
int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel
int num_tensors;
};
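// Illustrative note (not part of the patch): block_range is a prefix sum with a leading
// zero, so for three tensors needing 4, 6 and 2 blocks it holds {0, 4, 10, 12}; a block
// with blockIdx.x == 7 then belongs to tensor 1 because block_range[1] <= 7 < block_range[2].
// A minimal host-side sketch of the lookup the kernels below perform:
inline int find_tensor_for_block(const MultiSwizzleArgs& args, int bid) {
  int tensor_id = 0;
  while (args.block_range[tensor_id + 1] <= bid) ++tensor_id;  // linear scan; num_tensors <= 64
  return tensor_id;
}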
int n_tiles_in_tb = N_TILES_IN_TB; template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
const int K_i32 = K / 4; __global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) {
if (blockIdx.x == gridDim.x - 1) { // Find tensor corresponding to block
n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; const int bid = blockIdx.x;
} int tensor_id = 0;
while (kernel_args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
// Get args corresponding to block
const void* input = kernel_args.input_list[tensor_id];
void* output = kernel_args.output_list[tensor_id];
const int M = kernel_args.m_list[tensor_id];
const int K = kernel_args.k_list[tensor_id];
const int original_M = kernel_args.original_m_list[tensor_id];
const int original_K = kernel_args.original_k_list[tensor_id];
const int* input_i32 = reinterpret_cast<const int*>(input) + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
extern __shared__ int4 slm_v4i[]; // Get block index in grid. Emulate 2D grid.
const int num_tiles_k = K / SF_TILE_DIM_K;
const int num_tiles_m = M / SF_TILE_DIM_M;
const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB);
const int grid_dim_y = num_tiles_m;
const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;
// load, global -> regs swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
int regs_vec[N_SF_PER_TD_PER_TILE]; input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { }
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = *reinterpret_cast<const int*>(
input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD);
}
// shuffle regs template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
regs_shuffle<int>(regs_vec); __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) {
// Find tensor corresponding to block
const int bid = blockIdx.x;
int tensor_id = 0;
while (kernel_args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
// Get args corresponding to block
const void* input = kernel_args.input_list[tensor_id];
void* output = kernel_args.output_list[tensor_id];
const int M = kernel_args.m_list[tensor_id];
const int K = kernel_args.k_list[tensor_id];
const int original_M = kernel_args.original_m_list[tensor_id];
const int original_K = kernel_args.original_k_list[tensor_id];
// store, regs -> shared constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
#pragma unroll constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
for (int i = 0; i < N_TILE_PER_TD; i++) { constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
/* TODO rotate i */
slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
reinterpret_cast<int4*>(regs_vec)[i];
}
}
__syncthreads();
// store, shared -> global // Get block index in grid. Emulate 2D grid.
int linear_id = threadIdx.y * blockDim.x + threadIdx.x; const int num_tiles_k = K / SF_TILE_DIM_K;
__align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32); const int num_tiles_m = M / SF_TILE_DIM_M;
#pragma unroll const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM);
for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) { const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD);
output_v4i[i] = slm_v4i[i]; const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
} const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;
swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
} }
#endif
} // namespace
namespace transformer_engine { } // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
...@@ -385,50 +383,52 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -385,50 +383,52 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int n_tiles_in_tb = TB_DIM * vec_load_size; int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_first_dim();
const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
switch (vec_load_size) { switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
case 4: case 4:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(
output->scale_inv.dptr, m, k); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
case 2: case 2:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(
output->scale_inv.dptr, m, k); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
case 1: case 1:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(
output->scale_inv.dptr, m, k); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
#else #else
case 4: case 4:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(
output->scale_inv.dptr, m, k); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
case 2: case 2:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(
output->scale_inv.dptr, m, k); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
case 1: case 1:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(
output->scale_inv.dptr, m, k); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
#endif #endif
default: default:
...@@ -442,50 +442,58 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -442,50 +442,58 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int n_tiles_in_tb = TB_DIM * vec_load_size; int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
switch (vec_load_size) { switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
case 4: case 4:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break; break;
case 2: case 2:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break; break;
case 1: case 1:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break; break;
#else #else
case 4: case 4:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break; break;
case 2: case 2:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break; break;
case 1: case 1:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break; break;
#endif #endif
default: default:
...@@ -498,10 +506,260 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -498,10 +506,260 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
} else { } else {
NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
} }
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) { NVTE_CHECK_CUDA(cudaGetLastError());
printf("CUDA Error: %s\n", cudaGetErrorString(err)); }
exit(-1);
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
const int vec_load_size, const bool is_rowwise,
cudaStream_t stream) {
int n_tiles_in_tb = TB_DIM * vec_load_size;
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
/* Calculate number of CUDA blocks needed for each tensor.
* We have to do it here because we have to iterate over all tensors in this batch to
* get the minimum vec_load_size.
*/
for (size_t j = 0; j < kernel_args.num_tensors; j++) {
const int m = kernel_args.m_list[j];
const int k = kernel_args.k_list[j];
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
if (is_rowwise) {
kernel_args.block_range[j + 1] =
kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m;
} else {
kernel_args.block_range[j + 1] =
kernel_args.block_range[j] +
DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size);
}
}
// Launch kernel
const int num_blocks = kernel_args.block_range[kernel_args.num_tensors];
dim3 block_size(TB_DIM, TB_DIM);
if (is_rowwise) {
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#else
case 4:
cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
} else {
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#else
case 4:
cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
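The block_range array filled above is a running prefix sum of per-tensor block counts, so CUDA block b belongs to the tensor i with block_range[i] <= b < block_range[i + 1]. A minimal device-side sketch of how a kernel could recover its tensor index from such a table (illustrative only; the struct and helper below are assumptions for the sketch, not part of this commit, and the actual multi-tensor kernels may do this differently):

// Illustrative only -- not part of this commit. Mirrors just the fields of
// MultiSwizzleArgs that matter here; sizes and names are assumed.
struct BlockRangeView {
  int num_tensors;
  int block_range[64 + 1];  // prefix sums; num_tensors + 1 entries are used
};

__device__ inline int find_tensor_id(const BlockRangeView& args, int block_id) {
  int lo = 0, hi = args.num_tensors;
  // Binary search over the prefix sums: invariant block_range[lo] <= block_id < block_range[hi].
  while (lo + 1 < hi) {
    const int mid = (lo + hi) / 2;
    if (block_id >= args.block_range[mid]) {
      lo = mid;  // block lies at or beyond this tensor's first block
    } else {
      hi = mid;
    }
  }
  return lo;  // tensor whose half-open block range contains block_id
}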
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
std::vector<Tensor*>& output, cudaStream_t stream) {
auto num_tensors = input.size();
bool all_has_data = true;
bool all_has_columnwise_data = true;
for (size_t i = 0; i < num_tensors; i++) {
if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) {
NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + ".");
}
// We don't allow empty tensors. They should be filtered out before calling this function.
if (input[i]->data.numel() == 0) {
NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
}
CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
all_has_data &= input[i]->has_data();
all_has_columnwise_data &= input[i]->has_columnwise_data();
}
NVTE_CHECK(all_has_data || all_has_columnwise_data,
"All tensors should have data or columnwise data.");
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
if (all_has_data) {
MultiSwizzleArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0;
int vec_load_size = 4;
for (size_t i = 0; i < num_tensors; i++) {
//Launch kernel if argument struct is full
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
// There is no int3, and int4/int2 loads would be misaligned for a width of 3.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, true, stream);
// Reset the argument struct and vec_load_size
kernel_args.num_tensors = 0;
vec_load_size = 4;
}
const int m = input[i]->scale_inv.shape[0];
const int k = input[i]->scale_inv.shape[1];
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
NVTE_CHECK(
m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
int num_tiles_k = k / SF_TILE_DIM_K;
int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
// We use the minimum vec_load_size across all tensors.
vec_load_size = std::min(vec_load_size, vec_load_size_i);
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
kernel_args.m_list[pos] = m;
kernel_args.k_list[pos] = k;
kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE;
kernel_args.num_tensors++;
}
// Launch the remaining tensors
// There is no int3, and int4/int2 loads would be misaligned for a width of 3.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, true, stream);
}
if (all_has_columnwise_data) {
MultiSwizzleArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0;
int vec_load_size = 4;
for (size_t i = 0; i < num_tensors; i++) {
//Launch kernel if argument struct is full
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
// There is no int3, and int4/int2 loads would be misaligned for a width of 3.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, false, stream);
// Reset the argument struct and vec_load_size
kernel_args.num_tensors = 0;
vec_load_size = 4;
}
const int m = input[i]->columnwise_scale_inv.shape[1];
const int k = input[i]->columnwise_scale_inv.shape[0];
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
output[i]->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
int num_tiles_k = k / SF_TILE_DIM_K;
int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
// We use the minimum vec_load_size across all tensors.
vec_load_size = std::min(vec_load_size, vec_load_size_i);
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
kernel_args.m_list[pos] = m;
kernel_args.k_list[pos] = k;
kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE;
kernel_args.num_tensors++;
}
// Launch the remaining tensors
// There is no int3, and int4/int2 loads would be misaligned for a width of 3.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, false, stream);
}
}
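For reference, the vectorization rule used above, vec_load_size_i = (num_tiles_k - 1) % 4 + 1, taken as a minimum over the whole batch and with 3 demoted to 1 afterwards, ends up selecting the widest of int4/int2/int loads that evenly divides every tensor's K-tile count. A small host-side sketch with hypothetical tile counts (illustrative only, not part of this commit):

// Illustrative only -- not part of this commit.
#include <algorithm>
#include <cstdio>

int main() {
  const int num_tiles_k_list[] = {8, 6, 5};  // hypothetical K-tile counts in one batch
  int vec_load_size = 4;                     // start from the widest vector load
  for (int num_tiles_k : num_tiles_k_list) {
    const int vec_i = (num_tiles_k - 1) % 4 + 1;  // 8 -> 4, 6 -> 2, 5 -> 1
    vec_load_size = std::min(vec_load_size, vec_i);
  }
  if (vec_load_size == 3) vec_load_size = 1;  // no int3 vector type
  std::printf("vec_load_size = %d\n", vec_load_size);  // prints 1 for this batch
  return 0;
}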
} // namespace transformer_engine
...@@ -516,3 +774,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
using namespace transformer_engine;
swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors);
using namespace transformer_engine;
NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
std::vector<Tensor*> input_list, output_list;
for (size_t i = 0; i < num_tensors; i++) {
input_list.push_back(convertNVTETensorCheck(inputs[i]));
output_list.push_back(convertNVTETensorCheck(outputs[i]));
}
multi_tensor_swizzle_scaling_factors(input_list, output_list, stream);
}
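A hedged usage sketch for this new batched entry point; the wrapper function below is illustrative, and it assumes the NVTETensor handles were created elsewhere with MXFP8 block scaling and padded scale_inv buffers, as the checks in multi_tensor_swizzle_scaling_factors require:

// Illustrative only -- not part of this commit. Also assumes the Transformer Engine
// header that declares nvte_multi_tensor_swizzle_scaling_factors is included.
#include <vector>

void swizzle_all_scaling_factors(const std::vector<NVTETensor>& inputs,
                                 std::vector<NVTETensor>& outputs, cudaStream_t stream) {
  // One call swizzles the scaling factors of the whole batch on the given stream.
  nvte_multi_tensor_swizzle_scaling_factors(inputs.data(), outputs.data(), inputs.size(),
                                            stream);
}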
...@@ -197,7 +197,8 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
// Note: amax is supported for non-FP8 output as it can be fused into the computation
// and later used for quantization with no need to compute it separately
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
......
...@@ -18,6 +18,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
...@@ -184,6 +185,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
}
}
// Trigger the next kernel here so that its load from global memory can overlap with this
// kernel's store to global memory.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast tmp_output_c;
...@@ -419,6 +426,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
}
}
// Trigger the next kernel here so that its load from global memory can overlap with this
// kernel's store to global memory.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile
// Edge case: in the non-full tile case, there are three subcases
// for full thread tile, it's the same thing here
...@@ -883,6 +896,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream);
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1;
...@@ -917,9 +932,23 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
scale_t_stride_y = scale_inv_t.shape[1];
}
#ifdef __HIP_PLATFORM_AMD__
const size_t block_len = blockwise_fp8_block_len();
const size_t num_blocks_x = DIVUP(row_length, block_len);
const size_t num_blocks_y = DIVUP(num_rows, block_len);
#else
const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
dim3 grid(num_blocks_x, num_blocks_y, 1);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
cfg.attrs = attribute;
cfg.numAttrs = 1;
}
#endif
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
...@@ -929,10 +958,13 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
#ifdef __HIP_PLATFORM_AMD__
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
#endif
if (full_tile) {
#ifndef __HIP_PLATFORM_AMD__
CUtensorMap tensor_map_output_trans;
...@@ -940,15 +972,28 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
tensor_map_output_trans =
get_tensor_map<OutputType>(output_t, num_rows, row_length);
}
cudaLaunchKernelEx(&cfg,
block_scaled_cast_transpose_kernel<kReturnTranspose, float,
InputType, OutputType>,
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
} else {
cudaLaunchKernelEx(
&cfg,
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
InputType, OutputType>,
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale);
#else
while (true) {
if (128 == block_len) {
...@@ -978,7 +1023,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
break;
}
}
} else {
while (true) {
if (128 == block_len) {
...@@ -1008,6 +1052,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
break;
}
}
#endif
} // full-tile
) // return_transpose
) // OutputType
......
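The changes in this file pair the two halves of CUDA's programmatic dependent launch (PDL): the kernels call cudaTriggerProgrammaticLaunchCompletion() once their remaining work is only global-memory stores, and the host launches them through cudaLaunchKernelEx with cudaLaunchAttributeProgrammaticStreamSerialization so a kernel may start before its predecessor in the stream fully drains. A minimal, self-contained sketch of that pattern (kernel names are illustrative, not from this commit; assumes an sm_90+ device and a matching compilation target):

// Illustrative only -- not part of this commit.
#include <cuda_runtime.h>

__global__ void producer(float* out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 1.f;
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Remaining work is only the store above; let the next kernel start launching early.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}

__global__ void consumer(const float* in, float* out, int n) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Block until the producer's global-memory writes are visible.
  cudaGridDependencySynchronize();
#endif
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.f * in[i];
}

void launch_pair(float* a, float* b, int n, cudaStream_t stream) {
  const int threads = 256, blocks = (n + threads - 1) / threads;
  producer<<<blocks, threads, 0, stream>>>(a, n);

  // Opt the dependent launch into programmatic stream serialization.
  cudaLaunchAttribute attr[1];
  attr[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr[0].val.programmaticStreamSerializationAllowed = 1;
  cudaLaunchConfig_t cfg = {dim3(blocks), dim3(threads), 0, stream, attr, 1};
  cudaLaunchKernelEx(&cfg, consumer, static_cast<const float*>(a), b, n);
}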
...@@ -24,6 +24,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh" #include "common/utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -74,7 +75,7 @@ Step 2: Cast and store to output_c ...@@ -74,7 +75,7 @@ Step 2: Cast and store to output_c
* What each thread does in each loop: * What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times * 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row * Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and write to output_c at a time * 16 elements are quantized and wsrite to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | | T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
...@@ -251,6 +252,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -251,6 +252,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__syncthreads(); __syncthreads();
// If we are not returning columnwise output, trigger the next kernel here so that its load
// from global memory can overlap with this kernel's rowwise stores.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// Step 2: Cast and store to output_c
if (return_rowwise) {
constexpr int r_stride =
...@@ -356,6 +365,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
// If we are returning columnwise output, trigger the next kernel here so that its load
// from global memory can overlap with this kernel's columnwise stores.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (return_columnwise_gemm_ready || return_columnwise_compact) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
if (return_columnwise_gemm_ready) {
constexpr int c_stride =
...@@ -1424,20 +1441,30 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
}
#ifdef __HIP_PLATFORM_AMD__
const size_t block_len = blockwise_fp8_block_len();
const size_t num_blocks_x = DIVUP(row_length, (size_t)block_len);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)block_len);
#else
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
dim3 grid(num_blocks_x, num_blocks_y, 1);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
#endif
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
output.dtype, OutputType,
#ifdef __HIP_PLATFORM_AMD__
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
#endif
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
...@@ -1505,21 +1532,30 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
}
#else
size_t smem_bytes = kSMemSize * sizeof(InputType);
cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >=
90) {
cfg.attrs = attribute;
cfg.numAttrs = 1;
}
// shared memory must be requested up front
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
}
cudaLaunchKernelEx(&cfg,
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
OutputType>,
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, epsilon, rowwise_option, columnwise_option,
pow2_scale);
#endif
) // kAligned
) // OutputType
......