Unverified Commit a92a0ad2 authored by Ming-Xu Huang, committed by GitHub

[JAX] Local-Amax for Current-Scaling (#2183)



* Add the Amax primitive and related args.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Enable local amax for current-scaling and optionally run an all-reduce across FSDP/TP/SP (see the usage sketch after this list).
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add documentation for the Amax primitive.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the function name conflict.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Apply changes suggested in review feedback.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix lint errors.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong amax scope in the backward pass.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add more description for amax-scope.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong attribute name.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Keep dims in the amax calculation.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove keepDim and add shardy_rule.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix shardy_rule.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove the extra collective bytes from ref_coll_count due to local amax.
Signed-off-by: Ming Huang <mingh@nvidia.com>
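
For reference, a minimal usage sketch of the new knobs; only the `amax_scope` / `using_global_amax_of_x` arguments come from this change, while the imports, shapes, and the `quantizer_set` setup are illustrative assumptions:

```python
# Illustrative sketch only -- not part of this diff.
import jax.numpy as jnp
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax.cpp_extensions.quantization import AmaxScope
from transformer_engine.jax.dense import dense

x = jnp.ones((128, 256), dtype=jnp.bfloat16)
kernel = jnp.ones((256, 512), dtype=jnp.bfloat16)
# quantizer_set: an existing current-scaling QuantizerSet from the user's setup (assumed).

# Quantize x with its amax max-reduced across TP/SP instead of only the local shard.
casted_x = tex.quantize(x, quantizer=quantizer_set.x, amax_scope=AmaxScope.TPSP)

# Or let the dense wrapper pick the scopes: TP/SP for x, FSDP for the kernel.
out = dense(x, kernel, contracting_dims=((1,), (0,)),
            quantizer_set=quantizer_set, using_global_amax_of_x=True)
```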

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 2db20a6f
......@@ -76,8 +76,6 @@ class TestDistributedLayernorm:
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
)
......
......@@ -26,7 +26,7 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl
from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
......@@ -979,6 +979,7 @@ def act_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
......@@ -987,6 +988,7 @@ def act_lu(
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which to run the amax calculation. Only takes effect when using current-scaling. Default is AmaxScope.LOCAL.
Returns:
If quantizer is None:
......@@ -1044,7 +1046,13 @@ def act_lu(
activation_type=activation_type,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
out, _ = _quantize_dbias_impl(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
)
return out
if isinstance(quantizer, DelayedScaleQuantizer):
......
......@@ -173,7 +173,7 @@ class BasePrimitive(metaclass=ABCMeta):
_primitive_registry = {}
def register_primitive(cls):
def register_primitive(cls, outer_only=False):
"""
Register a JAX primitive and add it to the internal registry.
"""
......@@ -186,6 +186,7 @@ def register_primitive(cls):
def name_of_wrapper_p():
return cls.name + "_wrapper"
if not outer_only:
inner_p = core.Primitive(cls.name)
dispatch.prim_requires_devices_during_lowering.add(inner_p)
inner_p.multiple_results = cls.multiple_results
......
......@@ -27,7 +27,7 @@ from .misc import (
NamedSharding,
get_cudnn_version,
)
from .quantization import _quantize_dbias_impl
from .quantization import _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
......@@ -880,6 +880,7 @@ def layernorm_fwd(
zero_centered_gamma: bool,
epsilon: float,
quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
"""Layer normalization forward pass with optional quantization.
......@@ -893,6 +894,7 @@ def layernorm_fwd(
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which to run the amax calculation. Only takes effect when using current-scaling. Default is AmaxScope.LOCAL.
Returns:
A tuple containing:
......@@ -952,7 +954,13 @@ def layernorm_fwd(
epsilon=epsilon,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
out, _ = _quantize_dbias_impl(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
)
return out, mu, rsigma
is_2x2x = quantizer.is_2x2x()
......@@ -1082,6 +1090,7 @@ def rmsnorm_fwd(
zero_centered_gamma: bool,
epsilon: float,
quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
"""Root mean square normalization forward pass with optional quantization.
......@@ -1093,6 +1102,7 @@ def rmsnorm_fwd(
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which to run the amax calculation. Only takes effect when using current-scaling. Default is AmaxScope.LOCAL.
Returns:
A tuple containing:
......@@ -1153,7 +1163,11 @@ def rmsnorm_fwd(
quantizer=None,
)
out, _ = _quantize_dbias_impl(
out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype
out.data,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
)
return out, rsigma
......@@ -1278,6 +1292,7 @@ def normalization_fwd(
epsilon: float,
norm_type: str,
quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
):
"""Common wrapper for normalization forward pass.
......@@ -1294,6 +1309,7 @@ def normalization_fwd(
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which to run the amax calculation. Only takes effect when using current-scaling. Default is AmaxScope.LOCAL.
Returns:
A tuple containing:
......@@ -1311,12 +1327,27 @@ def normalization_fwd(
zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
"""
if norm_type == "layernorm":
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
output, mu, rsigma = layernorm_fwd(
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
quantizer,
amax_scope=amax_scope,
)
elif norm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer)
output, rsigma = rmsnorm_fwd(
x,
gamma,
zero_centered_gamma,
epsilon,
quantizer,
amax_scope=amax_scope,
)
mu = None
else:
raise ValueError(f"{norm_type=} is not supported.")
......
......@@ -6,6 +6,8 @@ import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math
from enum import Enum
import jax
import jax.numpy as jnp
......@@ -26,7 +28,12 @@ from .misc import (
get_min_device_compute_capability,
NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp,
global_mesh_resource,
lax_paral_op,
)
from ..quantize import (
ScaledTensor2x,
ScaledTensor,
......@@ -526,6 +533,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
class AmaxScope(Enum):
"""
Amax Scope Enum
"""
LOCAL = 1
TPSP = 2
FSDP = 3
class AmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning
"""
name = "jax_local_amax"
multiple_results = False
impl_static_args = (1,) # amax_scope
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
*,
amax_scope,
):
"""
amax calculation abstract
"""
del amax_scope
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
return out_aval
@staticmethod
def impl(
x,
amax_scope,
):
"""
amax calculation implementation
"""
del amax_scope
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation infer_sharding_from_operands
"""
del (amax_scope, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
)
return amax_sharding
@staticmethod
def partition(
amax_scope,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation partition
"""
del result_infos
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
)
def sharded_impl(x):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
)
if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP
gmesh = global_mesh_resource()
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh)
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
if amax_scope is AmaxScope.FSDP: # Run AR across FSDP
gmesh = global_mesh_resource()
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, mesh, value_types, result_types):
"""
amax calculation shardy_sharding_rule
"""
del amax_scope, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
return SdyShardingRule((input_spec,), (output_spec,))
register_primitive(AmaxCalculationPrimitive, outer_only=True)
def _jax_quantize(
x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
......@@ -572,6 +699,7 @@ def _quantize_dbias_impl(
is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -628,7 +756,10 @@ def _quantize_dbias_impl(
# until the tensor is dequantized (e.g. in the GEMM).
amax = x.amax
if amax is None:
amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,))
amax = AmaxCalculationPrimitive.outer_primitive.bind(
x.data,
amax_scope=amax_scope,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
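
For orientation only (not part of this diff, and not necessarily the exact body of compute_scale_from_amax): current-scaling derives the quantization scale directly from the freshly computed amax, roughly as in this hedged sketch:

```python
# Illustrative sketch, assuming a float32 amax of shape (1,) and an FP8 q_dtype.
import jax.numpy as jnp

def scale_from_amax(amax: jnp.ndarray, q_dtype) -> jnp.ndarray:
    # Map the observed absolute maximum onto the representable range of the
    # FP8 dtype; fall back to a scale of 1.0 when amax is zero.
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, dtype=jnp.float32)
    return jnp.where(amax > 0.0, fp8_max / amax, jnp.ones_like(amax))
```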
......@@ -700,6 +831,7 @@ def quantize(
x: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -710,6 +842,7 @@ def quantize(
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
is None.
amax_scope: Indicates the scope over which to run the amax calculation. Only takes effect when using current-scaling. Default is AmaxScope.LOCAL.
Returns:
A ScaledTensor containing the quantized input tensor.
......@@ -718,6 +851,7 @@ def quantize(
x,
quantizer=quantizer,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
)
return out
......@@ -727,6 +861,7 @@ def quantize_dbias(
quantizer: Quantizer,
is_dbias: bool = True,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -737,6 +872,8 @@ def quantize_dbias(
is_dbias: If True, compute bias gradient. Defaults to True.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
amax_scope: Indicates the scope over which to run the amax calculation. Only takes effect when using current-scaling. Default is AmaxScope.LOCAL.
Returns:
A tuple containing:
......@@ -750,6 +887,7 @@ def quantize_dbias(
quantizer=quantizer,
is_dbias=is_dbias,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
)
......
......@@ -15,6 +15,7 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import (
ScaledTensorFactory,
ScalingMode,
......@@ -64,6 +65,7 @@ def dense(
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
using_global_amax_of_x: bool = False,
):
"""Perform dense layer transformation with optional quantization.
......@@ -77,6 +79,7 @@ def dense(
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
quantizer_set: QuantizerSet which contains quantizers for different tensor types
using_global_amax_of_x: Indicates whether to use the global amax for x. Only takes effect when using current-scaling. Default is False.
Returns:
Transformed output tensor
......@@ -93,6 +96,7 @@ def dense(
input_axes,
kernel_axes,
quantizer_set,
using_global_amax_of_x,
)
return output
......@@ -103,6 +107,7 @@ def dense(
3,
4,
5,
7,
),
)
def _dense(
......@@ -113,6 +118,7 @@ def _dense(
input_axes,
kernel_axes,
quantizer_set,
using_global_amax_of_x,
):
"""Internal implementation of dense layer transformation with custom VJP.
......@@ -127,6 +133,7 @@ def _dense(
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
using_global_amax_of_x: Indicates whether to use the global amax for x. Only takes effect when using current-scaling. Default is False.
Returns:
Transformed output tensor
......@@ -139,6 +146,7 @@ def _dense(
input_axes,
kernel_axes,
quantizer_set,
using_global_amax_of_x,
)
return output
......@@ -151,6 +159,7 @@ def _dense_fwd_rule(
input_axes,
kernel_axes,
quantizer_set,
using_global_amax_of_x,
):
"""Forward pass rule for dense layer transformation.
......@@ -175,6 +184,7 @@ def _dense_fwd_rule(
x,
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
......@@ -182,6 +192,7 @@ def _dense_fwd_rule(
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -212,7 +223,7 @@ def _dense_fwd_rule(
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
......@@ -238,6 +249,7 @@ def _dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
)
# GEMM NT
......
......@@ -21,6 +21,7 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
with_sharding_constraint_by_logical_axes,
......@@ -272,13 +273,12 @@ def _layernorm_mlp_fwd_rule(
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(
kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
)
# NN GEMM
......@@ -317,6 +317,7 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_2 = tex.quantize(
kernel_2,
quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
)
# NN GEMM
......@@ -417,6 +418,7 @@ def _layernorm_mlp_bwd_rule(
grad,
is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......
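
Not part of this diff: a standalone sketch of the idea AmaxCalculationPrimitive implements under custom_partitioning. Each shard computes its local amax, and for the non-LOCAL scopes a max all-reduce merges the per-shard values across the chosen mesh axis. pmap/pmax is used here purely for illustration; the primitive itself relies on lax_paral_op over the TP/SP or FSDP mesh resources.

```python
# Illustrative sketch only -- names and shapes are assumptions, not commit code.
import jax
import jax.numpy as jnp

def local_then_global_amax(x):
    # Local (per-shard) amax, matching the default AmaxScope.LOCAL behaviour.
    local = jnp.max(jnp.abs(x)).astype(jnp.float32).reshape((1,))
    # Max all-reduce across the sharded axis, analogous to AmaxScope.TPSP / FSDP.
    return jax.lax.pmax(local, axis_name="shard")

n = jax.local_device_count()
x = jnp.linspace(-3.0, 5.0, n * 8, dtype=jnp.bfloat16).reshape(n, 8)
amax = jax.pmap(local_then_global_amax, axis_name="shard")(x)  # shape (n, 1), all 5.0
```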