Unverified Commit 0e80c847 authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[Common] Split cast/gated kernels by scaling mode (#2248)



* Separated gated and dequantize kernels
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Separated quantize, dequantize and gated functions
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed lint issues
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed persistent lint issues
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added missing compute capability 10.0 check for Quantize FP8 TMA kernels
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed the issue which was added again by autofix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed files description. Completely removed non-identity activations from the NVFP4 transpose test suite
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Removed unsupported template arguments in NVFP4 quantize
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed undefined symbol error
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed condition
Signed-off-by: default avatarOleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Fixed CUDA version check
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed arch conditions order
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Clean up
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Small fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Small fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixes per the PR review
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Split quantize helper into two (FWD and BWD) functions
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Moved activation functions from cast.cu. Removed cast.cu from the fast-math compilation list
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Enabled fast math for activations by default
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Disabled fast math for activations by default
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: default avatarOleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 490a5f41
......@@ -168,7 +168,7 @@ list(APPEND transformer_engine_cuda_sources
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
......@@ -336,8 +336,7 @@ option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --u
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu)
activation/swiglu.cu)
endif()
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
......
......@@ -14,26 +14,17 @@
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include "../cast/dispatch/gated.cuh"
#include "../cast/dispatch/quantize.cuh"
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = true;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, OP>(input, output, nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
......@@ -42,20 +33,17 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, OP>(grad, input, output, dbias, workspace,
nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
dispatch::quantize_gated_fwd_helper<Param, ActOP>(input, output, p, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
......@@ -63,8 +51,7 @@ template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
dispatch::quantize_gated_bwd_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
}
} // namespace transformer_engine
......
......@@ -20,6 +20,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
......@@ -48,6 +61,19 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
......
......@@ -20,6 +20,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
......@@ -48,6 +61,19 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
......
......@@ -20,6 +20,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
......
......@@ -10,36 +10,20 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>
#include <cfloat>
#include <limits>
#include <mutex>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "dispatch/dequantize.cuh"
#include "dispatch/quantize.cuh"
#include "transformer_engine/transpose.h"
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
......@@ -59,15 +43,8 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
NVTE_API_CALL(nvte_quantize_v2);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
input, grad, output, dbias, workspace, quant_config, stream);
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
}
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
......@@ -77,87 +54,17 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr const NVTETensor activation_input = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, output, dbias, workspace, nullptr, stream);
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
stream);
}
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
......@@ -166,12 +73,7 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_API_CALL(nvte_multi_tensor_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
const size_t num_streams = nvte_get_num_compute_streams();
......@@ -184,9 +86,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
}
for (int i = 0; i < num_tensors; i++) {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
inputs[i], grad, outputs[i], dbias, workspace, nullptr,
detail::get_compute_stream(i % num_streams));
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams));
}
// record events on compute streams
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file common.cuh
* \brief Common functions in quantize.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace common {
inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
const size_t N = product(t->data.shape);
const bool isFullTile = (N % elems_per_block == 0);
return isFullTile;
}
inline bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr size_t TMA_bytes = 16;
const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}
namespace kernel {
constexpr size_t THREADS_PER_BLOCK = 256;
template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
const size_t rows, const size_t cols) {
using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>;
const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= cols) {
return;
}
const float *const thread_in_base = dbias_partial + thread_id * nvec;
OType *const thread_out_base = dbias_output + thread_id * nvec;
ComputeVec ldg_vec;
ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < rows; ++i) {
ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base);
}
} // namespace kernel
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream) {
using namespace kernel;
constexpr size_t reduce_dbias_store_bytes = 8; // stg.64
constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, IType>
<<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace common
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize.cuh
* \brief Dequantize dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#include "../nvfp4/dequantize_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
fp8::dequantize(input, output, stream);
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (is_supported_by_CC_100()) {
mxfp8::dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
nvfp4::dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gated.cuh
* \brief Gated dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"
namespace transformer_engine {
namespace dispatch {
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
cudaStream_t stream) {
const Tensor input = *convertNVTETensorCheck(nvte_input);
Tensor *output = convertNVTETensorCheck(nvte_output);
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim() / 2;
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == cols,
"Wrong output shape. Expected (after flattening) [*, ", cols, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
if (use_tma_kernels) {
Tensor dummy_grad_tensor;
fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
output, p, stream);
} else {
fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
NVTE_CHECK(cols % 32 == 0,
"Invalid input shape. Expected the last dimension to be "
"divisible by 32, but got ",
cols, ".");
if (output->has_data()) {
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"The type of the output tensor should be FP8.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
"The type of the columnwise output tensor should be FP8.");
}
NVTE_CHECK(is_supported_by_CC_100(),
"Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
Tensor dummy_grad_tensor;
mxfp8::quantize_gated</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
output, p, stream);
break;
}
default:
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
const Tensor &grad = *(convertNVTETensorCheck(nvte_grad));
const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input);
Tensor *output = convertNVTETensorCheck(nvte_output);
CheckInputTensor(grad, "grad");
CheckInputTensor(gated_input, "gated_input");
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ",
gated_input.flat_last_dim(), ".");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows,
"Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
NVTE_CHECK(grad.flat_last_dim() == cols,
"Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [",
grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [",
rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == cols * 2,
"Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(gated_input.data.shape == output->data.shape,
"Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
", output shape: ", output->data.shape, ".");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
if (use_tma_kernels) {
fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
} else {
fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
NVTE_CHECK(cols % 32 == 0,
"Invalid input shape. Expected the last dimension to be "
"divisible by 32, but got ",
cols, ".");
if (output->has_data()) {
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"The type of the output tensor should be FP8.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
"The type of the columnwise output tensor should be FP8.");
}
NVTE_CHECK(is_supported_by_CC_100(),
"Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
mxfp8::quantize_gated</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
break;
}
default:
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize.cuh
* \brief Quantize dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *input_tensor = convertNVTETensorCheck(input);
Tensor *output_tensor = convertNVTETensorCheck(output);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
if (output_tensor->has_columnwise_data()) {
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_ACT) {
cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, float, ParamOP, OP>(
*input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
}
} else if (output_tensor->has_data()) {
fp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
*input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
mxfp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
*input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4::quantize_transpose</*use_2d_quantization*/ false>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
quantize_transpose_vector_blockwise_fp4(
/*input=*/input_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
}
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output,
NVTETensor dbias, NVTETensor workspace,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *grad_tensor = convertNVTETensorCheck(grad);
const Tensor *input_tensor = convertNVTETensor(input);
Tensor *output_tensor = convertNVTETensorCheck(output);
Tensor *dbias_tensor = convertNVTETensor(dbias);
Tensor *workspace_tensor = convertNVTETensor(workspace);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) {
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT) {
cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, float, ParamOP, OP>(
*grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream);
}
} else if (output_tensor->has_data()) {
fp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*grad_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = grad_tensor->flat_first_dim();
int32_t cols = grad_tensor->flat_last_dim();
auto dtype = grad_tensor->dtype();
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
*grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4::quantize_transpose</*use_2d_quantization*/ false>(
*grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
quantize_transpose_vector_blockwise_fp4(
/*input=*/grad_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_fp8.cuh
* \brief CUDA kernels to dequantize from FP8.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
struct DequantizeParam {
const float *scale_inv;
};
__device__ inline float dequantize_func(float value, const DequantizeParam &param) {
return value * (*(param.scale_inv));
}
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
DequantizeParam p; p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, DequantizeParam, dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr), nullptr,
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gated_fp8.cuh
* \brief CUDA kernels to cast to FP8 with gated activations.
*/
#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_
#define TRANSFORMER_ENGINE_GATED_FP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
namespace transformer_engine {
namespace dispatch {
namespace fp8 {
namespace kernel {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X;
constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128
constexpr size_t BUFFERS_NUM = 2;
constexpr size_t BUFFER_DIM_Y = 32;
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128
constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32
static_assert(ITERATIONS >= 1);
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols,
const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t grad_mem = IS_BWD ? buff_size_aligned_in : 0;
constexpr size_t in_act_mem = buff_size_aligned_in;
constexpr size_t in_gate_mem = buff_size_aligned_in;
constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
const uint64_t *TMAP_output_act = reinterpret_cast<const uint64_t *>(&tensor_map_output_act);
const uint64_t *TMAP_output_gate = reinterpret_cast<const uint64_t *>(&tensor_map_output_gate);
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
// Prefetch data of the first stage
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh,
TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate,
chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
} else {
copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh,
TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
}
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const size_t buff = it % BUFFERS_NUM;
const size_t next_it = it + 1;
if (next_it < ITERATIONS) {
const size_t next_buff = next_it % BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_BWD) {
copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y,
in_transaction_size, &mbar[next_it], is_master_thread);
} else {
copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x,
chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate,
chunk_it_offset_x, chunk_it_offset_y, in_transaction_size,
&mbar[next_it], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[it], parity);
IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems;
IType *in_act_sh_curr = in_act_sh + buff * buff_elems;
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
if (act_elt <= p.limit) {
dact_x = s + s * (1 - s) * p.alpha * x;
} else {
dact_x = 0.0f;
}
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, p) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
}
// Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
// dGeLU
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_act_sh_curr));
if constexpr (IS_BWD) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_gate_sh_curr));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
// Destroy the barriers. This invalidates the memory region of the barrier.
// If further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_invalid(&mbar[it]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace kernel
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p,
cudaStream_t stream) {
using namespace kernel;
checkCuDriverContext(stream);
NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block_dim(THREADS_PER_CHUNK);
const dim3 grid_dim(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act{};
alignas(64) CUtensorMap tensor_map_output_gate{};
if constexpr (IS_BWD) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
auto kernel = cast_fp8_gated_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_size));
kernel<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
output->flat_last_dim(), p, stream);); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_gated_bwd(const Tensor &input, const Tensor &grad, Tensor *output, ParamOP &p,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP, DActOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), p, stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace fp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GATED_FP8_CUH_
This diff is collapsed.
......@@ -4,36 +4,27 @@
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_kernels.cuh
* \brief CUDA kernels to cast from MXFP8.
/*! \file dequantize_mxfp8.cuh
* \brief CUDA kernels to dequantize from MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
#include <transformer_engine/transformer_engine.h>
namespace transformer_engine {
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
namespace dequantization {
namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace dequantize_kernel {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
......@@ -228,29 +219,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace dequantize_kernel
void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr), nullptr,
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
using namespace dequantize_kernel;
bool use_rowwise_scaling = input.has_data();
bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream);
......@@ -334,113 +306,8 @@ void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
}
#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
constexpr float factor_inv = 1.0 / (6.0 * 448.0);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // CUDA_VERSION
void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // CUDA_VERSION >= 12080
}
} // namespace dequantization
namespace detail {
void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
dequantization::fp8_dequantize(input, output, stream);
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (is_supported_by_CC_100()) {
dequantization::mxfp8_dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
dequantization::fp4_dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
} // namespace detail
} // namespace mxfp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_
This diff is collapsed.
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file core_nvfp4.cuh
* \brief Core functions used in NVFP4.
*/
#ifndef TRANSFORMER_ENGINE_CORE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <limits>
#include "../../common.h"
#include "../../util/curanddx.hpp"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif // FP4_TYPE_SUPPORTED
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
using nvfp4_scale_t = fp8e4m3;
namespace quantization_and_transposition_SF {
#if FP4_TYPE_SUPPORTED
// Used in transpose variant
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
// constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
// return S_dec_b_fp8;
// NOTE: Divide by 6.0f is not elegant and not efficient.
// However, this is part of the emulation code to ensure exact match.
using namespace detail;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
const float S_dec_b = block_amax / fp4_max * S_enc;
return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
#endif // FP4_TYPE_SUPPORTED
} // namespace quantization_and_transposition_SF
namespace quantization_SF {
#if FP4_TYPE_SUPPORTED
// Used in non-transpose variant
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// return S_dec_b_fp8;
return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
#endif // FP4_TYPE_SUPPORTED
} // namespace quantization_SF
namespace core {
#if FP4_TYPE_SUPPORTED
using namespace ptx;
// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
using namespace detail;
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
__device__ __forceinline__ uint32_t
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> &rng,
// philox4x32_native_state<10>: 10 rounds of philox4_32
uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
random_uint4 = rng.generate4();
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
#endif // FP4_TYPE_SUPPORTED
} // namespace core
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_CORE_NVFP4_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_nvfp4.cuh
* \brief CUDA kernels to dequantize from NVFP4.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif // FP4_TYPE_SUPPORTED
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
namespace dequantize_kernel {
#if FP4_TYPE_SUPPORTED
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
constexpr float factor_inv = 1.0 / (6.0 * 448.0);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // FP4_TYPE_SUPPORTED
} // namespace dequantize_kernel
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
using namespace dequantize_kernel;
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // FP4_TYPE_SUPPORTED
}
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment