Unverified Commit 9d4e11ea authored by Tim Moon, committed by GitHub

[PyTorch] Debug GEMM refactor (#1652)



* Minor stylistic tweaks and typo fixes

Review suggestions from @ptrendx
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix incorrect col strides for MXFP8 matrices
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent ba5dc5dd
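
For context on the column-stride fix in the MXFP8 hunks below: cuBLAS uses column-major storage, so for D = op(A) * op(B) with op(A) of shape m x k and op(B) of shape k x n, each operand's leading dimension is the row count of the matrix as stored, which flips with the transpose flag. A minimal sketch of the rule, with illustrative helper names that are not TE's API:

#include <cstddef>

// Leading dimension = number of rows of the *stored* matrix.
// A transposed A is stored k x m, so lda = k; otherwise lda = m.
inline size_t gemm_lda(bool is_A_transposed, size_t m, size_t k) {
  return is_A_transposed ? k : m;
}

// A transposed B is stored n x k, so ldb = n; otherwise ldb = k.
inline size_t gemm_ldb(bool is_B_transposed, size_t k, size_t n) {
  return is_B_transposed ? n : k;
}

The MXFP8 branches previously hard-coded ret.lda = m and ret.ldb = k, which is only correct in the non-transposed case; the hunks below apply the rule for both cases.
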
@@ -96,8 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   const int arch = cuda::sm_arch();
   // Transpose mode with column-major ordering
-  bool transa_bool = transA == CUBLAS_OP_T;
-  bool transb_bool = transB == CUBLAS_OP_T;
+  bool is_A_transposed = transA == CUBLAS_OP_T;
+  bool is_B_transposed = transB == CUBLAS_OP_T;
   // Configure A matrix
   if (is_tensor_scaling(A.scaling_mode)) {
@@ -106,8 +106,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.transA = transA;
     ret.Atype = A.data.dtype;
     ret.A_scale_inv = A.scale_inv.dptr;
-    ret.lda = transa_bool ? k : m;
-    if (arch < 100 && !transa_bool) {
+    ret.lda = is_A_transposed ? k : m;
+    if (arch < 100 && !is_A_transposed) {
       // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
         ret.A = A.columnwise_data.dptr;
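
A note on the arch < 100 branch above: Hopper's FP8 cuBLAS GEMMs only support the TN layout, so a non-transposed FP8 operand is swapped for its pre-transposed column-wise copy. A hedged sketch of that substitution, with illustrative names, and with the op-flag handling assumed rather than taken from the hidden remainder of this hunk:

#include <cublas_v2.h>

struct OperandView {
  const void *ptr;
  cublasOperation_t trans;
};

// If A is not transposed but a column-wise (already-transposed) FP8
// copy exists, use that copy and flip the op flag to CUBLAS_OP_T.
OperandView pick_fp8_A_for_hopper(const void *rowwise, const void *columnwise,
                                  bool is_A_transposed) {
  if (!is_A_transposed && columnwise != nullptr) {
    return {columnwise, CUBLAS_OP_T};
  }
  return {rowwise, is_A_transposed ? CUBLAS_OP_T : CUBLAS_OP_N};
}
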
@@ -123,28 +123,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     // MXFP8
     // Note: Row-wise and column-wise data are scaled along different
     // dimensions (with matrix interpreted in row-major order).
-    if (transa_bool) {
+    if (is_A_transposed) {
       NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
     } else {
-      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage");
+      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
     }
-    ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr;
+    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
     ret.transA = transA;
-    ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype;
-    ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
-    ret.lda = m;
+    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
+    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
+    ret.lda = is_A_transposed ? k : m;
   } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
     // FP8 block scaling
     // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
-    if (transa_bool) {
+    if (is_A_transposed) {
       NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
     } else {
-      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage");
+      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
     }
-    ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr;
+    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
     ret.transA = CUBLAS_OP_T;
-    ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype;
-    ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
+    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
+    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
     ret.lda = k;
     // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
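
On the MXFP8 note in this hunk: each group of 32 consecutive FP8 elements shares one E8M0 (power-of-two) scale, and the row-wise and column-wise buffers group those 32 elements along different dimensions of the row-major matrix, so the two buffers are not interchangeable and the stride fix above must match the buffer in use. A rough sketch of decoding one such block, assuming the OCP MX spec's E8M0 bias of 127 and taking the FP8 payloads as already-converted floats; this is illustrative, not TE code:

#include <cmath>
#include <cstddef>
#include <cstdint>

// Decode one MX block: 32 values sharing a single E8M0 scale.
void decode_mx_block(const float (&vals)[32], uint8_t e8m0_scale,
                     float (&out)[32]) {
  const float scale = std::ldexp(1.0f, static_cast<int>(e8m0_scale) - 127);
  for (size_t i = 0; i < 32; ++i) out[i] = vals[i] * scale;
}
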
@@ -165,8 +165,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.transB = transB;
     ret.Btype = B.data.dtype;
     ret.B_scale_inv = B.scale_inv.dptr;
-    ret.ldb = transb_bool ? n : k;
-    if (arch < 100 && transb_bool) {
+    ret.ldb = is_B_transposed ? n : k;
+    if (arch < 100 && is_B_transposed) {
       // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
         ret.B = B.columnwise_data.dptr;
@@ -182,28 +182,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     // MXFP8
     // Note: Row-wise and column-wise data are scaled along different
     // dimensions (with matrix interpreted in row-major order).
-    if (transb_bool) {
+    if (is_B_transposed) {
       NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
     } else {
       NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
     }
-    ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr;
+    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
     ret.transB = transB;
-    ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype;
-    ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
-    ret.ldb = k;
+    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
+    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
+    ret.ldb = is_B_transposed ? n : k;
   } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
     // FP8 block scaling
     // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
-    if (transb_bool) {
+    if (is_B_transposed) {
       NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
     } else {
       NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
     }
-    ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr;
+    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
     ret.transB = CUBLAS_OP_N;
-    ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype;
-    ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
+    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
+    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
     ret.ldb = k;
     // Requirements from
@@ -392,7 +392,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                            &B_scale_inverse, sizeof(B_scale_inverse)));
     NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
                   inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
-               "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D");
+               "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
     scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                          ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                          : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
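
The NVTE_CHECK above encodes the supported combinations of FP8 block-scaling granularities: 1D by 1D, 1D by 2D, and 2D by 1D are allowed, while 2D by 2D is not. A trivial sketch of that rule, with an illustrative enum standing in for the NVTE scaling modes:

enum class BlockScaling { k1D, k2D };

// Every pairing is allowed except both operands using 2D block scaling.
bool block_scaling_pair_supported(BlockScaling a, BlockScaling b) {
  return !(a == BlockScaling::k2D && b == BlockScaling::k2D);
}
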