Unverified Commit 0e80c847 authored by Oleg Goncharov, committed by GitHub

[Common] Split cast/gated kernels by scaling mode (#2248)



* Separated gated and dequantize kernels
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Separated quantize, dequantize and gated functions
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed lint issues
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed persistent lint issues
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added missing compute capability 10.0 check for Quantize FP8 TMA kernels
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed the issue that was reintroduced by the autofix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Updated file descriptions. Completely removed non-identity activations from the NVFP4 transpose test suite
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Removed unsupported template arguments in NVFP4 quantize
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed undefined symbol error
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed condition
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Fixed CUDA version check
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed arch conditions order
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Small fix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Small fix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixes per the PR review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fix
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Split quantize helper into two (FWD and BWD) functions
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Moved activation functions from cast.cu. Removed cast.cu from the fast-math compilation list
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Enabled fast math for activations by default
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Disabled fast math for activations by default
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 490a5f41
@@ -168,7 +168,7 @@ list(APPEND transformer_engine_cuda_sources
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
@@ -336,8 +336,7 @@ option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --u
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu)
activation/swiglu.cu)
endif()
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
......
@@ -14,26 +14,17 @@
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include "../cast/dispatch/gated.cuh"
#include "../cast/dispatch/quantize.cuh"
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = true;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, OP>(input, output, nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
@@ -42,20 +33,17 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, OP>(grad, input, output, dbias, workspace,
nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
dispatch::quantize_gated_fwd_helper<Param, ActOP>(input, output, p, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
@@ -63,8 +51,7 @@ template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
dispatch::quantize_gated_bwd_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
}
} // namespace transformer_engine
......
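Editor's note: the refactor above replaces the single flag-driven quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, ...> entry point with separate forward and backward dispatchers, so forward callers no longer thread dummy grad/dbias/workspace arguments through the call. A minimal C++ sketch of this split, with hypothetical stub names standing in for the TE types (not code from this PR):

#include <cstdio>

struct TensorStub { const char *name; };

// Before: one helper, boolean template flags select the variant, and
// forward-only callers must pass nullptr for grad/dbias/workspace.
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT>
void quantize_helper_old(const TensorStub *input, const TensorStub *grad,
                         TensorStub *output, TensorStub *dbias, TensorStub *workspace) {
  std::printf("old: dbias=%d dact=%d act=%d grad=%p\n", IS_DBIAS, IS_DACT, IS_ACT,
              static_cast<const void *>(grad));
  (void)input; (void)output; (void)dbias; (void)workspace;
}

// After: the forward path takes only what it needs...
template <bool IS_ACT>
void quantize_fwd_helper(const TensorStub *input, TensorStub *output) {
  std::printf("fwd: act=%d in=%s out=%s\n", IS_ACT, input->name, output->name);
}

// ...and the backward path owns grad/dbias handling.
template <bool IS_DBIAS, bool IS_DACT>
void quantize_bwd_helper(const TensorStub *grad, const TensorStub *input,
                         TensorStub *output, TensorStub *dbias) {
  std::printf("bwd: dbias=%d dact=%d grad=%s\n", IS_DBIAS, IS_DACT, grad->name);
  (void)input; (void)output; (void)dbias;
}

int main() {
  TensorStub in{"in"}, out{"out"}, grad{"grad"}, dbias{"dbias"};
  quantize_helper_old<false, false, true>(&in, nullptr, &out, nullptr, nullptr);
  quantize_fwd_helper<true>(&in, &out);
  quantize_bwd_helper<true, true>(&grad, &in, &out, &dbias);
}

The gated helpers follow the same pattern: quantize_gated_helper<IS_DGATED, ...> becomes quantize_gated_fwd_helper and quantize_gated_bwd_helper.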
@@ -20,6 +20,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
@@ -48,6 +61,19 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
......
@@ -20,6 +20,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
@@ -48,6 +61,19 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
......
@@ -20,6 +20,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
......
@@ -10,36 +10,20 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>
#include <cfloat>
#include <limits>
#include <mutex>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "dispatch/dequantize.cuh"
#include "dispatch/quantize.cuh"
#include "transformer_engine/transpose.h"
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
@@ -59,15 +43,8 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
NVTE_API_CALL(nvte_quantize_v2);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
input, grad, output, dbias, workspace, quant_config, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
}
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
@@ -77,87 +54,17 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr const NVTETensor activation_input = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
stream);
}
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
@@ -166,12 +73,7 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_API_CALL(nvte_multi_tensor_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
const size_t num_streams = nvte_get_num_compute_streams();
@@ -184,9 +86,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
}
for (int i = 0; i < num_tensors; i++) {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
inputs[i], grad, outputs[i], dbias, workspace, nullptr,
detail::get_compute_stream(i % num_streams));
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams));
}
// record events on compute streams
......
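Editor's note: nvte_multi_tensor_quantize dispatches each tensor to a compute stream in round-robin order (i % num_streams) so that independent quantize kernels can overlap. A plain-CUDA sketch of that pattern, with a hypothetical kernel and stream count (the real code reuses TE's pooled compute streams and records completion events afterwards):

#include <cuda_runtime.h>
#include <vector>

__global__ void scale_kernel(float *data, size_t n, float s) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= s;
}

void multi_tensor_scale(const std::vector<float *> &tensors, size_t n) {
  constexpr int kNumStreams = 4;  // stands in for nvte_get_num_compute_streams()
  cudaStream_t streams[kNumStreams];
  for (int s = 0; s < kNumStreams; ++s) cudaStreamCreate(&streams[s]);
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  for (size_t i = 0; i < tensors.size(); ++i) {
    // Round-robin: tensor i runs on stream i % kNumStreams, so work on
    // independent tensors overlaps across streams.
    scale_kernel<<<blocks, threads, 0, streams[i % kNumStreams]>>>(tensors[i], n, 2.0f);
  }
  for (int s = 0; s < kNumStreams; ++s) {
    cudaStreamSynchronize(streams[s]);
    cudaStreamDestroy(streams[s]);
  }
}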
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file common.cuh
* \brief Common functions shared by the quantize dispatchers.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace common {
inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
const size_t N = product(t->data.shape);
const bool isFullTile = (N % elems_per_block == 0);
return isFullTile;
}
inline bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr size_t TMA_bytes = 16;
const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}
namespace kernel {
constexpr size_t THREADS_PER_BLOCK = 256;
template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
const size_t rows, const size_t cols) {
using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>;
const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= cols) {
return;
}
const float *const thread_in_base = dbias_partial + thread_id * nvec;
OType *const thread_out_base = dbias_output + thread_id * nvec;
ComputeVec ldg_vec;
ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < rows; ++i) {
ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base);
}
} // namespace kernel
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream) {
using namespace kernel;
constexpr size_t reduce_dbias_store_bytes = 8; // stg.64
constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape: cols (", cols,
") must be divisible by ", reduce_dbias_nvec, ".");
const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, IType>
<<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace common
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
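Editor's note: reduce_dbias_kernel collapses the [rows x cols] partial-dbias workspace down its columns, each thread accumulating nvec adjacent columns through vectorized loads. A scalar plain-CUDA sketch of the same reduction, one column per thread and no Vec<> machinery:

#include <cuda_runtime.h>

__global__ void reduce_dbias_scalar(float *dbias_out, const float *dbias_partial,
                                    size_t rows, size_t cols) {
  const size_t col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= cols) return;
  float acc = 0.0f;
  for (size_t r = 0; r < rows; ++r) {
    acc += dbias_partial[r * cols + col];  // strided walk down one column
  }
  dbias_out[col] = acc;  // the real kernel also casts fp32 to the output dtype
}

The production kernel additionally vectorizes the loads (nvec elements per thread, sized for stg.64 stores) and checks that cols is divisible by nvec.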
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize.cuh
* \brief Dequantize dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#include "../nvfp4/dequantize_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
fp8::dequantize(input, output, stream);
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (is_supported_by_CC_100()) {
mxfp8::dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
nvfp4::dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
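Editor's note: dequantize_helper is now a pure dispatcher: validate the tensors, switch on input.scaling_mode, forward to the per-format backend, and reject MXFP8 below SM 10.0. A condensed, hypothetical C++ sketch of that shape (stub types and names, not the TE API):

#include <stdexcept>

enum class ScalingMode { kDelayedTensor, kMXFP8_1D, kNVFP4_1D };

struct TensorView { ScalingMode mode; };

namespace fp8   { inline void dequantize(const TensorView &) {} }
namespace mxfp8 { inline void dequantize(const TensorView &) {} }
namespace nvfp4 { inline void dequantize(const TensorView &) {} }

inline bool is_sm100_or_newer() { return true; }  // stands in for is_supported_by_CC_100()

void dequantize_dispatch(const TensorView &t) {
  switch (t.mode) {
    case ScalingMode::kDelayedTensor:
      fp8::dequantize(t);
      break;
    case ScalingMode::kMXFP8_1D:
      if (!is_sm100_or_newer()) throw std::runtime_error("MXFP8 dequantize needs SM 10.0+");
      mxfp8::dequantize(t);
      break;
    case ScalingMode::kNVFP4_1D:
      nvfp4::dequantize(t);
      break;
  }
}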
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gated.cuh
* \brief Gated dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"
namespace transformer_engine {
namespace dispatch {
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
cudaStream_t stream) {
const Tensor input = *convertNVTETensorCheck(nvte_input);
Tensor *output = convertNVTETensorCheck(nvte_output);
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim() / 2;
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == cols,
"Wrong output shape. Expected (after flattening) [*, ", cols, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
if (use_tma_kernels) {
Tensor dummy_grad_tensor;
fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
output, p, stream);
} else {
fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
NVTE_CHECK(cols % 32 == 0,
"Invalid input shape. Expected the last dimension to be "
"divisible by 32, but got ",
cols, ".");
if (output->has_data()) {
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"The type of the output tensor should be FP8.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
"The type of the columnwise output tensor should be FP8.");
}
NVTE_CHECK(is_supported_by_CC_100(),
"Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
Tensor dummy_grad_tensor;
mxfp8::quantize_gated</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
output, p, stream);
break;
}
default:
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
const Tensor &grad = *(convertNVTETensorCheck(nvte_grad));
const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input);
Tensor *output = convertNVTETensorCheck(nvte_output);
CheckInputTensor(grad, "grad");
CheckInputTensor(gated_input, "gated_input");
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ",
gated_input.flat_last_dim(), ".");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows,
"Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
NVTE_CHECK(grad.flat_last_dim() == cols,
"Wrong Grad shape. Expected last dimension (after flattening) [*, ", cols, "], got [",
grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [",
rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == cols * 2,
"Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(gated_input.data.shape == output->data.shape,
"Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
", output shape: ", output->data.shape, ".");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
if (use_tma_kernels) {
fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
} else {
fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
NVTE_CHECK(cols % 32 == 0,
"Invalid input shape. Expected the last dimension to be "
"divisible by 32, but got ",
cols, ".");
if (output->has_data()) {
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"The type of the output tensor should be FP8.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
"The type of the columnwise output tensor should be FP8.");
}
NVTE_CHECK(is_supported_by_CC_100(),
"Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
mxfp8::quantize_gated</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
break;
}
default:
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
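Editor's note: both gated helpers rely on the packing contract checked above: the input flattens to [rows, 2*cols] with the activation half in columns [0, cols) and the gate half in [cols, 2*cols), and the forward output is [rows, cols]. A CPU reference sketch of the forward contract, using SiLU as a stand-in for the templated ActOP (illustrative, not the TE kernel):

#include <cmath>
#include <vector>

std::vector<float> gated_fwd_ref(const std::vector<float> &in, size_t rows, size_t cols2) {
  const size_t cols = cols2 / 2;  // cols2 must be even, as the dispatcher checks
  std::vector<float> out(rows * cols);
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      const float act  = in[r * cols2 + c];         // first half: activation input
      const float gate = in[r * cols2 + cols + c];  // second half: gate input
      const float silu = act / (1.0f + std::exp(-act));
      out[r * cols + c] = silu * gate;              // out = Act(act) * gate
    }
  }
  return out;
}

The backward helper inverts this: grad is [rows, cols] and the output regains the [rows, 2*cols] packing, holding d(act) and d(gate) halves.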
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize.cuh
* \brief Quantize dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *input_tensor = convertNVTETensorCheck(input);
Tensor *output_tensor = convertNVTETensorCheck(output);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
if (output_tensor->has_columnwise_data()) {
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_ACT) {
cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, float, ParamOP, OP>(
*input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
}
} else if (output_tensor->has_data()) {
fp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
*input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
mxfp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
*input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4::quantize_transpose</*use_2d_quantization=*/false>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
quantize_transpose_vector_blockwise_fp4(
/*input=*/input_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
}
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output,
NVTETensor dbias, NVTETensor workspace,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *grad_tensor = convertNVTETensorCheck(grad);
const Tensor *input_tensor = convertNVTETensor(input);
Tensor *output_tensor = convertNVTETensorCheck(output);
Tensor *dbias_tensor = convertNVTETensor(dbias);
Tensor *workspace_tensor = convertNVTETensor(workspace);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) {
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT) {
cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, float, ParamOP, OP>(
*grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream);
}
} else if (output_tensor->has_data()) {
fp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*grad_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = grad_tensor->flat_first_dim();
int32_t cols = grad_tensor->flat_last_dim();
auto dtype = grad_tensor->dtype();
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
*grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4::quantize_transpose</*use_2d_quantization=*/false>(
*grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
quantize_transpose_vector_blockwise_fp4(
/*input=*/grad_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_DBIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_DBIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
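Editor's note: the FP8 delayed-tensor-scaling paths above all follow the same bookkeeping: the kernel consumes a precomputed scale, records the amax of the unscaled values for the next scale update, and publishes scale_inv = 1/scale for consumers. A small numeric sketch with illustrative values (E4M3's max magnitude of 448 is a known FP8 constant, not read from this diff):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float x[4] = {0.5f, -2.0f, 1.25f, -0.75f};
  const float scale = 224.0f;  // from a previous step's amax history; keeps scale*x within E4M3 range
  float amax = 0.0f;
  for (float v : x) {
    const float q = scale * v;            // value handed to the FP8 cast
    amax = std::max(amax, std::fabs(v));  // amax tracks the *unscaled* data
    std::printf("q=%f\n", q);
  }
  const float scale_inv = 1.0f / scale;   // written once per tensor for the consumer GEMM
  std::printf("amax=%f scale_inv=%e\n", amax, scale_inv);
}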
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_fp8.cuh
* \brief CUDA kernels to dequantize from FP8.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
struct DequantizeParam {
const float *scale_inv;
};
__device__ inline float dequantize_func(float value, const DequantizeParam &param) {
return value * (*(param.scale_inv));
}
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
DequantizeParam p; p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, DequantizeParam, dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr), nullptr,
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
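Editor's note: the VectorizedUnaryKernelLauncher instantiation above reduces to one multiply per element by a scale_inv value read from device memory, which is the functional core of dequantize_func. A scalar plain-CUDA sketch of that core (float stands in for the FP8 input type, which the real kernel converts on load):

#include <cuda_runtime.h>

__global__ void dequantize_scalar(const float *in, float *out,
                                  const float *scale_inv, size_t n) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * (*scale_inv);  // same math as dequantize_func
}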
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gated_fp8.cuh
* \brief CUDA kernels to cast to FP8 with gated activations.
*/
#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_
#define TRANSFORMER_ENGINE_GATED_FP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
namespace kernel {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X;
constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128
constexpr size_t BUFFERS_NUM = 2;
constexpr size_t BUFFER_DIM_Y = 32;
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128
constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32
static_assert(ITERATIONS >= 1);
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols,
const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t grad_mem = IS_BWD ? buff_size_aligned_in : 0;
constexpr size_t in_act_mem = buff_size_aligned_in;
constexpr size_t in_gate_mem = buff_size_aligned_in;
constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
const uint64_t *TMAP_output_act = reinterpret_cast<const uint64_t *>(&tensor_map_output_act);
const uint64_t *TMAP_output_gate = reinterpret_cast<const uint64_t *>(&tensor_map_output_gate);
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
// Prefetch data of the first stage
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh,
TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate,
chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
} else {
copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh,
TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
}
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const size_t buff = it % BUFFERS_NUM;
const size_t next_it = it + 1;
if (next_it < ITERATIONS) {
const size_t next_buff = next_it % BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y,
in_transaction_size, &mbar[next_it], is_master_thread);
} else {
copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x,
chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate,
chunk_it_offset_x, chunk_it_offset_y, in_transaction_size,
&mbar[next_it], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[it], parity);
IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems;
IType *in_act_sh_curr = in_act_sh + buff * buff_elems;
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true;  // gate derivative is 1 (identity pass-through) unless clamping applies
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
if (act_elt <= p.limit) {
dact_x = s + s * (1 - s) * p.alpha * x;
} else {
dact_x = 0.0f;
}
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, p) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
}
// Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
// Activation-branch output
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_act_sh_curr));
if constexpr (IS_BWD) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_gate_sh_curr));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
// Destroy the barriers. This invalidates the memory region of the barrier.
// If further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_invalid(&mbar[it]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace kernel
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p,
cudaStream_t stream) {
using namespace kernel;
checkCuDriverContext(stream);
NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block_dim(THREADS_PER_CHUNK);
const dim3 grid_dim(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act{};
alignas(64) CUtensorMap tensor_map_output_gate{};
if constexpr (IS_BWD) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
auto kernel = cast_fp8_gated_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_size));
kernel<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
}
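// Shared-memory budget sketch (added illustration, assuming IType = bf16 and
// OType = fp8 in the BWD case): buff_elems_total = 2 * 32 * 128 = 8192, so
//   buff_size_aligned_in  = 8192 * 2 B = 16384 B,
//   buff_size_aligned_out = 8192 * 1 B =  8192 B,
//   shmem_size = 16384 (grad) + 2 * 16384 (act/gate in) + 2 * 8192 (act/gate out)
//              + 128 (TMA alignment slack) = 65664 B,
// which exceeds the default 48 KB dynamic shared memory limit and is why
// cudaFuncSetAttribute must raise cudaFuncAttributeMaxDynamicSharedMemorySize
// before the launch above.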
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
output->flat_last_dim(), p, stream);); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_gated_bwd(const Tensor &input, const Tensor &grad, Tensor *output, ParamOP &p,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP, DActOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), p, stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GATED_FP8_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_fp8.cuh
* \brief CUDA kernels to quantize to FP8.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
#include "../core/common.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
namespace quantize_2D_kernel {
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
constexpr size_t FP8_BUFFERS_NUM = 2;
constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1;
static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM);
constexpr size_t FP8_BUFFER_DIM_Y = 16;
constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128
constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16
constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128
constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16
constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16
static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM);
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType>
__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input,
const __grid_constant__ CUtensorMap tensor_map_output,
float *const dbias_workspace, float *const amax_ptr,
float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows,
const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X;
const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK;
const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
const size_t dbias_offset_Y = blockIdx.y + tid_Y;
const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X;
const bool col_out_of_bounds = my_column >= cols;
const size_t dbias_stride = cols;
float partial_dbias = 0.f;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS];
initialize_barriers<FP8_ITERATIONS, FP8_THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
const size_t chunk_offset_Y = block_offset_Y;
const size_t chunk_offset_X = block_offset_X;
#pragma unroll
for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y;
const size_t chunk_stage_offset_X = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size,
&mbar[prefetch_buff], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff],
is_master_thread);
}
}
#pragma unroll
for (int iter = 0; iter < FP8_ITERATIONS; ++iter) {
const size_t buff = iter % FP8_BUFFERS_NUM;
const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y;
if (next_iter < FP8_ITERATIONS) {
const size_t next_buff = next_iter % FP8_BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter],
is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread);
}
}
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
#pragma unroll
for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = row >= rows;
const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;
float elt = static_cast<float>(in_sh[buff][shmem_offset_y][shmem_offset_x]);
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
if constexpr (IS_DACT) {
if (!out_of_bounds) {
partial_dbias += elt;
}
} else {
// Without DACT, out-of-bounds elements are zero-filled by the TMA load,
// so accumulating them into the partial sum is safe
partial_dbias += elt;
}
}
__builtin_assume(amax >= 0);
if constexpr (IS_DACT) {
if (!out_of_bounds) {
amax = fmaxf(amax, fabsf(elt));
}
} else {
// Without DACT, out-of-bounds elements are zero-filled by the TMA load,
// so including them in the amax is safe
amax = fmaxf(amax, fabsf(elt));
}
out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast<OType>(elt * scale);
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<FP8_PREFETCH_BUFFERS_NUM>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
parity ^= 1;
if constexpr (IS_DBIAS) {
const size_t dbias_offset_X = my_column;
const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias;
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<FP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
destroy_barriers<FP8_ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
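// Worked example of the delayed-scaling arithmetic above (added sketch,
// assuming OType = fp8e4m3, whose max magnitude is 448): with scale = 256, an
// in-bounds element elt = 1.25 is stored as static_cast<fp8e4m3>(1.25f * 256.0f)
// = 320 (exactly representable), and dequantization recovers 320 * scale_inv =
// 320 / 256 = 1.25. The kernel only records amax = max|elt| and scale_inv =
// 1/scale; the next scale is derived from the amax history outside this kernel.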
} // namespace quantize_2D_kernel
namespace quantize_1D_kernel {
using namespace quantize_2D_kernel;
constexpr size_t CHUNKS_PER_BLOCK = 128;
constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK;
constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK;
constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE;
constexpr size_t CHUNKS_PER_ITERATION = 32;
constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE;
constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION;
constexpr size_t SHMEM_BUFFERS = 2;
static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0);
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &), typename IType,
typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr,
float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK;
const IType *input = input_ptr + block_offset;
OType *output = output_ptr + block_offset;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_BLOCK>(mbar, is_master_thread);
int parity = 0;
copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread);
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
const size_t buff = iter % SHMEM_BUFFERS;
const size_t it_offset = iter * SHMEM_DIM;
const size_t next_iter = iter + 1;
const size_t next_buff = next_iter % SHMEM_BUFFERS;
const size_t next_iter_offset = next_iter * SHMEM_DIM;
if (next_iter < ITERATIONS) {
copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
&(mbar[next_iter]), is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
#pragma unroll
for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
float elt = static_cast<float>(in_sh[buff][shmem_offset]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(elt));
out_sh[buff][shmem_offset] = static_cast<OType>(elt * scale);
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
ptx::cp_async_bulk_tensor_1d_shared_to_global(
reinterpret_cast<uint64_t *>(output + it_offset),
reinterpret_cast<uint64_t *>(&out_sh[buff]), transaction_size_OUT);
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_BLOCK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
destroy_barriers<ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace quantize_1D_kernel
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
using namespace quantize_1D_kernel;
const size_t N = product(input.data.shape);
const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
NVTE_CHECK(isFullTile, "Only full tiles are supported.");
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
const size_t chunks = DIVUP(N, CHUNK_SIZE);
const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
const float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block(THREADS_PER_BLOCK);
const dim3 grid(blocks);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
const IType *input_ptr = reinterpret_cast<const IType *>(input.data.dptr);
OType *output_ptr = reinterpret_cast<OType *>(output->data.dptr);
cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>(
input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*)
); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
}
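// Full-tile sizing sketch (added illustration): the 1D path only accepts
// element counts that are a multiple of ELEMS_PER_BLOCK.
static_assert(quantize_1D_kernel::ELEMS_PER_BLOCK == 16384,
              "128 chunks per block * 128 elements per chunk");
// E.g. a 4096 x 8192 tensor (N = 33554432 = 2048 full blocks) takes this path,
// while a 1024 x 1000 tensor (N = 1024000, not a multiple of 16384) falls back
// to the vectorized launcher in quantize() below.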
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias,
Tensor *workspace, cudaStream_t stream) {
using namespace quantize_2D_kernel;
checkCuDriverContext(stream);
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X);
const size_t blocks_Y = chunks_Y;
const size_t blocks_X = chunks_X;
const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols;
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {dbias_rows, dbias_cols};
workspace->data.dtype = DType::kFloat32;
return;
}
}
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block(FP8_THREADS_PER_CHUNK);
const dim3 grid(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->data.dtype, OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
}
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));
cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);
NVTE_CHECK_CUDA(cudaGetLastError());
if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
}
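// DBias two-pass reduction sketch (added illustration of the code above):
//   pass 1: cast_fp8_2D_kernel accumulates partial_dbias per thread-column over
//           its 128-row chunk and stores it to workspace[blockIdx.y][column],
//           i.e. workspace has shape {blocks_Y, cols} (the sizing-query path
//           above reports exactly that shape when workspace->data.dptr is null);
//   pass 2: common::reduce_dbias sums the blocks_Y partial rows into dbias[cols].
// E.g. rows = 512, cols = 256 -> blocks_Y = 4, so workspace holds 4 x 256 floats.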
namespace detail {
using Empty = transformer_engine::Empty;
__device__ inline float identity(float value, const Empty &) { return value; }
} // namespace detail
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
cudaStream_t stream) {
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<const fp32 *>(noop->data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
cudaStream_t stream) {
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
const size_t N = product(input->data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input->data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input->data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output,
Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
using namespace quantize_1D_kernel;
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias != nullptr);
CheckOutputTensor(*dbias, "dbias");
}
if constexpr (IS_DACT) {
NVTE_CHECK(act_input != nullptr);
CheckInputTensor(*act_input, "activation_input");
NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match.");
}
NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
// The TMA-based paths below require compute capability >= 10.0
if (is_supported_by_CC_100()) {
if (!IS_DBIAS && !IS_DACT) {
if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8
quantize_1D<IS_ACT, ParamOP, OP>(input, output, stream);
} else {
// Unaligned
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
}
} else if (!IS_DBIAS && IS_DACT) {
if (common::dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8 (+dAct)
quantize_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
} else {
// Unaligned
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
}
} else {
quantize_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
}
} else {
if (IS_DBIAS) {
// zhongboz: should we just ignore IS_ACT here?
NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) +
" or IS_DBIAS=true" + " on GPU with compute capability < 10.0.");
}
if (!IS_DACT) {
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
} else {
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
}
}
}
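// Dispatch sketch for quantize() above (added summary of the branch logic as written):
//   CC >= 10.0, !IS_DBIAS && !IS_DACT, N % 16384 == 0, FP8 out, 16B-aligned -> quantize_1D (TMA)
//   CC >= 10.0, !IS_DBIAS &&  IS_DACT, TMA-friendly dims and alignment      -> quantize_2D (TMA)
//   CC >= 10.0, IS_DBIAS (with or without IS_DACT)                          -> quantize_2D (TMA)
//   CC <  10.0, IS_DBIAS                                                    -> NVTE_ERROR
//   otherwise                                                               -> vectorized launchers
// Hypothetical invocation (names illustrative, not from this file):
//   quantize</*IS_DBIAS=*/false, /*IS_DACT=*/true, /*IS_ACT=*/false, Empty, dgelu<fp32, fp32>>(
//       grad, &gelu_input, &noop, &fp8_out, /*dbias=*/nullptr, /*workspace=*/nullptr, stream);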
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_
@@ -4,36 +4,27 @@
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_kernels.cuh
* \brief CUDA kernels to cast from MXFP8.
/*! \file dequantize_mxfp8.cuh
* \brief CUDA kernels to dequantize from MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
#include <transformer_engine/transformer_engine.h>
namespace transformer_engine {
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
namespace dequantization {
namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace dequantize_kernel {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
@@ -228,29 +219,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace dequantize_kernel
void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr), nullptr,
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
using namespace dequantize_kernel;
bool use_rowwise_scaling = input.has_data();
bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream);
@@ -334,113 +306,8 @@ void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
}
#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
constexpr float factor_inv = 1.0 / (6.0 * 448.0);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // CUDA_VERSION
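// NVFP4 decoding sketch (added illustration of the kernel above): each
// 16-element block carries an fp8e4m3 per-block scale, plus one global amax
// for the whole tensor. The effective dequantization scale is
//   final_scale = scale * amax / (6.0 * 448.0)
// where 6.0 is the largest fp4e2m1 magnitude and 448.0 the largest fp8e4m3
// magnitude. E.g. amax = 8.96 and scale = 300 give final_scale = 1.0, so a
// stored fp4 value of 4.0 decodes to 4.0.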
void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // CUDA_VERSION >= 12080
}
} // namespace dequantization
namespace detail {
void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
dequantization::fp8_dequantize(input, output, stream);
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (is_supported_by_CC_100()) {
dequantization::mxfp8_dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
dequantization::fp4_dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
} // namespace detail
} // namespace mxfp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_
@@ -4,280 +4,27 @@
* See LICENSE for license information.
************************************************************************/
/*! \file cast_gated_kernels.cuh
* \brief CUDA gated activations kernels to cast to/from FP8/MXFP8.
/*! \file gated_mxfp8.cuh
* \brief CUDA kernels to cast to MXFP8 with gated activations.
*/
#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
#ifndef TRANSFORMER_ENGINE_GATED_MXFP8_CUH_
#define TRANSFORMER_ENGINE_GATED_MXFP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include <cfloat>
#include "../common.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "ptx.cuh"
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
namespace transformer_engine {
namespace gated_kernels {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X;
constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128
constexpr size_t BUFFERS_NUM = 2;
constexpr size_t BUFFER_DIM_Y = 32;
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128
constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32
static_assert(ITERATIONS >= 1);
__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
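// Derivation behind the fused SiLU fast path below (added note): with
// s = sigmoid(x),
//   silu(x)  = x * s
//   dsilu(x) = s + x * s' = s + x * s * (1 - s),   since sigmoid'(x) = s * (1 - s),
// which matches the specialized (ActOP == silu, DActOP == dsilu) branch that
// computes the sigmoid only once.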
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols,
const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0;
constexpr size_t in_act_mem = buff_size_aligned_in;
constexpr size_t in_gate_mem = buff_size_aligned_in;
constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
const uint64_t *TMAP_output_act = reinterpret_cast<const uint64_t *>(&tensor_map_output_act);
const uint64_t *TMAP_output_gate = reinterpret_cast<const uint64_t *>(&tensor_map_output_gate);
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
// Prefetch data of the first stage
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh,
TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate,
chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
} else {
copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh,
TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
}
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const size_t buff = it % BUFFERS_NUM;
const size_t next_it = it + 1;
if (next_it < ITERATIONS) {
const size_t next_buff = next_it % BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y,
in_transaction_size, &mbar[next_it], is_master_thread);
} else {
copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x,
chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate,
chunk_it_offset_x, chunk_it_offset_y, in_transaction_size,
&mbar[next_it], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[it], parity);
IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems;
IType *in_act_sh_curr = in_act_sh + buff * buff_elems;
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true;  // derivative of the gate function: identity (true) unless clamped below
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
if (act_elt <= p.limit) {
dact_x = s + s * (1 - s) * p.alpha * x;
} else {
dact_x = 0.0f;
}
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, p) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
}
// Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
// dGeLU
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_act_sh_curr));
if constexpr (IS_DGATED) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_gate_sh_curr));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
// Destroy the barriers. This invalidates the memory region of the barrier.
// If further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_invalid(&mbar[it]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
namespace mxfp8_kernel {
namespace dispatch {
namespace mxfp8 {
namespace gated_kernel {
constexpr size_t CHUNK_DIM_Y = 64;
constexpr size_t CHUNK_DIM_X = 64;
@@ -302,20 +49,21 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
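// Bank-layout sketch (added illustration): shared memory has 32 banks of 4 B,
// i.e. a 128 B stripe; with SCALE_DIM_X = 32 rowwise-scaled columns per scaling
// block, 4 consecutive threads span one full stripe.
static_assert(THREADS_PER_BANK == 4, "expected 128 / 32 threads per 128B stripe");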
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_act_colwise,
const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise, const ParamOP p) {
quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_act_colwise,
const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const size_t rows, const size_t cols,
const size_t scale_stride_rowwise,
const size_t scale_stride_colwise, const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
@@ -385,14 +133,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0);
const size_t out_mem = out_act_mem + out_gate_mem;
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
@@ -427,7 +175,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
int parity = 0;
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y,
&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y,
&in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y,
@@ -454,7 +202,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset],
@@ -497,7 +245,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
const float x = act_elt;
float act_x;
@@ -526,20 +274,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
}
}
after_act_colwise[i] = after_act_elt;
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
after_gate_colwise[i] = after_gate_elt;
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(after_act_elt);
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
cached_gate_sh[shmem_offset_colwise] = static_cast<IType>(after_gate_elt);
}
}
@@ -549,7 +297,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if (!out_of_bounds) {
thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
}
}
@@ -578,7 +326,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// All threads read the reduced amax (ACT)
thread_amax_act = subamax_colwise_buff[0][tid_X_colwise];
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
// Make sure the previous read of the ACT values has completed,
// so the data are not overwritten
__syncthreads();
@@ -622,7 +370,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act);
float block_scale_inverse_gate;
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
......@@ -639,7 +387,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const size_t shmem_offset_elt =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
OType2 out_pair;
ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]};
const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act,
@@ -685,7 +433,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Load cached elements
in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]);
}
// Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
@@ -695,7 +443,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e]));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e]));
}
}
@@ -705,7 +453,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e],
in_cached_act[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act);
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e],
in_cached_gate[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate);
@@ -717,7 +465,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (!std::is_same_v<IType, float>) {
thread_amax_act = static_cast<float>(
__hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y)));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
thread_amax_gate = static_cast<float>(
__hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y)));
}
@@ -735,7 +483,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
in_act.load_from(&in_act_sh[shmem_offset_rowwise]);
in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]);
}
@@ -753,7 +501,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
const float x = act_elt;
float act_x;
@@ -786,7 +534,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
}
}
@@ -796,7 +544,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
}
}
@@ -822,7 +570,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float block_scale_inverse_gate;
ptx::floatx2 block_scale_inverse_2x_gate;
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
@@ -853,7 +601,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act);
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
IType2 in_gate;
OType2 &out_gate_pair = reinterpret_cast<OType2 &>(out_gate.data.elt[e]);
@@ -873,7 +621,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
}
}
@@ -894,7 +642,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_act_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_rowwise_sh[buff_offset]));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_rowwise_sh[buff_offset]));
@@ -904,7 +652,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_act_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_colwise_sh[buff_offset]));
if constexpr (IS_DGATED) {
if constexpr (IS_BWD) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_colwise_sh[buff_offset]));
@@ -920,94 +668,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
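// MXFP8 e8m0 scale sketch (added illustration, assuming round-up semantics in
// ptx::float_to_e8m0): each 32-element block stores a power-of-two scale as a
// biased exponent byte:
//   biased_exponent     = 127 + ceil(log2(block_amax / 448.0f))  // 448 = e4m3 max
//   block_scale_inverse = exp2f(127 - biased_exponent)           // ptx::exp2f_rcp
// E.g. block_amax = 112 -> 112/448 = 0.25 -> exponent -2 -> byte 125; elements
// are multiplied by 2^2 = 4 before the fp8e4m3 cast, so the block maximum maps
// to 448 and dequantization multiplies by 2^-2.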
} // namespace mxfp8_kernel
} // namespace gated_kernel
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p,
cudaStream_t stream) {
checkCuDriverContext(stream);
if (output->has_data()) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block_dim(THREADS_PER_CHUNK);
const dim3 grid_dim(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act{};
alignas(64) CUtensorMap tensor_map_output_gate{};
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
using namespace gated_kernel;
checkCuDriverContext(stream);
const bool USE_ROWWISE_SCALING = output->has_data();
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
  const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
  const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE)
                                       ? THREADS_PER_CHUNK_COLWISE
                                       : THREADS_PER_CHUNK_NON_COLWISE;
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
      if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
}
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
      const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
      const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
size_t out_mem = out_act_mem + out_gate_mem;
if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
      switch (scaling_type) {
        case ScalingType::ROWWISE: {
          auto kernel =
              quantize_gated_mxfp8_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, true,
                                          false, THREADS_PER_CHUNK_NON_COLWISE>;
          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
          kernel<<<grid, block_size, shmem_size, stream>>>(
              tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
              tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
              tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
              scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
          break;
        }
        case ScalingType::COLWISE: {
          auto kernel =
              quantize_gated_mxfp8_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, false,
                                          true, THREADS_PER_CHUNK_COLWISE>;
          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
          kernel<<<grid, block_size, shmem_size, stream>>>(
              tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
              tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
              tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
              scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
          break;
        }
        case ScalingType::BIDIMENSIONAL: {
          auto kernel =
              quantize_gated_mxfp8_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, true,
                                          true, THREADS_PER_CHUNK_NON_COLWISE>;
          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
          kernel<<<grid, block_size, shmem_size, stream>>>(
              tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
              tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
              tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
              scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
          break;
        }
      }
      NVTE_CHECK_CUDA(cudaGetLastError()););  // NOLINT(*)
  );                                          // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2,
"Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2,
"], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
output->flat_last_dim(), p, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(),
"Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(),
", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2,
"Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2,
"], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match. Input shape: ", input.data.shape,
", output shape: ", output->data.shape, ".");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP, DActOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), p, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input");
CheckOutputTensor(*output, "output", allow_empty);
NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even.");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
if constexpr (IS_DGATED) {
CheckInputTensor(grad, "grad");
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input.");
NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input.");
}
  NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
             "Either rowwise or columnwise output data needs to be allocated.");
bool is_fp8_rowwise_output = true;
bool is_fp8_colwise_output = true;
if (output->has_data()) {
is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype);
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output.");
}
if (output->has_columnwise_data()) {
is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype);
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output.");
}
const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0;
if (is_delayed_tensor_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input, output, p, stream);
}
}
} else if (is_mxfp8_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
"by 32, got input of shape ", gated_input.data.shape);
}
} else {
NVTE_ERROR("Not supported scaling mode");
}
}
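// Dispatch summary for quantize_gated (a restatement of the branches implemented above):
//   delayed tensor scaling, TMA-eligible  -> cast_fp8_gated (fused TMA kernel)
//   delayed tensor scaling, fallback      -> cast_dgated / cast_gated (vectorized kernels)
//   MXFP8 scaling, TMA-eligible           -> cast_mxfp8_gated
//   MXFP8 scaling, fallback               -> error (FP8 outputs with last dim % 32 == 0 required)
//   other scaling modes                   -> error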
} // namespace gated_kernels
namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
ParamOP p, cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
  const Tensor &gated_input_tensor = *convertNVTETensorCheck(gated_input);
Tensor *output_tensor = convertNVTETensorCheck(output);
if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
output_tensor, p, stream);
} else {
if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, p,
stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, p, stream);
}
} else {
      // MX scaling
      NVTE_ERROR("MX scaling is not supported on architectures older than 10.0.");
}
}
}
} // namespace detail
} // namespace mxfp8
} // namespace dispatch
} // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_GATED_MXFP8_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_mxfp8.cuh
* \brief CUDA kernels to quantize to MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace quantize_kernel {
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
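// A sketch of the arithmetic behind the shared-memory swizzle used by the kernels below
// (derived from the constants above): each rowwise thread reads SCALE_DIM_X (32) consecutive
// 1-byte elements, i.e. 8 four-byte banks, so every THREADS_PER_BANK (4) consecutive threads
// span all 32 banks and threads whose lanes differ by a multiple of 4 would start loads in the
// same bank. Rotating each wave's PACK_SIZE-element group by the thread's bank_group,
//   swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X,
// offsets those threads into different banks.
static_assert(TOTAL_BANKS_WIDTH == 128, "32 banks x 4 bytes, in 1-byte elements");
static_assert(THREADS_PER_BANK == 4, "4 threads cover the full 32-bank width");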
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const float *noop, float *const dbias_workspace, float *const amax_ptr,
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT;
constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS;
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X;
const size_t tid_Y_colwise = 0;
const size_t tid_X_colwise = threadIdx.x;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const size_t thread_offset_Y_colwise = tid_Y_colwise;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
  // helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0);
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
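  // Alignment sketch (assuming TMA_SHMEM_ALIGNMENT is a power of two, e.g. 128, which the
  // bit-mask trick above requires): a base pointer ending in 0x40 is rounded up to the next
  // 0x80 boundary, while an already-aligned pointer passes through unchanged.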
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
float partial_dbias_colwise = 0.0f;
float thread_dbias_rowwise[SCALE_DIM_X];
if constexpr (IS_DBIAS) {
#pragma unroll
for (int j = 0; j < SCALE_DIM_X; ++j) {
thread_dbias_rowwise[j] = 0.0f;
}
}
float block_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0],
&tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
}
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity);
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
thread_amax = 0.0f;
float in_compute_colwise[BUFF_DIM_Y];
IType in_colwise_IType[BUFF_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
}
thread_amax = static_cast<float>(thread_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[shmem_offset_colwise]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
partial_dbias_colwise += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
} else {
            // Without activations, out-of-bounds elements are zero-filled by TMA, so including
            // them in the amax is safe
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
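      // Worked example (a sketch for FP8 E4M3 output, max_norm = 448): a block amax of 224
      // gives 224 / 448 = 0.5 = 2^-1, float_to_e8m0 stores that power-of-two exponent as the
      // decode scale, and exp2f_rcp below yields the multiplier 2.0 that maps the block's amax
      // onto the FP8 max-norm.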
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
thread_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          // TMA requires 16-byte data alignment (i.e. cols % 8 == 0 for BF16 elements), so a
          // single check in the column direction is enough to guarantee the entire wave stays
          // within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[shmem_offset_rowwise]);
}
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in.data.elt[e]);
elt *= OP(act_in_elt, {});
}
// If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again
if constexpr (IS_DBIAS && (!COLWISE_SCALING)) {
thread_dbias_rowwise[j] += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
} else {
            // Without activations, out-of-bounds elements are zero-filled by TMA, so including
            // them in the amax is safe
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
if (rowwise_scale_is_within_bounds) {
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<OType2, PACK_SIZE / 2> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
IType2 in;
OType2 &out_pair = reinterpret_cast<OType2 &>(out.data.elt[e]);
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = in_IType[w].data.elt[e];
} else if constexpr (IS_CACHED_ACT_OP) {
in.x = in_cached[w].data.elt[2 * e];
in.y = in_cached[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in.x = in_compute_rowwise[j];
in.y = in_compute_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
__builtin_assume(block_amax >= 0);
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
parity ^= 1;
if constexpr (IS_DBIAS) {
float thread_partial_dbias = 0.0f;
if constexpr (COLWISE_SCALING) {
thread_partial_dbias = partial_dbias_colwise;
} else {
// Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
// HEIGHT = THREADS_Y
// WIDTH = THREADS_X * (SCALE_DIM_X + 1)
// Added extra 1-element padding per thread_X to reduce bank conflicts
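      // E.g., with SCALE_DIM_X = 32 each thread's stride becomes 33 floats instead of 32, so
      // columns that would otherwise map to the same 32-bank pattern are shifted by one bank
      // per thread_X.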
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
const int dbias_stride = cols;
const int dbias_offset_Y = blockIdx.y;
const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, block_amax);
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace quantize_kernel
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
using namespace quantize_kernel;
checkCuDriverContext(stream);
bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data();
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
if (use_rowwise_scaling) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
}
if (use_colwise_scaling) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Columnwise scaling tensor must be allocated");
}
CheckNoopTensor(*noop, "cast_noop");
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT);
constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1;
const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;
e8m0_t *const scales_rowwise_ptr =
use_rowwise_scaling ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr) : nullptr;
e8m0_t *const scales_colwise_ptr =
use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols;
  NVTE_CHECK(use_rowwise_scaling || use_colwise_scaling,
             "Either rowwise or columnwise output data needs to be allocated.");
  ScalingType scaling_type;
  if (use_rowwise_scaling && (!use_colwise_scaling)) {
    scaling_type = ScalingType::ROWWISE;
  } else if ((!use_rowwise_scaling) && use_colwise_scaling) {
    scaling_type = ScalingType::COLWISE;
  } else {
    scaling_type = ScalingType::BIDIMENSIONAL;
  }
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {dbias_rows, dbias_cols};
workspace->data.dtype = DType::kFloat32;
return;
}
}
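  // A null workspace data pointer is the shape-query convention here: the required
  // {dbias_rows, dbias_cols} float32 shape is reported back, and the caller is expected to
  // allocate the workspace and invoke this function again.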
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, input_type_bit_size);
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, output_type_bit_size);
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
switch (scaling_type) {
case ScalingType::ROWWISE: {
auto kernel =
quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}
case ScalingType::COLWISE: {
auto kernel =
quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}
case ScalingType::BIDIMENSIONAL: {
auto kernel =
quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}
}
if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
}
} // namespace mxfp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file core_nvfp4.cuh
* \brief Core functions used in NVFP4.
*/
#ifndef TRANSFORMER_ENGINE_CORE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <limits>
#include "../../common.h"
#include "../../util/curanddx.hpp"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif // FP4_TYPE_SUPPORTED
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
using nvfp4_scale_t = fp8e4m3;
namespace quantization_and_transposition_SF {
#if FP4_TYPE_SUPPORTED
// Used in transpose variant
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
// constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
// return S_dec_b_fp8;
// NOTE: Divide by 6.0f is not elegant and not efficient.
// However, this is part of the emulation code to ensure exact match.
using namespace detail;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
const float S_dec_b = block_amax / fp4_max * S_enc;
return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
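// Numeric sketch (assuming S_enc = (fp8_max * fp4_max) / global_amax, as computed by
// core::compute_global_encode_scaling_factor_FP4 below): with global_amax = 1344 and
// block_amax = 3, S_enc = 448 * 6 / 1344 = 2, so S_dec_b = 3 / 6 * 2 = 1, i.e. the
// per-block decode factor is re-expressed in units of the global encode scale.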
#endif // FP4_TYPE_SUPPORTED
} // namespace quantization_and_transposition_SF
namespace quantization_SF {
#if FP4_TYPE_SUPPORTED
// Used in non-transpose variant
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// return S_dec_b_fp8;
return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
#endif // FP4_TYPE_SUPPORTED
} // namespace quantization_SF
namespace core {
#if FP4_TYPE_SUPPORTED
using namespace ptx;
// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
using namespace detail;
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
__device__ __forceinline__ uint32_t
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> &rng,
// philox4x32_native_state<10>: 10 rounds of philox4_32
uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
random_uint4 = rng.generate4();
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
  const uint32_t *const rbits_arr = reinterpret_cast<const uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
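// Usage sketch (hypothetical caller): keep one uint4 of random bits per thread and let
// get_rbits lazily refill it every fourth call:
//   uint4 rnd = rng.generate4();
//   int idx = 0;
//   const uint32_t r0 = get_rbits(rng, rnd, idx);  // consumes lane 0
//   const uint32_t r1 = get_rbits(rng, rnd, idx);  // consumes lane 1; refill occurs at idx == 4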
#endif // FP4_TYPE_SUPPORTED
} // namespace core
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_CORE_NVFP4_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_nvfp4.cuh
* \brief CUDA kernels to dequantize from NVFP4.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif // FP4_TYPE_SUPPORTED
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
namespace dequantize_kernel {
#if FP4_TYPE_SUPPORTED
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
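  // Each uint64_t load carries 16 packed FP4 values (4 x fp4e2m1x4); the loop below widens
  // them four at a time to float4 and applies the combined dequantization scale.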
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
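  // Reconstruction sketch: quantization stored the per-block scale as (block_amax / 6) * S_enc
  // with S_enc = (448 * 6) / amax, so multiplying by amax / (6 * 448) cancels the global encode
  // factor and leaves the pure per-block decode scale block_amax / 6.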
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // FP4_TYPE_SUPPORTED
} // namespace dequantize_kernel
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
using namespace dequantize_kernel;
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // FP4_TYPE_SUPPORTED
}
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_nvfp4.cuh
* \brief CUDA kernels to cast to NVFP4.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
namespace quantize_kernel {
using namespace ptx;
using namespace quantization_SF;
using namespace core;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 16 = 256 / 16
#define DIRECT_SCALING_FACTORS_STORE 1
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
quantize_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_colwise,
fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
const float *noop, float *const amax_ptr,
const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool ROWWISE_SCALING = true;
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;
  static_assert(BUFF_DIM_Y >= SCALE_DIM_Y,
                "Number of buffer rows must be greater than or equal to the size of the "
                "columnwise scaling block");
  static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
  static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE,
                "Number of buffer rows must be greater than or equal to the number of rowwise "
                "processing threads in the Y dimension");
constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
  constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8;  // Each byte packs two 4-bit elements
constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
// static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of
// // threads to process one row in a single iteration
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;
  // helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
fp8e4m3 *out_rowwise_scales_sh =
reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc =
(nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr);
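  // S_enc is the reciprocal of the second-stage (per-tensor) scale; a null pointer means no
  // second stage was provided, so the encode factor degenerates to 1.0f.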
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const int buff_offset_in = buff * BUFF_IN_DIM;
const int buff_offset_out = buff * BUFF_OUT_DIM;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_IN_DIM;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM_Y];
IType in_colwise_IType[SCALE_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
            // Without activations, out-of-bounds elements are zero (TMA zero-fills them),
            // so including them cannot change the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);
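      // E8M0 stores only a (biased) power-of-two exponent, so the effective scale is
      // roughly 2^log2(block_amax / max_norm); the exact rounding is whatever
      // ptx::float_to_e8m0 implements.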
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
if (colwise_scale_is_within_bounds) {
scales_colwise_e8m0[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const int shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const int shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          // TMA requires 16B data alignment (i.e. cols % 8 == 0 for BF16 elements),
          // so a single check in the column direction is enough to guarantee that
          // the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
              // Without activations, out-of-bounds elements are zero (TMA zero-fills them),
              // so including them cannot change the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);
#if DIRECT_SCALING_FACTORS_STORE
// Check boundaries
if (rowwise_scale_is_within_bounds) {
const int scales_offset_Y =
scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
const int scales_offset_X = scales_offset_X_rowwise;
const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
}
#else
const int shmem_scales_offset_Y =
stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
const int shmem_scales_offset_X = tid_X_rowwise;
const int scale_idx =
shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
// Compute "correct" per-block encoding scaling factor
const float block_scale_inverse =
__fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8)); // S_enc_b_fp8
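        // Worked sketch with hypothetical values: for S_enc = 8 and S_dec_b = 2, an
        // element x = 1.0 is quantized as q = x * (S_enc / S_dec_b) = 4.0 (then cast
        // to FP4), and a consumer recovers x ~= q * S_dec_b / S_enc = 1.0.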
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in01 = in_IType[w].data.elt[2 * e];
in23 = in_IType[w].data.elt[2 * e + 1];
} else if constexpr (IS_CACHED_ACT_OP) {
in01.x = in_cached[w].data.elt[4 * e];
in01.y = in_cached[w].data.elt[4 * e + 1];
in23.x = in_cached[w].data.elt[4 * e + 2];
in23.y = in_cached[w].data.elt[4 * e + 3];
} else {
const int j = w * PACK_SIZE + 4 * e;
in01.x = in_compute_rowwise[j];
in01.y = in_compute_rowwise[j + 1];
in23.x = in_compute_rowwise[j + 2];
in23.y = in_compute_rowwise[j + 3];
}
fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
__builtin_assume(block_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
#if !DIRECT_SCALING_FACTORS_STORE
// Vectorized store of scaling factors.
// Each thread stores multiple scaling factors in one store instruction.
if constexpr (ROWWISE_SCALING) {
// Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
const int scale_idx_global =
scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
(scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
const ScalesVec_t &scales =
*reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
}
}
#endif
float chunk_amax = 0.0f;
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, chunk_amax);
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace quantize_kernel
// This kernel supports only two scaling cases:
//   1. r16c0  - rowwise NVFP4 (1x16 blocks) only
//   2. r16c32 - rowwise NVFP4 (1x16 blocks) AND columnwise MXFP8 (32x1 blocks)
inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
using namespace quantize_kernel;
using namespace ptx;
checkCuDriverContext(stream);
constexpr bool COMPUTE_ACTIVATIONS = false;
using ParamOP = Empty;
constexpr float (*OP)(float, const ParamOP &) = nullptr;
NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
bool use_colwise_scaling = output->has_columnwise_data();
if (use_colwise_scaling) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Columnwise scaling tensor must be allocated");
}
CheckNoopTensor(*noop, "cast_noop");
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = output->scale_inv.shape[1];
const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;
fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
e8m0_t *const scales_colwise_e8m0_ptr =
use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const ScalingType scaling_type =
use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const nvfp4_second_stage_scale_ptr =
reinterpret_cast<const float *>(output->scale.dptr);
  // The output data type is only required for the columnwise MXFP8 scaling.
  // It has no effect on the rowwise NVFP4 scaling, but defaults to E4M3 so that the
  // type-switch macros work.
const DType output_data_type =
use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, 4);
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);
constexpr size_t in_mem = buff_size_aligned_in;
const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;
const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;
const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
out_rowwise_scales_mem + out_colwise_scales_mem +
TMA_SHMEM_ALIGNMENT;
const size_t dshmem_size = in_mem + out_mem;
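              // Dynamic SHMEM layout, in allocation order (see the pointer carving
              // inside the kernel): [input buffers | NVFP4 rowwise data |
              // MXFP8 colwise data | NVFP4 rowwise scales | MXFP8 colwise scales],
              // plus TMA_SHMEM_ALIGNMENT bytes of slack for the manual alignment.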
switch (scaling_type) {
case ScalingType::ROWWISE: {
auto kernel =
quantize_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}
case ScalingType::BIDIMENSIONAL: {
auto kernel =
quantize_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}
} NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // FP4_TYPE_SUPPORTED
}
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_transpose_nvfp4.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
namespace quantize_transpose_kernel {
using namespace quantization_and_transposition_SF;
using namespace core;
using namespace ptx;
#if FP4_TYPE_SUPPORTED
constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts)
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;
constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;
// Each call generates 4x uint32_t random numbers
constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4;
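// With a 128x128 chunk: SCALES_PER_THREAD = 2 * (128 * 128) / 16 / 128 = 16 scaling
// factors per thread (rowwise + columnwise), so RNG_GENS_PER_THREAD = 16 / 4 = 4.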
constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;
// NOTE: Should this be SCALE_DIM or BLOCK_DIM? Both are 16, so it works for both the 1D and 2D cases
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8
constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;
// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
// Output buffer (NVFP4)
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;
// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;
constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16
constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;  // 32 / 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;
static_assert(BUFF_DIM_Y >= SCALE_DIM &&
              "Number of buffer rows must be greater than or equal to the size of the "
              "columnwise scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
              "Number of buffer rows must be greater than or equal to the number of "
              "rowwise processing threads in the Y dimension");
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM;  // 256 / 16 = 16
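// Swizzling sketch: WAVES = SCALE_DIM / PACK_SIZE = 2 groups of 8 elements per
// scaling block. A thread in bank group g reads group ((w + g) * 8) % 16 in wave w,
// so adjacent bank groups traverse the two groups in opposite order and avoid
// hitting the same shared memory banks simultaneously.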
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr,
nvfp4_scale_t *const scales_t_ptr, const float *noop,
const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
  // Index of the random number; incremented on each use and wraps back to 0 after all
  // four values of random_uint4 are consumed
  int rnd_idx = 0;
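  // rbits sketch (assumed contract of get_rbits): each FP4x4 conversion consumes one
  // 32-bit lane of random_uint4, cycling rnd_idx through x, y, z, w and refilling via
  // rng.generate4() once all four lanes are used.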
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
// const size_t tid_X_t = 0;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
  // Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) does not guarantee that the pointer is aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
  // Compute global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This matches how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
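  // Assumed convention (the standard NVFP4 two-level scheme): S_enc ~= (FP4_max *
  // FP8_max) / global_amax = (6.0 * 448.0) / amax, which lets the per-block decode
  // factors S_dec_b fit into the E4M3 range.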
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise =
(row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
              // Without activations, out-of-bounds elements are zero (TMA zero-fills
              // them), so including them cannot change the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
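        // The fminf clamp keeps the inverse finite (<= float_max) when S_dec_b
        // underflows to zero, e.g. for an all-zero block.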
// 3. Scale elements
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = ptx::mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
        // Helps reduce bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
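        // Store sketch: the 16 FP4 results (8 bytes) go out as two uint32 words; the
        // group-based rotation above makes the two 16-lane halves of a warp write the
        // words in opposite order, reducing store bank conflicts.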
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          // TMA requires 16B data alignment (i.e. cols % 8 == 0 for BF16 elements),
          // so a single check in the column direction is enough to guarantee that
          // the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
                // Without activations, out-of-bounds elements are zero (TMA
                // zero-fills them), so including them cannot change the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
  // Vectorized store of scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr,
nvfp4_scale_t *const scales_t_ptr, const float *noop,
const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
  // Index of the random number; incremented on each use and wraps back to 0 after all
  // four values of random_uint4 are consumed
  int rnd_idx = 0;
// NEW: 2D Block-based scaling constants
constexpr size_t BLOCK_DIM = 16;
constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2
constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8
constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile
constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2
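  // Layout sketch: each warp covers BLOCKS_PER_WARP = 2 adjacent 16x16 blocks of a
  // 32x128 tile; its 32 lanes span 32 columns (16 per block), which is why the warp
  // reduction below works on 16-lane halves and lanes 0 and 16 publish the results.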
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
  // Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) does not guarantee that the pointer is aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
  // Compute global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This matches how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
const size_t warp_id = threadIdx.x / 32;
const size_t lane_id = threadIdx.x % 32;
float thread_amax = 0.0f;
const size_t block_in_warp = lane_id / BLOCKS_PER_WARP;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
__shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1];
  // Helper function for warp reduction.
  // XOR shuffles with delta 8..1 reduce over each 16-lane half of the warp, so lanes
  // 0 and 16 end up holding the amax of their respective 16x16 block.
  auto warp_reduce_amax = [](float thread_amax) -> float {
#pragma unroll
    for (int delta = 8; delta >= 1; delta /= 2) {
      float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta);
      thread_amax = fmaxf(thread_amax, other_amax);
    }
    return thread_amax;
  };
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
#pragma unroll
for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
const size_t block_in_tile_y = block_iter;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
for (int elem = 0; elem < BLOCK_DIM; elem += 2) {
const size_t elem_0_row = block_iter * BLOCK_DIM + elem;
const size_t elem_1_row = elem_0_row + 1;
const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
const size_t elem_1_col = elem_0_col;
const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col;
const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col;
IType2 val_2x;
val_2x.x = in_sh[shmem_offset_0];
val_2x.y = in_sh[shmem_offset_1];
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x);
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else {
for (int elem = 0; elem < BLOCK_DIM; ++elem) {
const size_t elem_row = block_iter * BLOCK_DIM + elem;
const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
// Bounds checking
const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows);
const bool col_out_of_bounds = (block_offset_X + elem_col >= cols);
if (!row_out_of_bounds && !col_out_of_bounds) {
const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col;
float elt = static_cast<float>(in_sh[shmem_offset]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset] = static_cast<IType>(elt);
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
}
}
      // Warp reduction to get the block amax
      block_amax = warp_reduce_amax(thread_amax);
if (lane_id == 0 || lane_id == 16) {
block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax;
}
}
    // Sync threads to ensure the block_amax_matrix stores are complete
__syncthreads();
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
        // 1. Load input elements (the 2D block amax was already computed above)
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
}
} else {
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
        // Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
        // 3. Scale elements
        fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = ptx::mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
        // Helps reduce bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = tid_X_rowwise;
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// Used as an IType container for the BF16/FP16 -> NVFP4 cast only
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
}
} else if constexpr (IS_CACHED_ACT_OP) {
// Ensures that all writes to the activation cache made in the colwise section above are visible to all threads
__syncthreads();
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
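// Note on the swizzled indexing above: (w + bank_group) rotates the order in
// which each thread visits the PACK_SIZE-wide element groups, so threads in
// different bank groups begin their loads at different shared memory banks
// rather than all starting at the same one.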
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
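// Two FP4 elements are packed per byte, so the element index is halved below
// to obtain the byte offset into the output buffer.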
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store of the colwise scaling factors staged in SHMEM to global memory
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif // FP4_TYPE_SUPPORTED
} // namespace quantize_transpose_kernel
template <bool use_2d_quantization>
void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config, cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
using namespace quantize_transpose_kernel;
using namespace ptx;
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
// If the transposed output is allocated, return the transposed data. Otherwise, it is not
// necessary to return the transposed data.
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
constexpr bool COMPUTE_ACTIVATIONS = false;
using ParamOP = Empty;
constexpr float (*OP)(float, const ParamOP &) = nullptr;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", false);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if (return_transpose) {
NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
"Transposed output must have FP4 type.");
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Transposed scaling tensor must be allocated");
}
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
NVTE_CHECK(cols % 32 == 0,
"Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_NUM;
const size_t scale_stride = output->scale_inv.shape[1];
const size_t scale_stride_transpose =
return_transpose ? output->columnwise_scale_inv.shape[1] : 0;
nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
nvfp4_scale_t *const scales_transpose_ptr =
reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
const float *const amax_colwise_ptr =
reinterpret_cast<const float *>(output->columnwise_amax.dptr);
const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
const size_t *rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
}
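// The RNG state holds two 64-bit values, presumably a seed/offset pair for the
// counter-based generator that supplies rbits during stochastic rounding.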
using IType = bf16;
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
alignas(64) CUtensorMap tensor_map_output_transpose{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
4);
if (return_transpose) {
create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
}
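// The trailing argument of create_2D_tensor_map is the element width in bits:
// sizeof(IType) * 8 = 16 for the BF16 input and 4 for the packed FP4 outputs.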
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);
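// Sizing arithmetic: FP4 outputs occupy 4 bits per element, hence
// (buff_elems_total * 4) / 8 bytes rounded up to the TMA alignment; each
// 16-element NVFP4 block contributes one nvfp4_scale_t, hence the division
// by 16 above.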
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_data_mem = buff_size_aligned_out;
constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
constexpr size_t out_scales_transpose_mem = buff_size_scales;
constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;
constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel = quantize_transpose_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
if constexpr (use_2d_quantization) {
kernel = quantize_transpose_nvfp4_2D_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
}
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
}););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // FP4_TYPE_SUPPORTED
}
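// Usage sketch (illustrative only; variable names and values are assumed):
//   QuantizationConfig cfg;
//   cfg.stochastic_rounding = true;
//   quantize_transpose</*use_2d_quantization=*/false>(input, &noop, &output,
//                                                     &cfg, stream);
// The output tensor must carry rowwise FP4 data and an allocated scale_inv;
// to also produce the transposed copy, allocate columnwise_data and
// columnwise_scale_inv before the call.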
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_