Unverified Commit 355c4e42 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] FP8 GEMM via dot_general + direct quant (#1819)



* fp8 gemm with direct quant
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 4732ed76
......@@ -142,10 +142,6 @@ def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
def _dequantize(x, scale_inv, dq_dtype):
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(
jax.jit,
......@@ -156,27 +152,25 @@ def _dequantize(x, scale_inv, dq_dtype):
),
)
def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
out_fp8 = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=jnp.float32
)
return out_3d
scale_inv = (lhs.scale_inv * rhs.scale_inv).astype(jnp.float32)
return (out_fp8 * scale_inv).astype(lhs.dq_dtype)
def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
"""FP8 GEMM"""
assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment