Unverified commit ee384ab5, authored by Alp Dener and committed by GitHub
Browse files

Make `CanonicalizeGemmInput()` support non-TN layout FP8 GEMM on Blackwell...


Make `CanonicalizeGemmInput()` support non-TN layout FP8 GEMM on Blackwell with column-wise/transposed data (#2233)

Modified CanonicalizeGemmInput() logic to pull from column-wise data for FP8 GEMM on Blackwell when row-wise is not available.
Signed-off-by: Alp Dener <adener@nvidia.com>
parent c593bcef
...@@ -140,6 +140,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -140,6 +140,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else { } else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
} }
} else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
"Input A is missing column-wise usage");
ret.A = A.columnwise_data.dptr;
ret.transA = is_A_transposed ? CUBLAS_OP_N : CUBLAS_OP_T;
ret.Atype = A.columnwise_data.dtype;
ret.A_scale_inv = A.columnwise_scale_inv.dptr;
ret.lda = is_A_transposed ? m : k;
} }
if (is_fp8_dtype(ret.Atype)) { if (is_fp8_dtype(ret.Atype)) {
...@@ -221,6 +231,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -221,6 +231,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else { } else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
} }
} else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
"Input B is missing column-wise usage");
ret.B = B.columnwise_data.dptr;
ret.transB = is_B_transposed ? CUBLAS_OP_N : CUBLAS_OP_T;
ret.Btype = B.columnwise_data.dtype;
ret.B_scale_inv = B.columnwise_scale_inv.dptr;
ret.ldb = is_B_transposed ? k : n;
} }
if (is_fp8_dtype(ret.Atype)) { if (is_fp8_dtype(ret.Atype)) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment