Unverified Commit 5350f277 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Remove unnecessary MXFP8 scale_inv padding (#1954)



* remove unnecessary padding
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* adapt the test_distributed_layernorm byte count
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent ed75c2b0
...@@ -75,8 +75,6 @@ class TestDistributedLayernorm: ...@@ -75,8 +75,6 @@ class TestDistributedLayernorm:
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
) )
other_bytes = 0 other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding
if fp8_recipe == recipe.Float8CurrentScaling(): if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count( return generate_collectives_count(
......
...@@ -33,7 +33,6 @@ from ..quantize import ( ...@@ -33,7 +33,6 @@ from ..quantize import (
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
remove_padding_from_scale_inv,
) )
from .misc import get_padded_spec from .misc import get_padded_spec
...@@ -399,7 +398,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -399,7 +398,6 @@ class GemmPrimitive(BasePrimitive):
lhs_transposed, rhs_transposed = _get_gemm_layout( lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
) )
lhs_scale_inv = apply_padding_to_scale_inv( lhs_scale_inv = apply_padding_to_scale_inv(
lhs_scale_inv, lhs_scale_inv,
scaling_mode, scaling_mode,
...@@ -885,16 +883,6 @@ def gemm_uses_jax_dot() -> bool: ...@@ -885,16 +883,6 @@ def gemm_uses_jax_dot() -> bool:
return not GemmPrimitive.enabled() return not GemmPrimitive.enabled()
def _get_scale_inv_without_padding(scaled_tensor):
return remove_padding_from_scale_inv(
scaled_tensor.scale_inv,
scaled_tensor.scaling_mode,
scaled_tensor.data.shape,
is_colwise=scaled_tensor.is_colwise,
flatten_axis=scaled_tensor.flatten_axis,
)
def _te_gemm( def _te_gemm(
lhs: Union[jax.Array, ScaledTensor], lhs: Union[jax.Array, ScaledTensor],
rhs: Union[jax.Array, ScaledTensor], rhs: Union[jax.Array, ScaledTensor],
...@@ -909,6 +897,7 @@ def _te_gemm( ...@@ -909,6 +897,7 @@ def _te_gemm(
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
lhs_data = lhs lhs_data = lhs
rhs_data = rhs rhs_data = rhs
...@@ -933,7 +922,7 @@ def _te_gemm( ...@@ -933,7 +922,7 @@ def _te_gemm(
lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor()
scaling_mode = lhs_q.scaling_mode scaling_mode = lhs_q.scaling_mode
lhs_data = lhs_q.data lhs_data = lhs_q.data
lhs_scale_inv = _get_scale_inv_without_padding(lhs_q) lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T": if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)
...@@ -951,7 +940,7 @@ def _te_gemm( ...@@ -951,7 +940,7 @@ def _te_gemm(
f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}."
) )
rhs_data = rhs_q.data rhs_data = rhs_q.data
rhs_scale_inv = _get_scale_inv_without_padding(rhs_q) rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T": if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)
...@@ -1230,10 +1219,6 @@ def _jax_gemm_mxfp8_1d( ...@@ -1230,10 +1219,6 @@ def _jax_gemm_mxfp8_1d(
lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch)) lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch)) rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))
# Slice out the padding as scaled_matmul does not support padded scales yet
lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : int(lhs_3d.shape[2] / 32)])
rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : int(rhs_3d.shape[2] / 32)])
# JAX scaled_matmul only supports NT now (TN-gemm) # JAX scaled_matmul only supports NT now (TN-gemm)
# * Expected shape: # * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K) # * lhs_data (B, M, K) * rhs_data (B, N, K)
......
...@@ -121,9 +121,6 @@ class BlockScaleDequantizer(Dequantizer): ...@@ -121,9 +121,6 @@ class BlockScaleDequantizer(Dequantizer):
scale_shape = scaling_mode.get_scale_shape( scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
scale_inv = jax.lax.slice(
scale_inv, [0] * len(scale_shape), scale_shape
) # slice out the padding
data = data.reshape( data = data.reshape(
*data_shape[: flatten_axis - 1], *data_shape[: flatten_axis - 1],
...@@ -211,28 +208,38 @@ def _grouped_dequantize(grouped_scaled_tensor): ...@@ -211,28 +208,38 @@ def _grouped_dequantize(grouped_scaled_tensor):
f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to" f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
f" {data_i.size}" f" {data_i.size}"
) )
scale_shape_i = scaling_mode.get_scale_shape( padded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i, data_shape_i,
grouped_scaled_tensor.is_colwise, grouped_scaled_tensor.is_colwise,
is_padded=True, is_padded=True,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
scale_shape_i_size = math.prod(scale_shape_i) unpadded_scale_shape_i = scaling_mode.get_scale_shape(
scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + scale_shape_i_size] data_shape_i,
grouped_scaled_tensor.is_colwise,
is_padded=False,
flatten_axis=flatten_axis,
)
scale_inv_i = scale_inv[
scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)
].reshape(padded_scale_shape_i)
scale_inv_i = jax.lax.slice(
scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i
)
dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode) dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
if len(data_i) == 0: if len(data_i) == 0:
out_i = [] out_i = []
else: else:
out_i = dequantizer_type._dequantize_func( out_i = dequantizer_type._dequantize_func(
data_i.reshape(data_shape_i), data_i.reshape(data_shape_i),
scale_inv_i.reshape(scale_shape_i), scale_inv_i,
grouped_scaled_tensor.dq_dtype, grouped_scaled_tensor.dq_dtype,
scaling_mode=grouped_scaled_tensor.scaling_mode, scaling_mode=grouped_scaled_tensor.scaling_mode,
is_colwise=grouped_scaled_tensor.is_colwise, is_colwise=grouped_scaled_tensor.is_colwise,
flatten_axis=grouped_scaled_tensor.flatten_axis, flatten_axis=grouped_scaled_tensor.flatten_axis,
) )
output.append(out_i) output.append(out_i)
scale_inv_ptr += scale_shape_i_size scale_inv_ptr += math.prod(padded_scale_shape_i)
return output return output
......
...@@ -17,7 +17,6 @@ from jax.tree_util import register_pytree_node_class ...@@ -17,7 +17,6 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout from transformer_engine_jax import QuantizeLayout
from .helper import apply_padding_to_scale_inv
from .scaling_modes import ScalingMode, TensorUsage from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import ( from ..sharding import (
...@@ -135,15 +134,17 @@ class ScaledTensor1x(ScaledTensor): ...@@ -135,15 +134,17 @@ class ScaledTensor1x(ScaledTensor):
if self.scaling_mode == ScalingMode.NO_SCALING: if self.scaling_mode == ScalingMode.NO_SCALING:
self.scale_inv = jnp.empty((0,), dtype=jnp.float32) self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
else: else:
self.scale_inv = apply_padding_to_scale_inv( unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.scale_inv,
self.scaling_mode,
self.data.shape, self.data.shape,
is_colwise=self.is_colwise, is_colwise=self.is_colwise,
is_padded=False,
flatten_axis=self.flatten_axis, flatten_axis=self.flatten_axis,
) )
assert self.scale_inv.shape == unpadded_scale_shape, (
"Unpadded inverse scale factor has wrong shape, expected"
f" {unpadded_scale_shape} but got {self.scale_inv.shape}."
)
def tree_flatten(self): def tree_flatten(self):
"""Flattens the tensor for JAX tree operations. """Flattens the tensor for JAX tree operations.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment