Unverified commit a92a0ad2, authored by Ming-Xu Huang and committed by GitHub

[JAX] Local-Amax for Current-Scaling (#2183)

* Adding Amax Primitive and related args.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Enable local-amax for current-scaling and optionally run AR across FSDP/TP/SP.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding doc for Amax Primitive.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the function name conflict.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modifications as suggested by review feedback.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix errors from lint.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong amax-scope in the bwd.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Added more description for amax-scope.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong attribute name.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Keep dim for AmaxCalculation.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove keepDim and add shardy_rule.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix shardy_rule.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove extra collective bytes from ref_coll_count due to local amax.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------

Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 2db20a6f
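With current scaling, the FP8 scale of a tensor is derived from its amax (maximum absolute value) at quantization time. Previously the amax came from a plain jnp.amax, which XLA turns into a global max-reduction when the tensor is sharded; this commit makes the reduction scope an explicit argument: AmaxScope.LOCAL (per-shard amax, no collective), AmaxScope.TPSP (max all-reduce across the tensor/sequence-parallel axes), or AmaxScope.FSDP (max all-reduce across the FSDP axis). A minimal sketch of the scale derivation, not the library code; the FP8 E4M3 maximum of 448 and the zero-guard epsilon are assumptions:

```python
# Toy sketch of scale-from-amax under current scaling (not the library code).
# Assumptions: FP8 E4M3 max-normal value of 448; epsilon guards an all-zero input.
import jax.numpy as jnp

FP8_E4M3_MAX = 448.0

def toy_scale_from_amax(x: jnp.ndarray) -> jnp.ndarray:
    # AmaxScope.LOCAL: the amax of (this shard of) x; no collective issued.
    amax = jnp.amax(jnp.abs(x)).astype(jnp.float32).reshape((1,))
    # Map the observed dynamic range onto the FP8 representable range.
    return FP8_E4M3_MAX / jnp.maximum(amax, 1e-12)
```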
...
@@ -76,8 +76,6 @@ class TestDistributedLayernorm:
             all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
         )
         other_bytes = 0
-        if fp8_recipe == recipe.Float8CurrentScaling():
-            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction
         return generate_collectives_count(
             allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
         )
...
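The deletion above matches the last commit bullet: with a local amax, current scaling no longer issues a per-tensor amax all-reduce in this test, so the expected all-reduce byte count shrinks by one jax_dtype.itemsize.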
...
@@ -26,7 +26,7 @@ from .misc import (
     should_apply_1x_fused_dbias_war_for_arch_l_100,
     NamedSharding,
 )
-from .quantization import _jax_dbias, _quantize_dbias_impl
+from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope
 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
 from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
...
@@ -979,6 +979,7 @@ def act_lu(
     x: jnp.ndarray,
     activation_type: Sequence[Union[str, Callable]],
     quantizer: Optional[Quantizer] = None,
+    amax_scope: AmaxScope = AmaxScope.LOCAL,
 ) -> Union[jnp.ndarray, ScaledTensor]:
     """Activation with optional quantization.

...
@@ -987,6 +988,7 @@
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
         activation_type: Type of activation function to apply.
         quantizer: Optional quantizer for FP8 quantization of the output.
+        amax_scope: Scope over which the amax is computed. Only takes effect with current scaling. Defaults to AmaxScope.LOCAL.

     Returns:
         If quantizer is None:
...
@@ -1044,7 +1046,13 @@
             activation_type=activation_type,
             quantizer=None,
         )
-        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
+        out, _ = _quantize_dbias_impl(
+            out,
+            is_dbias=False,
+            quantizer=quantizer,
+            dq_dtype=x.dtype,
+            amax_scope=amax_scope,
+        )
         return out

     if isinstance(quantizer, DelayedScaleQuantizer):
...
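A hedged usage fragment for the updated act_lu signature (x and the current-scaling quantizer are assumed to exist; the activation tuple is illustrative):

```python
# amax_scope only takes effect when the quantizer uses current scaling; the
# default AmaxScope.LOCAL computes the amax per shard and issues no collective.
out = act_lu(
    x,
    activation_type=("gelu",),
    quantizer=quantizer,
    amax_scope=AmaxScope.LOCAL,
)
```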
...
@@ -173,7 +173,7 @@ class BasePrimitive(metaclass=ABCMeta):
     _primitive_registry = {}


-def register_primitive(cls):
+def register_primitive(cls, outer_only=False):
     """
     Register a JAX primitive and add it to the internal registry.
     """
...
@@ -186,13 +186,14 @@ def register_primitive(cls):
     def name_of_wrapper_p():
         return cls.name + "_wrapper"

-    inner_p = core.Primitive(cls.name)
-    dispatch.prim_requires_devices_during_lowering.add(inner_p)
-    inner_p.multiple_results = cls.multiple_results
-    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
-    inner_p.def_abstract_eval(cls.abstract)
-    mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
-    cls.inner_primitive = inner_p
+    if not outer_only:
+        inner_p = core.Primitive(cls.name)
+        dispatch.prim_requires_devices_during_lowering.add(inner_p)
+        inner_p.multiple_results = cls.multiple_results
+        inner_p.def_impl(partial(xla.apply_primitive, inner_p))
+        inner_p.def_abstract_eval(cls.abstract)
+        mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
+        cls.inner_primitive = inner_p

     outer_p = core.Primitive(name_of_wrapper_p())
     dispatch.prim_requires_devices_during_lowering.add(outer_p)
...
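The new outer_only flag lets a class skip the inner core.Primitive, which exists only to carry a CUDA lowering; a primitive whose implementation is plain JAX, like the amax primitive added below, needs just the outer wrapper that drives custom_partitioning. A standalone toy of the same pattern, following the registry's own core.Primitive usage (the names here are illustrative, not from the diff):

```python
# A primitive whose impl is ordinary JAX code needs no MLIR/CUDA lowering:
# def_impl + def_abstract_eval are enough for eager binding and tracing.
from jax import core
import jax.numpy as jnp

toy_amax_p = core.Primitive("toy_amax")
toy_amax_p.multiple_results = False

@toy_amax_p.def_impl
def _toy_amax_impl(x):
    return jnp.amax(jnp.abs(x)).astype(jnp.float32).reshape((1,))

@toy_amax_p.def_abstract_eval
def _toy_amax_abstract(x_aval):
    return core.ShapedArray((1,), jnp.float32)

print(toy_amax_p.bind(jnp.array([-3.0, 2.0])))  # [3.]
```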
...
@@ -27,7 +27,7 @@ from .misc import (
     NamedSharding,
     get_cudnn_version,
 )
-from .quantization import _quantize_dbias_impl
+from .quantization import _quantize_dbias_impl, AmaxScope
 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
 from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
...
@@ -880,6 +880,7 @@ def layernorm_fwd(
     zero_centered_gamma: bool,
     epsilon: float,
     quantizer: Optional[Quantizer],
+    amax_scope: AmaxScope = AmaxScope.LOCAL,
 ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
     """Layer normalization forward pass with optional quantization.

...
@@ -893,6 +894,7 @@ def layernorm_fwd(
         zero_centered_gamma: If True, gamma is zero-centered.
         epsilon: Small constant for numerical stability.
         quantizer: Optional quantizer for FP8 quantization of the output.
+        amax_scope: Scope over which the amax is computed. Only takes effect with current scaling. Defaults to AmaxScope.LOCAL.

     Returns:
         A tuple containing:
...
@@ -952,7 +954,13 @@ def layernorm_fwd(
             epsilon=epsilon,
             quantizer=None,
         )
-        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
+        out, _ = _quantize_dbias_impl(
+            out,
+            is_dbias=False,
+            quantizer=quantizer,
+            dq_dtype=x.dtype,
+            amax_scope=amax_scope,
+        )
         return out, mu, rsigma

     is_2x2x = quantizer.is_2x2x()
...
@@ -1082,6 +1090,7 @@ def rmsnorm_fwd(
     zero_centered_gamma: bool,
     epsilon: float,
     quantizer: Optional[Quantizer],
+    amax_scope: AmaxScope = AmaxScope.LOCAL,
 ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
     """Root mean square normalization forward pass with optional quantization.

...
@@ -1093,6 +1102,7 @@ def rmsnorm_fwd(
         zero_centered_gamma: If True, gamma is zero-centered.
         epsilon: Small constant for numerical stability.
         quantizer: Optional quantizer for FP8 quantization of the output.
+        amax_scope: Scope over which the amax is computed. Only takes effect with current scaling. Defaults to AmaxScope.LOCAL.

     Returns:
         A tuple containing:
...
@@ -1153,7 +1163,11 @@ def rmsnorm_fwd(
             quantizer=None,
         )
         out, _ = _quantize_dbias_impl(
-            out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype
+            out.data,
+            is_dbias=False,
+            quantizer=quantizer,
+            dq_dtype=x.dtype,
+            amax_scope=amax_scope,
         )
         return out, rsigma
...
@@ -1278,6 +1292,7 @@ def normalization_fwd(
     epsilon: float,
     norm_type: str,
     quantizer: Optional[Quantizer],
+    amax_scope: AmaxScope = AmaxScope.LOCAL,
 ):
     """Common wrapper for normalization forward pass.

...
@@ -1294,6 +1309,7 @@ def normalization_fwd(
             - 'layernorm': Layer normalization
             - 'rmsnorm': Root mean square normalization
         quantizer: Optional quantizer for FP8 quantization of the output.
+        amax_scope: Scope over which the amax is computed. Only takes effect with current scaling. Defaults to AmaxScope.LOCAL.

     Returns:
         A tuple containing:
...
@@ -1311,12 +1327,27 @@ def normalization_fwd(
         zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
     """
     if norm_type == "layernorm":
-        output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
+        output, mu, rsigma = layernorm_fwd(
+            x,
+            gamma,
+            beta,
+            zero_centered_gamma,
+            epsilon,
+            quantizer,
+            amax_scope=amax_scope,
+        )
     elif norm_type == "rmsnorm":
         assert (
             not zero_centered_gamma
         ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
-        output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer)
+        output, rsigma = rmsnorm_fwd(
+            x,
+            gamma,
+            zero_centered_gamma,
+            epsilon,
+            quantizer,
+            amax_scope=amax_scope,
+        )
         mu = None
     else:
         raise ValueError(f"{norm_type=} is not supported.")
...
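A hedged usage fragment for the updated wrapper (the tensors and the quantizer are assumed to exist; keyword values are illustrative):

```python
# TPSP scope max-all-reduces the amax across tensor/sequence-parallel axes
# before the scale is derived, as _layernorm_mlp_fwd_rule below does for the
# layernorm output feeding the first GEMM.
output, mu, rsigma = normalization_fwd(
    x,
    gamma,
    beta,
    zero_centered_gamma=False,
    epsilon=1e-6,
    norm_type="layernorm",
    quantizer=quantizer,
    amax_scope=AmaxScope.TPSP,
)
```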
...
@@ -6,6 +6,8 @@ import operator
 from functools import reduce
 from typing import Tuple, Optional, Union
 import math
+from enum import Enum
+
 import jax
 import jax.numpy as jnp
...
@@ -26,7 +28,12 @@ from .misc import (
     get_min_device_compute_capability,
     NamedSharding,
 )
-from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
+from ..sharding import (
+    all_reduce_max_along_all_axes_except_PP,
+    all_reduce_sum_along_dp_fsdp,
+    global_mesh_resource,
+    lax_paral_op,
+)
 from ..quantize import (
     ScaledTensor2x,
     ScaledTensor,
...
@@ -526,6 +533,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive):
     """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


+class AmaxScope(Enum):
+    """
+    Amax Scope Enum
+    """
+
+    LOCAL = 1
+    TPSP = 2
+    FSDP = 3
+
+
+class AmaxCalculationPrimitive(BasePrimitive):
+    """
+    Amax Calculation Primitive with custom_partitioning
+    """
+
+    name = "jax_local_amax"
+    multiple_results = False
+    impl_static_args = (1,)  # amax_scope
+    inner_primitive = None
+    outer_primitive = None
+
+    @staticmethod
+    def abstract(
+        x_aval,
+        *,
+        amax_scope,
+    ):
+        """
+        amax calculation abstract
+        """
+        del amax_scope
+        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
+        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+        out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        return out_aval
+
+    @staticmethod
+    def impl(
+        x,
+        amax_scope,
+    ):
+        """
+        amax calculation implementation
+        """
+        del amax_scope
+        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
+        return amax
+
+    @staticmethod
+    def infer_sharding_from_operands(
+        amax_scope,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        """
+        amax calculation infer_sharding_from_operands
+        """
+        del (amax_scope, arg_infos, result_infos)  # Unused.
+        amax_sharding = NamedSharding(
+            mesh,
+            PartitionSpec(None),
+            desc="AmaxCalculationPrimitive.out_sharding",
+        )
+        return amax_sharding
+
+    @staticmethod
+    def partition(
+        amax_scope,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        """
+        amax calculation partition
+        """
+        del result_infos
+        amax_sharding = NamedSharding(
+            mesh,
+            PartitionSpec(None),
+            desc="AmaxCalculationPrimitive.out_sharding",
+        )
+
+        def sharded_impl(x):
+            amax = AmaxCalculationPrimitive.impl(
+                x,
+                amax_scope=amax_scope,
+            )
+            if amax_scope is AmaxScope.TPSP:  # Run AR across TP/SP
+                gmesh = global_mesh_resource()
+                amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh)
+                amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
+            if amax_scope is AmaxScope.FSDP:  # Run AR across FSDP
+                gmesh = global_mesh_resource()
+                amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
+            return amax
+
+        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        return mesh, sharded_impl, amax_sharding, arg_shardings
+
+    @staticmethod
+    def shardy_sharding_rule(amax_scope, mesh, value_types, result_types):
+        """
+        amax calculation shardy_sharding_rule
+        """
+        del amax_scope, mesh, result_types
+        prefix = "AmaxCal"
+        input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
+        output_spec = (f"{prefix}_amax",)
+        return SdyShardingRule((input_spec,), (output_spec,))
+
+
+register_primitive(AmaxCalculationPrimitive, outer_only=True)
+
+
 def _jax_quantize(
     x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
 ):
...
@@ -572,6 +699,7 @@ def _quantize_dbias_impl(
     is_dbias: bool = False,
     dq_dtype: Optional[jnp.dtype] = None,
     flatten_axis: int = -1,
+    amax_scope: AmaxScope = AmaxScope.LOCAL,  # Only takes effect with current scaling
 ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
     """
     Cast wrapper
...
@@ -628,7 +756,10 @@
         # until the tensor is dequantized (e.g. in the GEMM).
         amax = x.amax
         if amax is None:
-            amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,))
+            amax = AmaxCalculationPrimitive.outer_primitive.bind(
+                x.data,
+                amax_scope=amax_scope,
+            )
         scale = compute_scale_from_amax(amax, quantizer.q_dtype)
     elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
         scale = quantizer.scale
...
@@ -700,6 +831,7 @@ def quantize(
     x: Union[jnp.ndarray, NoScaleTensor],
     quantizer: Quantizer,
     flatten_axis: int = -1,
+    amax_scope: AmaxScope = AmaxScope.LOCAL,
 ) -> Tuple[ScaledTensor]:
     """Quantize input tensor according to the quantizer.

...
@@ -710,6 +842,7 @@
         flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
             Defaults to -1.
             is None.
+        amax_scope: Scope over which the amax is computed. Only takes effect with current scaling. Defaults to AmaxScope.LOCAL.

     Returns:
         A ScaledTensor containing the quantized input tensor.
...
@@ -718,6 +851,7 @@
         x,
         quantizer=quantizer,
         flatten_axis=flatten_axis,
+        amax_scope=amax_scope,
     )
     return out
...
@@ -727,6 +861,7 @@ def quantize_dbias(
     quantizer: Quantizer,
     is_dbias: bool = True,
     flatten_axis: int = -1,
+    amax_scope: AmaxScope = AmaxScope.LOCAL,
 ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
     """Quantize input tensor and compute bias gradient.

...
@@ -737,6 +872,8 @@
         is_dbias: If True, compute bias gradient. Defaults to True.
         flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
             Defaults to -1.
+        amax_scope: Scope over which the amax is computed. Only takes effect with
+            current scaling. Defaults to AmaxScope.LOCAL.

     Returns:
         A tuple containing:
...
@@ -750,6 +887,7 @@
         quantizer=quantizer,
         is_dbias=is_dbias,
         flatten_axis=flatten_axis,
+        amax_scope=amax_scope,
     )
...
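The partition rule above widens the amax scope by wrapping the per-shard computation in max all-reduces over the named mesh resources. A runnable single-device toy of the same mechanism, with a generic "tp" axis standing in for the real tp/tpsp/fsdp resources:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], ("tp",))  # one device keeps the demo portable

def amax_fn(x, widen):
    amax = jnp.amax(jnp.abs(x)).astype(jnp.float32).reshape((1,))
    if widen:  # analogue of AmaxScope.TPSP/FSDP: pmax over the named axis
        amax = jax.lax.pmax(amax, axis_name="tp")
    return amax

local_amax = shard_map(partial(amax_fn, widen=False), mesh=mesh,
                       in_specs=P("tp"), out_specs=P("tp"))
wide_amax = shard_map(partial(amax_fn, widen=True), mesh=mesh,
                      in_specs=P("tp"), out_specs=P("tp"))

x = jnp.array([-3.0, 2.0, 0.5, -1.0])
print(local_amax(x), wide_amax(x))  # equal on one device; per-shard vs agreed on many
```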
...
@@ -15,6 +15,7 @@ import jax
 import jax.numpy as jnp

 from . import cpp_extensions as tex
+from .cpp_extensions.quantization import AmaxScope
 from .quantize import (
     ScaledTensorFactory,
     ScalingMode,
...
@@ -64,6 +65,7 @@ def dense(
     input_axes: Tuple[str, ...] = None,
     kernel_axes: Tuple[str, ...] = None,
     quantizer_set: QuantizerSet = noop_quantizer_set,
+    using_global_amax_of_x: bool = False,
 ):
     """Perform dense layer transformation with optional quantization.

...
@@ -77,6 +79,7 @@ def dense(
         bias: Optional bias tensor to add after the transformation
         contracting_dims: Tuple of sequences specifying which dimensions to contract
         quantizer_set: QuantizerSet which contains quantizers for different tensor types
+        using_global_amax_of_x: Whether to use a global amax for x (reduced across TP/SP). Only takes effect with current scaling. Defaults to False.

     Returns:
         Transformed output tensor
...
@@ -93,6 +96,7 @@ def dense(
         input_axes,
         kernel_axes,
         quantizer_set,
+        using_global_amax_of_x,
     )
     return output
...
@@ -103,6 +107,7 @@ def dense(
         3,
         4,
         5,
+        7,
     ),
 )
 def _dense(
...
@@ -113,6 +118,7 @@ def _dense(
     input_axes,
     kernel_axes,
     quantizer_set,
+    using_global_amax_of_x,
 ):
     """Internal implementation of dense layer transformation with custom VJP.

...
@@ -127,6 +133,7 @@ def _dense(
         input_axes: Logical axes for sharding the activation input
         kernel_axes: Logical axes for sharding the weight matrix
         quantizer_set: QuantizerSet which contains quantizers for different tensor types
+        using_global_amax_of_x: Whether to use a global amax for x (reduced across TP/SP). Only takes effect with current scaling. Defaults to False.

     Returns:
         Transformed output tensor
...
@@ -139,6 +146,7 @@ def _dense(
         input_axes,
         kernel_axes,
         quantizer_set,
+        using_global_amax_of_x,
     )
     return output
...
@@ -151,6 +159,7 @@ def _dense_fwd_rule(
     input_axes,
     kernel_axes,
     quantizer_set,
+    using_global_amax_of_x,
 ):
     """Forward pass rule for dense layer transformation.

...
@@ -175,6 +184,7 @@ def _dense_fwd_rule(
         x,
         flatten_axis=flatten_axis_x,
         quantizer=quantizer_set.x,
+        amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
     )
     casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
...
@@ -182,6 +192,7 @@
         kernel,
         flatten_axis=flatten_axis_k,
         quantizer=quantizer_set.kernel,
+        amax_scope=AmaxScope.FSDP,
     )
     casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
...
@@ -212,7 +223,7 @@ def _dense_fwd_rule(
 def _dense_bwd_rule(
-    contracting_dims, input_axes, kernel_axes, ctx, grad
+    contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad
 ):  # pylint: disable=unused-argument
     """Backward pass rule for dense layer transformation.

...
@@ -238,6 +249,7 @@ def _dense_bwd_rule(
         is_dbias=use_bias,
         flatten_axis=flatten_axis_k,
         quantizer=quantizer_set.dgrad,
+        amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
     )

     # GEMM NT
...
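A hedged usage fragment for the new flag (the tensors and quantizer_set are assumed to exist; the contracting dims shown are the usual last-of-x against first-of-kernel):

```python
# With the flag set, the fwd x-quantization uses AmaxScope.TPSP and the bwd
# dgrad quantization flips to AmaxScope.LOCAL (see the bwd rule above); the
# kernel is always quantized with AmaxScope.FSDP.
out = dense(
    x,
    kernel,
    bias,
    contracting_dims=((-1,), (0,)),
    quantizer_set=quantizer_set,
    using_global_amax_of_x=True,
)
```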
...
@@ -21,6 +21,7 @@ import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name

 from . import cpp_extensions as tex
+from .cpp_extensions.quantization import AmaxScope
 from .layernorm import canonicalize_norm_type
 from .quantize import (
     with_sharding_constraint_by_logical_axes,
...
@@ -272,13 +273,12 @@ def _layernorm_mlp_fwd_rule(
         epsilon,
         norm_type,
         quantizer=ffn1_quantizer_set.x,
+        amax_scope=AmaxScope.TPSP,
     )
     casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

     casted_kernel_1 = tex.quantize(
-        kernel_1,
-        flatten_axis=-2,
-        quantizer=ffn1_quantizer_set.kernel,
+        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
     )

     # NN GEMM
...
@@ -317,6 +317,7 @@ def _layernorm_mlp_fwd_rule(
     casted_kernel_2 = tex.quantize(
         kernel_2,
         quantizer=ffn2_quantizer_set.kernel,
+        amax_scope=AmaxScope.FSDP,
     )

     # NN GEMM
...
@@ -417,6 +418,7 @@ def _layernorm_mlp_bwd_rule(
         grad,
         is_dbias=use_bias_2,
         quantizer=ffn1_quantizer_set.dgrad,
+        amax_scope=AmaxScope.TPSP,
     )

     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...
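Taken together, the scope choices in this file track how each tensor is sharded: the layernorm output feeding the first GEMM and the incoming grad in the backward pass use AmaxScope.TPSP (activations live on the tensor/sequence-parallel axes), while both kernels use AmaxScope.FSDP (weights are sharded along the FSDP axis).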