Commit 87e3e56e authored by yuguo
Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
......@@ -36,7 +36,8 @@ enum class CommOverlapAlgo {
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
ATOMIC_GEMM_RS_P2P = 7,
EXTERNAL_BULK_OVERLAP_AG = 8,
};
class CommOverlapCore {
......@@ -133,6 +134,11 @@ class CommOverlapCore {
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
}; // CommOverlapCore
class CommOverlapBase : public CommOverlapCore {
......@@ -200,6 +206,9 @@ class CommOverlapBase : public CommOverlapCore {
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override;
void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) override;
}; // CommOverlapBase
class CommOverlapP2PBase : public CommOverlapCore {
......@@ -281,6 +290,15 @@ class CommOverlapP2PBase : public CommOverlapCore {
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override;
/*
** This function overlaps the all-gather (AG) for the current communicator object with the
** GEMM for the overlap_gemm object. The GEMM for overlap_gemm is assumed to have already
** been launched.
*/
void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
}; // CommOverlapP2PBase
} // namespace transformer_engine
......
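The new CommOverlapCore::bulk_overlap_external_ag hook pairs the all-gather of one communicator with a GEMM that another overlap object has already put in flight (CommOverlapBase implements it; CommOverlapP2PBase raises an error). A minimal, hedged usage sketch under those assumptions; the wrapper function, the header path, and the assumption that the method is publicly callable are illustrative, not the library's documented calling sequence:

```cpp
// Hedged sketch: drive the new external bulk AG overlap. `ag_overlap` and the
// surrounding setup are hypothetical; the GEMM of the other overlap object is
// assumed to already be running on stream_main (per the class comment above).
#include <cuda_runtime.h>
#include <transformer_engine/comm_gemm_overlap.h>  // assumed header location

void run_external_bulk_overlap_ag(transformer_engine::CommOverlapBase &ag_overlap,
                                  cudaStream_t stream_main) {
  cudaStream_t send_stream = nullptr, recv_stream = nullptr;
  cudaStreamCreate(&send_stream);
  cudaStreamCreate(&recv_stream);

  // Overlap this communicator's all-gather with the externally launched GEMM.
  ag_overlap.bulk_overlap_external_ag(send_stream, recv_stream, stream_main);

  cudaStreamSynchronize(stream_main);
  cudaStreamDestroy(send_stream);
  cudaStreamDestroy(recv_stream);
}
```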
......@@ -44,6 +44,36 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
* allowing for using a scaling factor for the GEMM result and the accumulation input
*
* Computes:
* - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = alpha*AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(alpha*AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
* \param[in] B The B matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] alpha Scaling factor applied to the result of the GEMM
* \param[in] beta Scaling factor applied to original value of D when
* accumulating into it. beta=0 means no accumulation.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
bool use_split_accumulator, int math_sm_count, cudaStream_t stream);
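A hedged call sketch against the declaration above, chiefly to illustrate the alpha/beta semantics (alpha scales the A*B product; a nonzero beta accumulates the original value of D). The wrapper function is illustrative, the tensors are assumed to have been created elsewhere, and the header path is an assumption:

```cpp
// Hedged sketch of calling nvte_cublas_gemm_scaled; tensor creation is elided.
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>  // assumed header location

void gemm_scaled_example(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor bias,
                         NVTETensor pre_gelu_out, NVTETensor workspace,
                         cudaStream_t stream) {
  const float alpha = 1.0f;  // scales the result of the GEMM
  const float beta = 1.0f;   // nonzero beta accumulates into the original value of D
  nvte_cublas_gemm_scaled(A, B, D, bias, pre_gelu_out,
                          /*transa=*/true, /*transb=*/false, /*grad=*/false,
                          workspace, alpha, beta,
                          /*use_split_accumulator=*/false,
                          /*math_sm_count=*/0, stream);
}
```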
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
* \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
......
......@@ -20,7 +20,6 @@ extern "C" {
/*! \brief Computes L2 norm for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -33,22 +32,19 @@ extern "C" {
* \param[out] ret_per_tensor L2 norm for each tensor.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream);
int max_chunks_per_tensor, cudaStream_t stream);
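For reference, a hedged sketch of a call site after the device_id removal; the wrapper function and pre-created tensors are illustrative, and the header path is an assumption:

```cpp
// Hedged sketch of the updated nvte_multi_tensor_l2norm_cuda call; the
// deprecated device_id argument is simply dropped.
#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>  // assumed header location

void l2norm_example(NVTETensor noop_flag, NVTETensor grad, NVTETensor output,
                    NVTETensor output_per_tensor, NVTETensor ret,
                    NVTETensor ret_per_tensor, int max_chunks_per_tensor,
                    cudaStream_t stream) {
  NVTETensor list0[] = {grad};           // one tensor list with a single tensor
  NVTETensor *tensor_lists[] = {list0};  // shape: [num_tensor_lists][num_tensors_per_list]
  nvte_multi_tensor_l2norm_cuda(/*chunk_size=*/65536, noop_flag, tensor_lists,
                                /*num_tensor_lists=*/1, /*num_tensors_per_list=*/1,
                                output, output_per_tensor, ret, ret_per_tensor,
                                /*per_tensor=*/1, max_chunks_per_tensor, stream);
}
```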
/*! \brief Computes L2 norm for a list of tensors after unscaling.
*
* Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
......@@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* where the master parameters only store the remainder bits.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream);
const float weight_decay, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* when model parameters are in Float8 precision.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] fp8_dtype FP8 data type for model parameters.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support and LR scheduling.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support, LR scheduling, and FP32 master weights.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda(
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for SGD optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
* \param[in] first_run Whether momentum buffers have been initialized.
* \param[in] wd_after_momentum Whether to apply weight decay after momentum update.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Check overflow and scale a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream);
float scale, cudaStream_t stream);
/*! \brief Check overflow and scale a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
* \param[in] max_fp8 Maximum representable value in underlying FP8 format.
* \param[in] force_pow_2_scales Ensure scaling factors are a power of 2.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);
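Reading the parameter documentation above, the per-tensor arithmetic is presumably of the following shape: the scale is derived from amax so that amax maps to max_fp8, epsilon stabilizes the division, and the scale can optionally be forced to a power of two. A hedged host-side illustration of that reading; this is an interpretation of the docs, not the kernel source:

```cpp
// Hedged illustration of the documented intent; the exact kernel arithmetic
// (in particular the power-of-two rounding direction) may differ.
#include <cmath>

void compute_scale_and_scale_inv_sketch(float amax, float max_fp8, bool force_pow_2_scales,
                                        float epsilon, float *scale, float *scale_inv) {
  float s = max_fp8 / (amax + epsilon);  // epsilon added to the denominator for stability
  if (force_pow_2_scales) {
    s = std::exp2(std::floor(std::log2(s)));  // one common choice: round down to a power of two
  }
  *scale = s;
  *scale_inv = 1.0f / s;
}
```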
#ifdef __cplusplus
} // extern "C"
......
......@@ -30,6 +30,20 @@ extern "C" {
*/
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Swizzle scaling factors of multiple tensors into the required interleaved layout for GEMM
*
* \param[in] inputs Input tensors with non-swizzled scale_inv.
* \param[in,out] outputs Output tensors that host the swizzled scale_inv.
* \param[in] num_tensors Number of input/output tensors.
* \param[in] stream CUDA stream used for the operation.
*
* Requirements:
* - scale_inv is stored in row-major order.
* - The scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
* - Data is quantized along the K-dimension, i.e. the 1D scaling block lies along the K-dimension.
*/
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream);
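The requirements above imply a particular padded shape for each tensor's scale_inv. A hedged helper that computes those padded shapes, assuming MXFP8-style 1x32 scaling blocks along K; the block size, the row/column mapping, and the helper names are assumptions, not part of this header:

```cpp
// Hedged sketch: padded scale_inv shapes implied by the requirements above.
// Assumes one scaling factor per 32 elements along K (MXFP8-style blocks).
#include <cstddef>

constexpr std::size_t kScaleBlockK = 32;  // assumption

inline std::size_t round_up(std::size_t x, std::size_t multiple) {
  return (x + multiple - 1) / multiple * multiple;
}

// Row-scale factors: logical shape [M, K / 32], padded to multiples of 128 x 4.
inline void padded_rowwise_scale_shape(std::size_t M, std::size_t K,
                                       std::size_t *rows, std::size_t *cols) {
  *rows = round_up(M, 128);
  *cols = round_up(K / kScaleBlockK, 4);
}

// Column-scale factors: logical shape [K / 32, N], padded to multiples of 4 x 128.
inline void padded_colwise_scale_shape(std::size_t K, std::size_t N,
                                       std::size_t *rows, std::size_t *cols) {
  *rows = round_up(K / kScaleBlockK, 4);
  *cols = round_up(N, 128);
}
```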
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -318,6 +318,14 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, cudaStream_t stream);
/*! \brief Swap the first two tensor dimensions.
*
* \param[in] input Input tensor of shape [M, N, ...].
* \param[out] output Output tensor of shape [N, M, ...].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream);
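As a concrete reading of the brief above: an input of shape [4, 8, 16] yields an output of shape [8, 4, 16], with the trailing dimensions copied unchanged. A hedged reference of the semantics (not the library's kernel):

```cpp
// Hedged reference of nvte_swap_first_dims semantics for a row-major tensor
// viewed as [M][N][inner]: the output element at [n][m] copies input [m][n].
#include <cstddef>
#include <cstring>

void swap_first_dims_reference(const void *input, void *output,
                               std::size_t M, std::size_t N, std::size_t inner_bytes) {
  const char *src = static_cast<const char *>(input);
  char *dst = static_cast<char *>(output);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      std::memcpy(dst + (n * M + m) * inner_bytes,
                  src + (m * N + n) * inner_bytes, inner_bytes);
    }
  }
}
```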
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -8,6 +8,7 @@
transformer_engine::cuda::stream_priority_range*;
transformer_engine::cuda::current_device*;
transformer_engine::cuda_driver::get_symbol*;
transformer_engine::cuda_driver::ensure_context_exists*;
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*;
......
......@@ -577,7 +577,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
const float weight_decay, cudaStream_t stream) {
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -644,9 +644,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
AdamFunctor<p_in_type, g_in_type, float, int64_t>(), stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
......@@ -655,7 +655,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5>(
(int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
......@@ -667,7 +667,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id,
AdamFunctor<p_in_type, g_in_type, float, int32_t>(),
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
......@@ -678,9 +678,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
}
}
NVTE_CHECK_CUDA(cudaGetLastError());
......@@ -691,7 +690,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -733,7 +732,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(),
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay););
NVTE_CHECK_CUDA(cudaGetLastError());
......@@ -744,7 +743,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -814,7 +813,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5, true>(
(int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), stream, beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
} else {
......@@ -824,9 +823,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5, true>(chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -836,7 +834,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon,
Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v
const size_t num_tensor_lists = tensor_lists.size();
......@@ -868,7 +866,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2,
AdamCapturableFunctor<dtype, float>(), stream, beta1, beta2,
reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
reinterpret_cast<float *>(inv_scale.data.dptr));)
......@@ -881,8 +879,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
Tensor lr, const float beta1, const float beta2,
const float epsilon, Tensor step, const int mode,
const int bias_correction, const float weight_decay,
Tensor inv_scale, const int device_id,
cudaStream_t stream) {
Tensor inv_scale, cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
......@@ -917,7 +914,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1,
AdamCapturableMasterFunctor<dtype, float>(), stream, beta1,
beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction,
epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
......@@ -933,28 +930,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
epsilon, step, mode, bias_correction, weight_decay, stream);
}
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
const float weight_decay, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
epsilon, step, mode, bias_correction, weight_decay, stream);
}
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -963,22 +960,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), stream);
}
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
using namespace transformer_engine;
......@@ -986,14 +982,14 @@ void nvte_multi_tensor_adam_capturable_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
}
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
using namespace transformer_engine;
......@@ -1001,5 +997,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
}
......@@ -61,7 +61,7 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
float epsilon, const int device_id,
cudaStream_t stream) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
ComputeScaleAndScaleInvFunctor(), stream, max_fp8,
force_pow_2_scales, epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -69,15 +69,17 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream) {
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
force_pow_2_scales, epsilon, stream);
}
......@@ -401,12 +401,11 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
bool per_tensor, int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
bool per_tensor, int max_chunks_per_tensor, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 1>(
chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(),
stream, reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
......@@ -416,7 +415,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
// This involves one more small kernel launch, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
......@@ -429,12 +427,11 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
Tensor output, Tensor output_per_tensor, Tensor ret,
Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
int max_chunks_per_tensor, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 1>(
chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(),
stream, reinterpret_cast<float *>(inv_scale.data.dptr),
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
......@@ -445,7 +442,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
// This involves one more small kernel launch, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
......@@ -461,8 +457,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
int max_chunks_per_tensor, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
using namespace transformer_engine;
......@@ -471,7 +466,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
max_chunks_per_tensor, stream);
}
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -480,7 +475,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
using namespace transformer_engine;
......@@ -489,5 +484,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
*convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
*convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, stream);
}
......@@ -14,53 +14,6 @@
// This header is the one-stop shop for all your multi-tensor apply needs.
// Change device if needed.
class OptionalCUDAGuard {
public:
explicit OptionalCUDAGuard(int new_device) {
if (new_device < 0) return;
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
if (new_device != current_device) {
NVTE_CHECK_CUDA(cudaSetDevice(new_device));
device_changed_ = true;
prev_device_ = current_device;
}
}
OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
: prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
other.device_changed_ = false;
}
OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
if (this != &other) {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
prev_device_ = other.prev_device_;
device_changed_ = other.device_changed_;
other.device_changed_ = false;
}
return *this;
}
~OptionalCUDAGuard() {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
}
private:
int prev_device_;
bool device_changed_ = false;
};
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
......@@ -94,7 +47,7 @@ template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typen
void multi_tensor_apply(int64_t chunk_size,
const transformer_engine::Tensor &noop_flag,
std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
T callable, cudaStream_t stream, ArgTypes... args) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
......@@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t chunk_size,
TensorListMetadata<depth, USE_FP8> tl;
const OptionalCUDAGuard device_guard(device_id);
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
......
......@@ -104,13 +104,13 @@ struct ScaleFunctor {
void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float scale,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type,
multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);))
ScaleFunctor<p_in_type, g_in_type>(), stream, scale);))
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) {
float scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine;
multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
}
......@@ -127,8 +127,7 @@ struct SGDFunctor {
void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale, const int device_id,
cudaStream_t stream) {
bool wd_after_momentum, float scale, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
......@@ -154,28 +153,28 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
// Case 1. fp16, fp16, fp16, No
if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening,
SGDFunctor<3, fp16, fp16>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 2. fp32, fp32, fp32, No
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening,
SGDFunctor<3, float, float>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening,
SGDFunctor<4, fp16, float>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening,
SGDFunctor<4, float, float>(), stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
} else {
NVTE_ERROR("Unsupported combination of weight and gradient types.");
......@@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
dampening, lr, nesterov, first_run, wd_after_momentum, scale, stream);
}
......@@ -72,6 +72,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support computing amax for non-quantized outputs
}
bool gamma_in_weight_dtype = false;
if (cudnn_backend) {
// TODO: add check for GPU ARCH
......
......@@ -75,6 +75,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
......@@ -120,9 +121,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
compute_t b_ij = beta[it].data.elt[jt];
compute_t temp_output = g_ij * y_ij + b_ij;
if (params.fp8_out) {
if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
}
if (params.fp8_out) {
temp_output = temp_output * scale;
}
......@@ -132,16 +135,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
idx += VEC_COLS_PER_LDG;
}
}
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
// Reduce amax over block
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......@@ -211,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
......@@ -279,14 +284,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
}
// Apply fp8 factors
if (params.fp8_out) {
if (params.fp8_out || requires_amax) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale;
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
}
}
}
......@@ -298,17 +305,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
}
}
// Finalize fp8 factors
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
// Reduce amax over block
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
......@@ -58,6 +58,10 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support computing amax for non-quantized outputs
}
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
......
......@@ -71,6 +71,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
......@@ -112,9 +113,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
}
compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) {
if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
}
if (params.fp8_out) {
temp_output = temp_output * scale;
}
......@@ -124,16 +127,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
idx += VEC_COLS_PER_LDG;
}
}
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
// Reduce amax over block
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......@@ -201,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
......@@ -254,14 +259,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
}
// Apply fp8 factors
if (params.fp8_out) {
if (params.fp8_out || requires_amax) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale;
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
}
}
}
......@@ -273,17 +280,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
}
}
// Finalize fp8 factors
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
// Reduce amax over block
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
......@@ -15,15 +15,17 @@
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace {
constexpr int TB_DIM = 32;
constexpr int NEW_SF_TILE_DIM_K = 16;
constexpr int N_SF_PER_TD_PER_TILE = 4;
constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
constexpr __device__ __host__ int TB_DIM = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
// output is in ~K-major interleaved blocks
constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr int NEW_SF_TILE_DIM_M_I32 = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32;
template <typename LType>
__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
......@@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
}
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M,
const int K) {
__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M,
const int K, const int original_M,
const int original_K, const int bid_x,
const int bid_y, const int grid_dim_x,
const int grid_dim_y) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
......@@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
int m_tiles_in_tb = N_TILE_PER_TD;
int k_tiles_in_tb = TB_DIM;
if (blockIdx.x == gridDim.x - 1) {
if (bid_x == grid_dim_x - 1) {
k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
}
if (blockIdx.y == gridDim.y - 1) {
if (bid_y == grid_dim_y - 1) {
m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
}
const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
const int input_offset =
bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) + input_offset;
int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) {
output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
(blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
output_i32[i] = reinterpret_cast<int32_t*>(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 +
(bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
}
extern __shared__ int slm[];
......@@ -90,8 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
threadIdx.y < k_tiles_in_tb) {
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD));
const int thread_offset =
(threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD;
regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
// Pad zeros: map the vectorized int32 offset back to a per-scale-factor (byte) index in the
// padded buffer and zero out any lanes that fall outside the original (unpadded) extent.
if (padding_m || padding_k) {
for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
const int index = (input_offset + thread_offset) * sizeof(int) + j;
if (index / M >= original_K || index % M >= original_M) {
reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
}
}
}
}
// local shuffle
......@@ -126,83 +144,13 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
}
}
#ifdef __HIP_PLATFORM_AMD__
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel_int(const void* input, void* output, const int M,
const int K) {
constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
// input is in M-major
constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;
const int M_i32 = M / 4;
const int K_i32 = K;
int m_tiles_in_tb = N_TILE_PER_TD;
int k_tiles_in_tb = TB_DIM;
if (blockIdx.x == gridDim.x - 1) {
k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
}
if (blockIdx.y == gridDim.y - 1) {
m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
}
const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) {
output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
(blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
}
extern __shared__ int slm[];
// load, global -> regs
int regs_vec[N_SF_PER_TD_PER_TILE];
if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
threadIdx.y < k_tiles_in_tb) {
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = *reinterpret_cast<const int*>(
input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD);
}
// local shuffle
regs_shuffle_with_bit_shifts(regs_vec);
// store, regs -> shared
int tM = threadIdx.x * N_SF_PER_TD;
int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
for (int i = 0; i < N_SF_PER_TD; i++) {
/* TODO rotate_i */
slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
reinterpret_cast<int*>(regs_vec)[i];
}
}
__syncthreads();
// store, shared -> global
int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) {
__align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
__align__(16) int4* slm_v4i =
reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
j += blockDim.x * blockDim.y) {
output_v4i[j] = slm_v4i[j];
}
}
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K,
const int original_M, const int original_K) {
swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}
#endif
template <typename LType>
__device__ inline void regs_shuffle(LType* regs_vec) {
......@@ -221,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) {
}
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M,
const int K) {
__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M,
const int K, const int original_M,
const int original_K, const int bid_x,
const int bid_y, const int grid_dim_x,
const int grid_dim_y) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
......@@ -232,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
int n_tiles_in_tb = N_TILES_IN_TB;
const int K_i32 = K / 4;
if (blockIdx.x == gridDim.x - 1) {
if (bid_x == grid_dim_x - 1) {
n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
}
const int* input_i32 = reinterpret_cast<const int*>(input) +
blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB;
const int* input_i32 = reinterpret_cast<const int*>(input) + input_offset;
int* output_i32 = reinterpret_cast<int*>(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 +
bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
extern __shared__ int4 slm_v4i[];
......@@ -248,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD));
const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD;
regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
if (padding_m || padding_k) {
// Pad zeros
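// Here the input is K-major: index % K is the K coordinate and index / K the
// M coordinate, mirroring the zero-padding check in the column-wise kernel above.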
for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
const int index = (input_offset + thread_offset) * sizeof(int) + j;
if (index / K >= original_M || index % K >= original_K) {
reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
}
}
}
}
// shuffle regs
......@@ -274,64 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
}
}
#ifdef __HIP_PLATFORM_AMD__
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel_int(const void* input, void* output, const int M,
const int K) {
constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K,
const int original_M, const int original_K) {
swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
}
// input is in K-major
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M;
constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
struct MultiSwizzleArgs {
// (input) Data buffers for input scaling factors
void* input_list[kMaxTensorsPerKernel];
// (output) Data buffers for swizzled scaling factors
void* output_list[kMaxTensorsPerKernel];
// Input scaling factor m
int m_list[kMaxTensorsPerKernel];
// Input scaling factor k
int k_list[kMaxTensorsPerKernel];
// Input scaling factor m before padding
int original_m_list[kMaxTensorsPerKernel];
// Input scaling factor k before padding
int original_k_list[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of CUDA blocks needed for each
// tensor
int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel
int num_tensors;
};
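// The host packs up to kMaxTensorsPerKernel tensors into one MultiSwizzleArgs and
// records in block_range a prefix sum (with leading zero) of the number of CUDA
// blocks each tensor needs. A block finds its tensor by a linear search over
// block_range, then converts its position inside that tensor's slice back into an
// emulated 2D grid index. Illustrative example: two tensors needing 3 and 5 blocks
// give block_range = {0, 3, 8}; block 5 belongs to tensor 1 with local block id 5 - 3 = 2.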
int n_tiles_in_tb = N_TILES_IN_TB;
const int K_i32 = K / 4;
if (blockIdx.x == gridDim.x - 1) {
n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) {
// Find tensor corresponding to block
const int bid = blockIdx.x;
int tensor_id = 0;
while (kernel_args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
// Get args corresponding to block
const void* input = kernel_args.input_list[tensor_id];
void* output = kernel_args.output_list[tensor_id];
const int M = kernel_args.m_list[tensor_id];
const int K = kernel_args.k_list[tensor_id];
const int original_M = kernel_args.original_m_list[tensor_id];
const int original_K = kernel_args.original_k_list[tensor_id];
const int* input_i32 = reinterpret_cast<const int*>(input) +
blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
extern __shared__ int4 slm_v4i[];
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
// load, global -> regs
int regs_vec[N_SF_PER_TD_PER_TILE];
if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = *reinterpret_cast<const int*>(
input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD);
}
// Get block index in grid. Emulate 2D grid.
const int num_tiles_k = K / SF_TILE_DIM_K;
const int num_tiles_m = M / SF_TILE_DIM_M;
const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB);
const int grid_dim_y = num_tiles_m;
const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;
// shuffle regs
regs_shuffle<int>(regs_vec);
swizzle_row_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
}
// store, regs -> shared
#pragma unroll
for (int i = 0; i < N_TILE_PER_TD; i++) {
/* TODO rotate i */
slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
reinterpret_cast<int4*>(regs_vec)[i];
}
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) {
// Find tensor corresponding to block
const int bid = blockIdx.x;
int tensor_id = 0;
while (kernel_args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
__syncthreads();
// Get args corresponding to block
const void* input = kernel_args.input_list[tensor_id];
void* output = kernel_args.output_list[tensor_id];
const int M = kernel_args.m_list[tensor_id];
const int K = kernel_args.k_list[tensor_id];
const int original_M = kernel_args.original_m_list[tensor_id];
const int original_K = kernel_args.original_k_list[tensor_id];
// store, shared -> global
int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
__align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32);
#pragma unroll
for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) {
output_v4i[i] = slm_v4i[i];
}
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
// Get block index in grid. Emulate 2D grid.
const int num_tiles_k = K / SF_TILE_DIM_K;
const int num_tiles_m = M / SF_TILE_DIM_M;
const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM);
const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD);
const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y;
const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y;
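// This emulated per-tensor grid matches the launch geometry of the single-tensor
// column-wise path, dim3(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)),
// with N_TILE_PER_TD playing the role of vec_load_size.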
swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y);
}
#endif
} // namespace
namespace transformer_engine {
} // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
......@@ -385,50 +383,52 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_first_dim();
const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
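// Each row-wise MXFP8 scaling factor covers MXFP8_BLOCK_SIZE elements along K, so the
// unpadded scale-factor extents are flat_first_dim() x (flat_last_dim() / MXFP8_BLOCK_SIZE);
// m and k used in the launches below are the padded extents of the scale_inv buffer.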
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
#else
case 4:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
#endif
default:
......@@ -442,50 +442,58 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
#else
case 4:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
#endif
default:
......@@ -498,10 +506,260 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
} else {
NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("CUDA Error: %s\n", cudaGetErrorString(err));
exit(-1);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
const int vec_load_size, const bool is_rowwise,
cudaStream_t stream) {
int n_tiles_in_tb = TB_DIM * vec_load_size;
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
/* Calculate the number of CUDA blocks needed for each tensor.
 * This has to happen here, at launch time, because the per-tensor block counts depend on
 * vec_load_size, which is only known after iterating over every tensor in the batch.
 */
for (size_t j = 0; j < kernel_args.num_tensors; j++) {
const int m = kernel_args.m_list[j];
const int k = kernel_args.k_list[j];
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
if (is_rowwise) {
kernel_args.block_range[j + 1] =
kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m;
} else {
kernel_args.block_range[j + 1] =
kernel_args.block_range[j] +
DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size);
}
}
// Launch kernel
const int num_blocks = kernel_args.block_range[kernel_args.num_tensors];
dim3 block_size(TB_DIM, TB_DIM);
if (is_rowwise) {
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#else
case 4:
cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
} else {
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
(const void *)multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#else
case 4:
cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
std::vector<Tensor*>& output, cudaStream_t stream) {
auto num_tensors = input.size();
bool all_has_data = true;
bool all_has_columnwise_data = true;
for (size_t i = 0; i < num_tensors; i++) {
if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) {
NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + ".");
}
// We don't allow empty tensors. They should be filtered out before calling this function.
if (input[i]->data.numel() == 0) {
NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
}
CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
all_has_data &= input[i]->has_data();
all_has_columnwise_data &= input[i]->has_columnwise_data();
}
NVTE_CHECK(all_has_data || all_has_columnwise_data,
"All tensors should have data or columnwise data.");
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
if (all_has_data) {
MultiSwizzleArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0;
int vec_load_size = 4;
for (size_t i = 0; i < num_tensors; i++) {
// Launch kernel if argument struct is full
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
// There is no int3 and misaligned if using int4/int2.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, true, stream);
// Reset the argument struct and vec_load_size
kernel_args.num_tensors = 0;
vec_load_size = 4;
}
const int m = input[i]->scale_inv.shape[0];
const int k = input[i]->scale_inv.shape[1];
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
NVTE_CHECK(
m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
int num_tiles_k = k / SF_TILE_DIM_K;
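// (num_tiles_k - 1) % 4 + 1 maps num_tiles_k to its remainder mod 4, with 0 mapped to 4.
// For the valid widths (1, 2, 4 ints) this is the widest vectorized load that evenly
// divides num_tiles_k, e.g. 12 tiles -> int4, 10 -> int2, 9 -> int; a result of 3 is
// later demoted to 1 because there is no int3 load type.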
int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
// We use the minimum vec_load_size across all tensors.
vec_load_size = std::min(vec_load_size, vec_load_size_i);
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
kernel_args.m_list[pos] = m;
kernel_args.k_list[pos] = k;
kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE;
kernel_args.num_tensors++;
}
// Launch the remaining tensors
// There is no int3 and misaligned if using int4/int2.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, true, stream);
}
if (all_has_columnwise_data) {
MultiSwizzleArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0;
int vec_load_size = 4;
for (size_t i = 0; i < num_tensors; i++) {
// Launch kernel if argument struct is full
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
// There is no int3 and misaligned if using int4/int2.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, false, stream);
// Reset the argument struct and vec_load_size
kernel_args.num_tensors = 0;
vec_load_size = 4;
}
const int m = input[i]->columnwise_scale_inv.shape[1];
const int k = input[i]->columnwise_scale_inv.shape[0];
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
output[i]->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
int num_tiles_k = k / SF_TILE_DIM_K;
int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
// We use the minimum vec_load_size across all tensors.
vec_load_size = std::min(vec_load_size, vec_load_size_i);
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
kernel_args.m_list[pos] = m;
kernel_args.k_list[pos] = k;
kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE;
kernel_args.num_tensors++;
}
// Launch the remaining tensors
// There is no int3 and misaligned if using int4/int2.
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, false, stream);
}
}
} // namespace transformer_engine
......@@ -516,3 +774,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
using namespace transformer_engine;
swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors);
using namespace transformer_engine;
NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
std::vector<Tensor*> input_list, output_list;
for (size_t i = 0; i < num_tensors; i++) {
input_list.push_back(convertNVTETensorCheck(inputs[i]));
output_list.push_back(convertNVTETensorCheck(outputs[i]));
}
multi_tensor_swizzle_scaling_factors(input_list, output_list, stream);
}
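/* Usage sketch (editor's illustration, not part of this change): the batched entry point
 * replaces a loop of per-tensor nvte_swizzle_scaling_factors calls. The tensor handles
 * below are assumed to already hold MXFP8 data with padded scale-inv buffers.
 *
 *   std::vector<NVTETensor> ins  = ...;  // hypothetical: prepared by the caller
 *   std::vector<NVTETensor> outs = ...;  // one output per input, matching shapes
 *   nvte_multi_tensor_swizzle_scaling_factors(ins.data(), outs.data(), ins.size(), stream);
 */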
......@@ -197,7 +197,8 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
// Note: amax is supported for non-FP8 output as it can be fused into the computation
// and later used for quantization with no need to compute it separately
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
......
......@@ -18,6 +18,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
......@@ -184,6 +185,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
}
}
// Trigger the next kernel here so that its load from global memory can overlap with this kernel's
// store to global memory.
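// Programmatic Dependent Launch (sm_90 and newer): a grid launched with the
// cudaLaunchAttributeProgrammaticStreamSerialization attribute may start executing once
// every block of this kernel has called cudaTriggerProgrammaticLaunchCompletion, instead
// of waiting for the kernel to fully drain.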
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast tmp_output_c;
......@@ -419,6 +426,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
}
}
// Trigger the next kernel here so that its load from global memory can overlap with this kernel's
// store to global memory.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile
// Edge case: in the non-full tile case, there are three subcases
// for full thread tile, it's the same thing here
......@@ -883,6 +896,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream);
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1;
......@@ -917,9 +932,23 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
scale_t_stride_y = scale_inv_t.shape[1];
}
#ifdef __HIP_PLATFORM_AMD__
const size_t block_len = blockwise_fp8_block_len();
const size_t num_blocks_x = DIVUP(row_length, block_len);
const size_t num_blocks_y = DIVUP(num_rows, block_len);
#else
const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
dim3 grid(num_blocks_x, num_blocks_y, 1);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
cfg.attrs = attribute;
cfg.numAttrs = 1;
}
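// With programmaticStreamSerializationAllowed set, this kernel opts into Programmatic
// Dependent Launch on sm_90+: it may begin before the preceding kernel in the stream has
// finished, as soon as that kernel signals completion via
// cudaTriggerProgrammaticLaunchCompletion. On older architectures the plain launch config
// keeps the usual stream ordering.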
#endif
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
......@@ -929,10 +958,13 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
#ifdef __HIP_PLATFORM_AMD__
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
#endif
if (full_tile) {
#ifndef __HIP_PLATFORM_AMD__
CUtensorMap tensor_map_output_trans;
......@@ -940,15 +972,28 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
tensor_map_output_trans =
get_tensor_map<OutputType>(output_t, num_rows, row_length);
}
block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale);
cudaLaunchKernelEx(&cfg,
block_scaled_cast_transpose_kernel<kReturnTranspose, float,
InputType, OutputType>,
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
} else {
cudaLaunchKernelEx(
&cfg,
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
InputType, OutputType>,
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale);
#else
while (true) {
if (128 == block_len) {
......@@ -978,7 +1023,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
break;
}
}
#endif
} else {
while (true) {
if (128 == block_len) {
......@@ -1008,6 +1052,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
break;
}
}
#endif
} // full-tile
) // return_transpose
) // OutputType
......
......@@ -24,6 +24,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh"
namespace transformer_engine {
......@@ -74,7 +75,7 @@ Step 2: Cast and store to output_c
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and write to output_c at a time
* 16 elements are quantized and written to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
......@@ -251,6 +252,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__syncthreads();
// If columnwise output is not returned, trigger the next kernel here so that its load from
// global memory can overlap with this kernel's rowwise stores.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// Step 2: Cast and store to output_c
if (return_rowwise) {
constexpr int r_stride =
......@@ -356,6 +365,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
// If columnwise output is returned, trigger the next kernel here so that its load from
// global memory can overlap with this kernel's columnwise stores.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (return_columnwise_gemm_ready || return_columnwise_compact) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
if (return_columnwise_gemm_ready) {
constexpr int c_stride =
......@@ -1424,20 +1441,30 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
}
#ifdef __HIP_PLATFORM_AMD__
const size_t block_len = blockwise_fp8_block_len();
const size_t num_blocks_x = DIVUP(row_length, (size_t)block_len);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)block_len);
#else
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
dim3 grid(num_blocks_x, num_blocks_y, 1);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
#endif
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
output.dtype, OutputType,
#ifdef __HIP_PLATFORM_AMD__
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
#endif
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
......@@ -1505,25 +1532,34 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
}
#else
size_t smem_bytes = kSMemSize * sizeof(InputType);
cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >=
90) {
cfg.attrs = attribute;
cfg.numAttrs = 1;
}
// Dynamic shared memory of 48 KB or more must be requested explicitly
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
}
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
columnwise_option, pow2_scale);
}
cudaLaunchKernelEx(&cfg,
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
OutputType>,
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, epsilon, rowwise_option, columnwise_option,
pow2_scale);
#endif
) // kAligned
) // OutputType
) // InputType
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
}
......