Unverified commit ac4e0fd6, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Rework amax reduction over TPSP (#2218)



* rm using_global_amax_of_x
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 7022d50f
......@@ -551,7 +551,10 @@ class AmaxCalculationPrimitive(BasePrimitive):
name = "jax_local_amax"
multiple_results = False
impl_static_args = (1,) # amax_scope
impl_static_args = (
1,
2,
) # amax_scope, batch_sequence_transpose
inner_primitive = None
outer_primitive = None
......@@ -560,11 +563,12 @@ class AmaxCalculationPrimitive(BasePrimitive):
x_aval,
*,
amax_scope,
batch_sequence_transpose,
):
"""
amax calcuation abstract
"""
del amax_scope
del amax_scope, batch_sequence_transpose
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -576,17 +580,19 @@ class AmaxCalculationPrimitive(BasePrimitive):
def impl(
x,
amax_scope,
batch_sequence_transpose,
):
"""
amax calcuation implementation
"""
del amax_scope
del amax_scope, batch_sequence_transpose
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
batch_sequence_transpose,
mesh,
arg_infos,
result_infos,
......@@ -594,7 +600,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
"""
amax calcuation infer_sharding_from_operands
"""
del (amax_scope, arg_infos, result_infos) # Unused.
del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
......@@ -605,6 +611,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
@staticmethod
def partition(
amax_scope,
batch_sequence_transpose,
mesh,
arg_infos,
result_infos,
......@@ -613,25 +620,26 @@ class AmaxCalculationPrimitive(BasePrimitive):
amax calcuation partition
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
desc="AmaxCalculation.amax_sharding",
)
def sharded_impl(x):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
)
if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP
gmesh = global_mesh_resource()
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh)
sequence_dim = 0 if batch_sequence_transpose else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
if amax_scope is AmaxScope.FSDP: # Run AR across FSDP
gmesh = global_mesh_resource()
# Run AR across FSDP
if amax_scope is AmaxScope.FSDP:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
......@@ -640,11 +648,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, mesh, value_types, result_types):
def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types):
"""
amax calcuation shardy_sharding_rule
"""
del amax_scope, mesh, result_types
del amax_scope, batch_sequence_transpose, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
......@@ -701,6 +709,7 @@ def _quantize_dbias_impl(
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling
batch_sequence_transpose: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -745,6 +754,8 @@ def _quantize_dbias_impl(
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
)
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
......@@ -760,6 +771,7 @@ def _quantize_dbias_impl(
amax = AmaxCalculationPrimitive.outer_primitive.bind(
x.data,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
......@@ -833,6 +845,7 @@ def quantize(
quantizer: Quantizer,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -853,6 +866,7 @@ def quantize(
quantizer=quantizer,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
)
return out
......@@ -863,6 +877,7 @@ def quantize_dbias(
is_dbias: bool = True,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -889,6 +904,7 @@ def quantize_dbias(
is_dbias=is_dbias,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
)
......
......@@ -67,7 +67,6 @@ def dense(
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None,
using_global_amax_of_x: bool = False,
collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
......@@ -86,7 +85,6 @@ def dense(
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output
using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
......@@ -109,14 +107,13 @@ def dense(
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
def _dense(
x,
kernel,
......@@ -126,7 +123,6 @@ def _dense(
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set, # need to be a diff_arg for DelayedScaling state management
):
......@@ -144,7 +140,6 @@ def _dense(
input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output_axes
kernel_axes: Logical axes for sharding the weight matrix
using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
......@@ -160,7 +155,6 @@ def _dense(
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
......@@ -176,7 +170,6 @@ def _dense_fwd_rule(
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
):
......@@ -203,7 +196,8 @@ def _dense_fwd_rule(
x,
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
......@@ -250,7 +244,6 @@ def _dense_bwd_rule(
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
ctx,
grad,
......@@ -280,7 +273,8 @@ def _dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose,
)
# GEMM NT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment