Unverified commit 2ac3c168, authored by jberchtold-nvidia and committed by GitHub
Browse files

[JAX] Defer TE/JAX cublas shape check on fp8 gemms until lowering (#2292)



Defer cublas check on fp8 gemms until lowering
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 818b30cc
@@ -470,29 +470,6 @@ class GemmPrimitive(BasePrimitive):
f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
) )
lhs_axis_boundary = get_lhs_axis_boundary(lhs_contracting_dims, lhs_is_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs.shape[lhs_axis_boundary:])
if lhs_is_transposed
else reduce(operator.mul, lhs.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_contracting_dims, rhs_is_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs.shape[:rhs_axis_boundary])
if rhs_is_transposed
else reduce(operator.mul, rhs.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
# Determine output shape and dtype # Determine output shape and dtype
assert ( assert (
dtypes.canonicalize_dtype(out_dtype).itemsize > 1 dtypes.canonicalize_dtype(out_dtype).itemsize > 1
@@ -601,6 +578,29 @@ class GemmPrimitive(BasePrimitive):
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
) )
lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:])
if lhs_transposed
else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary])
if rhs_transposed
else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
kwargs = { kwargs = {
"scaling_mode": int(scaling_mode.value), "scaling_mode": int(scaling_mode.value),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment