Unverified Commit c9e8e305 authored by Phuong Nguyen, committed by GitHub

[JAX] Removes unnecessary reshapes for FP8 GEMM (#1820)



* removes unnecessary reshapes for FP8 GEMM

* use jax.nn.scaled_matmul
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 355c4e42
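For context on the change below: the old path reshaped and transposed both FP8 operands into a [B, M, K] / [B, N, K] layout before calling jax.lax.dot_general and then reshaped the result back; the new path builds dimension numbers for the original shapes and lets dot_general perform the contraction directly. A minimal sketch of that idea on plain arrays (shapes and dtypes here are illustrative, not taken from the diff):

import jax
import jax.numpy as jnp

# Hypothetical shapes, only to show dot_general's dimension numbers doing the
# work that an explicit reshape/transpose used to do.
lhs = jnp.ones((4, 8, 16), dtype=jnp.bfloat16)   # [..., M, K]
rhs = jnp.ones((32, 16), dtype=jnp.bfloat16)     # [N, K]

# Contract the last axis of lhs against the last axis of rhs, no batch dims.
dim_nums = (((2,), (1,)), ((), ()))
out = jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=jnp.float32)
print(out.shape)  # (4, 8, 32): remaining lhs dims followed by remaining rhs dims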
@@ -142,59 +142,30 @@ def _calculate_remaining_shape(shape, contracting_dims):
     return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
 
 
-# Apply jit to guarantee correctness of FP8 GEMM.
-@partial(
-    jax.jit,
-    static_argnums=(
-        2,
-        3,
-        4,
-    ),
-)
-def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
-    # Reshape + Transpose
-    # [..., M, K] -> [B, M, K]
-    # [..., K, M] -> [B, M, K]
-    lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
-    rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
-    dim_nums = (((2,), (2,)), ((0,), (0,)))
-    out_fp8 = jax.lax.dot_general(
-        lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=jnp.float32
-    )
-    scale_inv = (lhs.scale_inv * rhs.scale_inv).astype(jnp.float32)
-    return (out_fp8 * scale_inv).astype(lhs.dq_dtype)
-
-
-def _jax_gemm_tensor_scaling_fp8(
-    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
-):
-    """FP8 GEMM"""
-    assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
-
+def _transpose_contract_dims(ndim, contracting_dims):
+    return tuple(ndim - i - 1 for i in contracting_dims)[::-1]
+
+
+# Apply jit to guarantee correctness of FP8 GEMM.
+@partial(jax.jit, static_argnums=(2, 3))
+def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
     (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
     if lhs.data_layout == "T":
-        lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
+        lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract)
     if rhs.data_layout == "T":
-        rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
-    lhs_dn = (lhs_contract, lhs_batch)
-    rhs_dn = (rhs_contract, rhs_batch)
-
-    lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
-    rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
-
-    precision = (
-        jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
-    )
-    out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
-
-    # Reshape [B, M, N] -> [..., M, N]
-    out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
-    return out
+        rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract)
+
+    dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
+
+    out_fp8 = jax.lax.dot_general(
+        lhs.data, rhs.data, dim_nums, precision=precision, preferred_element_type=jnp.float32
+    )
+    scale_inv = (lhs.scale_inv * rhs.scale_inv).astype(jnp.float32)
+    return (out_fp8 * scale_inv).astype(lhs.dq_dtype)
 
 
+@partial(jax.jit, static_argnums=(2,))
 def _jax_gemm_mxfp8_1d(
     lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
 ):
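Note on the new _transpose_contract_dims helper above: assuming data_layout == "T" denotes a buffer stored with its axis order reversed, a logical contracting axis i lands on stored axis ndim - 1 - i; the helper applies that mapping and reverses the result so the axes stay in ascending order. A small standalone check of that reading (not part of the diff):

def _transpose_contract_dims(ndim, contracting_dims):
    return tuple(ndim - i - 1 for i in contracting_dims)[::-1]

# Logical axes (1, 2) of a 3-D operand map to stored axes (0, 1) once reversed.
assert _transpose_contract_dims(3, (1, 2)) == (0, 1)
assert _transpose_contract_dims(2, (1,)) == (0,)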
@@ -204,7 +175,6 @@ def _jax_gemm_mxfp8_1d(
     assert (
         rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
     ), "rhs does not have MXFP8 1D scaling mode"
-    from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
 
     (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
@@ -235,7 +205,7 @@
     # * Expected shape:
     # *     lhs_data (B, M, K)          * rhs_data (B, N, K)
     # *     lhs_scale (B, M, K_block)   * rhs_scale (B, N, K_block)
-    out_3d = scaled_matmul_wrapper(
+    out_3d = jax.nn.scaled_matmul(
         lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
    )
     # Reshape [1, reduce(..., M), N] -> [..., M, N]
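The private scaled_matmul_wrapper import is replaced here by the public jax.nn.scaled_matmul API, called with the same operand and scale layout shown in the comment above. A minimal sketch of such a call, assuming a JAX build that exposes jax.nn.scaled_matmul and a backend with MXFP8 (32-element block scaling) support; the sizes are illustrative:

import jax
import jax.numpy as jnp

B, M, N, K, BLOCK = 1, 128, 128, 128, 32   # MXFP8 scales one factor per 32-wide block

lhs = jnp.ones((B, M, K), dtype=jnp.float8_e4m3fn)
rhs = jnp.ones((B, N, K), dtype=jnp.float8_e4m3fn)
lhs_scales = jnp.ones((B, M, K // BLOCK), dtype=jnp.float8_e8m0fnu)
rhs_scales = jnp.ones((B, N, K // BLOCK), dtype=jnp.float8_e8m0fnu)

out = jax.nn.scaled_matmul(
    lhs, rhs, lhs_scales, rhs_scales, preferred_element_type=jnp.bfloat16
)
print(out.shape)  # (B, M, N)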
@@ -262,9 +232,16 @@ def _jax_gemm(
     dim_nums = (contracting_dims, ((), ()))
 
     def _jax_gemm_fp8_impl(lhs, rhs):
         if lhs.scaling_mode.is_tensor_scaling():
-            return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
+            assert (
+                rhs.scaling_mode == lhs.scaling_mode
+            ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
+            precision = (
+                jax.lax.Precision.HIGHEST
+                if QuantizeConfig.FP8_2X_ACC_FPROP
+                else jax.lax.Precision.DEFAULT
+            )
+            return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
 
         if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
             return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
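In the dispatcher above, the precision choice (jax.lax.Precision.HIGHEST when QuantizeConfig.FP8_2X_ACC_FPROP is set, otherwise DEFAULT) is now made per call and forwarded as a static argument to the jitted tensor-scaling GEMM. A standalone sketch of that static-argument pattern (names here are placeholders, not the library's):

from functools import partial
import jax
import jax.numpy as jnp

# dim_nums and precision are hashable, so they can be static_argnums of a jitted
# function; each distinct (dim_nums, precision) pair triggers its own trace.
@partial(jax.jit, static_argnums=(2, 3))
def _gemm(lhs, rhs, dim_nums, precision):
    return jax.lax.dot_general(
        lhs, rhs, dim_nums, precision=precision, preferred_element_type=jnp.float32
    )

lhs = jnp.ones((8, 16))
rhs = jnp.ones((16, 4))
out = _gemm(lhs, rhs, (((1,), (0,)), ((), ())), jax.lax.Precision.HIGHEST)
print(out.shape)  # (8, 4)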