Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_
#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_
#include <transformer_engine/transformer_engine.h>
namespace transformer_engine {
struct MatmulConfig {
NVTETensor bias_tensor = nullptr;
NVTETensor dbias_tensor = nullptr;
bool with_gelu_epilogue = false;
bool with_dgelu_epilogue = false;
NVTETensor epilogue_aux_tensor = nullptr;
bool use_split_accumulator = false;
int sm_count = 0;
static constexpr size_t attr_sizes[] = {
sizeof(NVTETensor), // bias_tensor
sizeof(NVTETensor), // dbias_tensor
sizeof(bool), // with_gelu_epilogue
sizeof(bool), // with_dgelu_epilogue
sizeof(NVTETensor), // epilogue_aux_tensor
sizeof(bool), // use_split_accumulator
sizeof(int) // sm_count
};
};
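// Usage sketch (illustrative only, not part of this header's API surface): a
// MatmulConfig can be filled in and passed where an NVTEMatmulConfig handle is
// expected, as multi_stream_cublas_gemm does in this change. The tensors,
// workspace, and stream below are assumed to exist on the caller's side, and
// nvte_cublas_gemm_v2 is declared in the public gemm header.
//
//   MatmulConfig cfg;
//   cfg.bias_tensor = bias;                  // fused bias epilogue
//   cfg.with_gelu_epilogue = true;           // fused GELU epilogue
//   cfg.epilogue_aux_tensor = pre_gelu_out;  // aux output required by the epilogue
//   cfg.use_split_accumulator = true;
//   const float alpha = 1.0f, beta = 0.0f;
//   nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha, A, B, &beta, D, D,
//                       workspace, &cfg, stream);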
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
@@ -15,23 +15,58 @@
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>
#include <algorithm>
#include <cstdint>
#include <mutex>
#include <vector>
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
#include "./config.h"
#ifndef __HIP_PLATFORM_AMD__
#include "cutlass_grouped_gemm.cuh"
#include "./cutlass_grouped_gemm.cuh"
#endif
#ifndef __HIP_PLATFORM_AMD__
namespace {
/* Use CUDA constant memory to store the scalars 1.0f and 0.0f for cuBLAS usage.
*/
__device__ __constant__ float one_device;
__device__ __constant__ float zero_device;
inline float *GetScalarOne() {
static std::once_flag init_flag;
std::call_once(init_flag, []() {
float one = 1.0f;
NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float)));
});
// return address by cudaGetSymbolAddress
float *dev_ptr;
NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), one_device));
return dev_ptr;
}
inline float *GetScalarZero() {
static std::once_flag init_flag;
std::call_once(init_flag, []() {
float zero = 0.0f;
NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float)));
});
// return address by cudaGetSymbolAddress
float *dev_ptr;
NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), zero_device));
return dev_ptr;
}
__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; }
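// Usage sketch for GetScalarOne/GetScalarZero/set_float_kernel above (mirroring the
// call sites later in this file): when alpha/beta must live in device memory, e.g.
//   const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();
// or, for an arbitrary host value staged into the tail of the workspace,
//   set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta);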
uint32_t _getAlignment(uintptr_t address) {
// alignments are in bytes
uint32_t alignment = 256;
@@ -91,6 +126,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
bool is_A_transposed = transA == CUBLAS_OP_T;
bool is_B_transposed = transB == CUBLAS_OP_T;
// Set conditions for MXFP8 and NVFP4 gemm execution.
const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
// Configure A matrix
if (is_tensor_scaling(A.scaling_mode)) {
// Unscaled or FP8 tensor scaling
@@ -111,10 +150,32 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
}
} else if (is_mxfp_scaling(A.scaling_mode)) {
// MXFP8
if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.lda % 16 == 0,
"Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
}
} else if (nvfp4) {
// NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.
if (is_A_transposed) {
NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
} else {
NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode),
"Input A has unsupported combination of recipe and layout");
NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
}
ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
ret.transA = CUBLAS_OP_T; // NVFP4 gemm is only supported in TN layout.
ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.lda = k;
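// Reading of the TN-only handling above (descriptive, no additional behavior):
// with transA == CUBLAS_OP_T the row-wise NVFP4 data of A is used directly, while
// with transA == CUBLAS_OP_N the column-wise (transposed) copy and its scale-inv
// are substituted so the GEMM can still run in TN layout.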
} else if (mxfp8) {
// MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe.
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
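// For example (assuming the usual 32-element MX scaling blocks): row-wise data
// carries one scale per 1x32 block along the contiguous dimension, while the
// column-wise copy carries one scale per 32x1 block, which is why the two usages
// are not interchangeable here.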
if (is_A_transposed) {
NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
} else {
@@ -141,7 +202,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK((ret.lda % 16) == 0,
"Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
"Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
// Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
// Smallest supported CType is 2 bytes in this scaling mode.
NVTE_CHECK((m % 8) == 0,
@@ -170,10 +231,26 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
}
} else if (is_mxfp_scaling(B.scaling_mode)) {
// MXFP8
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.ldb % 16 == 0,
"Leading dimension requirement on B for FP8 GEMM. Caller must pad.");
}
} else if (nvfp4) {
if (is_B_transposed) {
NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
"Input B has unsupported combination of recipe and layout");
NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
} else {
NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
}
ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
ret.transB = CUBLAS_OP_N; // NVFP4 gemm is only supported in TN layout.
ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
ret.ldb = k;
} else if (mxfp8) {
if (is_B_transposed) {
NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
} else {
@@ -238,7 +315,7 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
float alpha, float beta, bool use_split_accumulator, int math_sm_count,
const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count,
int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
cudaStream_t stream) {
// Tensor dims in row-major order
@@ -277,6 +354,49 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype);
// Update scaling factors with NVFP4 tensor scales
// TODO: Check whether scales are on CPU/GPU or add API to control.
// Currently scales are assumed to be on CPU when amax is provided
// and on GPU when not provided, but this is brittle.
if (use_fp4 && (inputA->amax.dptr != nullptr || inputB->amax.dptr != nullptr)) {
// Reserve some workspace for alpha scale
NVTE_CHECK(workspaceSize >= 4,
"NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
workspaceSize, " bytes remaining.");
workspaceSize = (workspaceSize / 4) * 4 - 4; // Round down to a 4-byte boundary and reserve the last 4 bytes
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
float *new_alpha_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
// Update alpha scale on device
// Note: Compute NVFP4 tensor scales based on amaxes and then
// divide from alpha scale. This way we only need to apply NVFP4
// tensor scales in matmul output, instead of in matmul inputs.
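// Sketch of the intent (the exact formula lives in nvte_nvfp4_compute_per_tensor_scale):
// with per-tensor dequantization scales S_A and S_B derived from the two amaxes
// (e.g. S = amax / (448 * 6) for the NVFP4 recipe -- an assumption, not taken from
// this file), the effective scale becomes new_alpha = old_alpha * S_A * S_B, so that
// D = new_alpha * (A_q @ B_q) matches alpha * (A @ B) without rescaling the
// quantized inputs.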
float old_alpha = *reinterpret_cast<const float *>(alpha); // Assumed to be on CPU
TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb,
old_alpha, new_alpha_tensor.data(), stream);
alpha = new_alpha_ptr;
// Make sure beta scale is on device
float old_beta = *reinterpret_cast<const float *>(beta); // Assumed to be on CPU
if (old_beta == 0) {
beta = GetScalarZero(); // Device constant memory
} else if (old_beta == 1) {
beta = GetScalarOne(); // Device constant memory
} else {
// Move beta to workspace
NVTE_CHECK(workspaceSize >= 4,
"NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ",
workspaceSize, " bytes remaining.");
workspaceSize = (workspaceSize / 4) * 4 - 4; // Round down to a 4-byte boundary and reserve the last 4 bytes
float *new_beta_ptr = reinterpret_cast<float *>(&workspace_ptr[workspaceSize]);
set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta);
NVTE_CHECK_CUDA(cudaGetLastError());
beta = new_beta_ptr;
}
}
const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
@@ -287,16 +407,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr,
"FP4 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP4 input to GEMM requires inverse of scale!");
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if (use_fp8 && gelu) {
if ((use_fp8 || use_fp4) && gelu) {
NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
"fp8 Aux output for gemm + gelu fusion not supported!");
}
if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!");
if (is_fp4_dtype(outputD->data.dtype)) {
NVTE_ERROR("FP4 GEMM output is not supported!");
}
if (use_fp4 && (D_type == CUDA_R_16F)) {
NVTE_ERROR("FP4 GEMM does not support FP16 output!");
}
cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
@@ -336,12 +463,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&math_sm_count, sizeof(math_sm_count)));
}
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
// Note: gelu fusion isn't available right now, and we don't need
// set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4
// as appropriate. Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
if (use_fp8) {
// Split accumulator.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode);
if (use_fp8 || use_fp4) {
// Fast accumulation is only supported for FP8.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
&fastAccuMode, sizeof(fastAccuMode)));
@@ -350,7 +479,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
cublasLtMatmulMatrixScale_t scaling_mode_a;
cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif // CUBLAS_VERSION >= 120800
if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) {
void *A_scale_inverse = param.A_scale_inv;
void *B_scale_inverse = param.B_scale_inv;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -363,7 +492,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#endif // CUBLAS_VERSION >= 120800
} else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
} else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
@@ -388,6 +517,34 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#else
NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif // CUBLAS_VERSION >= 120800
} else if (use_fp4) { // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
// Keep the alpha/beta computation dtype in FP32 via CUBLASLT_MATMUL_DESC_SCALE_TYPE
cublasDataType_t scale_type = CUDA_R_32F;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
// Set pointer mode: alpha and beta are both device pointers
// https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION);
#endif // CUBLAS_VERSION >= 120800
} else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
@@ -520,14 +677,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
#else
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
@@ -554,6 +708,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif
}
// align the workspace to 256 B
const int required_alignment = 256;
const auto original_workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
uint8_t *aligned_workspace_ptr =
reinterpret_cast<uint8_t *>(workspace) + required_alignment - original_workspace_alignment;
workspaceSize = workspaceSize - required_alignment + original_workspace_alignment;
const auto new_workspace_alignment =
_getAlignment(reinterpret_cast<uintptr_t>(aligned_workspace_ptr));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
@@ -561,7 +723,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -570,8 +731,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
NVTE_CHECK(workspace_alignment % 256 == 0,
"cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);
NVTE_CHECK(new_workspace_alignment % 256 == 0,
"cuBLAS workspace pointer must be aligned to 256 bytes, got ",
new_workspace_alignment);
const auto status =
cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
@@ -582,16 +744,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
static_cast<const void *>(&alpha), /* alpha */
param.A, /* A */
Adesc, param.B, /* B */
Bdesc, static_cast<const void *>(&beta), /* beta */
C, /* C */
Cdesc, D, /* D */
Ddesc, &heuristicResult.algo, /* algo */
workspace, /* workspace */
workspaceSize, stream)); /* stream */
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */
param.A, /* A */
Adesc, param.B, /* B */
Bdesc, beta, /* beta */
C, /* C */
Cdesc, D, /* D */
Ddesc, &heuristicResult.algo, /* algo */
aligned_workspace_ptr, /* workspace */
workspaceSize, stream)); /* stream */
// Update FP8 scale-inv in output tensor
// Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
@@ -666,13 +827,26 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
// Tensors
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
Tensor *outputD = convertNVTETensorCheck(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
// Scales
const float alpha = 1;
const float beta = accumulate ? 1 : 0;
// Check for NVFP4
// TODO Remove once alpha scale logic is moved into cublas_gemm function
if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
}
#ifdef __HIP_PLATFORM_AMD__
const size_t A0 = inputA->flat_first_dim();
const size_t A1 = inputA->flat_last_dim();
@@ -734,9 +908,135 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false,
nullptr, stream);
#endif //__HIP_PLATFORM_AMD__
&alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
#endif
}
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
NVTE_API_CALL(nvte_cublas_gemm_v2);
using namespace transformer_engine;
// Data tensors
const Tensor *A_tensor = convertNVTETensorCheck(A);
const Tensor *B_tensor = convertNVTETensorCheck(B);
const Tensor *C_tensor = convertNVTETensorCheck(C);
Tensor *D_tensor = convertNVTETensorCheck(D);
NVTE_CHECK(C_tensor == D_tensor,
"Currently nvte_cublas_gemm_v2 does not support different C and D tensors.");
// Workspace
void *workspace_ptr = nullptr;
size_t workspace_size = 0;
Tensor *workspace_tensor = convertNVTETensor(workspace);
if (workspace_tensor != nullptr) {
workspace_ptr = workspace_tensor->data.dptr;
workspace_size =
get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype);
}
// Additional config
MatmulConfig config_;
if (config != nullptr) {
config_ = *reinterpret_cast<MatmulConfig *>(config);
}
// Configure GEMM epilogue
const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue);
if (with_grad_epilogue) {
NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue,
"Invalid epilogue (bias=", config_.bias_tensor != nullptr,
", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
", dgelu=", config_.with_dgelu_epilogue, ").");
}
Tensor dummy_tensor;
Tensor *epilogue_bias_tensor = &dummy_tensor;
if (!with_grad_epilogue && config_.bias_tensor != nullptr) {
epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor);
} else if (with_grad_epilogue && config_.dbias_tensor != nullptr) {
epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor);
}
Tensor *epilogue_aux_tensor = &dummy_tensor;
if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) {
NVTE_CHECK(config_.epilogue_aux_tensor != nullptr,
"Requested epilogue (bias=", config_.bias_tensor != nullptr,
", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue,
", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor.");
epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor);
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK(*alpha == 1.0f, "alpha must be 1.0 for hip");
NVTE_CHECK(*beta == 1.0f || *beta == 0.0f, "beta must be 1.0 or 0.0 for hip");
bool accumulate = false;
if (*alpha == 1.0f && *beta == 1.0f) {
accumulate = true;
}
const size_t A0 = A_tensor->flat_first_dim();
const size_t A1 = A_tensor->flat_last_dim();
const size_t B0 = B_tensor->flat_first_dim();
const size_t B1 = B_tensor->flat_last_dim();
const int m = transa ? A0 : A1;
const int k = transa ? A1 : A0;
const int n = transb ? B1 : B0;
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
const bool use_int8 = is_int8_dtype(A_tensor->data.dtype) ||
is_int8_dtype(B_tensor->data.dtype);
const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
const bool use_fp8 = is_fp8_dtype(A_tensor->data.dtype) ||
is_fp8_dtype(B_tensor->data.dtype);
const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8 && config_.use_split_accumulator) nvte_use_hipblaslt = 1;
if ((epilogue_bias_tensor->data.dptr != nullptr) || (epilogue_aux_tensor->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor, m, n, k, lda, ldb, ldd, transa, transb, with_grad_epilogue,
workspace_ptr, workspace_size, accumulate, config_.use_split_accumulator, config_.sm_count, 0, 0,
false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
} else {
hipblas_gemm(A_tensor,
B_tensor,
D_tensor,
epilogue_bias_tensor,
epilogue_aux_tensor,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
with_grad_epilogue, workspace_ptr,
workspace_size,
accumulate, config_.use_split_accumulator,
config_.sm_count,
0,
0,
false,
nullptr,
stream);
}
#else
// Launch GEMM
cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor,
transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N,
with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta,
config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream);
#endif
}
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
@@ -745,13 +1045,21 @@ void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor
bool use_split_accumulator, int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
NVTE_API_CALL(nvte_cublas_gemm_scaled);
using namespace transformer_engine;
// Tensors
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
Tensor *outputD = convertNVTETensorCheck(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
// Check for NVFP4
// TODO Remove once alpha scale logic is moved into cublas_gemm function
if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) {
NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead.");
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK(alpha == 1.0f, "alpha must be 1.0 for hip");
NVTE_CHECK(beta == 1.0f || beta == 0.0f, "beta must be 1.0 or 0.0 for hip");
@@ -820,7 +1128,7 @@ void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
&alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
#endif
}
@@ -838,12 +1146,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
#else
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
NVTE_CHECK(
transformer_engine::cuda::cudart_version() >= 12020 &&
transformer_engine::cuda::cudart_version() < 13000,
@@ -854,7 +1162,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
#endif
#else
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
#endif // __HIP_PLATFORM_AMD__
#ifdef NVTE_CUBLAS_ATOMIC_GEMM_COMPILE
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
@@ -863,6 +1175,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
const Tensor *inputCounter = convertNVTETensor(counter);
Tensor *wspace = convertNVTETensor(workspace);
const void *alpha_ptr = GetScalarOne();
const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();
NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode),
"Atomic GEMM only supports delayed scaling.");
@@ -917,9 +1232,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split,
n_split, gemm_producer, inputCounter, stream);
alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split,
gemm_producer, inputCounter, stream);
#endif //__HIP_PLATFORM_AMD__
#endif // NVTE_CUBLAS_ATOMIC_GEMM_COMPILE
}
void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
@@ -948,17 +1264,59 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens
} else{
NVTE_FORCE_BLAS_MULSTREAM = false;
}
if (NVTE_FORCE_BLAS_MULSTREAM){
if (NVTE_FORCE_BLAS_MULSTREAM) {
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
detail::get_compute_stream(i % num_streams));
// Check whether GELU or dGELU epilogue is requested
Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]);
bool with_gelu_dgelu_epilogue =
(pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr);
// Construct config
MatmulConfig config;
if (grad) {
config.dbias_tensor = bias[i];
config.with_dgelu_epilogue = with_gelu_dgelu_epilogue;
} else {
config.bias_tensor = bias[i];
config.with_gelu_epilogue = with_gelu_dgelu_epilogue;
}
config.epilogue_aux_tensor = pre_gelu_out[i];
config.use_split_accumulator = use_split_accumulator;
config.sm_count = math_sm_count;
// Launch GEMM
const float alpha = 1.f;
const float beta = accumulate ? 1.f : 0.f;
nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
workspace[i % num_streams], &config,
detail::get_compute_stream(i % num_streams));
}
} else{
} else {
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
detail::get_compute_stream(i % num_streams), 1, 0, i % num_streams);
// Check whether GELU or dGELU epilogue is requested
Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]);
bool with_gelu_dgelu_epilogue =
(pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr);
// Construct config
MatmulConfig config;
if (grad) {
config.dbias_tensor = bias[i];
config.with_dgelu_epilogue = with_gelu_dgelu_epilogue;
} else {
config.bias_tensor = bias[i];
config.with_gelu_epilogue = with_gelu_dgelu_epilogue;
}
config.epilogue_aux_tensor = pre_gelu_out[i];
config.use_split_accumulator = use_split_accumulator;
config.sm_count = math_sm_count;
// Launch GEMM
const float alpha = 1.f;
const float beta = accumulate ? 1.f : 0.f;
nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i],
workspace[i % num_streams], &config,
detail::get_compute_stream(i % num_streams), 1, 0, i % num_streams);
}
}
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
namespace transformer_engine {
namespace {
constexpr int kThreadsPerWarp = 32;
constexpr float k16x16HadamardScale = 0.25f;
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr) {
auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
if constexpr (kTranspose) {
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
} else {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
}
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr, uint32_t stride) {
if constexpr (kTranspose) {
asm volatile(
"wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
} else {
asm volatile(
"wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
}
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3, void* addr,
uint32_t stride) {
if constexpr (kTranspose) {
asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
} else {
asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
}
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) {
asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 "
"%0, %1;\n\t"
: "=r"(a0)
: "r"(a0));
}
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) {
__nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16);
float f_a = __bfloat162float(bf16x2.x);
float f_b = __bfloat162float(bf16x2.y);
asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b));
float_dst = fabsf(float_dst);
}
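// Worked example for unpack_max_of_packed_bf16 above: for the packed pair
// (-3.5, 2.0), max.xorsign.abs.f32 picks the larger magnitude (3.5) and gives it
// the XOR of the input signs (negative here), and the final fabsf restores the
// magnitude, so float_dst == 3.5f.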
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1,
uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3,
uint32_t& amax_result) {
uint32_t zero = 0;
uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
asm volatile(
"wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n"
"{%0, %1, %2, %3, %4, %5, %6, %7}, \n"
"{%8, %9, %10, %11}, \n"
"{%12, %13, %14, %15}, \n"
"{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6),
"=r"(temp7)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
if constexpr (kCalculateAmax) {
uint32_t max_even;
uint32_t max_odd;
// Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3));
// N.B. mma is only called up to once per thread for identity and transpose respectively, so
// we don't have to accumulate into amax_result and can directly store into it.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(amax_result)
: "r"(max_even), "r"(max_odd));
}
}
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i,
uint16_t random_sign_mask,
uint32_t* had_frag_t,
uint16_t random_sign_mask_t) {
int32_t tid = threadIdx.x % 32; // Local tid
float temp_i[2];
float temp_t[2];
#pragma unroll
for (int i = 0; i < 2; i++) {
// i is the vertical fragment index.
// For a 16x16 matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
uint32_t r = i * 8 + tid / 4;
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int k = 0; k < 2; k++) {
// k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits.
// j is the column fragment idx selecting between even and odd fragments.
// j increments 8 columns by switching fragments.
uint32_t c = j * 8 + k + tid % 4 * 2;
// 1 -> -1.0f, 0 -> 1.0f
int32_t base_sign = __popc(r & c);
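// Worked example (ignoring the random sign mask, i.e. its bit is 0): for r = 3
// (0b0011) and c = 5 (0b0101), r & c = 0b0001 and __popc(...) = 1, so the low bit
// shifted into the sign position below yields -k16x16HadamardScale = -0.25f,
// matching the standard Walsh-Hadamard entry H[r][c] = (-1)^popcount(r & c).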
if constexpr (kReturnIdentity) {
int32_t sign_i;
// Because tensor cores want the dot-product dimension contiguous, the
// regular (non-inverse) Hadamard swaps the column/row sign handling
// relative to the inverse. In a simple reference,
// x.reshape(-1, 16) @ sign @ H16, this would be opposite, but
// (sign @ H16) is transposed in this fragment.
if constexpr (kInverseHadamardIdentity) {
sign_i = ((random_sign_mask >> r) ^ base_sign);
} else {
sign_i = ((random_sign_mask >> c) ^ base_sign);
}
temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
}
if constexpr (kReturnTransposed) {
int32_t sign_t;
if constexpr (kInverseHadamardTransposed) {
sign_t = ((random_sign_mask_t >> r) ^ base_sign);
} else {
sign_t = ((random_sign_mask_t >> c) ^ base_sign);
}
temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
}
}
if constexpr (kReturnIdentity) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_i[i * 2 + j])
: "f"(temp_i[1]), "f"(temp_i[0]));
}
if constexpr (kReturnTransposed) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_t[i * 2 + j])
: "f"(temp_t[1]), "f"(temp_t[0]));
}
}
}
}
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
uint32_t gmem_col_idx) {
uint32_t smem_row_idx = gmem_row_idx;
uint32_t xor_factor = (smem_row_idx * 2) % 8;
uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
return smem_row_idx * 8 + smem_col_idx;
}
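// Worked example for swizzle_128B_atom_32B above: gmem_row_idx = 3, gmem_col_idx = 5
// gives xor_factor = (3 * 2) % 8 = 6, smem_col_idx = 5 ^ 6 = 3, and a returned
// linear index of 3 * 8 + 3 = 27.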
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment
int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);
int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;
if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_t_reg)
: "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
}
if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[1]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[2])
: "r"(a_frag[2]), "r"(a_frag[3]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[2]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}
}
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
static_assert(kN > 0, "kN must be > 0");
// Round up to the next power of 2 by counting leading zeros.
return 1 << (32 - __builtin_clz(kN - 1));
}
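// Worked examples for NextPowerOf2 above: kN = 5 -> __builtin_clz(4) = 29,
// 1 << (32 - 29) = 8; kN = 8 -> __builtin_clz(7) = 29, so the result is 8 as well
// (values that are already powers of two are returned unchanged).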
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
const float transpose_amax, float* staging_for_pre_rht,
float* staging_for_identity, float* staging_for_transpose,
float* output_pre_rht_amax_ptr,
float* output_identity_amax_ptr,
float* output_transpose_amax_ptr, const int warpid) {
// intra-warp reduction
constexpr int kWarpSize = 32;
int local_rank = threadIdx.x % 32;
float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
float warp_transpose_amax =
kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;
// inter-warp reduction
if (threadIdx.x % 32 == 0) {
if (kReturnPreRhtAmax) {
staging_for_pre_rht[warpid] = warp_pre_rht_amax;
}
if (kReturnIdentityAmax) {
staging_for_identity[warpid] = warp_identity_amax;
}
if (kReturnTransposedAmax) {
staging_for_transpose[warpid] = warp_transpose_amax;
}
}
__syncthreads();
constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
if (warpid == 0) {
if (kReturnIdentityAmax) {
float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
if (local_rank == 0) {
atomicMaxFloat(output_identity_amax_ptr, identity_accum);
}
}
}
if (warpid == 1) {
if (kReturnTransposedAmax) {
float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
if (local_rank == 0) {
atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
}
}
}
if (warpid == 2) {
if (kReturnPreRhtAmax) {
float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
if (local_rank == 0) {
atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
}
}
}
}
__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr,
float* __restrict__ output_identity_amax_ptr,
float* __restrict__ output_transpose_amax_ptr) {
if (output_pre_rht_amax_ptr != nullptr) {
*output_pre_rht_amax_ptr = 0;
}
if (output_identity_amax_ptr != nullptr) {
*output_identity_amax_ptr = 0;
}
if (output_transpose_amax_ptr != nullptr) {
*output_transpose_amax_ptr = 0;
}
}
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X, int BUFF_DIM_Y,
int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y, bool kReturnPreRhtAmax,
bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input,
float* __restrict__ output_pre_rht_amax_ptr,
float* __restrict__ output_identity_amax_ptr,
float* __restrict__ output_transpose_amax_ptr,
uint16_t random_sign_mask, uint16_t random_sign_mask_t,
uint64_t num_rows, uint64_t row_length) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;
const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;
extern __shared__ __align__(128) char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding;
// __align__(128) does not guarantee the pointer is actually aligned!
uint8_t* dshmem = reinterpret_cast<uint8_t*>((base_shmem_ptr + 127) & ~127ULL);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
IType* in_sh_0 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_sh_1 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_shs[2] = {in_sh_0, in_sh_1};
constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
uint64_t* mbar = reinterpret_cast<uint64_t*>(dshmem);
dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
float* max_staging_identity = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_transpose = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_pre_rht = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(mbar,
is_master_thread);
copy_2d_to_shared(in_shs[0], reinterpret_cast<const void*>(&tensor_map_input),
input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
is_master_thread);
uint32_t had_frag_i[4];
uint32_t had_frag_t[4];
get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);
float local_pre_rht_amax = 0.0;
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t*>(&local_pre_rht_amax);
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
const int next_stage = stage + 1;
const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
if (next_stage < STAGES_X * STAGES_Y) {
const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong
reinterpret_cast<const void*>(&tensor_map_input), input_global_offset_X,
input_global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
const size_t compute_stage_x_num =
BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
const size_t in_row_stride = BUFF_DIM_X;
IType* in_sh_ptr = in_shs[stage % 2];
#pragma unroll
for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) {
const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
threadIdx.y * kHadamardDimension);
const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) {
ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>(
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}
// Ensure all threads have finished their computation before new data over-writes the shared
// memory.
__syncthreads();
}
}
}
const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
if constexpr (kReturnPreRhtAmax) {
unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
}
if constexpr (kReturnIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
}
if constexpr (kReturnTransposedAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
}
ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
output_transpose_amax_ptr, warpid);
destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <typename T, int kHadamardDimension, bool kComputeIdentity, bool kComputeTransposed,
bool kReturnIdentity, bool kReturnTransposed, bool kUpdateIdentityAmax,
bool kUpdateTransposeAmax, bool kOutputTrueTransposed>
__global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restrict__ output,
T* __restrict__ output_t, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, uint64_t num_input_rows,
uint64_t num_input_cols, float* __restrict__ amax,
float* __restrict__ amax_t, bool inverse_hadamard) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported.");
// The whole threadblock will share the same smem.
extern __shared__ __align__(16) T smem[];
// Each group of 32 threads (one warp) processes a 16x16 matrix. There is a (y, z) grid of 16x16 tiles.
// If y = 4, z = 4, then each threadblock processes a 4x4 grid of 16x16 matrices.
int32_t tid = threadIdx.x;
int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z;
int32_t local_bx = threadIdx.y;
int32_t local_by = threadIdx.z;
// Define the register fragments
uint32_t a_frag[4]; // A matrix fragment
uint32_t b_frag_i[4]; // Transposed Hadamard matrix fragment, used for A @ B(col major)
uint32_t b_frag_t[4]; // Hadamard matrix fragment, used for A.T @ B.T(col major)
uint32_t c_frag[4]; // Result fragment
// Row and col for each thread. The 32 threads of a warp work together, each loading
// a 128-bit (uint4) chunk, to copy the data from global memory to shared memory.
uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4));
uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4));
uint32_t smem_index = tid;
uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension;
uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension;
bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows);
if (!load) {
// Out of bounds; return early. There is no thread divergence since the whole warp
// returns early.
return;
}
uint64_t global_offset = input_start_col + input_start_row * num_input_cols;
uint64_t global_offset_t =
kOutputTrueTransposed ? (input_start_row + input_start_col * num_input_rows) : global_offset;
T* base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id;
uint32_t* smem_b32 = reinterpret_cast<uint32_t*>(base_smem);
uint4* smem_b128 = reinterpret_cast<uint4*>(base_smem);
// Asynchronously load the data from global memory to shared memory.
const uint4* input_b128 = reinterpret_cast<const uint4*>(input + global_offset);
// Each 16x16 chunk is divided into four 8x8 matrices. We load each 8x8 chunk
// consecutively into smem so that ldmatrix m8n8.x4 can read the data in the
// tensor-core swizzled format.
__pipeline_memcpy_async(&smem_b128[smem_index],
&input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col],
sizeof(uint4));
__pipeline_commit(); // Commit the memcpy. Wait when we are in the computation.
if (inverse_hadamard) {
get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
/*kInverseHadamard=*/true,
/*kInverseHadamardTransposed=*/true>(b_frag_i, random_sign_mask,
b_frag_t, random_sign_mask_t);
} else {
get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
/*kInverseHadamard=*/false,
/*kInverseHadamardTransposed=*/false>(
b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t);
}
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
__pipeline_wait_prior(0);
__syncwarp(); // ensure all lanes finished their cp.async before reading smem
// Load the A to a_frag.
if constexpr (kComputeIdentity) {
load_matrix_16x16_from_shared<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32,
kHadamardDimension);
// 16x16 @ 16x16 leveraging all threads in the warp.
mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg);
// Store the result to the shared memory in non-transposed order.
if constexpr (kReturnIdentity) {
uint4* output_b128 = reinterpret_cast<uint4*>(output + global_offset);
store_matrix_16x16_to_global<false>(c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_b128,
num_input_cols);
}
}
if constexpr (kComputeTransposed) {
if (kComputeIdentity) {
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
} else {
load_matrix_16x16_from_shared<true>(a_frag[0],
a_frag[2], // NOTE: intentional index swapping
a_frag[1], // NOTE: intentional index swapping
a_frag[3], smem_b32, kHadamardDimension);
}
mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateTransposeAmax>(
a_frag[0],
// The 2,1 ordering matches the fragment layout produced by the movmatrix
// instruction, so loading the matrix in 2,1 order here keeps this path
// consistent with the movmatrix path above.
a_frag[2], // NOTE: intentional index swapping for transpose purpose.
a_frag[1], // NOTE: intentional index swapping for transpose purpose.
a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1],
c_frag[2], c_frag[3], local_amax_t_reg);
// Store the result to the shared memory in non-transposed order.
if constexpr (kReturnTransposed) {
uint4* output_t_b128 = reinterpret_cast<uint4*>(output_t + global_offset_t);
store_matrix_16x16_to_global<!kOutputTrueTransposed>(
c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128,
kOutputTrueTransposed ? num_input_rows : num_input_cols);
}
}
if constexpr (kUpdateIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
local_amax = warp_reduce_max<kThreadsPerWarp>(local_amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero);
// atomic CAS to output memory.
if (tid % kThreadsPerWarp == 0) {
atomicMaxFloat(amax, local_amax);
}
}
if constexpr (kUpdateTransposeAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
local_amax_t = warp_reduce_max<kThreadsPerWarp>(local_amax_t);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero);
// Atomic max update into the global amax buffer.
if (tid % kThreadsPerWarp == 0) {
atomicMaxFloat(amax_t, local_amax_t);
}
}
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+.");
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
}
} // namespace
void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform);
// Check tensors
// NOTE (frsun): This is non-intuitive, we are writing the result of
// transposed RHT to the output of rowwise.
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor must be simple tensor, but scaling mode is ",
to_string(output_.scaling_mode), ".");
const SimpleTensor& input = input_.data;
SimpleTensor output;
SimpleTensor& output_t = output_.data;
// Check requested outputs
const bool return_identity = output.dptr != nullptr;
const bool return_transposed = output_t.dptr != nullptr;
if (!return_identity && !return_transposed) { // Nothing to do/ill-defined behavior.
return;
}
checkCuDriverContext(stream);
const size_t ndim = input.shape.size();
const size_t row_length = input.shape[ndim - 1];
size_t num_rows = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
num_rows *= input.shape[i];
}
using IType = bf16;
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
"row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(num_rows % kHadamardDimension == 0,
"num_rows must be divisible by hadamard_dimension");
constexpr uint64_t kThreadBlockX = 4;
// A value of 4 is used for Hopper; 8 is used for Blackwell for extra memory bandwidth.
constexpr uint64_t kThreadBlockY = 4;
uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY;
// The number of shared memory bytes required for **the whole thread block**.
size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM;
dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY);
dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX),
DIVUP(num_rows / kHadamardDimension, kThreadBlockY));
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transposed, kReturnTransposed,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity, kReturnIdentity,
auto kernel =
HadamardTransformKernel<IType, kHadamardDimension, kReturnIdentity, kReturnTransposed,
kReturnIdentity, kReturnTransposed, false, false, true>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(
reinterpret_cast<const IType*>(input.dptr), reinterpret_cast<IType*>(output.dptr),
reinterpret_cast<IType*>(output_t.dptr), random_sign_mask, random_sign_mask_t,
num_rows, row_length, nullptr, nullptr, false);););
NVTE_CHECK_CUDA(cudaGetLastError());
}
// Applies the 16x16 Hadamard transform to the input and its transpose, and then
// computes the absolute maximum (amax) of the results.
void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform_amax);
#if CUDA_VERSION >= 12080
// Check input tensor
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
const SimpleTensor& input = input_.data;
// Check amax tensors
SimpleTensor& output_pre_rht_amax = output_.amax;
SimpleTensor output_identity_amax;
SimpleTensor& output_transpose_amax = output_.columnwise_amax;
// Check requested outputs
const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr;
const bool return_identity_amax = output_identity_amax.dptr != nullptr;
const bool return_transposed_amax = output_transpose_amax.dptr != nullptr;
if (!return_identity_amax && !return_transposed_amax &&
!return_pre_rht_amax) { // Nothing to do/ill-defined behavior.
return;
}
// Zero out amaxes if needed
ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast<float*>(output_pre_rht_amax.dptr),
reinterpret_cast<float*>(output_identity_amax.dptr),
reinterpret_cast<float*>(output_transpose_amax.dptr));
NVTE_CHECK_CUDA(cudaGetLastError());
checkCuDriverContext(stream);
using IType = bf16;
const size_t ndim = input.shape.size();
const size_t row_length = input.shape[ndim - 1];
size_t num_rows = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
num_rows *= input.shape[i];
}
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
"row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(num_rows % kHadamardDimension == 0,
"num_rows must be divisible by hadamard_dimension");
constexpr uint64_t kChunkBlockXSmall = 128;
constexpr uint64_t kChunkBlockYSmall = 128;
constexpr uint64_t kBuffDimX = 64;
constexpr uint64_t kBuffDimY = 64;
alignas(64) CUtensorMap tensor_map_input{};
create_2D_tensor_map(
/*tensorMap=*/tensor_map_input,
/*tensor=*/input,
/*globalY=*/num_rows,
/*globalX=*/row_length,
/*shmemY=*/kBuffDimY,
/*shmemX=*/kBuffDimX,
/*stride_elems=*/row_length,
/*offset_elems=*/0,
/*type_num_bits=*/sizeof(IType) * 8,
/*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);
constexpr uint64_t kThreadBlockX = 4;
constexpr uint64_t kThreadBlockY = 1;
constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall));
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transposed_amax, kReturnTransposedAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity_amax, kReturnIdentityAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_pre_rht_amax, kReturnPreRhtAmax,
// *2 for ping-pong
size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
(kChunkBlockYSmall / kBuffDimY);
size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
// Add padding in case shmem ptr is not aligned to 128 bytes.
shmem_bytes = (shmem_bytes + 128);
auto kernel = HadamardAmaxTmaKernel<
IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY,
kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(
tensor_map_input, reinterpret_cast<float*>(output_pre_rht_amax.dptr),
reinterpret_cast<float*>(output_identity_amax.dptr),
reinterpret_cast<float*>(output_transpose_amax.dptr), random_sign_mask,
random_sign_mask_t, num_rows, row_length);)));
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace transformer_engine
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform);
using namespace transformer_engine;
hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), stream);
}
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform_amax);
using namespace transformer_engine;
hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
// clang-format off
namespace transformer_engine {
namespace detail {
namespace {
// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to perform some internal checks, e.g.,
// whether shared memory, if needed, is sufficient for the described problem; usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread());
using namespace cute;
using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
// calculate the global encode scale factor for a given global amax.
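// For example, global_amax = 448 yields an encode scale of (448 * 6) / 448 = 6,
// while global_amax = 0 falls back to a scale of 1.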
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float kFP8E4M3Max = 448.0f;
constexpr float kFP4E2M1Max = 6.0f;
// If scale is infinity, return max value of float32
float global_encode_scale = cutlass::minimum_with_nan_propagation<float>{}(
kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits<float>::max());
// If global amax is 0 or infinity, return 1
return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale;
}
template <class ElementA,
class ElementB,
class ASmemLayout,
class BSmemLayout>
struct SharedStorage {
static constexpr int AccumulatorPipelineStageCount = 16;
using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
using AccumulatorPipeline = cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
MainloopPipelineStageCount,
Shape<_1,_1,_1>,
AtomThrShapeMNK>;
using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
alignas(16) AccumulatorPipelineStorage accumulator;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) cute::uint64_t tma_barrier[1];
uint32_t tmem_base_ptr;
struct TensorStorage : cute::aligned_struct<128, _1> {
// cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
} tensors;
};
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
asm volatile( \
"{\n" \
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
"}" \
: "=h"(output_ptr[0]),
"=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
"f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
"r"(rbits[0]), "r"(rbits[1]));
#else
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return output;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16>
StochasticNumericConverter(cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const *rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
result_type output;
cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr = reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
cutlass::Array<float, 8> const *source_ptr = reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
cutlass::Array<uint32_t, 2> const *rbits_ptr = reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(rbits);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; i++) {
result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
}
return output;
}
template <class MShape, class NShape, class KShape, class ClusterTileShape,
class TA, class AStride, class ASmemLayout, class TmaLoadA,
class TB, class BStride, class BSmemLayout, class TmaLoadB,
class TC, class CStride, class CSmemLayout,
class TSFC,
class TiledMMA,
bool kEnableStochasticRounding = false>
__global__ static
void
rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a,
TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b,
TC * C, CStride dC, CSmemLayout ,
TSFC * SFC,
TiledMMA mma,
float const* global_amax,
const size_t* rng_state)
{
using namespace cute;
using X = Underscore;
// static constexpr bool kApplyStochasticRounding = true;
using ElementAccumulator = float;
static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{});
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
static constexpr uint32_t kTmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
static constexpr int kTmaRhtTensorTransactionBytes =
cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v<TB>);
static constexpr int AccumulatorPipelineStageCount = 16;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
MainloopPipelineStageCount,
Shape<_1,_1,_1>,
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
using TmemAllocator = cute::TMEM::Allocator1Sm;
static constexpr int VectorSize = 16;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
// Preconditions
CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<CSmemLayout>::value);
// Represent the full tensors
Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N));
Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16));
Tensor mC = make_tensor(cute::subbyte_iterator<TC>(C), make_shape(M,N), dC); // (M,N)
auto sfc_shape = make_shape(
M,
make_shape( make_shape(Int<16>{}, _4{}), N / 64 )
);
auto sfc_stride = make_stride(
N / 16,
make_stride( make_stride(_0{}, _1{}), _4{} )
);
auto sfc_layout = make_layout(sfc_shape, sfc_stride);
Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout);
auto cluster_shape = Shape< _1, _1, _1>{};
// Get the appropriate blocks for this Cluster
dim3 cluster_coord_in_grid = cluster_id_in_grid();
// Total number of k-tiles
const int K_TILE_MAX = min(N, K) / 64;
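// Persistent-kernel tile scheduling: each CTA starts at tile blockIdx.x and strides by
// gridDim.x over the (M-tile, N-chunk) space; each N chunk covers K_TILE_MAX tiles of 64 columns.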
uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile);
uint32_t tiles_in_n = (N + 64 - 1) / 64;
uint32_t linear_tile_idx = blockIdx.x;
uint32_t tile_idx_m = linear_tile_idx % tiles_in_m;
uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
auto mainloop_tiler = Shape<_128,_16,_64>{};
auto epilogue_tiler = Shape<_128,_64,_64>{};
Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{});
Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N)
Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Allocate SMEM
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
//
// MMA: Define C accumulators and A/B partitioning
//
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k)
auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, ElementAccumulator,
128, 64,
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1,_1>>{});
ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster);
using TiledMmaEpilogue = decltype(mma_epilogue);
Tensor tCgA = thr_mma.partition_A(gA_mk);
// Allocate "fragments" -- these are actually umma smem descriptors
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE)
auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{}));
auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler));
auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma,
Int<AccumulatorPipelineStageCount>{}));
auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue,
Int<AccumulatorPipelineStageCount / 4>{}));
TmemAllocator tmem_allocator{};
cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
Layout cta_layout_mnk = make_layout(cluster_shape);
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);
auto [tAgA, tAsA] = tma_partition(tma_load_a,
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
auto [tBgB, tBsB] = tma_partition(tma_load_b,
get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
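// Warp specialization: warp 0 drives the UMMA mainloop, warp 1 issues the TMA loads,
// and warps 4-7 run the quantizing epilogue (warp 2 only initializes the B-matrix TMA barrier).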
int warp_idx = cutlass::canonical_warp_idx_sync();
bool is_mma_warp = (warp_idx == 0);
bool is_dma_warp = (warp_idx == 1);
bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7);
if (is_epilogue_warp && elect_one_sync()) {
cute::prefetch(raw_pointer_cast(global_amax));
}
typename MainloopPipeline::Params mainloop_pipeline_params;
if (is_dma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (is_mma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
mainloop_pipeline_params.initializing_warp = 0;
MainloopPipeline mainloop_pipeline(shared_storage.mainloop,
mainloop_pipeline_params,
cluster_shape,
cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
MainloopPipelineState mainloop_pipe_consumer_state;
MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
using AccumulatorPipeline = cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
AccumulatorPipelineState accumulator_pipe_consumer_state;
AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state<AccumulatorPipeline>();
typename AccumulatorPipeline::Params accumulator_pipeline_params;
if (is_mma_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
}
if (is_epilogue_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
}
// Only one producer thread arrives on this barrier.
accumulator_pipeline_params.producer_arv_count = 1;
accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128;
accumulator_pipeline_params.initializing_warp = 1;
AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator,
accumulator_pipeline_params,
cluster_shape,
cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
if (warp_idx == 2 && elect_one_sync()) {
cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
}
__syncthreads();
using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
if (is_dma_warp) {
if (elect_one_sync()) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes);
copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0));
}
cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
do {
bool is_first_wave = linear_tile_idx == blockIdx.x;
uint32_t skip_wait = is_first_wave;
auto tAgA_mk = tAgA(_,tile_idx_m,_);
int k_tile = 0;
auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
CUTE_NO_UNROLL
while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) {
int k_tile_idx_n = tile_idx_n + k_tile;
++k_tile;
skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
++mainloop_pipe_producer_state;
barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
if (cute::elect_one_sync()) {
copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage));
}
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
} else if (is_mma_warp) {
mma.accumulate_ = UMMA::ScaleOut::Zero;
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
tmem_allocation_result_barrier.arrive();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_mma.data() = tmem_base_ptr;
do {
uint32_t skip_wait = K_TILE_MAX <= 0;
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
CUTE_NO_UNROLL
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; )
{
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
int read_stage = mainloop_pipe_consumer_state.index();
auto tCrA_mk = tCrA(_,_,_,read_stage);
auto tCrB_nk = tCrB(_,_,0,0);
CUTE_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block)
{
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
CUTE_UNROLL
for (int i = 0; i < 4; i++) {
auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i);
gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators);
}
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
++accumulator_pipe_producer_state;
}
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
++mainloop_pipe_consumer_state;
++k_tile;
skip_wait = k_tile >= K_TILE_MAX;
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
tmem_allocator.release_allocation_lock();
accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
} else if (is_epilogue_warp) {
const float global_amax_val = *global_amax;
static constexpr int FragmentSize = 256 / sizeof_bits_v<TC>;
tmem_allocation_result_barrier.arrive_and_wait();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_epilogue.data() = tmem_base_ptr;
int thread_idx = threadIdx.x % 128;
Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) // (MMA,MMA_M,MMA_N)
auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{}));
auto tiled_r2g = make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TC>{}, tiled_t2r);
auto thr_t2r = tiled_t2r.get_slice(thread_idx);
auto thr_r2g = tiled_r2g.get_slice(thread_idx);
// NVFP4 non-E8 recipe constants and global scales
static constexpr float fp4_max = 6.0f;
const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
const float global_decode_scale = 1.0f / global_encode_scale;
auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
do {
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,tile_idx_n+k_tile);
Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,tile_idx_n+k_tile);
accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index());
Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDrC = make_tensor<TC>(shape(tDgC));
Tensor tTR_rAcc_frag = recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
Tensor tDrC_frag = recast<cutlass::Array<TC, FragmentSize>>(coalesce(tDrC));
Tensor src = thr_r2g.retile_S(tDrC);
Tensor dst = thr_r2g.retile_D(tDgC);
Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout(
make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}),
make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{})
));
Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC));
Tensor tDrSFC = make_tensor<TSFC>(shape(tDgSFC));
static constexpr int NumVecs = size(tDgC) / VectorSize;
Tensor tC_rRowSFD_frg = recast<cutlass::Array<TSFC, NumVecs>>(tDrSFC);
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>, true> amax_reduction;
cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
cutlass::Array<ElementAccumulator, NumVecs> pvscales;
// TMEM_LOAD
copy(tiled_t2r, tDtC, tTR_rAcc);
cutlass::arch::fence_view_async_tmem_load();
accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
++accumulator_pipe_consumer_state;
// Round-trip the accumulators FP32 -> BF16 -> FP32 to emulate BF16 precision.
auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
auto compute_frgs = reinterpret_cast<cutlass::Array< ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array< TC, VectorSize> *>(tDrC_frag.data());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
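// NVFP4 two-level scaling: the per-16-element decode scale is (vector amax / FP4 max) *
// global_encode_scale, stored as UE4M3 in SFC; elements are then quantized by the reciprocal
// of the dequantized per-vector scale times global_decode_scale.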
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale);
auto pvscales_cvted = cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
tC_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(tC_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(qpvscale_ups, global_decode_scale);
auto acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
// Initialize RNG for tile
const size_t rng_sequence
= thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = uint4{0, 0, 0, 0};
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
// auto acc_scale = acc_scales[v];
if constexpr (kEnableStochasticRounding) {
random_uint4 = dist.generate4(rng);
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v],
acc_scale
),
reinterpret_cast<cutlass::Array<uint32_t, 4>*>(&random_uint4));
} else {
output_frgs[v] = cutlass::NumericArrayConverter<TC, ElementAccumulator, VectorSize>{}(cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(compute_frgs[v], acc_scale));
}
}
copy(tiled_r2g, src, dst);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
}
}
// this function computes RHT-GEMM for
// A: m x n: col-major
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
void
rht_gemm_ntt_w_sfc(int m, int n,
TA const* A,
TB const* B,
TC * C,
TSFC * SFC,
float const* global_amax,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 2048)
{
using namespace cute;
// Define shapes (dynamic)
auto M = static_cast<int>(m);
auto N = static_cast<int>(n);
// Define strides (mixed)
auto dA = make_stride(Int<1>{}, m); // (dM,dK)
auto dB = make_stride(Int<1>{}, 16); // (dN,dK)
auto dC = make_stride(n, Int<1>{}); // (dM,dN)
auto cga_shape = Shape< _1, _1, _1>{};
auto cga_tile_shape = Shape<_128,_16,_16>{};
auto cluster_tile_mainloop = Shape<_128,_16,_64>{};
// Construct the MMA
auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, float,
128, 16,
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1,_1>>{});
// MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never}
// Assert that the TiledMMA uses all CTAs in the CGA.
CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma));
CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma)));
// Determine the A and B shapes
auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape)));
using TiledMma = decltype(mma);
using AtomThrID = typename TiledMma::AtomThrID;
using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_K = decltype(cute::get<2>(cga_tile_shape));
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>());
auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
// Define the smem layouts (static)
// Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory
constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes
constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB);
constexpr int kReservedBytes = 256; // Reserve for barriers and other uses
constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
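// With BF16 tiles of 128x64 for A (16 KiB) and 16x16 for B (512 B) per stage, this works out
// to roughly 13 pipeline stages.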
auto sP = Int<kMaxStages>{}; // SMEM pipelines
auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE)
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE)
auto sC = Layout<_1>{}; // XXX Dummy
// Create GMEM tensors
Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N)
Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16)
// Create the TiledCopy
auto tma_load_a = make_tma_copy_A_sm100(
SM90_TMA_LOAD{},
tensorA,
sA(_,_,_,0),
cluster_tile_mainloop,
mma);
auto tma_load_b = make_tma_copy_B_sm100(
SM90_TMA_LOAD{},
tensorB,
sB(_,_,_,0),
cga_tile_shape,
mma);
// Assert checks on tile sizes -- no predication
NVTE_CHECK(M % size<0>(cga_tile_shape) == 0,
"Inner dimension must be divisible by ", static_cast<size_t>(size<0>(cga_tile_shape)), " but got ", M, ".");
NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0,
"Outer dimension must be divisible by ", 4 * static_cast<size_t>(size<1>(cga_tile_shape)),
" but got ", N, ".");
uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size));
tiles = (tiles < sm_count) ? tiles : sm_count;
dim3 dimBlock(256);
dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape));
dim3 dimGrid(tiles, 1, 1);
int smem_size = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
auto* kernel_ptr = &rht_gemm_device<
decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape),
TA, decltype(dA), decltype(sA), decltype(tma_load_a),
TB, decltype(dB), decltype(sB), decltype(tma_load_b),
TC, decltype(dC), decltype(sC),
TSFC,
decltype(mma),
kEnableStochasticRounding>;
cudaError_t status = cudaFuncSetAttribute(*kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (status != cudaSuccess) {
std::cerr << "Error: Failed to set Shared Memory size." << std::endl;
return;
}
(*kernel_ptr)
<<< dimGrid, dimBlock, smem_size, stream >>>
(M, N, k_tile_size, cga_tile_shape,
A, dA, sA, tma_load_a,
B, dB, sB, tma_load_b,
C, dC, sC,
SFC,
mma, global_amax,
rng_state);
}
// This function wraps rht_gemm_ntt_w_sfc to transpose the input tensor A.
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
void
rht_gemm_ttt_wrapper(int m, int n,
TA const* A,
TB const* B,
TC * C,
TSFC * SFC,
float const* global_amax,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 1024)
{
// In addition to transposing the input tensor A,
// we also need to reshape m, n to best utilize
// as many SMs as possible while keeping
// a relatively large contiguous dimension.
// For example, after swapping m and n for transpose purposes,
// the input / output tensor shapes for RHT-GEMM are:
// A: n x m: col-major
// B: 16 x 16: row-major
// C: n x m: row-major
// SFC: n x (m/16): row-major
rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding>(
n, m,
A, B, C,
SFC, global_amax,
rng_state,
sm_count, stream,
k_tile_size);
}
} // namespace
} // namespace detail
// clang-format on
void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_,
const Tensor &hadamard_matrix_,
QuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise);
// Check input and output tensors
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
const SimpleTensor &input = input_.data;
SimpleTensor &global_amax = output_.amax;
SimpleTensor &output_t = output_.data;
SimpleTensor &scale_inv_t = output_.scale_inv;
// Stochastic rounding config
const bool use_stochastic_rounding = quant_config.stochastic_rounding;
const size_t *rng_state = nullptr;
if (quant_config.rng_state != nullptr) {
Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state);
NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}
// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
using TC = cutlass::float_e2m1_t;
using TSFC = cutlass::float_ue4m3_t;
checkCuDriverContext(stream);
// Check Hadamard matrix
constexpr int kHadamardDimension = 16;
NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Hadamard matrix must be BF16 tensor, but scaling mode is ",
to_string(hadamard_matrix_.scaling_mode), ".");
NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
"Hadamard matrix must be BF16 tensor, but dtype is ",
to_string(hadamard_matrix_.dtype()), ".");
const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
NVTE_CHECK(
(hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
"Hadamard matrix must have shape=",
std::vector<size_t>{kHadamardDimension, kHadamardDimension},
", but got shape=", hadamard_matrix_.shape(), ".");
const size_t hadamard_dimension = hadamard_matrix.shape[0];
const size_t ndim = input.shape.size();
const size_t n = input.shape[ndim - 1];
size_t m = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
m *= input.shape[i];
}
auto sm_count = transformer_engine::cuda::sm_count();
NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension");
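// Heuristic k-tile sizes tuned for common (m, n) shapes encountered in practice; the default is 1024.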
int k_tile_size = 1024;
if (m == 8192 && n == 5120) {
k_tile_size = 512;
} else if (m == 8192 && n == 10240) {
k_tile_size = 1024;
} else if (m == 8192 && n == 2560) {
k_tile_size = 1280;
} else if (m == 8192 && n == 11328) {
k_tile_size = 1024;
} else if (m == 8192 && n == 512) {
k_tile_size = 256;
} else if (m == 8192 && n == 3584) {
k_tile_size = 512;
} else if (m == 11328 && n == 8192) {
k_tile_size = 1024;
} else if (m == 5120 && n == 8192) {
k_tile_size = 512;
} else if (m == 10240 && n == 8192) {
k_tile_size = 1024;
} else if (m == 2560 && n == 8192) {
k_tile_size = 1280;
} else if (m == 512 && n == 8192) {
k_tile_size = 256;
} else if (m == 3584 && n == 8192) {
k_tile_size = 512;
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding>(
/*m=*/m,
/*n=*/n,
/*A=*/reinterpret_cast<TA const *>(input.dptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*C=*/reinterpret_cast<TC *>(output_t.dptr),
/*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr),
/*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr),
/*rng_state=*/rng_state,
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size););
}
} // namespace transformer_engine
void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise);
using namespace transformer_engine;
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
hadamard_transform_cast_fusion_columnwise(
*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
*convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream);
}
......@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU,
SRELU,
SREGLU,
CLAMPED_SWIGLU
};
/*! \brief Computes the GeLU activation of the input.
......@@ -173,6 +174,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Swish activation of the input used in GPT OSS.
*
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This gated activation has two differences compared to the original SwiGLU:
* 1. Both the gate and the pre-activation values are clipped based on the parameter limit.
* 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) used in the Swish activation,
*    inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream);
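/* Illustrative per-element sketch of the semantics described above (an assumed reference,
 * not necessarily the exact kernel implementation); clip_by_limit is a placeholder for the
 * clipping convention of the GPT-OSS reference linked above:
 *
 *   float x = clip_by_limit(input[n][h]);        // pre-activation
 *   float g = clip_by_limit(input[n][H + h]);    // gate
 *   output[n][h] = x * sigmoid(alpha * x) * g;
 */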
/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -230,6 +251,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS.
*
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This activation has two differences compared to the original SwiGLU:
* 1. Both the gate and the pre-activation values are clipped based on the parameter limit.
* 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) used in the Swish activation,
*    inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream);
/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......
......@@ -67,6 +67,11 @@ class CommOverlapCore {
std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
private:
void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size,
int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm);
public:
CommOverlapCore() {} // dummy constructor for exposing type to Python
......@@ -78,17 +83,26 @@ class CommOverlapCore {
virtual ~CommOverlapCore();
void *get_ubuf_dptr() { return _ubuf.dptr(); }
void set_ubuf_scale_inv(float *scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) {
NVTE_ERROR("Operation is not implemented.");
}
TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
int get_tp_size() { return _tp_size; }
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
......@@ -150,6 +164,10 @@ class CommOverlapBase : public CommOverlapCore {
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
bool rs_overlap_first_gemm);
public:
CommOverlapBase() {} // dummy constructor for exposing type to Python
......@@ -228,6 +246,10 @@ class CommOverlapP2PBase : public CommOverlapCore {
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, bool aggregate);
public:
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python
......@@ -241,6 +263,9 @@ class CommOverlapP2PBase : public CommOverlapCore {
virtual ~CommOverlapP2PBase();
void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) override;
TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id);
void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
......
......@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
/*! \enum NVTE_Softmax_Type
* \brief Attention softmax types as described in
* Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3).
* For a given attention score S = Q*K^T, different softmax types perform different operations on S,
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter of shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
NVTE_VANILLA_SOFTMAX = 0,
/*! Off-by-one softmax */
NVTE_OFF_BY_ONE_SOFTMAX = 1,
/*! Learnable softmax */
NVTE_LEARNABLE_SOFTMAX = 2,
};
/*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends
*/
......@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
......@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
......@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream);
......@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
......@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
......
......@@ -15,9 +15,76 @@
#ifdef __cplusplus
extern "C" {
#endif
#endif // __cplusplus
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
/*! \brief Configuration for matrix multiplication. */
typedef void *NVTEMatmulConfig;
/*! \enum NVTEMatmulConfigAttribute
* \brief Type of option for matrix multiplication.
*/
enum NVTEMatmulConfigAttribute {
/*! Bias tensor
*
* If provided, the bias tensor is applied in the GEMM epilogue.
*/
kNVTEMatmulConfigBiasTensor = 0,
/*! Bias gradient tensor
*
* If provided, the bias gradient tensor will be filled in the GEMM epilogue.
*/
kNVTEMatmulConfigDBiasTensor = 1,
/*! Whether to compute GELU in GEMM epilogue. */
kNVTEMatmulConfigWithGELUEpilogue = 2,
/*! Whether to compute GELU backward in GEMM epilogue. */
kNVTEMatmulConfigWithDGELUEpilogue = 3,
/*! Auxiliary tensor for GEMM epilogue.
*
* For GELU, this will be filled with the GELU input. For GELU
* backward, this is expected to already be filled with the GELU
* input.
*/
kNVTEMatmulConfigEpilogueAuxTensor = 4,
/*! Whether to use split accumulator for FP8 GEMM. */
kNVTEMatmulConfigUseSplitAccumulator = 5,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEMatmulConfigSMCount = 6,
kNVTEMatmulConfigNumAttributes
};
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig nvte_create_matmul_config();
/*! \brief Query an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
void *buf, size_t size_in_bytes, size_t *size_written);
/*! \brief Set an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
 * \param[in] buf              Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
/*! \brief Destroy a matrix multiplication configuration. */
void nvte_destroy_matmul_config(NVTEMatmulConfig config);
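A minimal usage sketch of the configuration API declared above (illustrative only, not part of the header): it creates a config, sets two attributes, queries one back, and destroys the config. The `bias` tensor handle is assumed to have been created elsewhere.

#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

void configure_matmul(NVTETensor bias) {
  NVTEMatmulConfig cfg = nvte_create_matmul_config();

  // Request a bias epilogue and limit the kernel to 16 SMs.
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigBiasTensor, &bias, sizeof(NVTETensor));
  int sm_count = 16;
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int));

  // Passing a NULL buffer reports the number of bytes that would have been written.
  size_t needed = 0;
  nvte_get_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, nullptr, 0, &needed);
  int sm_count_readback = 0;
  nvte_get_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, &sm_count_readback, sizeof(int),
                                   &needed);

  nvte_destroy_matmul_config(cfg);
}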
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
......@@ -44,8 +111,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
*
* Computes:
* - `D = alpha * op(A) * op(B) + beta * C`
*
* \param[in] transa Whether to transpose A matrix.
* \param[in] transb Whether to transpose B matrix.
* \param[in] alpha Scaling factor applied to matmul output.
* \param[in] A A matrix.
* \param[in] B B matrix.
* \param[in] beta Scaling factor applied to C matrix.
* \param[in] C C matrix.
* \param[out] D Output matrix.
* \param[in] workspace Workspace tensor.
* \param[in] config Additional configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
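A hedged sketch of calling the v2 API above to compute `D = alpha * A * B + beta * C`. The tensor handles and stream are assumed to exist; whether `alpha`/`beta` may point to host memory depends on the backend and is an assumption here, and the trailing three ROCm-specific arguments simply mirror the defaults of the deprecated `nvte_cublas_gemm` above.

#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

void run_gemm_v2(NVTETensor A, NVTETensor B, NVTETensor C, NVTETensor D, NVTETensor workspace,
                 cudaStream_t stream) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  NVTEMatmulConfig cfg = nvte_create_matmul_config();
  bool use_split_accumulator = true;
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigUseSplitAccumulator,
                                   &use_split_accumulator, sizeof(bool));
  // transa/transb are assumed to follow the cuBLAS convention (0 = no transpose); the last
  // three arguments are specific to the ROCm fork and are set to the deprecated API's defaults.
  nvte_cublas_gemm_v2(/*transa=*/0, /*transb=*/0, &alpha, A, B, &beta, C, D, workspace, cfg,
                      stream, /*nvte_use_hipblaslt=*/false, /*nvte_use_rocblas=*/false,
                      /*compute_stream_offset=*/0);
  nvte_destroy_matmul_config(cfg);
}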
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
* allowing for using a scaling factor for the GEMM result and the accumulation input
* allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated)
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
*
* Computes:
* - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
......@@ -133,9 +223,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream to wait on.
*/
void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
......@@ -160,7 +250,9 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
#ifdef __cplusplus
} // extern "C"
#endif
#endif // __cplusplus
#ifdef __cplusplus
/*! \namespace transformer_engine
*/
......@@ -178,6 +270,89 @@ constexpr int num_batchgemm_streams = 1;
void nvte_cublas_handle_init();
/*! \struct MatmulConfigWrapper
* \brief C++ wrapper for NVTEMatmulConfig.
*/
class MatmulConfigWrapper {
public:
MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {}
MatmulConfigWrapper(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~MatmulConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEMatmulConfig.
*
* \return NVTEMatmulConfig held by this MatmulConfigWrapper.
*/
operator NVTEMatmulConfig() const noexcept { return config_; }
/*! \brief Set bias tensor. */
void set_bias_tensor(NVTETensor bias_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor,
sizeof(NVTETensor));
}
/*! \brief Set bias gradient tensor. */
void set_dbias_tensor(NVTETensor dbias_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor,
sizeof(NVTETensor));
}
/*! \brief Set whether to compute GELU in GEMM epilogue. */
void set_with_gelu_epilogue(bool with_gelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue,
&with_gelu_epilogue, sizeof(bool));
}
/*! \brief Set whether to compute GELU backward in GEMM epilogue. */
void set_with_dgelu_epilogue(bool with_dgelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue,
&with_dgelu_epilogue, sizeof(bool));
}
/*! \brief Set auxiliary tensor for GEMM epilogue. */
void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor,
&epilogue_aux_tensor, sizeof(NVTETensor));
}
/*! \brief Set whether to use split accumulator for FP8 GEMM. */
void set_use_split_accumulator(bool use_split_accumulator) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator,
&use_split_accumulator, sizeof(bool));
}
/*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
void set_sm_count(int sm_count) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int));
}
private:
/*! \brief Wrapped NVTEMatmulConfig. */
NVTEMatmulConfig config_ = nullptr;
};
} // namespace transformer_engine
#endif // __cplusplus
#endif // TRANSFORMER_ENGINE_GEMM_H_
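For completeness, a minimal sketch of the RAII wrapper defined above; the tensor handles are assumed to come from elsewhere. The wrapper converts implicitly to `NVTEMatmulConfig`, so it can be passed directly to `nvte_cublas_gemm_v2`, and the underlying config is destroyed when the wrapper goes out of scope.

#include <transformer_engine/gemm.h>

void run_with_wrapper(NVTETensor bias, NVTETensor gelu_input) {
  transformer_engine::MatmulConfigWrapper cfg;
  cfg.set_bias_tensor(bias);
  cfg.set_with_gelu_epilogue(true);
  cfg.set_epilogue_aux_tensor(gelu_input);  // filled with the GELU input in the epilogue
  cfg.set_use_split_accumulator(true);
  // Pass `cfg` wherever an NVTEMatmulConfig is expected; cleanup happens in the destructor.
}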
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file hadamard_transform.h
* \brief Functions for Hadamard transforms.
*/
#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Perform a randomized Hadamard transform on the input tensor.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
/*! \brief Perform the absolute maximum reduction on the input tensor, with or without
 * a randomized Hadamard transform. The rowwise result is the absolute maximum
 * of the input tensor. The columnwise result is the absolute maximum of the
 * transposed input tensor after the randomized Hadamard transform is applied.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
/*! \brief Perform the columnwise hadamard transform cast fusion.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] hadamard_matrix Hadamard matrix.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
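An illustrative calling sketch for the experimental API above, assuming `input`, `output`, and `amax_out` tensors of compatible shapes have been created elsewhere. The 16-bit mask values are arbitrary and only demonstrate the argument types; their exact interpretation is backend-defined and not documented here.

#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>

void run_hadamard(NVTETensor input, NVTETensor output, NVTETensor amax_out, cudaStream_t stream) {
  const int random_sign_mask = 0x5A5A;    // arbitrary 16-bit sign mask
  const int random_sign_mask_t = 0xA5A5;  // arbitrary 16-bit sign mask
  nvte_hadamard_transform(input, output, random_sign_mask, random_sign_mask_t, stream);
  // The amax variant reduces the rowwise/columnwise absolute maxima as described above.
  nvte_hadamard_transform_amax(input, amax_out, random_sign_mask, random_sign_mask_t, stream);
}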
......@@ -124,6 +124,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream);
/*! \brief Swizzle FP8 block scaling scaling factors into the MXFP8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv.
* \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
 * This function is used to emulate the FP8 block scaling recipe on Blackwell and newer
 * architectures, as that recipe is not natively supported by cublasLt on architectures other than Hopper.
* Requirements:
* - input is an FP8 block scaling tensor
* - input has rowwise usage
* - input.scale_inv is in GEMM_READY format
* - output is an MXFP8 tensor
* - output has rowwise usage
* - output.scale_inv has appropriate shape
* */
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
cudaStream_t stream);
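A minimal calling sketch under the requirements listed above: `fp8_block` is assumed to be an FP8 block scaling tensor with rowwise usage and GEMM_READY scale_inv, and `mxfp8_out` an MXFP8 tensor whose scale_inv already has the expected swizzled shape.

void emulate_block_scaling_on_blackwell(NVTETensor fp8_block, NVTETensor mxfp8_out,
                                        cudaStream_t stream) {
  // Re-lay out the FP8 block scaling factors into the MXFP8 interleaved layout so the tensor
  // can be fed to an MXFP8 GEMM on architectures newer than Hopper.
  nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(fp8_block, mxfp8_out, stream);
}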
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -73,6 +73,7 @@ enum NVTETensorParam {
kNVTEAmax = 3, /*!< Amax tensor */
kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTENumTensorParams
};
......@@ -95,10 +96,9 @@ enum NVTEScalingMode {
*/
NVTE_BLOCK_SCALING_1D = 2,
NVTE_BLOCK_SCALING_2D = 3,
/*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
*/
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
/*! Single scale per block of 16 elements consecutive in either
* rowwise or columnwise direction */
NVTE_NVFP4_1D_SCALING = 4,
NVTE_INVALID_SCALING = 100
};
......@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute {
* likely be refactored away in the future.
*/
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
/*! RNG state (NVTETensor with 2 elements - seed and offset) */
kNVTEQuantizationConfigRNGState = 4,
/*! Whether to use 2D block scaling for NVFP4 */
kNVTEQuantizationConfigNVFP42DQuantization = 5,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding = 6,
kNVTEQuantizationConfigNumAttributes
};
......@@ -458,6 +464,15 @@ inline bool is_fp4_dtype(const DType t) {
#endif
}
/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
*
* Return true if TE datatype is high precision
 * \param[in] t TE datatype of interest
*/
inline bool is_high_precision_dtype(const DType t) {
return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
}
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
*/
......@@ -593,6 +608,11 @@ class TensorWrapper {
return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape);
}
template <typename ShapeType>
TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape);
}
// Parameter getters
NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept {
......@@ -617,6 +637,10 @@ class TensorWrapper {
return get_parameter(kNVTEColumnwiseScaleInv);
}
NVTEBasicTensor get_columnwise_amax() const noexcept {
return get_parameter(kNVTEColumnwiseAmax);
}
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
......@@ -865,6 +889,24 @@ class QuantizationConfigWrapper {
&format, sizeof(Float8BlockScaleTensorFormat));
}
/*! \brief Set RNG state used for stochastic rounding */
void set_rng_state(NVTETensor rng_state) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state,
sizeof(NVTETensor));
}
/*! \brief Set whether to use 2D block scaling for NVFP4 */
void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization,
&nvfp4_2d_quantization, sizeof(bool));
}
/*! \brief Set whether to use stochastic rounding */
void set_stochastic_rounding(bool stochastic_rounding) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding,
&stochastic_rounding, sizeof(bool));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
......
......@@ -28,7 +28,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp_scaling(z->scaling_mode)) {
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
......@@ -65,11 +65,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool is_aligned = true;
#ifdef USE_ROCM
NVTE_CHECK(
!is_mxfp_scaling(z->scaling_mode),
!is_mxfp8_scaling(z->scaling_mode),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#else
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
......
......@@ -24,7 +24,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp_scaling(z->scaling_mode)) {
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
......@@ -51,11 +51,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool is_aligned = true;
#ifdef USE_ROCM
NVTE_CHECK(
!is_mxfp_scaling(z->scaling_mode),
!is_mxfp8_scaling(z->scaling_mode),
"Cudnn backend is need by mxfp scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#else
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
......
......@@ -4,10 +4,10 @@
"""This module provides predefined FP8 recipes."""
from __future__ import annotations
import warnings
import os
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
from typing import Any, Literal, Optional, Union, Callable, NamedTuple
from dataclasses import field
from pydantic.dataclasses import dataclass
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......@@ -23,9 +23,12 @@ class _FormatHelper(NamedTuple):
class Format(Enum):
"""
Supported FP8 formats.
Supported FP4 formats.
Values
------
E2M1 :
All FP4 tensors are in e2m1 format
E4M3 :
All FP8 tensors are in e4m3 format
E5M2 :
......@@ -35,6 +38,7 @@ class Format(Enum):
FP8 tensors in the backward pass are in e5m2 format
"""
E2M1 = _FormatHelper(max_fwd=6, max_bwd=6)
E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
......@@ -42,9 +46,13 @@ class Format(Enum):
@dataclass(frozen=True)
class MMParams:
"""for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator)
apply split accumulator or not, turning it on will increase accuracy but impact gemm performance,
so only turn it on for certain gemms
"""Matrix multiplication options.
Parameters
----------
use_split_accumulator : bool, default = `True`
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
use_split_accumulator: bool = True
......@@ -55,10 +63,24 @@ class QParams:
"""Quantization parameters.
power_2_scale: use power of 2 scale parameter
amax_epsilon: optional minimum value of abs max
random_hadamard_transform: whether to use random hadamard transform
stochastic_rounding: whether to use stochastic rounding
fp4_2d_quantization: whether to use 2D block scaling for FP4
"""
power_2_scale: bool = False
amax_epsilon: float = 0.0
random_hadamard_transform: bool = False
stochastic_rounding: bool = False
fp4_2d_quantization: bool = False
def __repr__(self) -> str:
return (
f"Qparams(\npower_2_scale={self.power_2_scale},\n"
f"amax_epsilon={self.amax_epsilon},\n"
f"random_hadamard_transform={self.random_hadamard_transform},\n"
f"stochastic_rounding={self.stochastic_rounding},\n"
f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
)
class Recipe:
......@@ -66,6 +88,10 @@ class Recipe:
Base recipe class.
"""
def nvfp4(self):
"""Whether the given recipe is NVFP4 1D block scaling."""
return isinstance(self, NVFP4BlockScaling)
def mxfp8(self):
"""Whether the given recipe is MXFP8 block scaling."""
return isinstance(self, MXFP8BlockScaling)
......@@ -86,6 +112,10 @@ class Recipe:
"""Whether the given recipe is float8 blockwise scaling."""
return isinstance(self, Float8BlockScaling)
def custom(self):
"""Whether the given recipe is custom."""
return isinstance(self, CustomRecipe)
@dataclass()
class DelayedScaling(Recipe):
......@@ -131,7 +161,7 @@ class DelayedScaling(Recipe):
where `Tensor` is a framework tensor type.
reduce_amax: bool, default = `True`
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
distributed group. If set to `False`, this reduction is skipped and every
GPU maintains local amaxes and scaling factors. To ensure results are
......@@ -139,7 +169,7 @@ class DelayedScaling(Recipe):
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
......@@ -184,6 +214,7 @@ class DelayedScaling(Recipe):
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"reduce_amax={self.reduce_amax}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
......@@ -201,10 +232,11 @@ class Float8CurrentScaling(Recipe):
pass.
"""
use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1"
fp8_format: Format = Format.HYBRID
fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False)
fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
......@@ -213,9 +245,6 @@ class Float8CurrentScaling(Recipe):
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert (
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8CurrentScaling."
def __repr__(self) -> str:
return (
......@@ -334,6 +363,7 @@ class Float8BlockScaling(Recipe):
assert (
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8BlockScaling."
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
def __repr__(self) -> str:
return (
......@@ -351,3 +381,134 @@ class Float8BlockScaling(Recipe):
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
@dataclass()
class NVFP4BlockScaling(Recipe):
"""
Use the NVFP4 scaling strategy.
This is a 2-level block scaling strategy. In level 1, each group of
16 consecutive values is scaled together using their own scaling
factor. The type of the scaling factor is E4M3 (4 bits of exponent,
3 bits of mantissa). In level 2, a global per tensor FP32 scaling
factor is used to scale the entire tensor.
Since the scaling happens in a particular direction (either rowwise
or columnwise), in this recipe the quantized tensor and its transpose
are not numerically equivalent. Due to this, when Transformer Engine
needs both the tensor and its transpose (e.g. to calculate both
forward and backward pass), during the quantization both versions are
computed from the high precision input to avoid double quantization
errors.
The default NVFP4 training recipe implements 3 techniques for quantizing
to a narrow format (4-bit):
- For weight tensors a variant of the NVFP4 quantization is used,
where a single scaling factor is shared by a 2D block of 16x16 elements.
- When quantizing gradients, stochastic rounding is applied to avoid the bias
introduced by quantization. With this, values are rounded probabilistically
to one of their two nearest representable numbers, with probabilities
inversely proportional to their distances.
- When quantizing inputs and gradients, random Hadamard transforms are applied
(16x16 Hadamard matrix) to smooth outliers in the tensor distributions
and make them easier to represent accurately in NVFP4.
These techniques are described more comprehensively in the NVFP4 paper titled
'Pretraining Large Language Models with NVFP4' (https://arxiv.org/abs/2509.25149v1).
Parameters
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
disable_rht : bool, default = `False`
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default = `False`
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = `False`
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
# Configuration envvars
disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1"
disable_stochastic_rounding: bool = (
os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1"
)
disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1"
fp4_format: Format = Format.E2M1
fp8_format: Format = Format.E4M3
# Not applying quantization to attention for now
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None:
assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"
# Quantization params
# Note: RHT is currently only applied to column-wise usage so that
# it can be used for wgrad GEMM.
self.fp4_quant_fwd_inp = QParams(
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=False,
fp4_2d_quantization=False,
)
self.fp4_quant_fwd_weight = QParams(
random_hadamard_transform=False,
stochastic_rounding=False,
fp4_2d_quantization=not self.disable_2d_quantization,
)
self.fp4_quant_bwd_grad = QParams(
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=not self.disable_stochastic_rounding,
fp4_2d_quantization=False,
)
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"fp4_format={str(self.fp4_format).split('.')[1]}, "
f"fp8_format={str(self.fp8_format).split('.')[1]}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}, "
f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
)
@dataclass()
class CustomRecipe(Recipe):
"""
Custom recipe that allows users to provide quantizer factories.
.. warning::
**EXPERIMENTAL**: Custom recipe is experimental, still under active development,
and the API is subject to change without notice. Use at your own risk.
Parameters
----------
qfactory : Callable
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as:
qfactory(
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
qfactory: Callable[..., Any]
fp8_dpa: bool = False
fp8_mha: bool = False
def __repr__(self) -> str:
return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}"
......@@ -27,6 +27,13 @@ namespace {
constexpr int amax_kernel_threads = 512;
__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
*amax_ptr = 0;
}
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N,
......@@ -131,7 +138,8 @@ template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
cudaStream_t stream) {
// Zero out amax so we can update with atomic max
NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr);
NVTE_CHECK_CUDA(cudaGetLastError());
// Return immediately if tensor is empty
if (N == 0) {
......@@ -216,15 +224,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING ||
output.scaling_mode == NVTE_NVFP4_1D_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling or "
"NVFP4 1D scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
NVTE_CHECK(output.amax.numel() == 1,
"Output tensor for amax computation has invalid amax tensor "
"(expected 1 entry, got shape=",
output.amax.shape, ")");
NVTE_CHECK(output.amax.dptr != nullptr,
NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr,
"Output tensor for amax computation has amax tensor without data");
NVTE_CHECK(output.amax.dtype == DType::kFloat32,
"Output tensor for amax computation has invalid amax tensor "
......@@ -243,11 +253,12 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
}
// Compute amax
float *amax_ptr = reinterpret_cast<float *>(
(output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
noop_ptr, stream);); // NOLINT(*)
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel<nvec>(
reinterpret_cast<const IType *>(input.data.dptr), amax_ptr, input.data.numel(), noop_ptr,
stream);); // NOLINT(*)
}
} // anonymous namespace
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace nvfp4_recipe {
// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0;
constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0);
// Kernel to compute alpha *= amax_A * amax_B / factor
__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A,
const float *amax_B, float *alpha_out) {
// factor_inv is defined in the enclosing namespace
*alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv;
}
} // namespace nvfp4_recipe
} // namespace transformer_engine
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out,
cudaStream_t stream) {
NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale);
using namespace transformer_engine;
auto *tA = convertNVTETensor(inpA);
auto *tB = convertNVTETensor(inpB);
auto *tOut = convertNVTETensor(alpha_out);
void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr;
void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr;
void *alpha_ptr = tOut->data.dptr;
// check that the pointers are not null
NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null");
NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null");
NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null");
nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>(
alpha_in, reinterpret_cast<const float *>(amax_A_ptr),
reinterpret_cast<const float *>(amax_B_ptr), reinterpret_cast<float *>(alpha_ptr));
NVTE_CHECK_CUDA(cudaGetLastError());
}
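For reference, a host-side sketch of the arithmetic performed by the single-thread kernel above, assuming the two amax values have already been copied to the host. The 6.0 and 448.0 constants match the E2M1 and E4M3 maxima, so the divisor appears to correspond to the product of the two-level scaling maxima of both operands.

float nvfp4_per_tensor_alpha(float alpha_in, float amax_A, float amax_B) {
  constexpr float kFactorInv = 1.0f / (6.0f * 6.0f * 448.0f * 448.0f);
  return alpha_in * amax_A * amax_B * kFactorInv;  // same formula as the device kernel
}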
......@@ -18,7 +18,9 @@
namespace transformer_engine {
namespace {
constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
constexpr int MXFP8_BLOCK_SIZE = 32;
constexpr int NVFP4_BLOCK_SIZE = 16;
constexpr __device__ __host__ int TB_DIM = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
......@@ -314,8 +316,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
const int original_K = kernel_args.original_k_list[tensor_id];
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
// Get block index in grid. Emulate 2D grid.
const int num_tiles_k = K / SF_TILE_DIM_K;
......@@ -332,9 +332,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
} // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + ".");
}
NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING ||
input->scaling_mode == NVTE_BLOCK_SCALING_1D ||
input->scaling_mode == NVTE_BLOCK_SCALING_2D ||
input->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
"Input tensor has invalid dtype (", to_string(input->dtype()), ").");
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
......@@ -345,176 +349,202 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
CheckInputTensor(*output, "scaling_factor_output");
auto& scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
"Unsupported scaling mode for swizzling.");
bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
// 1D block scaling, row-wise or column-wise
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
const int m =
input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1];
const int k =
input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0];
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output->has_data()) {
NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
output->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
output->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
int m, k;
if (input->has_data()) {
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else {
if (nvfp4) {
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
} else {
m = input->columnwise_scale_inv.shape[1];
k = input->columnwise_scale_inv.shape[0];
}
}
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output->has_data()) {
NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
output->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
output->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
}
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
// For NVFP4, the scale inverse for transposed data needs rowwise swizzle.
const bool rowwise_swizzle = input->has_data() || nvfp4;
const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
dim3 block_size(TB_DIM, TB_DIM);
if (input->has_data()) {
int vec_load_size = (num_tiles_k - 1) % 4 + 1;
/* there is no int3 and misaligned if using int4/int2 */
if (vec_load_size == 3) vec_load_size = 1;
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_first_dim();
const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
switch (vec_load_size) {
if (rowwise_swizzle) {
int vec_load_size = (num_tiles_k - 1) % 4 + 1;
/* there is no int3 and misaligned if using int4/int2 */
if (vec_load_size == 3) vec_load_size = 1;
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
int original_M, original_K;
void *input_scale_inv_ptr, *output_scale_inv_ptr;
if (!nvfp4 || input->has_data()) {
int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / block_scale_size;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
} else {
original_M = input->flat_last_dim();
original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
}
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 4:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
#else
case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break;
case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
if (input->has_columnwise_data()) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1;
if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
switch (vec_load_size) {
}
if (columnwise_swizzle) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1;
if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
// NVFP4 shouldn't end up here because it only needs rowwise swizzle
NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 4:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
#else
case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m,
k, original_M, original_K);
break;
case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, k,
original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, k,
original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, k,
original_M, original_K);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
// 2D block scaling
} else {
NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -650,6 +680,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
// TODO(nvfp4): Add NVFP4 support.
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
std::vector<Tensor*>& output, cudaStream_t stream) {
auto num_tensors = input.size();
......@@ -776,7 +808,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
* WIP (Phuong):
* - Opt for bank conflicts
* - Adding swizzle for 2d-block scaling.
*/
*/
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_scaling_factors);
using namespace transformer_engine;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>
#include <cstdint>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace {
constexpr uint32_t WARP_SIZE = 32;
} // namespace
namespace swizzle_kernel_1d {
constexpr uint32_t WARPS_X_PER_TB = 2; // configurable
constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable
// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where
// each thread stores a single row (of four bytes).
// Example:
// lane0.row = 0x00010203
// lane1.row = 0x04050607
// lane2.row = 0x08090a0b
// lane3.row = 0x0c0d0e0f
// Becomes:
// lane0.row = 0x0004080c
// lane1.row = 0x0105090d
// lane2.row = 0x02060a0e
// lane3.row = 0x03070b0f
uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row,
const uint32_t lane,
const uint32_t active_mask) {
using cu = const uint32_t;
// Threads operate in groups of 4, and each thread stores 4 bytes at a time.
// The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes
// until we have transposed the 4x4 matrix.
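  // Naming convention: each intermediate is named m_<lane0>_<lane1>_<lane2>_<lane3>,
  // where the four hex digits per group list that lane's bytes from most to least
  // significant. E.g. m_4567_0123_cdef_89ab means lane 0 holds 0x04050607 after the
  // first shuffle below.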
cu m_0123_4567_89ab_cdef = row;
cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4);
cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240);
cu m_5173_1537_d9fb_9dbf = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x3715);
cu m_0426_1537_8cae_9dbf = (lane & 1) ? m_5173_1537_d9fb_9dbf : m_0426_4062_8cae_c8ea;
cu m_8cae_9dbf_0426_1537 = __shfl_xor_sync(active_mask, m_0426_1537_8cae_9dbf, 2, 4);
cu m_048c_159d_8c04_9d15 = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x5410);
cu m_ae26_bf37_26ae_37bf = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x3276);
cu m_048c_159d_26ae_37bf = (lane & 2) ? m_ae26_bf37_26ae_37bf : m_048c_159d_8c04_9d15;
return m_048c_159d_26ae_37bf;
}
// Expands a uint32_t to a uint4 by duplicating each byte four times.
// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404}
uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) {
return {__byte_perm(x, 0, 0x0000), __byte_perm(x, 0, 0x1111), __byte_perm(x, 0, 0x2222),
__byte_perm(x, 0, 0x3333)};
}
// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data
// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors.
struct no_oob_tag_t {};
constexpr no_oob_tag_t NO_OOB_TAG;
template <typename OOBT>
void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel(
const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x,
const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride,
OOBT first_oob) {
// resolve kernel variant
constexpr bool no_oob = std::is_same_v<OOBT, no_oob_tag_t>;
static_assert(no_oob || std::is_same_v<OOBT, uint32_t>);
// load thread indices
const uint32_t lane = threadIdx.x;
__builtin_assume(lane < WARP_SIZE);
const uint32_t warp_x = threadIdx.z;
__builtin_assume(warp_x < WARPS_X_PER_TB);
const uint32_t warp_y = threadIdx.y;
__builtin_assume(warp_y < WARPS_Y_PER_TB);
// compute tile indices
const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y;
const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x;
const uint32_t in_tile_y = out_tile_x;
const uint32_t in_tile_x = out_tile_y;
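  // Note: input tile coordinates are the swap of the output tile coordinates because the
  // fp8 block scaling factors are stored transposed relative to the data (see launch_kernel).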
// bounds check; uniform branch
if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) {
return;
}
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
// load scaling factors for this lane's initial four 1x128 tiles
uint4 sf;
if constexpr (no_oob) {
sf = reinterpret_cast<const uint4*>(warp_src)[lane];
} else {
if ((out_tile_y < tiles_y - 1) || lane < first_oob) {
sf = reinterpret_cast<const uint4*>(warp_src)[lane];
} else {
sf = uint4{0, 0, 0, 0};
}
}
// pack the exponent bits of the scaling factors
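  // The bit manipulation below assumes each FP32 scale is a non-negative power of two
  // (zero sign and mantissa bits), so bits 30..23 hold the value as an E8M0 exponent.
  // The shifts place the exponent fields of x, y, z, w into bytes 0..3 of packed_exponents.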
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
// partially swizzle the scaling factors
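  // Each lane fetches the packed word currently held by lane (lane % 4) * 8 + lane / 4,
  // i.e. a transpose of the warp viewed as an 8x4 grid of lanes.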
constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches
const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4);
packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx);
// transpose 4x4 matrices of scaling factors
packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK);
// broadcast the scaling factors for sixteen 1x32 tiles
sf = broadcast_uint32_t_to_uint4(packed_exponents);
// store them cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
reinterpret_cast<uint4*>(warp_dst)[lane] = sf;
}
void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols,
cudaStream_t stream) {
NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ",
alignof(uint4), " bytes");
NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)),
"Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes");
NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors");
const uint32_t tiles_x = DIVUP(data_cols, 128u);
const uint32_t tiles_y = DIVUP(data_rows, 128u);
const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1};
const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB};
// Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales
// and a 128x4 tile in the output scales. The input scales are in transposed order.
const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4;
const uint32_t output_scale_inv_cols = tiles_x * 128 * 4;
const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float);
const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t);
const uint32_t first_oob = (input_scale_inv_cols % 128) / 4;
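  // Worked example: data_rows = 1000 gives input_scale_inv_cols = 1000, so
  // 1000 % 128 = 104 and first_oob = 26; in warps handling the last output tile row,
  // lanes 26..31 would read past the end of each input scale row and are zero-filled
  // inside the kernel instead.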
if (first_oob == 0) {
swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, NO_OOB_TAG);
} else {
swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, first_oob);
}
}
} // namespace swizzle_kernel_1d
namespace swizzle_kernel_2d {
constexpr uint32_t WARPS_X_PER_TB = 2; // configurable
constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable
void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel(
const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x,
const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) {
// load thread indices
const uint32_t lane = threadIdx.x;
__builtin_assume(lane < WARP_SIZE);
const uint32_t warp_x = threadIdx.z;
__builtin_assume(warp_x < WARPS_X_PER_TB);
const uint32_t warp_y = threadIdx.y;
__builtin_assume(warp_y < WARPS_Y_PER_TB);
// compute tile indices
const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y;
const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x;
const uint32_t in_tile_y = out_tile_y;
const uint32_t in_tile_x = out_tile_x;
// bounds check; uniform branch
if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) {
return;
}
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = sizeof(float);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
// load scaling factor for this warp's 128x128 tile
uint32_t sf = *reinterpret_cast<const uint32_t*>(warp_src);
// broadcast it to four scaling factors for 1x32 tiles
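  // Bit trick: assuming the FP32 scale is a non-negative power of two (zero sign and
  // mantissa bits), (sf << 1) places its 8-bit exponent in byte 3 and (sf >> 7) places it
  // in byte 2; OR-ing with (sf >> 16) then copies it into bytes 1 and 0, yielding the same
  // E8M0 value in all four bytes.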
sf = (sf << 1) | (sf >> 7);
sf = sf | (sf >> 16);
// broadcast it to sixteen scaling factors for 1x32 tiles
const uint4 sf4{sf, sf, sf, sf};
// store it cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
reinterpret_cast<uint4*>(warp_dst)[lane] = sf4;
}
void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols,
cudaStream_t stream) {
NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ",
alignof(float), " bytes");
NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)),
"Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes");
const uint32_t tiles_x = DIVUP(data_cols, 128u);
const uint32_t tiles_y = DIVUP(data_rows, 128u);
const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1};
const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB};
// Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales
// and a 128x4 tile in the output scales.
const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4;
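  // The input scale matrix is padded to a multiple of 4 columns, e.g. data_cols = 256
  // gives tiles_x = 2 but 4 input scale columns.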
const uint32_t output_scale_inv_cols = tiles_x * 128 * 4;
const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float);
const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t);
swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
in, out, tiles_x, tiles_y, in_y_stride, out_y_stride);
}
} // namespace swizzle_kernel_2d
void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* output,
cudaStream_t stream) {
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
return;
}
CheckInputTensor(*input, "block_scaling_scaling_factor_input");
CheckInputTensor(*output, "mxfp8_scaling_factor_output");
const NVTEScalingMode scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING,
"Output tensor must be an mxfp8 tensor");
NVTE_CHECK(input->data.dtype == transformer_engine::DType::kFloat8E4M3 ||
input->data.dtype == transformer_engine::DType::kFloat8E5M2,
"Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8");
NVTE_CHECK(output->data.dtype == input->data.dtype,
"Output data must have the same dtype as input data");
NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32, "Input must have FP32 scaling factors");
NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0,
"Output must have E8M0 scaling factors");
NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data");
NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input");
NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Output must have rowwise scaling factors");
NVTE_CHECK(input->data.shape.size() == 2, "Input data must be a matrix");
NVTE_CHECK(output->data.shape == input->data.shape,
"Output data must have the same shape as input data");
NVTE_CHECK(input->scale_inv.shape.size() == 2, "Input scaling factors must be a matrix");
NVTE_CHECK(output->scale_inv.shape.size() == 2, "Output scaling factors must be a matrix");
const size_t data_rows = input->data.shape[0];
const size_t data_cols = input->data.shape[1];
const size_t input_scale_inv_rows = input->scale_inv.shape[0];
const size_t input_scale_inv_cols = input->scale_inv.shape[1];
const size_t output_scale_inv_rows = output->scale_inv.shape[0];
const size_t output_scale_inv_cols = output->scale_inv.shape[1];
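  // Shape example: a 1024x1024 FP8 data matrix expects a 1024x32 output (MXFP8) scale
  // matrix; with NVTE_BLOCK_SCALING_1D the input scale matrix is 8x1024 (stored transposed,
  // columns padded to a multiple of 4), and with NVTE_BLOCK_SCALING_2D it is 8x8.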
NVTE_CHECK(output_scale_inv_rows == DIVUP<size_t>(data_rows, 128) * 128,
"Expected the output scaling factor matrix to have ",
DIVUP<size_t>(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows,
" rows instead.");
NVTE_CHECK(output_scale_inv_cols == DIVUP<size_t>(data_cols, 128) * 4,
"Expected the output scaling factor matrix to have ",
DIVUP<size_t>(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols,
" columns instead.");
if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
NVTE_CHECK(input_scale_inv_rows == DIVUP<size_t>(data_cols, 128),
"Expected the input scaling factor matrix to have ", DIVUP<size_t>(data_cols, 128),
" rows, but it has ", input_scale_inv_rows, " rows instead.");
NVTE_CHECK(input_scale_inv_cols == DIVUP<size_t>(data_rows, 4) * 4,
"Expected the input scaling factor matrix to have ", DIVUP<size_t>(data_rows, 4) * 4,
" columns, but it has ", input_scale_inv_cols, " columns instead.");
swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows,
data_cols, stream);
} else { // scaling_mode == NVTE_BLOCK_SCALING_2D
NVTE_CHECK(input_scale_inv_rows == DIVUP<size_t>(data_rows, 128),
"Expected the input scaling factor matrix to have ", DIVUP<size_t>(data_rows, 128),
" rows, but it has ", input_scale_inv_rows, " rows instead.");
NVTE_CHECK(input_scale_inv_cols == DIVUP<size_t>(data_cols, 512) * 4,
"Expected the input scaling factor matrix to have ",
DIVUP<size_t>(data_cols, 512) * 4, " columns, but it has ", input_scale_inv_cols,
" columns instead.");
swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows,
data_cols, stream);
}
}
} // namespace transformer_engine
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors);
using namespace transformer_engine;
swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input),
convertNVTETensorCheck(output), stream);
}
......@@ -11,6 +11,7 @@
#include <cstring>
#include <iostream>
#include <mutex>
#include <utility>
#include "common.h"
#include "common/util/cuda_runtime.h"
......@@ -67,8 +68,12 @@ std::string to_string(const NVTEScalingMode &mode) {
return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING:
return "NVTE_MXFP8_1D_SCALING";
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
case NVTE_BLOCK_SCALING_1D:
return "NVTE_BLOCK_SCALING_1D";
case NVTE_BLOCK_SCALING_2D:
return "NVTE_BLOCK_SCALING_2D";
case NVTE_NVFP4_1D_SCALING:
return "NVTE_NVFP4_1D_SCALING";
case NVTE_INVALID_SCALING:
return "NVTE_INVALID_SCALING";
}
......@@ -98,12 +103,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
t.columnwise_scale_inv.shape, ")");
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul, 4ul};
size_t expected_x, expected_y, alignment;
const size_t block_size_rowwise = 32;
const size_t block_size_colwise = 32;
if (t.has_data()) {
......@@ -114,6 +118,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
expected_y =
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
......@@ -126,11 +131,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
alignment;
alignment = block_alignment[0];
expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
t.columnwise_scale_inv.shape, ")");
}
} else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
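      // NVFP4 1D scaling uses one FP8E4M3 scale per 16 elements along the quantized
      // dimension, with the scale matrix padded to multiples of (128, 4). Example: a
      // 256x1024 tensor expects a 256x64 rowwise scale_inv.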
if (t.has_data()) {
const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128);
const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4);
const auto &expected = std::vector<size_t>{expected_y, expected_x};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
t.scale_inv.shape, ")");
}
if (t.has_columnwise_data()) {
const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128);
const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4);
const auto &expected = std::vector<size_t>{expected_y, expected_x};
NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
t.columnwise_scale_inv.shape, ")");
}
}
}
}
......@@ -158,6 +181,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
"(expected Float32 or Byte, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else if (is_fp4_dtype(type)) {
// TODO(ksivaman): Fix this to check for amaxes and other details.
// For now only needed for swizzle.
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
"_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got ",
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ",
name,
"_columnwise_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
......@@ -199,10 +242,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
"(expected Float32 or Float8E8M0, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else if (is_fp4_dtype(type)) {
// FP4 output needs to have the scale_inv
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
"_scale_inverse has invalid dtype "
"(expected Float8E4M3, got ",
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
name,
"_columnwise_scale_inverse has invalid dtype "
"(expected Float8E4M3, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
// Note: amax is supported for non-FP8 output as it can be fused into the computation
// and later used for quantization with no need to compute it separately
// Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
// NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
......@@ -507,6 +569,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
case kNVTEColumnwiseScaleInv:
t->columnwise_scale_inv = *param;
break;
case kNVTEColumnwiseAmax:
t->columnwise_amax = *param;
break;
default:
NVTE_ERROR("Unknown tensor parameter!");
}
......@@ -530,6 +595,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
return t.scale_inv;
case kNVTEColumnwiseScaleInv:
return t.columnwise_scale_inv;
case kNVTEColumnwiseAmax:
return t.columnwise_amax;
default:
NVTE_ERROR("Unknown tensor parameter!");
}
......@@ -645,6 +712,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
break;
case kNVTEQuantizationConfigRNGState:
std::memcpy(&config_.rng_state, buf, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
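  // Illustrative usage sketch (not part of this file; assumes the usual
  // nvte_create_quantization_config / nvte_destroy_quantization_config helpers and that the
  // stochastic-rounding attribute is a bool):
  //
  //   NVTEQuantizationConfig cfg = nvte_create_quantization_config();
  //   bool use_stochastic_rounding = true;
  //   nvte_set_quantization_config_attribute(cfg, kNVTEQuantizationConfigStochasticRounding,
  //                                          &use_stochastic_rounding,
  //                                          sizeof(use_stochastic_rounding));
  //   nvte_destroy_quantization_config(cfg);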
......