Unverified commit ac4e0fd6 authored by Phuong Nguyen, committed by GitHub

[JAX] Rework amax reduction over TPSP (#2218)



* rm using_global_amax_of_x
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 7022d50f
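
In brief: the `using_global_amax_of_x` boolean is dropped, the dense forward and backward rules now always request `AmaxScope.TPSP`, and a new static flag `batch_sequence_transpose` tells the partition rule which input dimension is the sequence dimension (0 for sequence-first layouts, 1 otherwise). The amax all-reduce over the tensor-sequence-parallel (TPSP) axis then runs only when that dimension is actually sharded over the TPSP mesh resource. A minimal, self-contained sketch of the new guard (hypothetical helper name, not TE's API):

```python
from typing import Optional, Sequence

def needs_tpsp_allreduce(
    x_spec: Sequence[Optional[str]],
    tpsp_resource: Optional[str],
    batch_sequence_transpose: bool,
) -> bool:
    """Mirror of the new guard in partition(): all-reduce the local amax
    across TPSP only when the input's sequence dim is sharded on it."""
    sequence_dim = 0 if batch_sequence_transpose else 1
    return tpsp_resource is not None and x_spec[sequence_dim] == tpsp_resource

# [batch, seq, hidden] with seq sharded over a "tsp" mesh axis -> reduce
assert needs_tpsp_allreduce(("dp", "tsp", None), "tsp", False)
# sequence dim unsharded -> the all-reduce is skipped
assert not needs_tpsp_allreduce(("dp", None, "tsp"), "tsp", False)
```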
@@ -551,7 +551,10 @@ class AmaxCalculationPrimitive(BasePrimitive):
     name = "jax_local_amax"
     multiple_results = False
-    impl_static_args = (1,)  # amax_scope
+    impl_static_args = (
+        1,
+        2,
+    )  # amax_scope, batch_sequence_transpose
     inner_primitive = None
     outer_primitive = None
@@ -560,11 +563,12 @@ class AmaxCalculationPrimitive(BasePrimitive):
         x_aval,
         *,
         amax_scope,
+        batch_sequence_transpose,
     ):
         """
         amax calculation abstract
         """
-        del amax_scope
+        del amax_scope, batch_sequence_transpose
         dtype = dtypes.canonicalize_dtype(x_aval.dtype)
         assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
@@ -576,17 +580,19 @@ class AmaxCalculationPrimitive(BasePrimitive):
     def impl(
         x,
         amax_scope,
+        batch_sequence_transpose,
     ):
         """
         amax calculation implementation
         """
-        del amax_scope
+        del amax_scope, batch_sequence_transpose
         amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
         return amax

     @staticmethod
     def infer_sharding_from_operands(
         amax_scope,
+        batch_sequence_transpose,
         mesh,
         arg_infos,
         result_infos,
@@ -594,7 +600,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
         """
         amax calculation infer_sharding_from_operands
         """
-        del (amax_scope, arg_infos, result_infos)  # Unused.
+        del (amax_scope, batch_sequence_transpose, arg_infos, result_infos)  # Unused.
         amax_sharding = NamedSharding(
             mesh,
             PartitionSpec(None),
@@ -605,6 +611,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
     @staticmethod
     def partition(
         amax_scope,
+        batch_sequence_transpose,
         mesh,
         arg_infos,
         result_infos,
@@ -613,25 +620,26 @@ class AmaxCalculationPrimitive(BasePrimitive):
         amax calculation partition
         """
         del result_infos
+        x_spec = get_padded_spec(arg_infos[0])
         amax_sharding = NamedSharding(
             mesh,
             PartitionSpec(None),
-            desc="AmaxCalculationPrimitive.out_sharding",
+            desc="AmaxCalculation.amax_sharding",
         )

         def sharded_impl(x):
             amax = AmaxCalculationPrimitive.impl(
                 x,
                 amax_scope=amax_scope,
+                batch_sequence_transpose=batch_sequence_transpose,
             )
-            if amax_scope is AmaxScope.TPSP:  # Run AR across TP/SP
-                gmesh = global_mesh_resource()
-                amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh)
-                amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
-            if amax_scope is AmaxScope.FSDP:  # Run AR across FSDP
-                gmesh = global_mesh_resource()
+            gmesh = global_mesh_resource()
+            sequence_dim = 0 if batch_sequence_transpose else 1
+            # Run AR across TPSP only when tensor-sequence is detected in the input spec
+            if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource:
+                amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
+            # Run AR across FSDP
+            if amax_scope is AmaxScope.FSDP:
                 amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
             return amax
@@ -640,11 +648,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
         return mesh, sharded_impl, amax_sharding, arg_shardings

     @staticmethod
-    def shardy_sharding_rule(amax_scope, mesh, value_types, result_types):
+    def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types):
         """
         amax calculation shardy_sharding_rule
         """
-        del amax_scope, mesh, result_types
+        del amax_scope, batch_sequence_transpose, mesh, result_types
         prefix = "AmaxCal"
         input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
         output_spec = (f"{prefix}_amax",)
@@ -701,6 +709,7 @@ def _quantize_dbias_impl(
     dq_dtype: Optional[jnp.dtype] = None,
     flatten_axis: int = -1,
     amax_scope: AmaxScope = AmaxScope.LOCAL,  # Only works when using current-scaling
+    batch_sequence_transpose: bool = False,
 ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
     """
     Cast wrapper
@@ -745,6 +754,8 @@ def _quantize_dbias_impl(
             quantizer=quantizer,
             dq_dtype=dq_dtype,
             flatten_axis=flatten_axis,
+            amax_scope=amax_scope,
+            batch_sequence_transpose=batch_sequence_transpose,
         )
         dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
         return out, dbias
@@ -760,6 +771,7 @@ def _quantize_dbias_impl(
         amax = AmaxCalculationPrimitive.outer_primitive.bind(
             x.data,
             amax_scope=amax_scope,
+            batch_sequence_transpose=batch_sequence_transpose,
         )
         scale = compute_scale_from_amax(amax, quantizer.q_dtype)
     elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
@@ -833,6 +845,7 @@ def quantize(
     quantizer: Quantizer,
     flatten_axis: int = -1,
     amax_scope: AmaxScope = AmaxScope.LOCAL,
+    batch_sequence_transpose: bool = False,
 ) -> Tuple[ScaledTensor]:
     """Quantize input tensor according to the quantizer.
@@ -853,6 +866,7 @@ def quantize(
         quantizer=quantizer,
         flatten_axis=flatten_axis,
         amax_scope=amax_scope,
+        batch_sequence_transpose=batch_sequence_transpose,
     )
     return out
@@ -863,6 +877,7 @@ def quantize_dbias(
     is_dbias: bool = True,
     flatten_axis: int = -1,
     amax_scope: AmaxScope = AmaxScope.LOCAL,
+    batch_sequence_transpose: bool = False,
 ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
     """Quantize input tensor and compute bias gradient.
@@ -889,6 +904,7 @@ def quantize_dbias(
         is_dbias=is_dbias,
         flatten_axis=flatten_axis,
         amax_scope=amax_scope,
+        batch_sequence_transpose=batch_sequence_transpose,
     )
...
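
The "AR" in the comments above is an all-reduce with `max` as the combiner. A standalone illustration of the pattern using plain `jax.pmap` with a toy axis name `"tsp"` (TE's `lax_paral_op` wrapper resolves the mesh resource instead; this is only a sketch of the collective):

```python
import jax
import jax.numpy as jnp

def sharded_global_amax(x_shards):
    """Each device computes a local amax; jax.lax.pmax then all-reduces the
    maxima across the named axis, so every shard holds the global amax."""
    def per_device(x):
        local_amax = jnp.amax(jnp.abs(x)).astype(jnp.float32)
        return jax.lax.pmax(local_amax, axis_name="tsp")
    return jax.pmap(per_device, axis_name="tsp")(x_shards)

# One shard per local device; with a single device this degenerates gracefully.
x = jnp.arange(jax.local_device_count() * 4, dtype=jnp.bfloat16)
print(sharded_global_amax(x.reshape(jax.local_device_count(), -1)))
```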
@@ -67,7 +67,6 @@ def dense(
     input_axes: Tuple[str, ...] = None,
     kernel_axes: Tuple[str, ...] = None,
     output_axes: Tuple[str, ...] = None,
-    using_global_amax_of_x: bool = False,
     collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
     quantizer_set: QuantizerSet = noop_quantizer_set,
 ):
@@ -86,7 +85,6 @@ def dense(
         input_axes: Logical axes for sharding the activation input
         kernel_axes: Logical axes for sharding the weight matrix
         output_axes: Logical axes for sharding the output
-        using_global_amax_of_x: Indicates whether to use the global amax for x. Only works when using current-scaling. Default is False.
         collective_op_set: A set of CollectiveOp objects for forward and backward passes.
         quantizer_set: QuantizerSet which contains quantizers for different tensor types
@@ -109,14 +107,13 @@ def dense(
         input_axes,
         kernel_axes,
         output_axes,
-        using_global_amax_of_x,
         collective_op_set,
         quantizer_set,
     )
     return output


-@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
 def _dense(
     x,
     kernel,
@@ -126,7 +123,6 @@ def _dense(
     input_axes,
     kernel_axes,
     output_axes,
-    using_global_amax_of_x,
     collective_op_set,
     quantizer_set,  # needs to be a diff_arg for DelayedScaling state management
 ):
@@ -144,7 +140,6 @@ def _dense(
         input_axes: Logical axes for sharding the activation input
         output_axes: Logical axes for sharding the output
         kernel_axes: Logical axes for sharding the weight matrix
-        using_global_amax_of_x: Indicates whether to use the global amax for x. Only works when using current-scaling. Default is False.
         collective_op_set: A set of CollectiveOp objects for forward and backward passes.
         quantizer_set: QuantizerSet which contains quantizers for different tensor types
@@ -160,7 +155,6 @@ def _dense(
         input_axes,
         kernel_axes,
         output_axes,
-        using_global_amax_of_x,
         collective_op_set,
         quantizer_set,
     )
@@ -176,7 +170,6 @@ def _dense_fwd_rule(
     input_axes,
     kernel_axes,
     output_axes,
-    using_global_amax_of_x,
     collective_op_set,
     quantizer_set,
 ):
@@ -203,7 +196,8 @@ def _dense_fwd_rule(
         x,
         flatten_axis=flatten_axis_x,
         quantizer=quantizer_set.x,
-        amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
+        amax_scope=AmaxScope.TPSP,
+        batch_sequence_transpose=batch_sequence_transpose,
     )
     casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
@@ -250,7 +244,6 @@ def _dense_bwd_rule(
     input_axes,
     kernel_axes,
     output_axes,
-    using_global_amax_of_x,
     collective_op_set,
     ctx,
     grad,
@@ -280,7 +273,8 @@ def _dense_bwd_rule(
         is_dbias=use_bias,
         flatten_axis=flatten_axis_k,
         quantizer=quantizer_set.dgrad,
-        amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
+        amax_scope=AmaxScope.TPSP,
+        batch_sequence_transpose=batch_sequence_transpose,
     )
     # GEMM NT
...
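
Why the global amax matters for current scaling: with both the forward (`quantizer_set.x`) and backward (`quantizer_set.dgrad`) paths now always using `AmaxScope.TPSP`, every shard in the TPSP group derives the same FP8 scale and quantizes with a consistent range. `compute_scale_from_amax` itself is not part of this diff; the sketch below shows a common current-scaling rule purely as an assumption, mapping the observed amax onto the quantized dtype's representable range:

```python
import jax.numpy as jnp

def compute_scale_from_amax_sketch(amax, q_dtype):
    """Assumed behavior of the helper called in this diff (not TE's code):
    scale = q_dtype_max / amax, guarded against a zero amax."""
    q_max = jnp.asarray(jnp.finfo(q_dtype).max, dtype=jnp.float32)
    return q_max / jnp.maximum(amax, jnp.float32(1e-12))

# e.g. an amax of 3.0, quantizing to FP8 E4M3
print(compute_scale_from_amax_sketch(jnp.float32(3.0), jnp.float8_e4m3fn))
```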