"examples/jax/vscode:/vscode.git/clone" did not exist on "bc99a88da65fa2e47a0eadff575b456bd4ec02e1"
Unverified Commit 355c4e42 authored by Phuong Nguyen, committed by GitHub

[JAX] FP8 GEMM via dot_general + direct quant (#1819)



* fp8 gemm with direct quant
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4732ed76
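
A one-line summary of the direct-quant idea in the diff below: with per-tensor scaling, each operand dequantizes as x_fp8 * x_scale_inv, and since both scales are scalars they can be pulled out of the matmul,

    (lhs_fp8 * lhs_scale_inv) @ (rhs_fp8 * rhs_scale_inv) = (lhs_fp8 @ rhs_fp8) * (lhs_scale_inv * rhs_scale_inv)

so dot_general can consume the FP8 data directly and the combined scale is applied once to the fp32 accumulator, instead of dequantizing both inputs before the GEMM.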
@@ -142,10 +142,6 @@ def _calculate_remaining_shape(shape, contracting_dims):
     return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
 
 
-def _dequantize(x, scale_inv, dq_dtype):
-    return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
-
-
 # Apply jit to guarantee correctness of FP8 GEMM.
 @partial(
     jax.jit,
@@ -156,27 +152,25 @@ def _dequantize(x, scale_inv, dq_dtype):
     ),
 )
 def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
-    # Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
-    lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
-    rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
-
     # Reshape + Transpose
     # [..., M, K] -> [B, M, K]
     # [..., K, M] -> [B, M, K]
-    lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
-    rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
+    lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
+    rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
 
     dim_nums = (((2,), (2,)), ((0,), (0,)))
-    out_3d = jax.lax.dot_general(
-        lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
-    )
-    return out_3d
+    out_fp8 = jax.lax.dot_general(
+        lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=jnp.float32
+    )
+    scale_inv = (lhs.scale_inv * rhs.scale_inv).astype(jnp.float32)
+    return (out_fp8 * scale_inv).astype(lhs.dq_dtype)
 
 
 def _jax_gemm_tensor_scaling_fp8(
     lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
 ):
-    """FP8 GEMM for XLA pattern match"""
+    """FP8 GEMM"""
     assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
     (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
...
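
For context, here is a minimal standalone sketch of the same pattern, separate from the actual change. The helper names (quantize_per_tensor, fp8_gemm) are hypothetical and not part of TransformerEngine's API; it assumes a JAX/XLA build with float8_e4m3fn support and an unbatched [M, K] x [K, N] matmul.

# Standalone sketch of the direct-quant FP8 GEMM pattern (hypothetical helpers).
import jax
import jax.numpy as jnp

FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn


def quantize_per_tensor(x, q_dtype=jnp.float8_e4m3fn):
    """Per-tensor quantization: returns FP8 data and its dequantization scale."""
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    scale = FP8_E4M3_MAX / jnp.maximum(amax, 1e-12)  # quantization scale
    scale_inv = 1.0 / scale                          # dequantization scale
    x_fp8 = (x.astype(jnp.float32) * scale).astype(q_dtype)
    return x_fp8, scale_inv


def fp8_gemm(lhs_fp8, lhs_scale_inv, rhs_fp8, rhs_scale_inv, out_dtype=jnp.bfloat16):
    """FP8 x FP8 matmul; the combined scale is applied to the fp32 accumulator."""
    out = jax.lax.dot_general(
        lhs_fp8,
        rhs_fp8,
        dimension_numbers=(((1,), (0,)), ((), ())),  # plain [M, K] x [K, N]
        preferred_element_type=jnp.float32,
    )
    return (out * (lhs_scale_inv * rhs_scale_inv)).astype(out_dtype)


a = jax.random.normal(jax.random.PRNGKey(0), (128, 256), dtype=jnp.bfloat16)
b = jax.random.normal(jax.random.PRNGKey(1), (256, 64), dtype=jnp.bfloat16)
a_fp8, a_sinv = quantize_per_tensor(a)
b_fp8, b_sinv = quantize_per_tensor(b)
out = fp8_gemm(a_fp8, a_sinv, b_fp8, b_sinv)  # approximates a @ b

The design point the PR relies on is visible here: preferred_element_type=jnp.float32 keeps accumulation in fp32 while the operands stay in FP8, and the two dequantization scales are folded into a single output multiply, so no dequantized copies of the inputs are materialized the way the removed _dequantize path did.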