Unverified Commit 0e45e138 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

Revert "[JAX] Removes unneccessary reshapes for FP8 GEMM (#1740)" (#1774)

This reverts commit 5bee81e2

.
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent a94e5dde
......@@ -142,36 +142,65 @@ def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
def _transpose_contract_dims(ndim, contracting_dims):
return tuple(ndim - i - 1 for i in contracting_dims)
def _dequantize(x, scale_inv, dq_dtype):
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(2, 3))
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
@partial(
jax.jit,
static_argnums=(
2,
3,
4,
),
)
def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
"""FP8 GEMM for XLA pattern match"""
lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
)
return out_3d
def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T":
lhs_contract = _transpose_contract_dims(lhs_dq.ndim, lhs_contract)
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.data_layout == "T":
rhs_contract = _transpose_contract_dims(rhs_dq.ndim, rhs_contract)
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
return jax.lax.dot_general(
lhs_dq, rhs_dq, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
precision = (
jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
)
out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
# Reshape [B, M, N] -> [..., M, N]
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
@partial(jax.jit, static_argnums=(2,))
def _jax_gemm_mxfp8_1d(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
......@@ -181,6 +210,7 @@ def _jax_gemm_mxfp8_1d(
assert (
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
......@@ -211,7 +241,7 @@ def _jax_gemm_mxfp8_1d(
# * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K)
# * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
out_3d = jax.nn.scaled_matmul(
out_3d = scaled_matmul_wrapper(
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
)
# Reshape [1, reduce(..., M), N] -> [..., M, N]
......@@ -238,16 +268,9 @@ def _jax_gemm(
dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode.is_tensor_scaling():
assert (
rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
precision = (
jax.lax.Precision.HIGHEST
if QuantizeConfig.FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment