Unverified Commit 0e45e138 authored by jberchtold-nvidia, committed by GitHub

Revert "[JAX] Removes unneccessary reshapes for FP8 GEMM (#1740)" (#1774)

This reverts commit 5bee81e2.

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent a94e5dde
@@ -142,36 +142,65 @@ def _calculate_remaining_shape(shape, contracting_dims):
     return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
 
 
-def _transpose_contract_dims(ndim, contracting_dims):
-    return tuple(ndim - i - 1 for i in contracting_dims)
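For reference, here is a standalone sketch of what these two helpers compute; the bodies are copied from the diff and the assertions are illustrative. Note that the removed `_transpose_contract_dims` sends a negative contracting dim out of range, whereas the modulo form restored further down wraps it back to a valid axis.

```python
def _calculate_remaining_shape(shape, contracting_dims):
    return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)

def _transpose_contract_dims(ndim, contracting_dims):
    return tuple(ndim - i - 1 for i in contracting_dims)

# Non-contracted axes of [B, M, K] when contracting over K (axis 2):
assert _calculate_remaining_shape((2, 4, 16), (2,)) == (2, 4)
# Axis 1 of a 2-D tensor becomes axis 0 after a full axis reversal:
assert _transpose_contract_dims(2, (1,)) == (0,)
# A negative contracting dim falls out of range (2 is not a valid axis of a
# 2-D tensor); the modulo form in the restored code maps -1 to 0 instead.
assert _transpose_contract_dims(2, (-1,)) == (2,)
```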
 def _dequantize(x, scale_inv, dq_dtype):
     return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
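Tensor-scaling dequantization, as used by the jitted path below, is a cast followed by a multiply with the per-tensor inverse scale. A minimal numeric sketch with made-up values:

```python
import jax.numpy as jnp

# Quantized FP8 data and a per-tensor inverse scale (illustrative values).
x = jnp.array([0.5, 1.0, 2.0], dtype=jnp.float8_e4m3fn)
scale_inv = jnp.asarray(4.0, jnp.float32)

# Cast first, then rescale in the dequantized dtype.
dq = x.astype(jnp.bfloat16) * scale_inv.astype(jnp.bfloat16)  # [2.0, 4.0, 8.0]
```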
 # Apply jit to guarantee correctness of FP8 GEMM.
-@partial(jax.jit, static_argnums=(2, 3))
-def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
-    """FP8 GEMM for XLA pattern match"""
+@partial(
+    jax.jit,
+    static_argnums=(
+        2,
+        3,
+        4,
+    ),
+)
+def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
     # Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
     lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
     rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
+
+    # Reshape + Transpose
+    # [..., M, K] -> [B, M, K]
+    # [..., K, M] -> [B, M, K]
+    lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
+    rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
+
+    dim_nums = (((2,), (2,)), ((0,), (0,)))
+
+    out_3d = jax.lax.dot_general(
+        lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
+    )
+    return out_3d
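Once both operands are normalized to rank 3, the contraction is fixed: axis 2 contracts with axis 2 and axis 0 batches with axis 0. A self-contained sketch of that dimension-number pattern, with illustrative shapes:

```python
import jax
import jax.numpy as jnp

B, M, N, K = 2, 4, 8, 16
lhs_3d = jnp.ones((B, M, K), dtype=jnp.bfloat16)  # [B, M, K]
rhs_3d = jnp.ones((B, N, K), dtype=jnp.bfloat16)  # [B, N, K]

# ((lhs contracting, rhs contracting), (lhs batch, rhs batch))
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(lhs_3d, rhs_3d, dim_nums, preferred_element_type=jnp.bfloat16)
assert out_3d.shape == (B, M, N)  # [B, M, N]
```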
+
+
+def _jax_gemm_tensor_scaling_fp8(
+    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
+):
+    """FP8 GEMM for XLA pattern match"""
+    assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
     (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
     if lhs.data_layout == "T":
-        lhs_contract = _transpose_contract_dims(lhs_dq.ndim, lhs_contract)
+        lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
     if rhs.data_layout == "T":
-        rhs_contract = _transpose_contract_dims(rhs_dq.ndim, rhs_contract)
+        rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
 
-    dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
+    lhs_dn = (lhs_contract, lhs_batch)
+    rhs_dn = (rhs_contract, rhs_batch)
+
+    lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
+    rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
 
-    return jax.lax.dot_general(
-        lhs_dq, rhs_dq, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
-    )
+    precision = (
+        jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
+    )
+    out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
+
+    # Reshape [B, M, N] -> [..., M, N]
+    out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
+    return out
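The modulo in the restored remap is what lets callers pass negative contracting dims such as -1. A standalone sketch; the helper name is a stand-in for the inline tuple comprehension above, not a TE function:

```python
# Stand-in name for illustration; mirrors the inline tuple comprehension above.
def remap_contract_dims_for_transposed(ndim, contracting_dims):
    return tuple((ndim - 1 - i) % ndim for i in contracting_dims)

assert remap_contract_dims_for_transposed(2, (1,)) == (0,)
# Negative indices wrap to a valid axis instead of falling out of range:
assert remap_contract_dims_for_transposed(2, (-1,)) == (0,)
assert remap_contract_dims_for_transposed(3, (2,)) == (0,)
```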
+@partial(jax.jit, static_argnums=(2,))
 def _jax_gemm_mxfp8_1d(
     lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
 ):
@@ -181,6 +210,7 @@ def _jax_gemm_mxfp8_1d(
     assert (
         rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
     ), "rhs does not have MXFP8 1D scaling mode"
+    from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
 
     (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
@@ -211,7 +241,7 @@ def _jax_gemm_mxfp8_1d(
     # * Expected shape:
     # * lhs_data (B, M, K) * rhs_data (B, N, K)
     # * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
-    out_3d = jax.nn.scaled_matmul(
+    out_3d = scaled_matmul_wrapper(
         lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
     )
     # Reshape [1, reduce(..., M), N] -> [..., M, N]
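The shape comment above pins down what the wrapped call computes. Below is a plain-JAX reference sketch of a block-scaled matmul with 32-element scaling blocks; it illustrates the math only, not the cuDNN kernel that `scaled_matmul_wrapper` dispatches to:

```python
import jax.numpy as jnp

def reference_scaled_matmul(lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, block=32):
    # lhs_3d: [B, M, K], rhs_3d: [B, N, K]
    # lhs_scale_3d: [B, M, K // block], rhs_scale_3d: [B, N, K // block]
    # Each group of `block` K-elements shares one scale factor.
    lhs_dq = lhs_3d.astype(jnp.float32) * jnp.repeat(lhs_scale_3d, block, axis=-1)
    rhs_dq = rhs_3d.astype(jnp.float32) * jnp.repeat(rhs_scale_3d, block, axis=-1)
    # Contract K with K and batch over B: [B, M, K] x [B, N, K] -> [B, M, N]
    return jnp.einsum("bmk,bnk->bmn", lhs_dq, rhs_dq)
```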
@@ -238,16 +268,9 @@ def _jax_gemm(
         dim_nums = (contracting_dims, ((), ()))
 
     def _jax_gemm_fp8_impl(lhs, rhs):
         if lhs.scaling_mode.is_tensor_scaling():
-            assert (
-                rhs.scaling_mode == lhs.scaling_mode
-            ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
-            precision = (
-                jax.lax.Precision.HIGHEST
-                if QuantizeConfig.FP8_2X_ACC_FPROP
-                else jax.lax.Precision.DEFAULT
-            )
-            return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
+            return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
 
         if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
             return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
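The precision logic deleted in this hunk does not disappear; as shown earlier in the diff, it moves into `_jax_gemm_tensor_scaling_fp8`. The mapping itself is a two-way switch; the flag below is a stand-in for `QuantizeConfig.FP8_2X_ACC_FPROP`:

```python
import jax

# Stand-in for QuantizeConfig.FP8_2X_ACC_FPROP (TE's config flag).
fp8_2x_acc_fprop = False

# HIGHEST asks XLA for its most precise accumulation; DEFAULT leaves the
# choice to XLA.
precision = (
    jax.lax.Precision.HIGHEST if fp8_2x_acc_fprop else jax.lax.Precision.DEFAULT
)
```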
...