Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
*
* \param[in] qkv_input Combined QKV input tensor for fused RoPE.
* \param[in] q_freqs The freqs tensor for Q.
* \param[in] k_freqs The freqs tensor for K.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] q_out Output tensor for Q.
* \param[out] k_out Output tensor for K.
* \param[out] v_out Output tensor for V.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] qkv_split_arg_list_0 The hidden size for Q.
* \param[in] qkv_split_arg_list_1 The hidden size for K.
* \param[in] qkv_split_arg_list_2 The hidden size for V.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
const NVTETensor k_freqs, const NVTETensor start_positions,
NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
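A minimal call sketch (illustrative only: it assumes `qkv`, `q_freqs`, `k_freqs` and the three output handles are pre-built NVTETensors with the documented shapes, and that `start_positions` may be left null when no offsets are used):

    // Hypothetical sizes: s=2048, b=4, h=16 heads, d=128, RoPE applied to the first d2=64.
    // The qkv_split_arg_list values are the per-token hidden sizes of Q, K and V inside qkv.
    nvte_fused_qkv_rope_forward(qkv, q_freqs, k_freqs, /*start_positions=*/nullptr,
                                q_out, k_out, v_out, NVTE_SBHD,
                                /*interleaved=*/false, /*cp_size=*/1, /*cp_rank=*/0,
                                /*s=*/2048, /*b=*/4, /*h=*/16, /*d=*/128, /*d2=*/64,
                                /*qkv_split_arg_list_0=*/16 * 128,
                                /*qkv_split_arg_list_1=*/16 * 128,
                                /*qkv_split_arg_list_2=*/16 * 128, stream);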
/*! \brief Compute the backward pass of the fused QKV RoPE.
*
* \param[in] q_grad_out Incoming gradient tensor for Q.
* \param[in] k_grad_out Incoming gradient tensor for K.
* \param[in] v_grad_out Incoming gradient tensor for V.
* \param[in] q_freqs The freqs tensor for Q.
* \param[in] k_freqs The freqs tensor for K.
* \param[out] qkv_grad_input Computed gradient tensor for the combined QKV input.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] qkv_split_arg_list_0 The hidden size for Q.
* \param[in] qkv_split_arg_list_1 The hidden size for K.
* \param[in] qkv_split_arg_list_2 The hidden size for V.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
const NVTETensor v_grad_out, const NVTETensor q_freqs,
const NVTETensor k_freqs, NVTETensor qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -133,21 +133,13 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream to wait on.
*/
void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor* workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
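A usage sketch for the renamed entry point (illustrative; it assumes the pointer arguments are arrays of length `num_gemms` and each workspace entry is per-GEMM scratch):

    // Run n independent GEMMs D[i] = op(A[i]) * op(B[i]) on internal streams;
    // math_sm_count = 0 leaves SM selection to cuBLAS heuristics.
    nvte_multi_tensor_gemm(A, B, D, bias, pre_gelu_out, /*num_gemms=*/n,
                           /*transa=*/true, /*transb=*/false, /*grad=*/false,
                           workspace, /*accumulate=*/false,
                           /*use_split_accumulator=*/false,
                           /*math_sm_count=*/0, stream);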
#ifdef __HIP_PLATFORM_AMD__
void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
void nvte_multi_stream_cublas_batchgemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
......
......@@ -24,7 +24,7 @@ extern "C" {
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta
* @f]
*
* Calling this function with workspace set to empty tensor will not perform the operation,
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] x Input tensor of shape [N, H].
......@@ -55,8 +55,8 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe
* else
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace set to empty tensor will not perform the operation,
* but instead set the shape and type of these tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
......@@ -90,9 +90,8 @@ void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETenso
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
......@@ -121,9 +120,8 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep
* @f]
* with respect to \f$x\f$ and \f$gamma\f$.
*
* Calling this function with workspace, barrier, dgamma_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
......@@ -142,6 +140,29 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
/*! \brief Compute the backward of RMSNorm and add an additional tensor to the output gradient.
*
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] add Additional tensor to add to output gradient [N, H].
* \param[in] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] workspace Workspace tensor.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETensor add,
const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx,
NVTETensor dgamma, NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
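The empty-workspace convention above makes this a two-phase API; a sketch with hypothetical handles (`sm_count` taken from the device properties):

    // Phase 1: workspace is an empty tensor, so nothing is computed; the call
    // only writes the required workspace shape/dtype into the handle.
    nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, workspace,
                         sm_count, /*zero_centered_gamma=*/false, stream);
    // ... allocate workspace with the reported shape (dtype is kByte) ...
    // Phase 2: performs the fused RMSNorm backward + residual add.
    nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, workspace,
                         sm_count, /*zero_centered_gamma=*/false, stream);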
/*! \brief Helper to enable cuDNN backend for normalization
*
* \param[in] enable Enable the cuDNN backend if true.
......
......@@ -86,6 +86,21 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
void nvte_compute_channel_colwise_amax(const NVTETensor input, NVTETensor output, const NVTETensor fp8_scale, cudaStream_t stream);
/*! \brief Compute an FP8 tensor's amax with quantization config.
*
* The amax (maximum absolute value) of the input tensor is computed
* and written to the amax buffer of the output tensor, using the provided
* quantization configuration.
* One useful configuration option is the noop tensor, which is required for CUDA graph capture.
*
* \param[in] input Input tensor. Must be unquantized.
* \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling.
* \param[in] config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig config, cudaStream_t stream);
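A sketch of the CUDA-graph use case (hedged: the config helpers and attribute name below follow recent TE headers and are illustrative, not normative):

    // noop is an FP32 NVTETensor holding a single device-side flag; the captured
    // kernel early-exits whenever the flag reads 1.0f, so a graph can be replayed
    // with the op enabled or disabled without re-capture.
    NVTEQuantizationConfig cfg = nvte_create_quantization_config();
    nvte_set_quantization_config_attribute(cfg, kNVTEQuantizationConfigNoopTensor,
                                           &noop, sizeof(noop));
    nvte_compute_amax_with_config(input, output, cfg, stream);
    nvte_destroy_quantization_config(cfg);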
/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
......
......@@ -421,6 +421,7 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
reinterpret_cast<float *>(ret.data.dptr),
per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
......@@ -448,6 +449,7 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
reinterpret_cast<float *>(ret.data.dptr),
per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_l2norm
......
......@@ -140,8 +140,8 @@ void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
if (_launch_params.barrier_bytes > 0) {
_launch_params.params.barrier =
reinterpret_cast<int*>(workspace_dptr + _launch_params.workspace_bytes);
cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes,
_launch_params.stream);
NVTE_CHECK_CUDA(cudaMemsetAsync(_launch_params.params.barrier, 0,
_launch_params.barrier_bytes, _launch_params.stream));
}
if constexpr (std::is_same_v<KernelParamsType, BackwardKernelParams>) {
_launch_params.params.dgamma_part =
......@@ -158,7 +158,7 @@ void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
template <>
void TeNormalizationPlan<ForwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr,
void* dx_dptr, void* dz_dptr, void* add_dptr,
void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
NVTE_ERROR("Forward normalization should not call the backward execute function!");
......@@ -168,8 +168,9 @@ template <>
void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
void* add_dptr, void* dbeta_dptr,
void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) {
_launch_params.stream = stream;
auto& kernel_params = _launch_params.params;
......@@ -179,6 +180,7 @@ void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamm
kernel_params.rs = rsigma_dptr;
kernel_params.dx = dx_dptr;
kernel_params.dz = dz_dptr;
kernel_params.add = add_dptr;
kernel_params.dgamma = dgamma_dptr;
if (_is_layernorm) {
......@@ -467,11 +469,14 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) {
void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
#ifdef USE_ROCM
assert(false);
#else
// cuDNN does not currently support fused backward+add
NVTE_CHECK(add_dptr == nullptr);
// Binding data pointers to graph tensors
_variant_pack = {
{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
......
......@@ -130,6 +130,9 @@ struct BackwardKernelParams : public KernelParamsBase {
// Input: gradient wrt. LN FWD output.
void* dz;
// Input: extra tensor to add for fused backward+add
void* add;
// Workspace for Wgrad pre-reduction.
void* dbeta_part;
void* dgamma_part;
......@@ -141,8 +144,10 @@ struct BackwardKernelParams : public KernelParamsBase {
void* dgamma;
};
using BackwardAddKernelParams = BackwardKernelParams;
enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Stage { Forward, Backward };
enum class NVTE_Norm_Stage { Forward, Backward, BackwardAdd };
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
struct TupleHash {
......@@ -225,8 +230,8 @@ class NormalizationPlanBase {
cudaStream_t stream) = 0;
virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) = 0;
void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr,
void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) = 0;
private:
virtual void _build() = 0;
......@@ -245,8 +250,8 @@ class TeNormalizationPlan : public NormalizationPlanBase {
cudaStream_t stream) override;
void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) override;
private:
void _set_workspace();
......@@ -274,8 +279,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
cudaStream_t stream) override;
void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) override;
private:
void _build() override;
......
......@@ -73,7 +73,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output
NVTE_CHECK(!cudnn_backend,
"cuDNN does not currently support amax output for non quantized output");
}
bool gamma_in_weight_dtype = false;
......@@ -192,7 +193,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream);
dz.data.dptr, nullptr /*add*/, dbeta->data.dptr, dgamma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
......
......@@ -14,7 +14,7 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
void launch_ln_bwd_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
......@@ -22,8 +22,8 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
......@@ -57,13 +57,14 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_),
Kernel_traits::SMEM_BYTES, stream));
}
using Kernel_traits_f =
......@@ -74,12 +75,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
auto kernel_f = &ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
void launch_ln_bwd_general_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
......@@ -95,8 +97,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
Kernel_traits::THREADS_PER_CTA, 0);
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
......@@ -117,10 +119,11 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream));
}
// Launch finalization kernel
......@@ -134,6 +137,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
}
#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
......@@ -142,8 +146,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
launch_params, configure_params); \
launch_ln_bwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
__VA_ARGS__>(launch_params, configure_params); \
} \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
......
......@@ -13,15 +13,15 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
void launch_ln_fwd_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &ln_fwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
......@@ -53,18 +53,20 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*)
Kernel_traits::SMEM_BYTES_FWD, stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_),
Kernel_traits::SMEM_BYTES_FWD, stream));
}
}
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
void launch_ln_fwd_general_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
......@@ -78,8 +80,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
......@@ -99,10 +101,11 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream));
}
}
......@@ -112,8 +115,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
launch_params, configure_params); \
launch_ln_fwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
__VA_ARGS__>(launch_params, configure_params); \
} \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
......
......@@ -59,7 +59,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output
NVTE_CHECK(!cudnn_backend,
"cuDNN does not currently support amax output for non quantized output");
}
bool training =
......@@ -169,7 +170,74 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream);
dz.data.dptr, nullptr /*add*/, nullptr /*dbeta*/, dgamma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const Tensor &rsigma,
const Tensor &gamma, Tensor *dx, Tensor *dgamma, Tensor *workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(dz.data.dtype == gamma.data.dtype);
NVTE_CHECK(add.data.dtype == gamma.data.dtype);
NVTE_CHECK(rsigma.data.dtype == DType::kFloat32);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
NVTE_CHECK(add.data.shape == x.data.shape);
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(add, "add");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
}
// cuDNN does not currently support fused backward+add
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
// TE backend does not currently support zero_centered_gamma_in_weight_dtype
NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(),
"zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add");
bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dgamma->data.dptr, add.data.dptr);
bool gamma_in_weight_dtype = false;
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd,
gamma.data.dtype, // wtype
x.data.dtype, // itype
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, add.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
......@@ -202,3 +270,19 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace),
multiprocessorCount, zero_centered_gamma, stream);
}
void nvte_rmsnorm_bwd_add(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor add, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_bwd_add);
using namespace transformer_engine;
rmsnorm_bwd_add(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
*convertNVTETensorCheck(add), *convertNVTETensorCheck(rsigma),
*convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma),
convertNVTETensor(workspace), multiprocessorCount, zero_centered_gamma, stream);
}
......@@ -7,13 +7,31 @@
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#include <type_traits>
#include "../../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace normalization {
template <typename T>
struct maybe_not_t {};
template <typename T, bool Enabled>
using maybe_t = std::conditional_t<Enabled, T, maybe_not_t<T>>;
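// dx_add_t below overlays the optional fused-add input (add) with the computed
// output-gradient vector (dx) in one storage blob (typically registers). When
// FusedAdd is false, add_t collapses to the empty maybe_not_t and the union
// costs nothing extra; when it is true, _padding parks add at the tail of dx's
// storage, so (for the type combinations registered here, where an add element
// is no wider than a dx element) add element j is still intact after dx
// elements 0..j-1 have been written.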
template <typename Ivec, typename Ovec, bool FusedAdd>
union dx_add_t {
using add_t = maybe_t<Ovec, FusedAdd>;
using dx_t = Ivec;
struct {
char _padding[sizeof(dx_t) > sizeof(add_t) ? sizeof(dx_t) - sizeof(add_t) : 0];
add_t add;
};
dx_t dx;
};
template <typename Ktraits, bool FusedAdd>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel(
BackwardKernelParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
......@@ -111,10 +129,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
}
}
dx_add_t<Ivec, Ovec, FusedAdd> temp[LDGS];
if constexpr (FusedAdd) {
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
temp[it].add.load_from(params.add, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
}
reduce_t result = reducer.allreduce({0, mdyy_local}, sum);
mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
......@@ -123,9 +150,13 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp));
dx[it].data.elt[jt] = dx_tmp;
if constexpr (FusedAdd) {
compute_t add_tmp = temp[it].add.data.elt[jt];
dx_tmp += add_tmp;
}
dx[it].store_to(params.dx, idx);
temp[it].dx.data.elt[jt] = dx_tmp;
}
temp[it].dx.store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
......@@ -274,7 +305,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi
}
}
template <typename Ktraits>
template <typename Ktraits, bool FusedAdd>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel(
BackwardKernelParams params) {
enum { LDGS = Ktraits::LDGS };
......@@ -379,14 +410,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec dx;
dx_add_t<Ivec, Ovec, FusedAdd> temp;
if constexpr (FusedAdd) {
temp.add.load_from_elts(params.add, row * params.cols + col, params.cols - col);
}
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij));
compute_t dx_ij = rs * (dy_ij - (mdyy * y_ij));
if constexpr (FusedAdd) {
compute_t add_ij = temp.add.data.elt[jt];
dx_ij += add_ij;
}
temp.dx.data.elt[jt] = dx_ij;
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
temp.dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
}
......
......@@ -12,18 +12,17 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false>
void launch_rmsnorm_bwd_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<BackwardKernelParams>& launch_params = *plaunch_params;
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits, FUSED_ADD>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
int ctas_per_sm = 0;
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
......@@ -37,19 +36,17 @@ void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t);
return;
}
#ifndef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
#else
#else
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
#endif
#endif
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
......@@ -57,13 +54,14 @@ void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_),
Kernel_traits::SMEM_BYTES, stream));
}
using Kernel_traits_f =
......@@ -74,20 +72,20 @@ void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false>
void launch_rmsnorm_bwd_general_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<BackwardKernelParams>& launch_params = *plaunch_params;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>;
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits, FUSED_ADD>;
// Configure kernel params
const int rows = launch_params.params.rows;
......@@ -95,9 +93,9 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
Kernel_traits::THREADS_PER_CTA, 0);
int ctas_per_sm = 0;
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
......@@ -120,10 +118,11 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream));
}
// Launch finalization kernel
......@@ -137,6 +136,7 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
}
#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
......@@ -145,15 +145,15 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
&launch_params, configure_params); \
launch_rmsnorm_bwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
__VA_ARGS__>(launch_params, configure_params); \
} \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \
} // namespace
// Create rmsnorm tuned launch function and register. Macro signature:
// Create rmsnorm bwd tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
......@@ -181,7 +181,7 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Create rmsnorm general launch function and register. Macro signature:
// Create rmsnorm bwd general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
......@@ -214,3 +214,108 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32,
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
// Create fused rmsnorm bwd + add tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4,
true);
// Create fused rmsnorm bwd + add general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4,
true);
......@@ -13,17 +13,16 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<ForwardKernelParams>* plaunch_params,
void launch_rmsnorm_fwd_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<ForwardKernelParams>& launch_params = *plaunch_params;
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
......@@ -55,20 +54,21 @@ void launch_tuned_(LaunchParams<ForwardKernelParams>* plaunch_params,
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*)
Kernel_traits::SMEM_BYTES_FWD, stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_),
Kernel_traits::SMEM_BYTES_FWD, stream));
}
}
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
void launch_rmsnorm_fwd_general_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
LaunchParams<ForwardKernelParams>& launch_params = *plaunch_params;
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>;
......@@ -81,8 +81,8 @@ void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
......@@ -102,10 +102,11 @@ void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream));
}
}
......@@ -115,8 +116,8 @@ void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
&launch_params, configure_params); \
launch_rmsnorm_fwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
__VA_ARGS__>(launch_params, configure_params); \
} \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
......
......@@ -35,17 +35,20 @@ void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t
switch (wait_kind) {
case WaitKind::KERNEL_WAIT:
wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case WaitKind::NVSHMEM_WAIT:
nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
(cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT));
break;
case WaitKind::STREAM_WAIT:
cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
CU_STREAM_WAIT_VALUE_GEQ);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
NVTE_CHECK_CUDA_DRIVER(cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
(cuuint64_t)wait_value, CU_STREAM_WAIT_VALUE_GEQ));
NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
(cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT));
break;
}
}
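Illustrative consumer-side call (assuming `sig_addr` is an NVSHMEM-allocated signal word that the producer sets with a put-with-signal):

    // Block the stream until the signal arrives, then reset it; STREAM_WAIT maps
    // to cuStreamWaitValue64/cuStreamWriteValue64 as in the switch above.
    nvshmem_wait_on_stream(sig_addr, WaitKind::STREAM_WAIT, stream);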
......@@ -251,6 +251,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK,
num_out_tokens);
NVTE_CHECK_CUDA(cudaGetLastError());
blocks = num_rows;
#ifdef __HIP_PLATFORM_AMD__
......@@ -260,6 +261,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
#endif
moe_permute_kernel<T, TCompute, 128, false><<<blocks, threads, 0, stream>>>(
input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
// moe_unpermute_bwd
......@@ -271,6 +273,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
moe_permute_kernel<T, TCompute, 1, false><<<blocks, threads, 0, stream>>>(
input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
// moe_unpermute_bwd with probs
......@@ -294,6 +297,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
} else {
NVTE_ERROR("topK cannot exceed 128.");
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
}
......@@ -322,11 +326,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f
moe_unpermute_kernel<T, TCompute, false><<<blocks, threads, smem_bytes, stream>>>(
input, output, row_id_map, nullptr, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
// moe_unpermute_fwd with probs
moe_unpermute_kernel<T, TCompute, true><<<blocks, threads, smem_bytes, stream>>>(
input, output, row_id_map, prob, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......
......@@ -30,7 +30,11 @@ constexpr int amax_kernel_threads = 512;
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) {
const size_t num_aligned_elements, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max = 0.f;
const int warp_id = threadIdx.x / THREADS_PER_WARP;
......@@ -124,9 +128,10 @@ void channel_colwise_amax_kernel_v2(const InputType* in, float* out, const float
}
template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
cudaStream_t stream) {
// Zero out amax so we can update with atomic max
cudaMemsetAsync(amax, 0, sizeof(float), stream);
NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
// Return immediately if tensor is empty
if (N == 0) {
......@@ -147,16 +152,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
switch (align) {
case Alignment::SAME_ALIGNED:
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
break;
case Alignment::SAME_UNALIGNED:
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
break;
case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
amax_kernel<1, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, N, noop_ptr);
break;
}
}
......@@ -188,8 +194,10 @@ void launch_channel_colwise_amax_kernel(const InputType *input, float *amax, con
} // namespace
} // namespace transformer_engine
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax);
namespace {
void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream,
const NVTEQuantizationConfig config_) {
using namespace transformer_engine;
// Check input tensor
......@@ -224,12 +232,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax", true);
float *noop_ptr = nullptr;
if (config_ != nullptr) {
const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
// extract noop tensor from config_cpp if it's not null
const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
noop_ptr = reinterpret_cast<float *>(
(noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
}
// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
stream);); // NOLINT(*)
noop_ptr, stream);); // NOLINT(*)
}
} // anonymous namespace
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax);
compute_amax_impl(input_, output_, stream, nullptr);
}
void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_,
const NVTEQuantizationConfig config_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax_with_config);
compute_amax_impl(input_, output_, stream, config_);
}
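A hedged usage sketch for the new entry point: it assumes the quantization-config helpers nvte_create_quantization_config / nvte_set_quantization_config_attribute, the kNVTEQuantizationConfigNoopTensor attribute, and pre-built input, output, and noop tensors (the noop being a one-element float tensor in device memory); none of those helper names are confirmed by this diff.

// Wire a noop tensor into the amax computation (helper names assumed).
NVTEQuantizationConfig config = nvte_create_quantization_config();
nvte_set_quantization_config_attribute(config, kNVTEQuantizationConfigNoopTensor,
                                       &noop, sizeof(noop));
nvte_compute_amax_with_config(input, output, config, stream);
// Writing 1.0f into the noop tensor on the device later turns replays of a
// captured graph containing this launch into no-ops.
nvte_destroy_quantization_config(config);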
void nvte_compute_channel_colwise_amax(const NVTETensor input_, const NVTETensor output_, const NVTETensor fp8_scale_, cudaStream_t stream) {
......@@ -271,7 +302,11 @@ namespace {
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
const float max_fp8, const bool force_pow_2_scales,
const float epsilon) {
const float epsilon, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
*scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
std::numeric_limits<float>::max());
}
......@@ -317,10 +352,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(output.data.dtype, DType,
max_fp8 = Quantized_Limits<DType>::max_norm;);
// Noop tensor for CUDA graph support
float *noop_ptr = nullptr;
if (config_ != nullptr) {
const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
// Extract the noop tensor from the config if one was provided
const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
noop_ptr = reinterpret_cast<float *>(
(noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
}
// Update scale
compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
config.amax_epsilon);
config.amax_epsilon, noop_ptr);
NVTE_CHECK_CUDA(cudaGetLastError());
}
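For reference, the arithmetic inside compute_scale_from_amax reduces to scale = max_fp8 / max(amax, epsilon), optionally floored to a power of two so that scaling and unscaling stay exactly invertible. A simplified host-side mirror of that computation, omitting the overflow clamp the device helper applies via std::numeric_limits<float>::max():

#include <algorithm>
#include <cmath>

// Simplified sketch: choose the largest scale that keeps |x| * scale within
// the FP8 format's max_norm. Not the library's exact device implementation.
float scale_from_amax_sketch(float amax, float max_fp8, bool force_pow_2,
                             float epsilon) {
  amax = std::max(amax, epsilon);                     // guard against amax ~ 0
  float scale = (amax > 0.f) ? max_fp8 / amax : 1.f;
  if (force_pow_2) {
    scale = std::exp2(std::floor(std::log2(scale)));  // round down to 2^k
  }
  return scale;
}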
......@@ -373,6 +373,7 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_
break;
}
})
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
......@@ -420,6 +421,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
break;
}
})))
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace fp8_block_scaling_recipe
......
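Most of the remaining hunks share one theme: every raw CUDA runtime call is now wrapped in NVTE_CHECK_CUDA, and every kernel launch is followed by a checked cudaGetLastError(), so failures surface at the call site instead of at the next synchronization. A minimal sketch of what such a macro typically expands to; the library's real macro lives in its error-handling header, and this stand-in only shows the shape:

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for NVTE_CHECK_CUDA: evaluate the call once and fail
// loudly with file/line context if it did not return cudaSuccess.
#define CHECK_CUDA_SKETCH(call)                                       \
  do {                                                                \
    const cudaError_t err_ = (call);                                  \
    if (err_ != cudaSuccess) {                                        \
      std::fprintf(stderr, "CUDA error: %s at %s:%d\n",               \
                   cudaGetErrorString(err_), __FILE__, __LINE__);     \
      std::abort();                                                   \
    }                                                                 \
  } while (0)

// Kernel launches return no status, so they are followed by
// CHECK_CUDA_SKETCH(cudaGetLastError()) instead, as in the hunks above.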
......@@ -410,22 +410,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
break;
#else
case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
......@@ -435,6 +438,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
if (input->has_columnwise_data()) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1;
......@@ -472,24 +476,27 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
break;
#else
case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
......@@ -500,6 +507,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
// 2D block scaling
......@@ -563,23 +571,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
break;
#else
case 4:
cudaFuncSetAttribute(
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
......@@ -614,23 +622,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
break;
#else
case 4:
cudaFuncSetAttribute(
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 2:
cudaFuncSetAttribute(
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
cudaFuncSetAttribute(
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
......
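The swizzle hunks all wrap one specific call: cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize, the opt-in a kernel needs before it may be launched with more dynamic shared memory than the default per-block limit. Checking its return matters because the request can fail, for example when slm_size exceeds the device's opt-in maximum, after which the subsequent launch would fail as well. A small self-contained sketch of the sequence; the kernel and helper are illustrative:

#include <cuda_runtime.h>

__global__ void big_smem_kernel() {
  extern __shared__ char smem[];  // dynamic shared-memory window
  (void)smem;                     // real kernels would use it here
}

// Opt in to a large dynamic shared-memory carve-out before launching.
// The attribute call can fail, so its status is propagated here (the
// library wraps it in NVTE_CHECK_CUDA instead).
cudaError_t launch_with_big_smem(size_t slm_size, cudaStream_t stream) {
  cudaError_t err = cudaFuncSetAttribute(
      big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
      static_cast<int>(slm_size));
  if (err != cudaSuccess) return err;
  big_smem_kernel<<<1, 256, slm_size, stream>>>();
  return cudaGetLastError();
}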
......@@ -560,11 +560,11 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
// Zero out tensor data if allocated
if (t.data.dptr != nullptr) {
const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream);
NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
}
// Set amax to 0 if allocated
if (t.amax.dptr != nullptr) {
cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream);
NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
}
}
......