".github/vscode:/vscode.git/clone" did not exist on "fb2a2463982143a05e67890643b5eb263e07b92b"
Unverified Commit 8c364b4d authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[Common][JAX] Improve error message for cublas fp8 gemm with incorrect shape (#2261)



* Improve error message for cublas fp8 gemm with incorrect shape
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Removed unnecessary non-contracting size check
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* rename inner dim -> leading dim
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 8eec2004
...@@ -141,6 +141,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -141,6 +141,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
} }
} }
if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.lda % 16 == 0,
"Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
}
} else if (nvfp4) { } else if (nvfp4) {
// NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.
...@@ -187,7 +193,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -187,7 +193,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK((ret.lda % 16) == 0, NVTE_CHECK((ret.lda % 16) == 0,
"Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
// Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
// Smallest supported CType is 2 bytes in this scaling mode. // Smallest supported CType is 2 bytes in this scaling mode.
NVTE_CHECK((m % 8) == 0, NVTE_CHECK((m % 8) == 0,
...@@ -216,6 +222,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -216,6 +222,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
} }
} }
if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.ldb % 16 == 0,
"Leading dimension requirement on B for FP8 GEMM. Caller must pad.");
}
} else if (nvfp4) { } else if (nvfp4) {
if (is_B_transposed) { if (is_B_transposed) {
NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode), NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
......
...@@ -360,6 +360,28 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): ...@@ -360,6 +360,28 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise):
return swizzled.reshape(original_shape) return swizzled.reshape(original_shape)
def get_lhs_axis_boundary(lhs_cdims, is_transposed):
    """Return the axis index splitting the LHS operand's dims for the GEMM.

    For a transposed LHS the contracting dims are the leading ones, so the
    boundary sits just past the largest contracting axis; otherwise the
    contracting dims trail and the boundary is the smallest contracting axis.
    """
    if is_transposed:
        return max(lhs_cdims) + 1
    return min(lhs_cdims)
def get_rhs_axis_boundary(rhs_cdims, is_transposed):
    """Return the axis index splitting the RHS operand's dims for the GEMM.

    Mirror of :func:`get_lhs_axis_boundary` with the cases swapped: a
    transposed RHS has trailing contracting dims (boundary is the smallest
    contracting axis); an untransposed RHS has leading ones (boundary is one
    past the largest contracting axis).
    """
    if is_transposed:
        return min(rhs_cdims)
    return max(rhs_cdims) + 1
def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name):
    """Validate an operand's contracting size against cuBLAS GEMM requirements.

    cuBLAS tensor-core GEMMs with quantized inputs require the contracting
    dimension to be a multiple of 32 for NVFP4 scaling and 16 otherwise.
    No check is performed for unquantized (``NO_SCALING``) inputs.

    Args:
        scaling_mode: ``ScalingMode`` of the quantized operand.
        contracting_size: Product of the operand's contracting dimension sizes.
        tensor_name: Operand label (e.g. ``"LHS"``/``"RHS"``) for the error message.

    Raises:
        AssertionError: If the contracting size is not properly aligned.
    """
    if scaling_mode == ScalingMode.NO_SCALING:
        return
    # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    alignment = 32 if scaling_mode.is_nvfp4_scaling else 16
    # Raise explicitly rather than using a bare `assert`, which `python -O`
    # strips — this shape/alignment check must survive optimized runs.
    if contracting_size % alignment != 0:
        raise AssertionError(
            f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of"
            f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}"
        )
class GemmPrimitive(BasePrimitive): class GemmPrimitive(BasePrimitive):
""" """
Primitive for cuBLAS GEMM Primitive for cuBLAS GEMM
...@@ -452,6 +474,29 @@ class GemmPrimitive(BasePrimitive): ...@@ -452,6 +474,29 @@ class GemmPrimitive(BasePrimitive):
f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
) )
lhs_axis_boundary = get_lhs_axis_boundary(lhs_contracting_dims, lhs_is_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs.shape[lhs_axis_boundary:])
if lhs_is_transposed
else reduce(operator.mul, lhs.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_contracting_dims, rhs_is_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs.shape[:rhs_axis_boundary])
if rhs_is_transposed
else reduce(operator.mul, rhs.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
# Determine output shape and dtype # Determine output shape and dtype
assert ( assert (
dtypes.canonicalize_dtype(out_dtype).itemsize > 1 dtypes.canonicalize_dtype(out_dtype).itemsize > 1
...@@ -563,8 +608,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -563,8 +608,8 @@ class GemmPrimitive(BasePrimitive):
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
kwargs = { kwargs = {
"scaling_mode": int(scaling_mode.value), "scaling_mode": int(scaling_mode.value),
"lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), "lhs_axis_boundary": get_lhs_axis_boundary(lhs_cdims, lhs_transposed),
"rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, "rhs_axis_boundary": get_rhs_axis_boundary(rhs_cdims, rhs_transposed),
"lhs_transposed": lhs_transposed, "lhs_transposed": lhs_transposed,
"rhs_transposed": rhs_transposed, "rhs_transposed": rhs_transposed,
"fuse_bias": fuse_bias, "fuse_bias": fuse_bias,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment