Commit ab3e5a92 authored by yuguo

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
@@ -63,92 +63,170 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) {
   NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
 }
 
+/* Parameters for cuBLAS GEMM
+ *
+ * cuBLAS follows the BLAS convention of column-major ordering. This
+ * is different than the row-major that is typically used in
+ * Transformer Engine.
+ *
+ */
 struct GemmParam {
-  void *A;
-  void *B;
-  cublasOperation_t transA;
-  cublasOperation_t transB;
-  transformer_engine::DType Atype;
-  transformer_engine::DType Btype;
-  void *A_scale_inv;
-  void *B_scale_inv;
-  int lda;
-  int ldb;
-
-  GemmParam(cublasOperation_t transA, cublasOperation_t transB)
-      : A(nullptr),
-        B(nullptr),
-        transA(transA),
-        transB(transB),
-        Atype(transformer_engine::DType::kNumTypes),
-        Btype(transformer_engine::DType::kNumTypes),
-        A_scale_inv(nullptr),
-        B_scale_inv(nullptr),
-        lda(0),
-        ldb(0) {}
+  void *A = nullptr;
+  void *B = nullptr;
+  cublasOperation_t transA = CUBLAS_OP_N;
+  cublasOperation_t transB = CUBLAS_OP_N;
+  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
+  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
+  void *A_scale_inv = nullptr;
+  void *B_scale_inv = nullptr;
+  int lda = 0;  // A column strides
+  int ldb = 0;  // B column strides
 };
 
+/* Populate parameters for cuBLAS GEMM
+ *
+ * cuBLAS follows the BLAS convention of column-major ordering. This
+ * is different than the row-major that is typically used in
+ * Transformer Engine.
+ *
+ */
 GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                 const transformer_engine::Tensor &B, const cublasOperation_t transB,
-                                const int k, const int lda, const int ldb) {
+                                int m, int n, int k) {
   using namespace transformer_engine;
-  NVTE_CHECK(A.scaling_mode == B.scaling_mode,
-             "Inputs A and B to GEMM need to have the same scaling mode!");
+  NVTE_CHECK(
+      A.scaling_mode == B.scaling_mode ||
+          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
+          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
+      "Inputs A and B to GEMM need to have compatible scaling modes!");
   NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
   NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
-  GemmParam ret(transA, transB);
-  ret.lda = lda;
-  ret.ldb = ldb;
-
-  if (is_tensor_scaling(A.scaling_mode)) {
-    ret.A = A.data.dptr;
-    ret.A_scale_inv = A.scale_inv.dptr;
-    if (transA == CUBLAS_OP_T) {
-      ret.Atype = A.data.dtype;
-    } else {
-      ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
-      if (is_fp8_dtype(ret.Atype)) {
-        int arch = cuda::sm_arch(cuda::current_device());
-        if (arch < 100) {
-          // Hopper and Ada - we need to use columnwise_data and change transA
-          NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
-          ret.A = A.columnwise_data.dptr;
-          ret.transA = CUBLAS_OP_T;
-          ret.A_scale_inv = A.columnwise_scale_inv.dptr;
-          ret.lda = k;
-        }
-      }
-    }
-    ret.B = B.data.dptr;
-    ret.B_scale_inv = B.scale_inv.dptr;
-    if (transB == CUBLAS_OP_T) {
-      ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
-      if (is_fp8_dtype(ret.Btype)) {
-        int arch = cuda::sm_arch(cuda::current_device());
-        if (arch < 100) {
-          // Hopper and Ada - we need to use columnwise_data and change transA
-          NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
-          ret.B = B.columnwise_data.dptr;
-          ret.transB = CUBLAS_OP_N;
-          ret.B_scale_inv = B.columnwise_scale_inv.dptr;
-          ret.ldb = k;
-        }
-      }
-    } else {
-      ret.Btype = B.data.dtype;
-    }
-  } else {
-    // If not tensor scaling (which includes also high precision types), we need to
-    // use the proper version of data
-    // We leave the transA/B values as is, since Blackwell supports transposes
-    ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
-    ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
-    ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
-    ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
-    ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
-    ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
-  }
+  GemmParam ret;
+
+  // Transpose mode with column-major ordering
+  bool is_A_transposed = transA == CUBLAS_OP_T;
+  bool is_B_transposed = transB == CUBLAS_OP_T;
+
+  // Configure A matrix
+  if (is_tensor_scaling(A.scaling_mode)) {
+    // Unscaled or FP8 tensor scaling
+    ret.A = A.data.dptr;
+    ret.transA = transA;
+    ret.Atype = A.data.dtype;
+    ret.A_scale_inv = A.scale_inv.dptr;
+    ret.lda = is_A_transposed ? k : m;
+    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
+      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
+        ret.A = A.columnwise_data.dptr;
+        ret.transA = CUBLAS_OP_T;
+        ret.Atype = A.columnwise_data.dtype;
+        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
+        ret.lda = k;
+      } else {
+        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
+      }
+    }
+  } else if (is_mxfp_scaling(A.scaling_mode)) {
+    // MXFP8
+    // Note: Row-wise and column-wise data are scaled along different
+    // dimensions (with matrix interpreted in row-major order).
+    if (is_A_transposed) {
+      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
+    } else {
+      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
+    }
+    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
+    ret.transA = transA;
+    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
+    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
+    ret.lda = is_A_transposed ? k : m;
+  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
+    // FP8 block scaling
+    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+    if (is_A_transposed) {
+      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
+    } else {
+      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
+    }
+    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
+    ret.transA = CUBLAS_OP_T;
+    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
+    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
+    ret.lda = k;
+    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
+    NVTE_CHECK((ret.lda % 16) == 0,
+               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
+    // Smallest supported CType is 2 bytes in this scaling mode.
+    NVTE_CHECK((m % 8) == 0,
+               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+  } else {
+    NVTE_ERROR("A has unsupported scaling mode");
+  }
+
+  // Configure B matrix
+  if (is_tensor_scaling(B.scaling_mode)) {
+    // Unscaled or FP8 tensor scaling
+    ret.B = B.data.dptr;
+    ret.transB = transB;
+    ret.Btype = B.data.dtype;
+    ret.B_scale_inv = B.scale_inv.dptr;
+    ret.ldb = is_B_transposed ? n : k;
+    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
+      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
+        ret.B = B.columnwise_data.dptr;
+        ret.transB = CUBLAS_OP_N;
+        ret.Btype = B.columnwise_data.dtype;
+        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
+        ret.ldb = k;
+      } else {
+        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
+      }
+    }
+  } else if (is_mxfp_scaling(B.scaling_mode)) {
+    // MXFP8
+    // Note: Row-wise and column-wise data are scaled along different
+    // dimensions (with matrix interpreted in row-major order).
+    if (is_B_transposed) {
+      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
+    } else {
+      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
+    }
+    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
+    ret.transB = transB;
+    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
+    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
+    ret.ldb = is_B_transposed ? n : k;
+  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
+    // FP8 block scaling
+    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+    if (is_B_transposed) {
+      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
+    } else {
+      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
+    }
+    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
+    ret.transB = CUBLAS_OP_N;
+    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
+    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
+    ret.ldb = k;
+    // Requirements from
+    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
+    NVTE_CHECK((ret.ldb % 16) == 0,
+               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
+      // Observed this requirement only present for B tensor is 1D quantized.
+      NVTE_CHECK((n % 8) == 0,
+                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    }
+  } else {
+    NVTE_ERROR("B has unsupported scaling mode");
+  }
   return ret;
 }
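Not part of the commit: a minimal standalone C++ sketch of the row-major/column-major reinterpretation that the new comment relies on. A row-major [rows, cols] buffer can be read by a column-major BLAS as a cols x rows matrix with leading dimension cols without any copy, which is why the code above only has to juggle trans flags and leading dimensions. The sizes are illustrative.

    #include <cassert>
    #include <vector>

    int main() {
      // Row-major 2x3 matrix, as Transformer Engine stores tensors:
      //   [ 1 2 3 ]
      //   [ 4 5 6 ]
      const int rows = 2, cols = 3;
      const std::vector<float> buf = {1, 2, 3, 4, 5, 6};

      // The same buffer read the way a column-major BLAS would read a
      // cols x rows matrix with leading dimension cols.
      const int ld = cols;
      auto colmajor = [&](int i, int j) { return buf[j * ld + i]; };

      // Column-major element (i, j) is row-major element (j, i): the buffer
      // acts as the transpose with no data movement.
      for (int j = 0; j < rows; ++j) {
        for (int i = 0; i < cols; ++i) {
          assert(colmajor(i, j) == buf[j * cols + i]);
        }
      }
      assert(colmajor(2, 1) == 6.0f);  // column 2, row 1 of the transposed view
      return 0;
    }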
@@ -167,18 +245,33 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #else  // Use cublasLt
 
 using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
 
 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
-                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
-                 int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
-                 void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
-                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
-                 const Tensor *inputCounter, cudaStream_t stream) {
+                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
+                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
+                 bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
+                 int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
+  // Tensor dims in row-major order
+  const int A0 = inputA->flat_first_dim();
+  const int A1 = inputA->flat_last_dim();
+  const int B0 = inputB->flat_first_dim();
+  const int B1 = inputB->flat_last_dim();
+
+  // GEMM dims in column-major order
+  const int m = transa == CUBLAS_OP_T ? A0 : A1;
+  const int n = transb == CUBLAS_OP_T ? B1 : B0;
+  const int k = transa == CUBLAS_OP_T ? A1 : A0;
+  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
+             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
+             ")");
+  const int ldd = m;
+
   // Return immediately if GEMM is trivial
   if (m <= 0 || n <= 0) {
     return;
   }
   NVTE_CHECK(k > 0);
 
-  const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb);
+  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);
 
   void *C = outputD->data.dptr;
   void *D = outputD->data.dptr;
   void *D_scale = outputD->scale.dptr;
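A small standalone check (not from the commit; the shapes are invented) of the dimension derivation above, for the common TN layout where A is a [768, 512] row-major weight and B is a [1024, 512] row-major activation:

    #include <cassert>

    int main() {
      const bool transa = true, transb = false;   // TN GEMM
      const int A0 = 768, A1 = 512;               // inputA->flat_first_dim(), flat_last_dim()
      const int B0 = 1024, B1 = 512;              // inputB->flat_first_dim(), flat_last_dim()

      const int m = transa ? A0 : A1;             // 768
      const int n = transb ? B1 : B0;             // 1024
      const int k = transa ? A1 : A0;             // 512
      assert((transb ? B0 : B1) == k);            // inner dimensions agree

      // Column-major D is m x n with ldd = m, i.e. a row-major [1024, 768] tensor.
      const int ldd = m;
      assert(m == 768 && n == 1024 && k == 512 && ldd == 768);
      return 0;
    }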
@@ -240,6 +333,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                param.transA == CUBLAS_OP_N ? k : m, param.lda));
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
                                                param.transB == CUBLAS_OP_N ? n : k, param.ldb));
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
 
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
@@ -265,7 +359,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   // Scaling factors.
 #if CUDA_VERSION >= 12080
-  cublasLtMatmulMatrixScale_t scaling_mode;
+  cublasLtMatmulMatrixScale_t scaling_mode_a;
+  cublasLtMatmulMatrixScale_t scaling_mode_b;
 #endif
   if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
     void *A_scale_inverse = param.A_scale_inv;
@@ -277,8 +372,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                      &B_scale_inverse, sizeof(B_scale_inverse)));
 #if CUDA_VERSION >= 12080
-    scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
-  } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) {
+    scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+    scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+  } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
     fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
     fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -287,7 +383,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                      &B_scale_inverse, sizeof(B_scale_inverse)));
-    scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
+    scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
+    scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
     // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
     // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
     if (cublasLtGetVersion() <= 120803) {
@@ -296,7 +393,32 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
           operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
           sizeof(dummy_a_vec_stride)));
     }
-#endif
+  } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
+              inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
+             (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
+              inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
+#if CUDA_VERSION >= 12090
+    float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
+    float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
+                                                     CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
+                                                     &A_scale_inverse, sizeof(A_scale_inverse)));
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
+                                                     CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+                                                     &B_scale_inverse, sizeof(B_scale_inverse)));
+    NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
+                  inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
+               "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
+    scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
+                         ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
+                         : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
+    scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
+                         ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
+                         : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
+#else
+    NVTE_ERROR("FP8 block scaling requires CUDA 12.9+");
+#endif  // CUDA_VERSION >= 12090
+#endif  // CUDA_VERSION >= 12080
   } else {
     NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
                to_string(inputB->scaling_mode) + ".");
@@ -304,9 +426,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #if CUDA_VERSION >= 12080
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
-      operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
+      operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
-      operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
+      operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
 #endif
   if (is_fp8_dtype(outputD->data.dtype)) {
     // Accumulation mode not supported for FP8 output
@@ -316,8 +438,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
 #if CUDA_VERSION >= 12080
-    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
-        operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
+    // NOTE: In all current cases where FP8 output is supported, the input is
+    // scaled identically to the output.
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
+                                                     CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
+                                                     &scaling_mode_a, sizeof(scaling_mode_a)));
 #endif
     // For FP8 output, cuBLAS requires C_type to match bias_type and
     // be FP16/BF16
@@ -375,6 +500,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
         operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
   }
 
+  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
+      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
+    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
+                epilogue == CUBLASLT_EPILOGUE_DGELU),
+               "Epilogue requested outside of the available and tested cuBLAS functionality for "
+               "float8 block scaled GEMM");
+  }
+
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                    &epilogue, sizeof(epilogue)));
@@ -422,7 +555,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
              "Unable to find suitable cuBLAS GEMM algorithm");
   NVTE_CHECK_CUBLAS(status);
-
   if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
 
   // D = alpha * (A * B) + beta * C
@@ -494,6 +626,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
   Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
   Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
 
+#ifdef __HIP_PLATFORM_AMD__
   const size_t A0 = inputA->flat_first_dim();
   const size_t A1 = inputA->flat_last_dim();
   const size_t B0 = inputB->flat_first_dim();
@@ -519,32 +652,13 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
     NVTE_ERROR("TT layout not allowed.");
   }
 
-#ifdef __HIP_PLATFORM_AMD__
   const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)){
-    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#else
-  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#endif //__HIP_PLATFORM_AMD__
-#ifdef __HIP_PLATFORM_AMD__
-              transa, transb,
-#else
-              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
-#endif //__HIP_PLATFORM_AMD__
-              grad,
-              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
-#ifdef __HIP_PLATFORM_AMD__
-              math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
-#else
-              math_sm_count, 0, 0, false, nullptr, stream);
-#endif
-#ifdef __HIP_PLATFORM_AMD__
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
+    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa, transb, grad,
+                wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0,
+                false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
   } else {
     hipblas_gemm(inputA,
                  inputB,
@@ -565,8 +679,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
                  nullptr,
                  stream);
   }
-#endif //__HIP_PLATFORM_AMD__
+#else
+  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
+              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
+              accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
+#endif  //__HIP_PLATFORM_AMD__
 }
 
 void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
@@ -596,7 +713,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
   NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                  is_delayed_tensor_scaling(inputB->scaling_mode),
              "Atomic GEMM only supports delayed scaling.");
 
+#ifdef __HIP_PLATFORM_AMD__
   const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
   const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
   const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
@@ -617,32 +734,13 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
     NVTE_ERROR("TT layout not allowed.");
   }
 
-#ifdef __HIP_PLATFORM_AMD__
   const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)){
-    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#else
-  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#endif //__HIP_PLATFORM_AMD__
-#ifdef __HIP_PLATFORM_AMD__
-              transa, transb,
-#else
-              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
-#endif //__HIP_PLATFORM_AMD__
-              grad,
-              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
-#ifdef __HIP_PLATFORM_AMD__
-              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
-#else
-              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
-#endif
-#ifdef __HIP_PLATFORM_AMD__
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
+    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa, transb, grad,
+                wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count,
+                m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
   } else {
     hipblas_gemm(inputA,
                  inputB,
@@ -663,8 +761,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                  inputCounter,
                  stream);
   }
-#endif //__HIP_PLATFORM_AMD__
+#else
+  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
+              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
+              accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
+              inputCounter, stream);
+#endif  //__HIP_PLATFORM_AMD__
 }
...
@@ -17,22 +17,31 @@
 extern "C" {
 #endif
 
-/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer)
- * The implementation is per the microscaling format MXFP8 defined by the OCP specification:
+/* Quantize the tensor
+ *
+ * The type of quantized tensor in the output depends on the scaling mode of the output
+ * tensor.
+ *
+ * Supported formats are:
+ *
+ * 1) MXFP8 scaling (for compute capability 10.0 or newer)
+ *
+ * The MXFP8 implementation is per the microscaling format MXFP8 defined by the OCP specification:
  * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
  *
- * Supported modes of scaling (live scaling):
- * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
+ * Supported modes of MXFP8 scaling (live scaling) for scaling mode NVTE_MXFP8_1D_SCALING:
+ * a) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
  *        - the scaled output tensor
  *        - the corresponding scaling factors
  *    The scaling factors are computed for blocks of the shape [1,32]
  *    (i.e., each scaling factor spans 32 contiguous elements along rows).
 *
- * 2) Columnwise scaling (along the dim=1) computes one set of the output data.
+ * b) Columnwise scaling (along the dim=1) computes one set of the output data.
 *    The scaling factors are computed for blocks of the shape [32,1]
 *    (i.e., each scaling factor spans 32 contiguous elements along columns).
 *
- * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
+ * c) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
 *    computes two sets of the output data: both 1) and 2).
 *
 * The shape of the MX block must be specified in the 'output' argument,
@@ -40,31 +49,69 @@ extern "C" {
 *
 * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter
 * of the output tensor should be set to 0.
+ *
+ * 2) NVTE_DELAYED_TENSOR_SCALING, which quantizes the entire tensor
+ *    using a single scaling factor. The absolute maximum value of the tensor should
+ *    be precalculated either online (current scaling) or based on a tensor history
+ *    (delayed scaling). The calls to nvte_quantize scale based on that value.
+ *    Note that the NVTE_DELAYED_TENSOR_SCALING NVTEScalingMode is reused for online
+ *    per-tensor scaling.
+ *
+ *
+ * 3) FP8 block scaling formats NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D
+ *    for compute capability of at least 9.0. These modes quantize the tensor by blocks
+ *    of size 1x128 (with columnwise mode of 128x1) and 128x128 respectively.
+ *
+ *    The supported modes are:
+ *    a) Rowwise scaling yields output data:
+ *       - the scaled output tensor in fp8 coefficients with identical shape to the
+ *         input tensor.
+ *       - scale factors which are computed for either 1D 1x128 or 2D 128x128 blocks.
+ *    b) Columnwise scaling yields output data:
+ *       - the scaled output tensor in fp8 coefficients with a shape equivalent to
+ *         the transpose of the input tensor.
+ *       - scale factors which are calculated for either 1D 128x1 or 2D 128x128 blocks
+ *         of the input tensor.
+ *    c) Both: both tensors and both sets of scale factors are calculated.
+ *
+ *    This quantization mode includes both the calculation of the scaling factors
+ *    per tile and the quantization of the row- and/or columnwise tiles. No precalculated
+ *    absolute max is required. The scaling factors are also rounded to powers of 2.
 */
 
-/*! \brief Casts input tensor to FP8/MXFP8.
- *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
- *  the block quantization (MXFP8) of the specified shape of the block will be used.
+/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8.
+ *  The type of quantized tensor in the output depends on the scaling mode of the output
+ *  tensor. See file level comments.
 *
 *  \param[in]     input   Input tensor to be cast.
- *  \param[in,out] output  Output FP8/MXFP8 tensor.
+ *  \param[in,out] output  Output FP8/MXFP8/BlockwiseFP8 tensor.
 *  \param[in]     stream  CUDA stream used for the operation.
 */
 void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
-/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel
+/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel
 *  based on the value of the 'noop' tensor.
- *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
- *  the block quantization (MXFP8) of the specified shape of the block will be used.
+ *  The type of quantized tensor in the output depends on the scaling mode of the output
+ *  tensor. See file level comments.
 *
 *  \param[in]     input   Input tensor to be cast.
- *  \param[in,out] output  Output FP8/MXFP8 tensor.
+ *  \param[in,out] output  Output quantized tensor.
 *  \param[out]    noop    Noop tensor.
 *  \param[in]     stream  CUDA stream used for the operation.
 */
 void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                         cudaStream_t stream);
 
+/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options.
+ *
+ *  \param[in]     input         Input tensor to be cast.
+ *  \param[in,out] output        Output quantized tensor.
+ *  \param[in]     quant_config  Quantization configuration.
+ *  \param[in]     stream        CUDA stream used for the operation.
+ */
+void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
+                      const NVTEQuantizationConfig quant_config, cudaStream_t stream);
+
 /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns.
 *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 *  the block quantization (MXFP8) of the specified shape of the block will be used.
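As a rough usage sketch (not part of this header; the buffer sizes, dtypes, and the use of TensorWrapper::data() as the underlying NVTETensor handle are assumptions for illustration), quantizing a BF16 buffer into an FP8 per-tensor-scaled output comes down to wrapping the device buffers and calling nvte_quantize:

    #include <vector>

    #include <transformer_engine/cast.h>
    #include <transformer_engine/transformer_engine.h>

    using transformer_engine::DType;
    using transformer_engine::TensorWrapper;

    // Assumes all pointers are valid device buffers of matching sizes.
    void quantize_to_fp8(void *bf16_in, void *fp8_out, float *amax, float *scale,
                         float *scale_inv, cudaStream_t stream) {
      const std::vector<size_t> shape = {1024, 1024};

      // High-precision input.
      TensorWrapper input(bf16_in, shape, DType::kBFloat16);

      // FP8 output using per-tensor scaling: one scale, its reciprocal, one amax slot.
      TensorWrapper output(fp8_out, shape, DType::kFloat8E4M3, amax, scale, scale_inv);

      nvte_quantize(input.data(), output.data(), stream);
    }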
...
@@ -7,6 +7,7 @@
 #ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_
 #define TRANSFORMER_ENGINE_FUSED_ROPE_H_
 
+#include "fused_attn.h"
 #include "transformer_engine.h"
 
 #ifdef __cplusplus
@@ -16,112 +17,63 @@ extern "C" {
 /*! \brief Apply rotary positional embedding to the input tensor.
 *
 *  \param[in]  input         Input tensor for fused rope.
+ *  \param[in]  cu_seqlens    The cumulative sum of sequence lengths tensor.
+ *                            (Required for the thd format, empty tensor for other formats)
 *  \param[in]  freqs         The freqs tensor.
 *  \param[out] output        Output tensor.
+ *  \param[in]  qkv_format    QKV format.
+ *  \param[in]  interleaved   Whether to use interleaved rotary position embedding.
+ *  \param[in]  cp_size       Context parallel world size.
+ *  \param[in]  cp_rank       Context parallel rank.
 *  \param[in]  s             Length of the s dimension of input.
 *  \param[in]  b             Length of the b dimension of input.
 *  \param[in]  h             Length of the h dimension of input.
 *  \param[in]  d             Length of the d dimension of input.
 *  \param[in]  d2            Length of the d dimension of freqs.
- *  \param[in]  stride_s      Stride of the s dimension of input.
- *  \param[in]  stride_b      Stride of the b dimension of input.
+ *  \param[in]  stride_s_or_t Stride of the s (sbhd/bshd) / t (thd) dimension of input.
+ *  \param[in]  stride_b      Stride of the b dimension of input (0 for thd).
 *  \param[in]  stride_h      Stride of the h dimension of input.
 *  \param[in]  stride_d      Stride of the d dimension of input.
- *  \param[in]  o_stride_s    Stride of the s dimension of output.
- *  \param[in]  o_stride_b    Stride of the b dimension of output.
- *  \param[in]  o_stride_h    Stride of the h dimension of output.
- *  \param[in]  o_stride_d    Stride of the d dimension of output.
 *  \param[in]  stream        CUDA stream used for the operation.
 */
-void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
-                             const int s, const int b, const int h, const int d, const int d2,
-                             const int stride_s, const int stride_b, const int stride_h,
-                             const int stride_d, const int o_stride_s, const int o_stride_b,
-                             const int o_stride_h, const int o_stride_d, cudaStream_t stream);
+void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
+                             const NVTETensor freqs, NVTETensor output,
+                             const NVTE_QKV_Format qkv_format, const bool interleaved,
+                             const int cp_size, const int cp_rank, const int s, const int b,
+                             const int h, const int d, const int d2, const int stride_s_or_t,
+                             const int stride_b, const int stride_h, const int stride_d,
+                             cudaStream_t stream);
 
 /*! \brief Compute the backward of the fused rope.
 *
 *  \param[in]  output_grads  Incoming gradient tensor for backward.
+ *  \param[in]  cu_seqlens    The cumulative sum of sequence lengths tensor.
+ *                            (Required for the thd format, empty tensor for other formats)
 *  \param[in]  freqs         The freqs tensor.
 *  \param[out] input_grads   Input gradient tensor to calculate.
+ *  \param[in]  qkv_format    QKV format.
+ *  \param[in]  interleaved   Whether to use interleaved rotary position embedding.
+ *  \param[in]  cp_size       Context parallel world size.
+ *  \param[in]  cp_rank       Context parallel rank.
 *  \param[in]  s             Length of the s dimension of output_grads.
 *  \param[in]  b             Length of the b dimension of output_grads.
 *  \param[in]  h             Length of the h dimension of output_grads.
 *  \param[in]  d             Length of the d dimension of output_grads.
 *  \param[in]  d2            Length of the d dimension of freqs.
- *  \param[in]  stride_s      Stride of the s dimension of output_grads.
- *  \param[in]  stride_b      Stride of the b dimension of output_grads.
+ *  \param[in]  stride_s_or_t Stride of the s (sbhd/bshd) / t (thd) dimension of output_grads.
+ *  \param[in]  stride_b      Stride of the b dimension of output_grads (0 for thd).
 *  \param[in]  stride_h      Stride of the h dimension of output_grads.
 *  \param[in]  stride_d      Stride of the d dimension of output_grads.
- *  \param[in]  o_stride_s    Stride of the s dimension of input_grads.
- *  \param[in]  o_stride_b    Stride of the b dimension of input_grads.
- *  \param[in]  o_stride_h    Stride of the h dimension of input_grads.
- *  \param[in]  o_stride_d    Stride of the d dimension of input_grads.
 *  \param[in]  stream        CUDA stream used for the operation.
 */
-void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
-                              NVTETensor input_grads, const int s, const int b, const int h,
-                              const int d, const int d2, const int stride_s, const int stride_b,
-                              const int stride_h, const int stride_d, const int o_stride_s,
-                              const int o_stride_b, const int o_stride_h, const int o_stride_d,
-                              cudaStream_t stream);
-
-/*! \brief Apply rotary positional embedding to the input tensor in thd format.
- *
- *  \param[in]  input       Input tensor for fused rope.
- *  \param[in]  cu_seqlens  The cumulative sum of sequence lengths tensor.
- *  \param[in]  freqs       The freqs tensor.
- *  \param[out] output      Output tensor.
- *  \param[in]  cp_size     Context parallel world size.
- *  \param[in]  cp_rank     Context parallel rank.
- *  \param[in]  max_s       Max sequence length.
- *  \param[in]  b           Batch size.
- *  \param[in]  h           Length of the h dimension of input.
- *  \param[in]  d           Length of the d dimension of input.
- *  \param[in]  d2          Length of the d dimension of freqs.
- *  \param[in]  stride_t    Stride of the t dimension of input.
- *  \param[in]  stride_h    Stride of the h dimension of input.
- *  \param[in]  stride_d    Stride of the d dimension of input.
- *  \param[in]  o_stride_t  Stride of the t dimension of output.
- *  \param[in]  o_stride_h  Stride of the h dimension of output.
- *  \param[in]  o_stride_d  Stride of the d dimension of output.
- *  \param[in]  stream      CUDA stream used for the operation.
- */
-void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                                 const NVTETensor freqs, NVTETensor output, const int cp_size,
-                                 const int cp_rank, const int max_s, const int b, const int h,
-                                 const int d, const int d2, const int stride_t, const int stride_h,
-                                 const int stride_d, const int o_stride_t, const int o_stride_h,
-                                 const int o_stride_d, cudaStream_t stream);
-
-/*! \brief Compute the backward of the fused rope in thd format.
- *
- *  \param[in]  output_grads  Incoming gradient tensor for backward.
- *  \param[in]  cu_seqlens    The cumulative sum of sequence lengths tensor.
- *  \param[in]  freqs         The freqs tensor.
- *  \param[out] input_grads   Input gradient to calculate.
- *  \param[in]  cp_size       Context parallel world size.
- *  \param[in]  cp_rank       Context parallel rank.
- *  \param[in]  max_s         Max sequence length.
- *  \param[in]  b             Batch size.
- *  \param[in]  h             Length of the h dimension of output_grads.
- *  \param[in]  d             Length of the d dimension of output_grads.
- *  \param[in]  d2            Length of the d dimension of freqs.
- *  \param[in]  stride_t      Stride of the t dimension of output_grads.
- *  \param[in]  stride_h      Stride of the h dimension of output_grads.
- *  \param[in]  stride_d      Stride of the d dimension of output_grads.
- *  \param[in]  o_stride_t    Stride of the t dimension of input_grads.
- *  \param[in]  o_stride_h    Stride of the h dimension of input_grads.
- *  \param[in]  o_stride_d    Stride of the d dimension of input_grads.
- *  \param[in]  stream        CUDA stream used for the operation.
- */
-void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
-                                  const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
-                                  const int cp_rank, const int max_s, const int b, const int h,
-                                  const int d, const int d2, const int stride_t, const int stride_h,
-                                  const int stride_d, const int o_stride_t, const int o_stride_h,
-                                  const int o_stride_d, cudaStream_t stream);
+void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
+                              const NVTETensor freqs, NVTETensor input_grads,
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const int cp_size, const int cp_rank, const int s, const int b,
+                              const int h, const int d, const int d2, const int stride_s_or_t,
+                              const int stride_b, const int stride_h, const int stride_d,
+                              cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
 #endif
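An illustration (not part of this header) of how the collapsed stride arguments are meant to be filled in for the unified API. The dimensions, the contiguous-layout assumption, and the function name below are assumptions for the sketch; for thd inputs, cu_seqlens is required, stride_s_or_t carries the token stride, and stride_b is passed as 0.

    // Hypothetical call for a contiguous bshd input of shape [b, s, h, d].
    void rope_forward_bshd_example(NVTETensor input, NVTETensor empty_cu_seqlens,
                                   NVTETensor freqs, NVTETensor output, cudaStream_t stream) {
      const int b = 8, s = 1024, h = 16, d = 64;  // assumed tensor dims
      const int d2 = 64;                          // rotary dim covered by freqs

      // Element strides of a contiguous [b, s, h, d] buffer.
      const int stride_d = 1;
      const int stride_h = d;
      const int stride_s = h * d;
      const int stride_b = s * h * d;

      nvte_fused_rope_forward(input, empty_cu_seqlens, freqs, output,
                              NVTE_BSHD, /*interleaved=*/false,
                              /*cp_size=*/1, /*cp_rank=*/0, s, b, h, d, d2,
                              /*stride_s_or_t=*/stride_s, stride_b, stride_h, stride_d,
                              stream);
    }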
...
@@ -149,6 +149,16 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
 void nvte_enable_cudnn_norm_fwd(bool enable);
 void nvte_enable_cudnn_norm_bwd(bool enable);
 
+/*! \brief Control whether norms compute `gamma += 1.0` for zero-centered gamma
+ *  in the weight dtype. If set to false, it is computed in the compute dtype.
+ *
+ *  Currently this only applies to the cuDNN backend. If cuDNN is not used,
+ *  this setting has no effect.
+ *
+ *  \param[in] enable Enable if true.
+ */
+void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable);
+
 enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
 
 #ifdef __cplusplus
...
@@ -42,6 +42,8 @@ struct NVTEShape {
   const size_t *data;
   /*! \brief Number of dimensions. */
   size_t ndim;
+  /*! \brief Copy of data. Num dims limited to permit fixed struct size. */
+  size_t owned_data[14];
 };
 
 /*! \struct NVTEBasicTensor
@@ -80,8 +82,13 @@ enum NVTEScalingMode {
   /*! Single scale per block of 32 elements consecutive in either
       rowwise or columnwise direction */
   NVTE_MXFP8_1D_SCALING = 1,
-  NVTE_INVALID_SCALING = 2,
-  NVTE_NO_SCALING = 3
+  /*! Tensor is split into NxN quantization tiles or 1xN quantization tiles,
+      which each yield a scale. The block_scaling_dim property of the quantizer
+      selects the granularity.
+   */
+  NVTE_BLOCK_SCALING_1D = 2,
+  NVTE_BLOCK_SCALING_2D = 3,
+  NVTE_INVALID_SCALING = 100
 };
 
 /*! \brief TE Tensor type
@@ -129,6 +136,15 @@ void *nvte_tensor_data(const NVTETensor tensor);
 */
 void *nvte_tensor_columnwise_data(const NVTETensor tensor);
 
+/*! \brief Construct a shape from an array of dimension sizes.
+ *
+ *  \param[in] data Pointer to start of shape array.
+ *  \param[in] ndim Number of dimensions (must be <= 14).
+ *
+ *  \return A shape. The shape will own its own copy of the data.
+ */
+NVTEShape nvte_make_shape(const size_t *data, size_t ndim);
+
 /*! \brief Get a tensor's data shape.
 *
 *  \param[in] tensor Tensor.
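A minimal usage sketch (illustrative, not from the header) showing that, per the documentation above, the returned shape carries its own copy of the dimensions:

    size_t dims[2] = {128, 64};
    NVTEShape shape = nvte_make_shape(dims, 2);
    // shape.ndim == 2 and the sizes are copied into shape.owned_data,
    // so `dims` does not need to outlive `shape`.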
@@ -281,6 +297,12 @@ enum NVTEQuantizationConfigAttribute {
   kNVTEQuantizationConfigForcePow2Scales = 0,
   /*! Small value to add to amax for numerical stability */
   kNVTEQuantizationConfigAmaxEpsilon = 1,
+  /*! Noop tensor (containing a scalar).
+      If the scalar element value is 1, the quantization kernel will exit early.
+      This is a tensor because the flag must be on the GPU in order to enable a
+      conditional early exit even when captured in a static CUDA graph.
+   */
+  kNVTEQuantizationConfigNoopTensor = 2,
   kNVTEQuantizationConfigNumAttributes
 };
@@ -406,8 +428,9 @@ class TensorWrapper {
                 float *amax_dptr = nullptr, float *scale_dptr = nullptr,
                 float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1},
                 const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
-      : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr,
-                      scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()},
+      : TensorWrapper(dptr, nvte_make_shape(shape.data(), shape.size()), dtype, amax_dptr,
+                      scale_dptr, scale_inv_dptr,
+                      nvte_make_shape(scale_inv_shape.data(), scale_inv_shape.size()),
                       scaling_mode) {}
 
   /*! \brief Constructs new empty TensorWrapper.
@@ -523,7 +546,9 @@ class TensorWrapper {
   *  \return Shape of this TensorWrapper.
   */
   const NVTEShape shape() const noexcept {
-    if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
+    if (tensor_ == nullptr) {
+      return nvte_make_shape(nullptr, 0);
+    }
     return nvte_tensor_shape(tensor_);
   }
@@ -532,7 +557,9 @@ class TensorWrapper {
   *  \return Shape of this TensorWrapper.
   */
   const NVTEShape columnwise_shape() const noexcept {
-    if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
+    if (tensor_ == nullptr) {
+      return nvte_make_shape(nullptr, 0);
+    }
     return nvte_tensor_columnwise_shape(tensor_);
   }
@@ -645,7 +672,9 @@ class TensorWrapper {
   *  \return scale_inv_shape of this TensorWrapper.
   */
   const NVTEShape scale_inv_shape() const noexcept {
-    if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
+    if (tensor_ == nullptr) {
+      return nvte_make_shape(nullptr, 0);
+    }
     return nvte_tensor_scale_inv_shape(tensor_);
   }
@@ -661,12 +690,20 @@ class TensorWrapper {
   void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); }
 
   static constexpr size_t defaultData = 1;
-  static constexpr NVTEShape defaultShape = {&defaultData, 1};
+  static constexpr NVTEShape defaultShape = {
+      &defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
 
  private:
-  NVTEShape convertShape(const NVTEShape &s) { return s; }
+  NVTEShape convertShape(const NVTEShape &s) {
+    NVTEShape ret = s;
+    // Move the ownership rather than pointing to the parent shape.
+    ret.data = ret.owned_data;
+    return ret;
+  }
 
-  NVTEShape convertShape(const std::vector<size_t> &s) { return {s.data(), s.size()}; }
+  NVTEShape convertShape(const std::vector<size_t> &s) {
+    return nvte_make_shape(s.data(), s.size());
+  }
 
   /*! \brief Wrapped NVTETensor. */
   NVTETensor tensor_ = nullptr;
@@ -719,6 +756,12 @@ class QuantizationConfigWrapper {
                                            &amax_epsilon, sizeof(float));
   }
 
+  /*! \brief Set noop tensor pointer */
+  void set_noop_tensor(NVTETensor noop_tensor) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, &noop_tensor,
+                                           sizeof(NVTETensor));
+  }
+
  private:
   /*! \brief Wrapped NVTEQuantizationConfig. */
   NVTEQuantizationConfig config_ = nullptr;
...
@@ -16,7 +16,9 @@
     transformer_engine::is_fp8_dtype*;
     *transformer_engine::CommOverlapBase*;
     *transformer_engine::CommOverlapP2PBase*;
-    *transformer_engine::CommOverlapCore*
+    *transformer_engine::CommOverlapCore*;
+    *nvshmem_wait_on_stream*;
+    *nvshmemi_init_thread*
   };
   local: *;
 };
@@ -39,6 +39,8 @@ Compute always in FP32
 namespace transformer_engine {
 namespace normalization {
 
+bool& use_zero_centered_gamma_in_weight_dtype();
+
 #ifndef __HIP_PLATFORM_AMD__
 cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
   return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
@@ -213,9 +215,12 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
     _ndim_scale_block = 1;
   }
 
-  _scalar_dptr = std::make_unique<char[]>(typeToSize(wtype));
+  const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype;
+  _scalar_dptr = std::make_unique<char[]>(typeToSize(gamma_dtype));
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
+      gamma_dtype, cpp_dtype,
+      *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
 
   _handle = cudnnExecutionPlanManager::Instance().GetHandle();
@@ -245,13 +250,13 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
                              .set_name("one")
                              .set_dim({1, 1, 1, 1})
                              .set_stride({1, 1, 1, 1})
-                             .set_data_type(get_cudnn_fe_dtype(wtype))
+                             .set_data_type(get_cudnn_fe_dtype(gamma_dtype))
                              .set_is_pass_by_value(true));
     auto centered_options = fe::graph::Pointwise_attributes()
                                 .set_mode(fe::PointwiseMode_t::ADD)
                                 .set_compute_data_type(get_cudnn_fe_dtype(ctype));
     _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options);
-    _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype));
+    _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(gamma_dtype));
   } else {
     _gamma = _gamma_zero;
   }
@@ -537,6 +542,18 @@ bool& _cudnn_norm_bwd_flag() {
 bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); }
 bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
 
+bool& _zero_centered_gamma_in_weight_dtype() {
+#ifdef USE_ROCM
+  static bool flag = false;
+  return flag;
+#else
+  static bool flag = transformer_engine::getenv<bool>("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE");
+  return flag;
+#endif
+}
+
+bool& use_zero_centered_gamma_in_weight_dtype() { return _zero_centered_gamma_in_weight_dtype(); }
+
 }  // namespace normalization
 }  // namespace transformer_engine
...@@ -559,3 +576,13 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { ...@@ -559,3 +576,13 @@ void nvte_enable_cudnn_norm_bwd(bool enable) {
transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
#endif #endif
} }
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
#ifdef USE_ROCM
bool flag = false;
transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = flag;
#else
transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
#endif
}
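// Note: the dtype used for the zero-centered-gamma offset can be controlled either
// through this API or via the NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE environment variable
// (see _zero_centered_gamma_in_weight_dtype() above); on ROCm builds the flag is
// always forced to false.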
...@@ -27,23 +27,28 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -27,23 +27,28 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const int multiprocessorCount, const bool zero_centered_gamma, const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) { cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_block_scaling(z->scaling_mode)) { !is_mxfp_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
} }
NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape == beta.data.shape); NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); NVTE_CHECK(gamma.data.dtype == beta.data.dtype,
"Gamma and Beta must have the same dtype. Gamma dtype: " +
to_string(gamma.data.dtype) + ", Beta dtype: " + to_string(beta.data.dtype));
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must have the same hidden size.");
NVTE_CHECK(epsilon >= 0.f); NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");
NVTE_CHECK(z->data.shape == x.data.shape); NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");
NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]}); NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]},
NVTE_CHECK(mu->data.dtype == DType::kFloat32); "Mu must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor.");
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]}); NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (!workspace->data.shape.empty()) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
...@@ -59,11 +64,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -59,11 +64,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool is_aligned = true; bool is_aligned = true;
#ifdef USE_ROCM #ifdef USE_ROCM
NVTE_CHECK( NVTE_CHECK(
!is_block_scaling(z->scaling_mode), !is_mxfp_scaling(z->scaling_mode),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet."); "Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#else #else
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#endif #endif
if (cudnn_backend) { if (cudnn_backend) {
......
...@@ -23,19 +23,20 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -23,19 +23,20 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) { const bool zero_centered_gamma, cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_block_scaling(z->scaling_mode)) { !is_mxfp_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
} }
NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1], "Gamma must have the same hidden size.");
NVTE_CHECK(epsilon >= 0.f); NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");
NVTE_CHECK(z->data.shape == x.data.shape); NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]}); NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (!workspace->data.shape.empty()) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
...@@ -49,11 +50,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -49,11 +50,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool is_aligned = true; bool is_aligned = true;
#ifdef USE_ROCM #ifdef USE_ROCM
NVTE_CHECK( NVTE_CHECK(
!is_block_scaling(z->scaling_mode), !is_mxfp_scaling(z->scaling_mode),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet."); "Cudnn backend is need by mxfp scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#else #else
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#endif #endif
bool training = bool training =
......
##########################################################################
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
##########################################################################
cmake_minimum_required (VERSION 3.18)
project(nvshmemapi LANGUAGES CXX CUDA)
# Configure dependencies
find_package(CUDAToolkit REQUIRED)
# find_package(MPI REQUIRED)
set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation")
add_library(nvshmemapi STATIC nvshmem_waitkernel.cu)
set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE)
target_link_directories(nvshmemapi PUBLIC ${NVSHMEM_HOME}/lib)
target_link_libraries(nvshmemapi PUBLIC -static-libstdc++ nvshmem_device nvshmem_host CUDA::nvml CUDA::cublas CUDA::cuda_driver)
target_include_directories(nvshmemapi PRIVATE
${NVSHMEM_HOME}/include/)
target_include_directories(nvshmemapi PUBLIC
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
"${CMAKE_CURRENT_SOURCE_DIR}")
set_target_properties(nvshmemapi PROPERTIES
CUDA_STANDARD 17
POSITION_INDEPENDENT_CODE ON
CUDA_SEPARABLE_COMPILATION ON)
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cuda_bf16.h>
#include <nvshmem.h>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include "../util/logging.h"
#include "nvshmem_waitkernel.h"
__global__ void __launch_bounds__(1)
wait_until_on_stream_and_reset(uint64_t* wait_flag, uint64_t wait_value,
uint64_t signal_reset) {
nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value);
*wait_flag = signal_reset;
}
void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream) {
uint64_t wait_value = 1;
uint64_t signal_reset = 0;
cudaStream_t cur_stream = stream;
NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && wait_kind <= WaitKind::STREAM_WAIT,
"Invalid wait kind: ", static_cast<int>(wait_kind));
switch (wait_kind) {
case WaitKind::KERNEL_WAIT:
wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
break;
case WaitKind::NVSHMEM_WAIT:
nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
break;
case WaitKind::STREAM_WAIT:
cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
CU_STREAM_WAIT_VALUE_GEQ);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
break;
}
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
#define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
#ifdef __cplusplus
#include <cstdint>
extern "C" {
#else
#include <stdint.h>
#endif
/*! \enum WaitKind
* \brief Types of wait operations that can be performed.
*/
enum class WaitKind {
KERNEL_WAIT = 0, /*!< Wait using a CUDA kernel */
NVSHMEM_WAIT = 1, /*!< Wait using NVSHMEM wait operation */
STREAM_WAIT = 2 /*!< Wait using CUDA stream synchronization */
};
/*! \brief Wait on a signal until a certain condition is met.
*
* \param[in] sig_addr The address of the signal to wait on.
* \param[in] wait_kind The kind of wait to perform.
* \param[in] stream The stream to wait on.
*/
void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream);
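/* Example (illustrative sketch, not part of the API): a consumer that has an
 * NVSHMEM-accessible signal word and wants the compute stream to block until a
 * peer sets it to 1, after which the helper resets it to 0 on the same stream.
 * The allocation and signaling side below are assumptions for illustration only.
 *
 *   uint64_t* sig = static_cast<uint64_t*>(nvshmem_malloc(sizeof(uint64_t)));
 *   // ... a peer PE eventually writes 1 to `sig` ...
 *   nvshmem_wait_on_stream(sig, WaitKind::STREAM_WAIT, stream);
 *   // work enqueued on `stream` after this point runs only once the signal arrived
 */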
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
...@@ -351,7 +351,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so ...@@ -351,7 +351,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
const transformer_engine::Tensor *input_fwd_cu = const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd); reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T, input_cu->data.dtype, T,
nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr), nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr), reinterpret_cast<T *>(output_cu->data.dptr),
...@@ -377,7 +377,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id ...@@ -377,7 +377,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
const transformer_engine::Tensor *prob_cu = const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob); reinterpret_cast<const transformer_engine::Tensor *>(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T, input_cu->data.dtype, T,
nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr), nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr), reinterpret_cast<T *>(output_cu->data.dptr),
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""This module provides predefined FP8 recipes.""" """This module provides predefined FP8 recipes."""
from __future__ import annotations from __future__ import annotations
import warnings import warnings
import os
from enum import Enum from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple from typing import Literal, Optional, Union, Callable, NamedTuple
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
...@@ -81,6 +82,10 @@ class Recipe: ...@@ -81,6 +82,10 @@ class Recipe:
"""Whether the given recipe is per-tensor scaling.""" """Whether the given recipe is per-tensor scaling."""
return isinstance(self, (DelayedScaling, Float8CurrentScaling)) return isinstance(self, (DelayedScaling, Float8CurrentScaling))
def float8_block_scaling(self):
"""Whether the given recipe is float8 blockwise scaling."""
return isinstance(self, Float8BlockScaling)
@dataclass() @dataclass()
class DelayedScaling(Recipe): class DelayedScaling(Recipe):
...@@ -287,3 +292,99 @@ class MXFP8BlockScaling(Recipe): ...@@ -287,3 +292,99 @@ class MXFP8BlockScaling(Recipe):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]},"
@dataclass()
class Float8BlockScaling(Recipe):
"""
Use block-wise scaling for FP8 tensors.
In this strategy, tensors are scaled in blockwise fashion. Values within
each block share a common scaling factor. The block dimensionality
can be configured. The scaling factors are float32 containers. They
will by default be constrained to powers of 2.
Since the scaling happens in a particular direction (either rowwise
or columnwise), the quantized tensor and its transpose are not numerically
equivalent. Due to this, when Transformer Engine needs both the FP8 tensor
and its transpose (e.g. to calculate both forward and backward pass),
during the quantization both versions are computed from the high precision
input to avoid double quantization errors.
NOTE: To relax the default constraint that scales be powers of 2, set env variable
NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults.
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
Or initialize the Recipe with non-default QParams in code for increased control.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of gradient tensor dY
x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for x.
w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for w.
grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for grad.
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=True
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating wgrad in backward pass
"""
use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1"
fp8_format: Format = Format.E4M3
fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
x_block_scaling_dim: int = 1
w_block_scaling_dim: int = 2
grad_block_scaling_dim: int = 1
fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True)
fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None:
assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x"
assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w"
assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad"
assert not (
self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2
), "2D by 2D block gemm not supported."
assert not (
self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2
), "2D by 2D block gemm not supported."
assert not (
self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2
), "2D by 2D block gemm not supported."
assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop."
assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad."
assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad."
def __repr__(self) -> str:
return (
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, "
f"x_block_scaling_dim={self.x_block_scaling_dim}, "
f"w_block_scaling_dim={self.w_block_scaling_dim}, "
f"grad_block_scaling_dim={self.grad_block_scaling_dim}, "
f"fp8_gemm_fprop={self.fp8_gemm_fprop}, "
f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
...@@ -156,7 +156,8 @@ namespace { ...@@ -156,7 +156,8 @@ namespace {
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
const float max_fp8, const bool force_pow_2_scales, const float max_fp8, const bool force_pow_2_scales,
const float epsilon) { const float epsilon) {
*scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon); *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
std::numeric_limits<float>::max());
} }
} // namespace } // namespace
......
...@@ -7,19 +7,21 @@ ...@@ -7,19 +7,21 @@
#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ #ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ #define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
#include <limits> #include "common/common.h"
namespace transformer_engine { namespace transformer_engine {
__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8, __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8,
bool force_pow_2_scales, float epsilon) { bool force_pow_2_scales, float epsilon,
float value_for_inf) {
// NOTE: a NaN amax evaluates false for the < comparison; it is handled further down.
if (amax < epsilon) { if (amax < epsilon) {
amax = epsilon; amax = epsilon;
} }
float scale = 1.f; float scale = 1.f;
if (isinf(amax) || amax == 0.f) { if (isinf(amax) || amax == 0.f || isnan(amax)) {
return scale; return scale;
} }
...@@ -32,18 +34,13 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f ...@@ -32,18 +34,13 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f
// the scale is not representable in FP32. // the scale is not representable in FP32.
if (isinf(scale)) { if (isinf(scale)) {
// use fp32 max to represent the scale // use fp32 max to represent the scale
scale = std::numeric_limits<float>::max(); scale = value_for_inf;
} }
if (isnan(scale)) {
scale = 1.f;
}
if (force_pow_2_scales) { if (force_pow_2_scales) {
uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale); uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
scale_bits &= 0xFF800000; scale_bits &= 0xFF800000;
// If the exponent was zero, we have a logic error. // If the exponent was zero, we have a logic error.
__builtin_assume(scale_bits != 0); __builtin_assume(scale_bits != 0 || scale == 0.0);
__builtin_assume(scale_bits != 0x80000000); __builtin_assume(scale_bits != 0x80000000);
scale = *reinterpret_cast<float *>(&scale_bits); scale = *reinterpret_cast<float *>(&scale_bits);
} }
...@@ -51,6 +48,26 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f ...@@ -51,6 +48,26 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f
return scale; return scale;
} }
// Calculate the quantization scale for an individual data element
// given the amax(abs(tile)) value for a given quantization tile.
//
//
// Arguments:
// IType: data type of the tensor being quantized (float or bf16)
// OType: quantized data type (e4m3 or e5m2)
// amax: The evaluation of amax(abs(tile)) for the quantization tile.
// eps: An epsilon used as a floor for amax.
// pow_2_scaling: Whether to force the scale to be a power of 2.
template <typename IType, typename OType>
__device__ __forceinline__ float compute_scale_from_types(const float amax, const float eps,
const bool pow_2_scaling) {
constexpr float fp8_max = TypeInfo<OType>::max_finite_value;
// NOTE: We're relying on compute_scale_from_amax to have behavior where it
// clips the mantissa of the max_finite_value if power of 2 scaling applies.
constexpr float value_for_inf = TypeInfo<IType>::max_finite_value;
return compute_scale_from_amax(amax, fp8_max, pow_2_scaling, eps, value_for_inf);
}
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ #endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
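// Worked example (illustration only, assuming the conventional scale = max_fp8 / amax
// computed in the elided lines above): for a tile with amax = 3.0 quantized to E4M3
// (max finite value 448), the raw scale is 448 / 3 = 149.33...; masking the mantissa
// with 0xFF800000 keeps only the sign and exponent bits, rounding the magnitude down
// to the nearest power of two, so the stored scale becomes 128 and scale_inv = 1/128.
// When the raw scale overflows to infinity, value_for_inf is substituted instead;
// compute_scale_from_types passes the input type's largest finite value for it.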
...@@ -211,53 +211,32 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { ...@@ -211,53 +211,32 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype()); reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
} }
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTEShape ret;
if (ndim == 0) {
ret.data = nullptr;
ret.ndim = 0;
return ret;
}
NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]),
"Too many dims for NVTEShape (requested: ", ndim,
", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")");
std::copy(data, data + ndim, ret.owned_data);
ret.data = ret.owned_data;
ret.ndim = ndim;
return ret;
}
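// Example (illustration only): building a 2-D shape. The dims are copied into the
// shape's owned_data storage, so the caller's array does not need to outlive the
// returned NVTEShape.
//
//   const size_t dims[] = {128, 256};
//   NVTEShape s = nvte_make_shape(dims, 2);  // s.ndim == 2, s.data points at s.owned_data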
NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) { if (tensor == nullptr) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
} }
NVTEShape ret;
// Determine tensor shape depending on tensor format // Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
switch (t.scaling_mode) { std::vector<size_t> shape = t.shape();
case NVTE_DELAYED_TENSOR_SCALING: {
if (!t.has_data() && t.has_columnwise_data()) {
// We can infer tensor shape if FP8 tensor only has FP8 data
// transpose. However, NVTEShape only contains a pointer and
// cannot store temporary data. We hack around this by caching
// the tensor shape within the empty FP8 data.
auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
shape_cache.clear();
if (!t.columnwise_data.shape.empty()) {
for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
shape_cache.push_back(t.columnwise_data.shape[i]);
}
shape_cache.push_back(t.columnwise_data.shape.front());
}
ret.data = shape_cache.data();
ret.ndim = shape_cache.size();
} else {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (!t.has_data() && t.has_columnwise_data()) {
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
} else {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
}
break;
}
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
transformer_engine::to_string(t.scaling_mode), "\"");
}
return ret; return nvte_make_shape(shape.data(), shape.size());
} }
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
...@@ -265,10 +244,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { ...@@ -265,10 +244,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret; return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size());
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
} }
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
...@@ -292,7 +268,7 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { ...@@ -292,7 +268,7 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
size_t nvte_tensor_element_size(const NVTETensor tensor) { size_t nvte_tensor_element_size(const NVTETensor tensor) {
if (tensor == nullptr) return sizeof(float); if (tensor == nullptr) return sizeof(float);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return transformer_engine::typeToSize(t.data.dtype); return transformer_engine::typeToSize(t.dtype());
} }
void *nvte_tensor_data(const NVTETensor tensor) { void *nvte_tensor_data(const NVTETensor tensor) {
...@@ -336,12 +312,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { ...@@ -336,12 +312,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
} }
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0}; if (tensor == nullptr) {
return nvte_make_shape(nullptr, 0);
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret; return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size());
ret.data = t.scale_inv.shape.data();
ret.ndim = t.scale_inv.shape.size();
return ret;
} }
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
...@@ -463,6 +438,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, ...@@ -463,6 +438,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigAmaxEpsilon: case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(buf, &config_.amax_epsilon, attr_size); std::memcpy(buf, &config_.amax_epsilon, attr_size);
break; break;
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(buf, &config_.noop_tensor, attr_size);
break;
default: default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")"); NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
} }
...@@ -492,6 +470,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, ...@@ -492,6 +470,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigAmaxEpsilon: case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(&config_.amax_epsilon, buf, attr_size); std::memcpy(&config_.amax_epsilon, buf, attr_size);
break; break;
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(&config_.noop_tensor, buf, attr_size);
break;
default: default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")"); NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
} }
......
...@@ -23,6 +23,42 @@ template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType ...@@ -23,6 +23,42 @@ template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output,
cudaStream_t stream); cudaStream_t stream);
void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
SimpleTensor &scale_inv_t, SimpleTensor &output,
SimpleTensor &output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream);
// enum class for rowwise usage
enum class FP8BlockwiseRowwiseOption {
// No rowwise data
NONE,
// Rowwise data, scales in GEMM format
ROWWISE
// TODO: FP8 all gather requires some changes.
// 1. Compact scales are better for gathering than the GEMM format.
};
// enum class for columnwise usage
// For Hopper (sm90), which only supports TN FP8 GEMM, the columnwise data needs to be transposed when doing 1D block scaling
enum class FP8BlockwiseColumnwiseOption {
// No columnwise data
NONE,
// Columnwise data transposed from original shape.
// Scales in GEMM format corresponding to GEMM ingesting transposed column data.
COLUMNWISE_TRANSPOSE
// TODO: FP8 all gather requires some changes.
// 1. The transpose gets in the way of the all gather.
// 2. Compact scales are better for gathering than the GEMM format.
};
void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
SimpleTensor &scale_inv_t, SimpleTensor &output,
SimpleTensor &output_t, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scale, cudaStream_t stream);
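// Example (illustrative sketch only; tensor construction omitted, variable names are
// hypothetical): producing rowwise data for the forward GEMM and, on GPUs without
// non-TN FP8 GEMM support, also the transposed columnwise copy:
//
//   quantize_transpose_vector_blockwise(
//       /*input=*/x, /*scale_inv=*/sinv, /*scale_inv_t=*/sinv_t,
//       /*output=*/x_fp8, /*output_t=*/x_fp8_t, /*epsilon=*/0.0f,
//       FP8BlockwiseRowwiseOption::ROWWISE,
//       FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
//       /*pow_2_scale=*/true, stream);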
} // namespace transformer_engine::detail } // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ #endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <cuda/barrier>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \
(defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900)
#define TMA_HW_SUPPORTED
#endif
namespace transformer_engine {
namespace {
// const values configuration
constexpr size_t kThreadsPerWarp = 32;
#ifdef TMA_HW_SUPPORTED
constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 32;
constexpr size_t WARP_TILE_DIM_Y = 64;
constexpr size_t THREAD_TILE_DIM_X = 16;
constexpr size_t THREAD_TILE_DIM_Y = 4;
#else
constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 64;
constexpr size_t WARP_TILE_DIM_Y = 32;
constexpr size_t THREAD_TILE_DIM_X = 8;
constexpr size_t THREAD_TILE_DIM_Y = 8;
#endif
#ifdef TMA_HW_SUPPORTED
constexpr size_t NUM_BYTES_PER_BANK = 4;
constexpr size_t NUM_BANKS_PER_SHARED_ELEM = THREAD_TILE_DIM_Y / NUM_BYTES_PER_BANK;
constexpr size_t SHARED_BLOCK_TILE_DIM_Y = BLOCK_TILE_DIM;
constexpr size_t SHARED_BLOCK_TILE_DIM_X_BANKS =
BLOCK_TILE_DIM / (NUM_BYTES_PER_BANK * NUM_BANKS_PER_SHARED_ELEM);
constexpr size_t NUM_BANKS_Y_IN_WARP = WARP_TILE_DIM_Y / NUM_BYTES_PER_BANK;
#endif
constexpr size_t ELE_PER_THREAD = THREAD_TILE_DIM_X * THREAD_TILE_DIM_Y;
constexpr size_t THREADS_PER_BLOCK = BLOCK_TILE_DIM * BLOCK_TILE_DIM / ELE_PER_THREAD;
constexpr size_t NUM_WARPS_X_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_X;
constexpr size_t NUM_WARPS_Y_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_Y;
constexpr size_t NUM_WARPS_IN_BLOCK = NUM_WARPS_X_IN_BLOCK * NUM_WARPS_Y_IN_BLOCK;
constexpr size_t NUM_THREADS_X_IN_WARP = WARP_TILE_DIM_X / THREAD_TILE_DIM_X;
constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP;
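// With the values above, the derived constants work out the same in both configurations:
// ELE_PER_THREAD = 16*4 = 64 (TMA) or 8*8 = 64 (no TMA), so THREADS_PER_BLOCK =
// 128*128 / 64 = 256 threads = 8 warps per 128x128 tile. With TMA the block is split
// into 4 (x) by 2 (y) warps, each warp covering a 32 (x) by 64 (y) sub-tile and each
// thread a 16x4 sub-tile; without TMA it is 2 (x) by 4 (y) warps, 64x32 per warp and
// 8x8 per thread.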
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c,
OType* const output_t, CType* const tile_scales_inv_c,
CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y, const size_t scale_t_stride_x,
const size_t scale_t_stride_y, const float epsilon,
const __grid_constant__ CUtensorMap tensor_map_output_t,
bool pow_2_scaling) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
IVec thrd_tile_input[THREAD_TILE_DIM_Y];
constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1;
OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_];
const int tid_in_warp = threadIdx.x % kThreadsPerWarp;
const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP;
const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP;
const int warp_id_in_block = threadIdx.x / kThreadsPerWarp;
const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK;
const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK;
// The index math below is ONLY valid when the input is a full tile (this kernel handles the aligned case)
const int tile_id_x = blockIdx.x;
const int tile_id_y = blockIdx.y;
const size_t block_tile_start_idx =
tile_id_y * BLOCK_TILE_DIM * row_length + tile_id_x * BLOCK_TILE_DIM;
const size_t warp_tile_start_idx =
block_tile_start_idx +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP;
const size_t thread_tile_start_idx = warp_tile_start_idx +
tid_in_warp_y * THREAD_TILE_DIM_Y * row_length +
tid_in_warp_x * THREAD_TILE_DIM_X;
CType warp_tile_amax;
CType block_tile_amax;
CType block_tile_scale;
CType amax = 0;
// Step 1: Load a block tile of input data into thread tiles on registers
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
thrd_tile_input[i].load_from(input + thread_tile_start_idx + i * row_length);
}
// Step 2: calculate block tile amax and scale
// Calculate thread_tile amax
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(static_cast<CType>(thrd_tile_input[i].data.elt[j])));
}
}
// Reduce amax in the warp (32x32 tile)
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] =
warp_tile_amax;
}
__syncthreads();
// Only 8 elements need reduction; a reduction tree would require multiple __syncthreads(),
// so we just let thread 0 do the job.
if (threadIdx.x == 0) {
CType blk_amax = block_tile_amax_shared[0];
#pragma unroll
for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) {
blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]);
}
block_tile_amax_shared[0] = blk_amax;
}
__syncthreads();
block_tile_amax = block_tile_amax_shared[0];
block_tile_scale =
compute_scale_from_types<IType, OType>(block_tile_amax, epsilon, pow_2_scaling);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast tmp_output_c;
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length);
}
// Step 4: store transpose into shared memory
if constexpr (kReturnTranspose) {
#ifdef TMA_HW_SUPPORTED
__shared__ alignas(128)
OVecTrans block_tile_trans_shared[SHARED_BLOCK_TILE_DIM_Y][SHARED_BLOCK_TILE_DIM_X_BANKS];
OType(*block_tile_trans_shared_otype_ptr)[BLOCK_TILE_DIM] =
reinterpret_cast<OType(*)[BLOCK_TILE_DIM]>(block_tile_trans_shared);
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_X; i++) {
auto warp_id_in_block_x_ = warp_id_in_block_y;
auto warp_id_in_block_y_ = warp_id_in_block_x;
int row_idx = warp_id_in_block_y_ * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP +
tid_in_warp_x * THREAD_TILE_DIM_X + i;
int col_idx =
warp_id_in_block_x_ * (NUM_BANKS_Y_IN_WARP / NUM_BANKS_PER_SHARED_ELEM) + tid_in_warp_y;
block_tile_trans_shared[row_idx][col_idx] = thrd_tile_out_trans[i];
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Step 5: store transpose output
// Initiate TMA transfer to copy shared memory to global memory
if (threadIdx.x == 0) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t*>(&tensor_map_output_t), tile_id_y * BLOCK_TILE_DIM,
tile_id_x * BLOCK_TILE_DIM,
reinterpret_cast<uint64_t*>(block_tile_trans_shared_otype_ptr));
// Wait for TMA transfer to have finished reading shared memory.
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for the group to have completed reading from shared memory.
ptx::cp_async_bulk_wait_group_read<0>();
}
#else
// Step 4 Alternative (when TMA is not available, skip writing to shared memory)
const size_t block_tile_t_start_idx =
tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM;
const size_t warp_tile_t_start_idx =
block_tile_t_start_idx +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP;
const size_t thread_tile_t_start_idx = warp_tile_t_start_idx +
tid_in_warp_x * THREAD_TILE_DIM_X * num_rows +
tid_in_warp_y * THREAD_TILE_DIM_Y;
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_X; i++) {
thrd_tile_out_trans[i].store_to(output_t + thread_tile_t_start_idx + i * num_rows);
}
#endif
}
}
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
bool pow_2_scaling) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
IVec thrd_tile_input[THREAD_TILE_DIM_Y];
constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1;
OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_];
const int tid_in_warp = threadIdx.x % kThreadsPerWarp;
const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP;
const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP;
const int warp_id_in_block = threadIdx.x / kThreadsPerWarp;
const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK;
const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK;
const int tile_id_x = blockIdx.x;
const int tile_id_y = blockIdx.y;
const size_t block_tile_start_row_idx = tile_id_y * BLOCK_TILE_DIM;
const size_t block_tile_start_col_idx = tile_id_x * BLOCK_TILE_DIM;
const size_t block_tile_start_idx =
block_tile_start_row_idx * row_length + block_tile_start_col_idx;
const size_t warp_tile_start_idx =
block_tile_start_idx +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP;
const size_t thread_tile_start_idx = warp_tile_start_idx +
tid_in_warp_y * THREAD_TILE_DIM_Y * row_length +
tid_in_warp_x * THREAD_TILE_DIM_X;
// Handle a non-full tile. There are three cases: full thread tile, nonfull thread tile,
// and empty thread tile.
// For an empty thread tile, directly write zeros to the transposed shared mem buffer.
// For a nonfull thread tile, fill the thread tile with zeros and act as if it were full.
const size_t thread_tile_start_row_idx =
tile_id_y * BLOCK_TILE_DIM + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP +
tid_in_warp_y * THREAD_TILE_DIM_Y;
const size_t thread_tile_start_col_idx =
tile_id_x * BLOCK_TILE_DIM + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP +
tid_in_warp_x * THREAD_TILE_DIM_X;
const size_t thread_tile_end_row_idx = thread_tile_start_row_idx + THREAD_TILE_DIM_Y - 1;
const size_t thread_tile_end_col_idx = thread_tile_start_col_idx + THREAD_TILE_DIM_X - 1;
bool full_thrd_tile =
(thread_tile_end_row_idx < num_rows) && (thread_tile_end_col_idx < row_length);
bool empty_thrd_tile =
(thread_tile_start_row_idx >= num_rows) || (thread_tile_start_col_idx >= row_length);
bool nonfull_thrd_tile = (!full_thrd_tile) && (!empty_thrd_tile);
const size_t thread_tile_ncols =
MIN(THREAD_TILE_DIM_X,
(MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1));
const size_t thread_tile_nrows =
MIN(THREAD_TILE_DIM_Y,
(MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1));
CType warp_tile_amax;
CType block_tile_amax;
CType block_tile_scale;
CType amax = 0;
if (!empty_thrd_tile) {
// Step 1: Load a block tile of input data into thread tiles on registers
// Edge case: nonfull thread tile case, will use the partial load function here
if (nonfull_thrd_tile) {
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
if (i >= thread_tile_nrows) {
thrd_tile_input[i].clear();
} else {
thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
}
} else {
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
THREAD_TILE_DIM_X);
}
}
// Step 2: calculate block tile amax and scale
// Calculate thread_tile amax
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(static_cast<CType>(thrd_tile_input[i].data.elt[j])));
}
}
}
// Reduce amax in the warp (32x32 tile)
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] =
warp_tile_amax;
}
__syncthreads();
// Only 8 elements need reduction; a reduction tree would require multiple __syncthreads(),
// so we just let thread 0 do the job.
if (threadIdx.x == 0) {
CType blk_amax = block_tile_amax_shared[0];
#pragma unroll
for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) {
blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]);
}
block_tile_amax_shared[0] = blk_amax;
}
__syncthreads();
block_tile_amax = block_tile_amax_shared[0];
block_tile_scale =
compute_scale_from_types<IType, OType>(block_tile_amax, epsilon, pow_2_scaling);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3: Store cast output, Step 4: do transpose within thread tile
// Edge case: in the non-full tile case, there are three subcases.
// For a full thread tile, this step works the same as in the aligned kernel.
// For a nonfull thread tile, take care when saving tmp_output_c to global memory:
// a vectorized store_to cannot be used, element-wise stores are needed.
// For an empty tile, this step is skipped entirely; go to Step 4 with
// thrd_tile_out_trans cleared to all zeros.
if constexpr (kReturnTranspose) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
thrd_tile_out_trans[j].clear();
}
}
if (!empty_thrd_tile) {
OVecCast tmp_output_c;
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
if (i >= thread_tile_nrows) {
continue;
}
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
if constexpr (kReturnTranspose) {
const size_t block_tile_t_start_idx =
tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM;
const size_t warp_tile_t_start_idx =
block_tile_t_start_idx +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP;
const size_t thread_tile_t_start_idx = warp_tile_t_start_idx +
tid_in_warp_x * THREAD_TILE_DIM_X * num_rows +
tid_in_warp_y * THREAD_TILE_DIM_Y;
#pragma unroll
for (int i = 0; i < thread_tile_ncols; i++) {
thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0,
thread_tile_nrows);
}
}
}
}
template <typename OutputType>
CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) {
CUtensorMapDataType dataType;
if constexpr (std::is_same_v<OutputType, __nv_fp8_e4m3> ||
std::is_same_v<OutputType, __nv_fp8_e5m2>) {
dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else {
NVTE_CHECK(false, "Invalid Output type (must be FP8).");
}
CUtensorMap tensor_map_output_trans{};
create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x,
/*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM,
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType));
return tensor_map_output_trans;
}
} // namespace
} // namespace transformer_engine
namespace transformer_engine::detail {
void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output,
SimpleTensor& output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_square_blockwise);
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
}
NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions.");
size_t scale_k = scale_inv.shape[1];
const size_t scale_stride_x = 1;
const size_t scale_stride_y = scale_k;
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions.");
scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}
const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
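// Each thread block quantizes one 128x128 tile; thread 0 of the block writes the
// single scale_inv entry for that tile (and the mirrored scale_inv_t entry when
// return_transpose is set).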
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output.dtype, OutputType,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
if (full_tile) {
CUtensorMap tensor_map_output_trans;
if (return_transpose) {
tensor_map_output_trans =
get_tensor_map<OutputType>(output_t, num_rows, row_length);
}
block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale);
} else {
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale);
} // full-tile
) // return_transpose
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace transformer_engine::detail
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/utils.cuh"
namespace transformer_engine {
namespace {
using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
using transformer_engine::detail::FP8BlockwiseRowwiseOption;
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and written to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr size_t kThreadsPerWarp = 32;
// Hyperparameters for performance tuning
constexpr int kTileDim = 128;  // Fixed to 128 because we are using 1x128 and 128x1 quantization
constexpr int kNVecIn = 8; // The number of elements each LDG touches
constexpr int kNVecOut = 16; // The number of elements each STG touches
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total
// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
constexpr int kNumThreadsStore = kTileDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
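// Illustrative sanity checks (not required by the kernel): with the constants above, each
// 128x128 input tile is stored in shared memory as 128 rows of 64 two-element SMemVec entries
// plus one padding entry per row (the "+ 1" in kSMemCol, presumably to reduce bank conflicts
// when Step 3 reads down a column), and 16 / 8 threads cooperate on each tile row for loads /
// stores respectively.
static_assert(kSMemCol == 65, "64 SMemVec entries per row plus 1 padding entry");
static_assert(kSMemSize == 16640, "128 rows * 65 SMemVec entries * 2 elements");
static_assert(kNumThreadsLoad == 16 && kNumThreadsStore == 8, "threads cooperating per tile row");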
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
bool return_columnwise_transpose =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
// Step 1: Load input to shared memory
{
constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0; // For not aligned case
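// Example (unaligned case): with row_length == 300 and blockIdx.x == 2, a thread whose
// c_g == 296 loads only the last 4 of its kNVecIn == 8 elements.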
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
// Step 2: Cast and store to output_c
if (return_rowwise) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g)
: 0; // For not aligned case
OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory
// Each group of kNumThreadsStore consecutive threads within a warp processes one row; we need
// to find the lane id of the first thread in the group to perform the reduction.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
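// Worked example: with kNumThreadsStore == 8, lane 19 belongs to the group covering lanes
// 16..23, so src_lane == 16 and mask == 0x00FF0000.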
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
CType scale;
// Step 2.4: Compute scale
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
}
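// Note: with the strides chosen by the launcher below (scale_stride_y == 1,
// scale_stride_x == scale_inv.shape[1]), scales for consecutive rows of the same
// 128-column tile are contiguous in memory.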
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
size_t col_idx = static_cast<size_t>(blockIdx.x);
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
output_vec.data.elt[i * kNVecSMem + j] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
}
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
output_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if (return_columnwise_transpose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory
size_t r_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Row in global memory
const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g)
: 0; // For not aligned case
OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory
// Each group of kNumThreadsStore consecutive threads within a warp processes one column of the
// tile (i.e. one row of the transposed output); we need to find the lane id of the first thread
// in the group to perform the reduction.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
// Step 3.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
size_t col_idx = static_cast<size_t>(blockIdx.y);
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
output_vec.data.elt[i] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case)
output_g += stride_g;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
} // namespace
} // namespace transformer_engine
namespace transformer_engine::detail {
void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output,
SimpleTensor& output_t, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow2_scale, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
// assert that rowwise_option and columnwise_option are not both NONE
NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
"rowwise_option and columnwise_option cannot both be NONE");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
num_elements *= input.shape.at(i);
}
// Early return if the input tensor is empty
if (num_elements == 0) {
return;
}
// Options for scale layout of cuBLAS GEMM kernel.
size_t scale_stride_x = 0;
size_t scale_stride_y = 0;
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE,
"Unexpected rowwise enum value");
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
size_t scale_k = scale_inv.shape[1];
scale_stride_x = scale_k;
scale_stride_y = 1;
}
if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
"Unexpected columnwise enum value");
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
}
}
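// Example: an input of shape [A, B, C] pairs with an output_t of shape [C, A, B].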
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
scale_t_stride_x = scale_inv_t.shape[1];
scale_t_stride_y = 1;
}
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
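// Example: a 4096 x 1024 input (num_rows = 4096, row_length = 1024) is covered by an
// 8 x 32 grid of 128x128 tiles, one thread block per tile.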
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output.dtype, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
size_t smem_bytes = kSMemSize * sizeof(InputType);
// Dynamic shared memory allocations of 48 KB or more must be requested explicitly
// via cudaFuncSetAttribute (e.g. a float input needs 16640 * 4 = 66560 bytes per block).
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
} block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
columnwise_option, pow2_scale);) // kAligned
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace transformer_engine::detail
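// Hypothetical usage sketch (illustration only; the SimpleTensor objects are assumed to have
// been populated elsewhere with device buffers of the shapes shown):
//
//   input       : bf16 activation of shape [4096, 1024]
//   output      : fp8 tensor of shape [4096, 1024] (rowwise-quantized copy)
//   output_t    : fp8 tensor of shape [1024, 4096] (columnwise-quantized transpose)
//   scale_inv   : fp32, 2D rowwise scale-inverse tensor
//   scale_inv_t : fp32, 2D columnwise scale-inverse tensor
//
//   transformer_engine::detail::quantize_transpose_vector_blockwise(
//       input, scale_inv, scale_inv_t, output, output_t,
//       /*epsilon=*/0.0f, FP8BlockwiseRowwiseOption::ROWWISE,
//       FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
//       /*pow2_scale=*/true, stream);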