Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -8,6 +8,7 @@ import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from functools import partial, reduce
import warnings
import jax
import jax.numpy as jnp
......@@ -21,19 +22,22 @@ from transformer_engine_jax import get_num_compute_streams
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
AbstractBaseTensor,
NoScaleTensor,
ScaledTensor,
ScaledTensor2x,
GroupedScaledTensor1x,
ScalingMode,
Quantizer,
GroupedQuantizer,
QuantizeConfig,
get_quantize_config,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
)
from ..sharding import global_mesh_resource
from .misc import get_padded_spec
......@@ -148,6 +152,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
return lhs_q, rhs_q
@partial(jax.jit, static_argnums=(1, 2))
def swizzled_scale(scale_inv, flatten_axis, is_colwise):
"Swizzle scale_inv via JAX transpose ops"
original_shape = scale_inv.shape
shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:]))
if is_colwise:
scale_inv = jnp.transpose(scale_inv.reshape(shape_2d))
cols, rows = shape_2d
else:
rows, cols = shape_2d
reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4)
swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4))
return swizzled.reshape(original_shape)
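# Illustrative sketch, not part of this change: exercising swizzled_scale on the
# smallest rowwise layout it accepts. The (128, 4) shape is hypothetical; real
# scale_inv arrays come from apply_padding_to_scale_inv.
def _example_swizzle():
    scale = jnp.zeros((128, 4), dtype=jnp.uint8)
    swizzled = swizzled_scale(scale, 1, False)  # flatten_axis=1, is_colwise=False
    return swizzled.shape == scale.shape  # same shape, elements permuted into tiles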
class GemmPrimitive(BasePrimitive):
"""
Primitive for cuBLAS GEMM
......@@ -226,6 +245,11 @@ class GemmPrimitive(BasePrimitive):
"require non-transposed LHS and transposed RHS operands "
"(`contracting_dims=((-1, ), (-1, ))`)."
)
else:
assert lhs.dtype == rhs.dtype, (
"For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal."
f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
)
# Determine output shape and dtype
assert (
......@@ -277,28 +301,18 @@ class GemmPrimitive(BasePrimitive):
)
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
# Need extra workspace for swizzled scale factors
lhs_swizzle_size = 0
rhs_swizzle_size = 0
swizzle_dtype = jnp.uint8
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_swizzle_size = lhs_scale_inv.size
rhs_swizzle_size = rhs_scale_inv.size
lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)
# Declare cuBLAS workspace
# cuBLAS workspace ptr must be 256-byte aligned, but JAX buffers are not
# necessarily 256-byte aligned, so we add some padding to ensure alignment.
workspace_size = get_cublas_workspace_size_bytes() + 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace
return output, bias_grad, pre_gelu_out, workspace
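# Illustrative sketch, not part of this change: the extra 256 bytes above let the
# C++ side round the raw workspace pointer up to the next 256-byte boundary,
# conceptually (helper name is hypothetical):
def _align_up_256(addr: int) -> int:
    # Round an integer address up to the next multiple of 256.
    return (addr + 255) & ~255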
@staticmethod
def outer_abstract(*args, **kwargs):
outputs = GemmPrimitive.abstract(*args, **kwargs)
return outputs[:-3] # discard workspace arrays
return outputs[:-1] # discard workspace array
@staticmethod
def lowering(
......@@ -365,24 +379,22 @@ class GemmPrimitive(BasePrimitive):
grad,
use_split_accumulator,
):
if scaling_mode.is_1d_block_scaling():
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
)
lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims)
rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1
lhs_scale_inv = apply_padding_to_scale_inv(
lhs_scale_inv,
scaling_mode,
lhs.shape,
is_colwise=lhs_transposed,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis
)
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv,
scaling_mode,
rhs.shape,
is_colwise=not rhs_transposed,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
)
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
outputs = GemmPrimitive.inner_primitive.bind(
lhs,
......@@ -399,7 +411,39 @@ class GemmPrimitive(BasePrimitive):
grad=grad,
use_split_accumulator=use_split_accumulator,
)
return outputs[:-3] # discard workspace arrays
return outputs[:-1] # discard workspace array
@staticmethod
def outer_impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
return GemmPrimitive.impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
)
@staticmethod
def batcher(
......@@ -451,6 +495,19 @@ class GemmPrimitive(BasePrimitive):
):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
gsr = global_mesh_resource()
# Warn if tensor sequence parallelism appears to be configured via tp_resource
if gsr.tp_resource is not None:
for i in range(len(lhs_specs) - 1):
if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource:
warnings.warn(
"Tensor sequence parallelism is detected as"
f" tp_resource='{gsr.tp_resource}' appears twice consecutively in"
f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for"
" tensor sequence parallelism to avoid potential issues."
)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
lhs_non_cdims, rhs_non_cdims = map(
......@@ -490,7 +547,8 @@ class GemmPrimitive(BasePrimitive):
# Non-contracting dims of RHS always need to be gathered along the FSDP axis
rhs_non_cspecs = tuple(
None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
None if spec is not None and spec == gsr.fsdp_resource else spec
for spec in rhs_non_cspecs
)
# Non-contracting dims of LHS need to be gathered along the SP axis.
......@@ -656,6 +714,12 @@ class GemmPrimitive(BasePrimitive):
prefix = "GemmPrimitive_"
warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
" please turn off Shardy by exporting the environment variable"
" 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
)
def _generate_operand_rules(name, ndim, cdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in cdims)
......@@ -732,7 +796,7 @@ def _te_gemm(
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands
......@@ -1085,7 +1149,7 @@ def _jax_gemm(
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
precision = (
jax.lax.Precision.HIGHEST
if QuantizeConfig.FP8_2X_ACC_FPROP
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......@@ -1112,8 +1176,8 @@ def _jax_gemm(
def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
lhs: Union[jnp.ndarray, AbstractBaseTensor],
rhs: Union[jnp.ndarray, AbstractBaseTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
......@@ -1169,6 +1233,11 @@ def gemm(
compute the GeLU contribution to the gradient. Only supported with TE's custom call to
cuBLAS GEMM.
"""
if isinstance(lhs, NoScaleTensor):
lhs = lhs.data
if isinstance(rhs, NoScaleTensor):
rhs = rhs.data
# Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
if lhs_quantizer is None or rhs_quantizer is None:
quantizer_set = kwargs.get("quantizer_set", None)
......
......@@ -193,6 +193,16 @@ def get_min_device_compute_capability():
)
def get_all_device_compute_capability():
"""
Returns a tuple with the compute capability of each local device.
"""
return tuple(
transformer_engine_jax.get_device_compute_capability(local_gpu_id)
for local_gpu_id in range(len(jax.local_devices()))
)
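# Illustrative sketch, not part of this change: a hypothetical helper showing how
# the query above can gate arch-dependent paths (e.g. threshold=100 for sm_100).
def _all_devices_at_least(threshold: int) -> bool:
    return all(cc >= threshold for cc in get_all_device_compute_capability())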
def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
"""
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
......
......@@ -30,7 +30,7 @@ from .misc import (
)
from .quantization import _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
......@@ -842,9 +842,12 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
output = normed_input * gamma + beta
if quantizer:
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
output = output.astype(x.dtype)
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
ln_out = NoScaleTensor(data=ln_out, amax=None)
return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1)
......@@ -866,9 +869,12 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
output = normed_input * gamma
if quantizer:
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
output = output.astype(x.dtype)
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
ln_out = NoScaleTensor(data=ln_out, amax=None)
return ln_out, jnp.squeeze(rsigma, axis=-1)
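# Illustrative sketch, not part of this change: with quantizer=None the JAX norm
# helpers above now return a NoScaleTensor, so high-precision callers unwrap it
# via .data. Shapes below are hypothetical.
def _example_rmsnorm_unquantized():
    x = jnp.ones((2, 8), jnp.bfloat16)
    gamma = jnp.ones((8,), jnp.bfloat16)
    ln_out, _rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma=False, epsilon=1e-6)
    return ln_out.data  # plain jnp.ndarray in x.dtype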
......@@ -930,7 +936,7 @@ def layernorm_fwd(
scale_dtype=jnp.float32,
is_outer=True,
)
return output, mu, rsigma
return NoScaleTensor(data=output, amax=None), mu, rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
......@@ -1064,7 +1070,7 @@ def layernorm_bwd(
)
mu_empty = jnp.zeros(mu.shape, mu.dtype)
rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
return vjp_func((dz, mu_empty, rsigma_empty))
return vjp_func((NoScaleTensor(data=dz, amax=None), mu_empty, rsigma_empty))
return NormBwdPrimitive.outer_primitive.bind(
dz,
x,
......@@ -1133,14 +1139,14 @@ def rmsnorm_fwd(
scale_dtype=jnp.float32,
is_outer=True,
)
return output, rsigma
return NoScaleTensor(data=output, amax=None), rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None)
out, _ = _quantize_dbias_impl(out, quantizer)
out, _ = _quantize_dbias_impl(out.data, quantizer)
return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
......@@ -1152,7 +1158,9 @@ def rmsnorm_fwd(
epsilon=epsilon,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
out, _ = _quantize_dbias_impl(
out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype
)
return out, rsigma
is_2x2x = quantizer.is_2x2x()
......@@ -1254,7 +1262,7 @@ def rmsnorm_bwd(
gamma,
)
rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
return vjp_func((dz, rsigma_empty))
return vjp_func((NoScaleTensor(data=dz, amax=None), rsigma_empty))
mu = jnp.empty(())
dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind(
dz,
......@@ -1276,7 +1284,6 @@ def normalization_fwd(
epsilon: float,
norm_type: str,
quantizer: Optional[Quantizer],
noop_scaled_tensor: bool = False,
):
"""Common wrapper for normalization forward pass.
......@@ -1293,7 +1300,6 @@ def normalization_fwd(
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
A tuple containing:
......@@ -1321,15 +1327,6 @@ def normalization_fwd(
else:
raise ValueError(f"{norm_type=} is not supported.")
if quantizer is None and noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype
),
mu,
rsigma,
)
return output, mu, rsigma
......
......@@ -4,7 +4,7 @@
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from typing import Tuple, Optional, Union
import math
from packaging import version
......@@ -38,6 +38,7 @@ from ..quantize import (
QuantizeLayout,
ScalingMode,
compute_scale_from_amax,
NoScaleTensor,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
......@@ -57,13 +58,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
name = "te_dbias_quantize_ffi"
multiple_results = True
impl_static_args = (
2,
3,
4,
5,
6,
7,
8,
9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
inner_primitive = None
outer_primitive = None
......@@ -72,6 +73,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
def abstract(
x_aval,
scale_aval,
amax_aval,
*,
out_dtype,
scaling_mode,
......@@ -95,7 +97,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
rowwise_out_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
updated_amax_aval = amax_aval
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
......@@ -168,6 +170,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
ctx,
x,
scale,
amax,
*,
out_dtype,
scaling_mode,
......@@ -181,13 +184,17 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
te_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, is_outer
x_aval, scale_aval = ctx.avals_in
x_aval, scale_aval, amax_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
assert scale_aval.dtype == amax_aval.dtype == jnp.float32
return ffi.ffi_lowering(
BaseDBiasQuantizePrimitive.name,
operand_output_aliases={2: 4}, # donate amax buffer to updated_amax
)(
ctx,
x,
scale,
amax,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
......@@ -198,6 +205,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
def impl(
x,
scale,
amax,
out_dtype,
scaling_mode,
q_layout,
......@@ -222,6 +230,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
x,
scale,
amax,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
......@@ -268,15 +277,15 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
del is_outer
check_valid_batch_dims(batch_dims)
assert BaseDBiasQuantizePrimitive.outer_primitive is not None
x, scale = batched_args
x_bdim, scale_bdim = batch_dims
amax_bdim = scale_bdim
x, scale, amax = batched_args
x_bdim, scale_bdim, amax_bdim = batch_dims
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
return (
BaseDBiasQuantizePrimitive.outer_primitive.bind(
x,
scale,
amax,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
......@@ -303,7 +312,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
del (out_dtype, result_infos, scale_dtype, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
amax_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
......@@ -329,10 +338,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.dbias_sharding",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
......@@ -341,14 +348,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
return (
out_sharding,
......@@ -375,7 +382,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
amax_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
......@@ -401,10 +408,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.dbias_sharding",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
......@@ -432,7 +437,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias_sharding,
)
def sharded_impl(x, scale):
def sharded_impl(x, scale, amax):
(
local_x,
local_colwise_x,
......@@ -443,6 +448,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) = BaseDBiasQuantizePrimitive.impl(
x,
scale,
amax,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
......@@ -510,7 +516,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
amax = (prefix + "amax",)
return SdyShardingRule(
(x_axes, ("…1",)),
(x_axes, ("…1",), amax),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
)
......@@ -530,11 +536,15 @@ def _jax_quantize(
x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
if quantizer is None:
if isinstance(x, NoScaleTensor):
return x
return NoScaleTensor(data=x, amax=None)
return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
if isinstance(dx, NoScaleTensor):
dx = dx.data
sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
dtype = dtype or dx.dtype
......@@ -553,7 +563,9 @@ def _jax_quantize_dbias(
flatten_axis: int = -1,
):
if quantizer is None:
if isinstance(x, NoScaleTensor):
return x, None
return NoScaleTensor(data=x, amax=None), None
return (
quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
_jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
......@@ -561,12 +573,11 @@ def _jax_quantize_dbias(
def _quantize_dbias_impl(
x: jnp.ndarray,
x: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer,
is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -576,28 +587,15 @@ def _quantize_dbias_impl(
quantizer is not None
), "quantizer must be provided if dq_dtype is provided"
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
# Early-exit for non-quantized call
dq_dtype = dq_dtype or x.dtype
dq_dtype = dq_dtype or x.data.dtype
if quantizer is None:
dbias = None
if is_dbias:
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
if noop_scaled_tensor:
# Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor()
# always works.
return (
ScaledTensorFactory.create_2x(
x,
None,
x,
None,
ScalingMode.NO_SCALING,
dq_dtype=x.dtype,
data_layout="NN",
flatten_axis=flatten_axis,
),
dbias,
)
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, dbias
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
......@@ -625,19 +623,26 @@ def _quantize_dbias_impl(
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
scale = jnp.empty((), jnp.float32)
amax = None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
# until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
amax = x.amax
if amax is None:
amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,))
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# Make sure amax is initialized to zero
if amax is None:
amax = jnp.zeros((1,), jnp.float32)
# It is faster to use 1x quantization for tensor scaling
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = (
......@@ -657,8 +662,9 @@ def _quantize_dbias_impl(
updated_amax,
dbias,
) = PrimitiveClass.outer_primitive.bind(
x,
x.data,
scale,
amax,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value,
......@@ -697,10 +703,9 @@ def _quantize_dbias_impl(
def quantize(
x: jnp.ndarray,
x: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -710,7 +715,6 @@ def quantize(
quantizer: Quantizer for FP8 quantization of the output.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
is None.
Returns:
......@@ -720,17 +724,15 @@ def quantize(
x,
quantizer=quantizer,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
)
return out
def quantize_dbias(
dz: jnp.ndarray,
dz: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer,
is_dbias: bool = True,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -741,8 +743,6 @@ def quantize_dbias(
is_dbias: If True, compute bias gradient. Defaults to True.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
quantizer is None.
Returns:
A tuple containing:
......@@ -756,7 +756,6 @@ def quantize_dbias(
quantizer=quantizer,
is_dbias=is_dbias,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
)
......@@ -931,6 +930,7 @@ def grouped_quantize(
x: jnp.ndarray,
quantizer: GroupedQuantizer,
group_sizes: jnp.ndarray = None,
amax: jnp.ndarray = None,
flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
"""Quantize a tensor in grouped manner.
......@@ -943,6 +943,7 @@ def grouped_quantize(
x: Input tensor to quantize
quantizer: The quantizer to use for quantization
group_sizes: Array of ints containing the size of each group (default: None)
amax: The amax of x; if None, it is computed from x. (default: None)
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
Returns:
......@@ -957,7 +958,9 @@ def grouped_quantize(
"""
if quantizer is None:
if isinstance(x, NoScaleTensor):
return x
return NoScaleTensor(data=x, amax=None)
# TODO(Phuong): add support for flatten_axis = -2
assert flatten_axis in (
......@@ -985,6 +988,9 @@ def grouped_quantize(
scale = scale.at[i].set(quantizer_i.scale[0])
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
if amax is not None:
row_amax = amax
else:
row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
segment_ids = jnp.repeat(
jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
......
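A minimal sketch (editor's illustration, not part of this diff) of the scale
derivation that compute_scale_from_amax performs for current tensor scaling above;
the E4M3 fp8_max = 448.0 value and the absence of margin handling are assumptions:

import jax.numpy as jnp

def _example_scale_from_amax(amax: jnp.ndarray) -> jnp.ndarray:
    fp8_max = 448.0  # assumed E4M3 representable maximum
    return jnp.where(amax > 0.0, fp8_max / amax, jnp.ones_like(amax))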
......@@ -37,9 +37,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
......@@ -253,11 +253,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n * act_len};
auto output_trans_shape = std::vector<size_t>{n * act_len, m};
auto dbias_shape = std::vector<size_t>{n * act_len};
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
auto act_input_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n * act_len), m};
auto dbias_shape = std::vector<size_t>{static_cast<size_t>(n * act_len)};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor =
......
......@@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
}
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode,
size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
......@@ -61,40 +61,6 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
// Swizzle scaling factors for MXFP8
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Get the swizzle buffer
NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
"Missing swizzled inverse scale buffer in the JAX primitive.");
auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
auto swizzled_scale_inv_dtype =
convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
}
}
return std::make_tuple(std::move(input), input_shape);
......@@ -103,21 +69,19 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
// Operands (MXFP8 scale factors are now swizzled on the JAX side before this call)
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or for FP8 GEMM
// when the device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode,
lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
......@@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // lhs_swizzled
.Ret<Buffer_Type>() // rhs_swizzled
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
......@@ -285,18 +247,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
if (is_tensor_scaling) {
cudaStream_t stream_0 = nvte_get_compute_stream(0);
size_t dpitch = tensor_scaling_sinv_aligment;
size_t spitch = lhs_sinv_dtype_bytes;
size_t width = lhs_sinv_dtype_bytes;
size_t height = lhs_sinv_size;
cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height,
cudaMemcpyDeviceToDevice, stream_0);
cudaMemcpyDeviceToDevice, stream);
spitch = rhs_sinv_dtype_bytes;
width = rhs_sinv_dtype_bytes;
height = rhs_sinv_size;
cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height,
cudaMemcpyDeviceToDevice, stream_0);
cudaMemcpyDeviceToDevice, stream);
lhs_sinv_ptr = lhs_scatter_aligned_ptr;
rhs_sinv_ptr = rhs_scatter_aligned_ptr;
}
......@@ -565,10 +526,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i));
}
nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans,
lhs_is_trans, grad, workspace_list.data(), accumulate,
use_split_accumulator, num_math_sm, stream);
nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans,
grad, workspace_list.data(), accumulate, use_split_accumulator,
num_math_sm, stream);
return ffi_with_cuda_error_check();
}
......
......@@ -118,7 +118,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
static_cast<size_t>(scale_inv_buf->dimensions().back())});
}
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
......@@ -135,7 +135,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
static_cast<size_t>(colwise_scale_inv_buf->dimensions().back())});
}
if (_norm_type == NVTE_Norm_Type::LayerNorm) {
......
......@@ -72,9 +72,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
}
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type output_trans_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
bool is_dbias, int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
......@@ -119,11 +120,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
......@@ -183,6 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
......
......@@ -16,13 +16,46 @@ import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import (
ScaledTensorFactory,
ScalingMode,
QuantizeLayout,
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported,
TensorUsage,
get_quantize_config,
)
def _all_gather_kernel(kernel, mesh_axis, axis_idx):
assert mesh_axis is not None
assert 0 < axis_idx < len(kernel.shape)
# TODO(Ming Hunag): Add a conditional branch for with/without shmap.
kernel_shape = kernel.shape
kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :])
global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx)
global_kernel = global_kernel.reshape(*kernel_whole_shape)
return global_kernel
def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx):
assert mesh_axis is not None
assert 0 < axis_idx < len(scattered_kernel_shape)
# TODO(Ming Hunag): Add a conditional branch for with/without shmap.
kernel = kernel.reshape(
*scattered_kernel_shape[:axis_idx],
-1,
scattered_kernel_shape[axis_idx],
*scattered_kernel_shape[axis_idx + 1 :],
)
kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx)
kernel = kernel.reshape(scattered_kernel_shape)
return kernel
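# Illustrative sketch, not part of this change: under shard_map over a mesh axis
# named "fsdp" (a hypothetical name), the two helpers above pair up:
# _all_gather_kernel materializes the full kernel for the grouped GEMM, and
# _psum_scatter_kernel reduces a full-shape wgrad back to the local shard.
def _example_fsdp_round_trip(local_kernel):
    full = _all_gather_kernel(local_kernel, "fsdp", axis_idx=2)
    # ... the forward/backward GEMMs would consume `full` here ...
    return _psum_scatter_kernel(full, local_kernel.shape, "fsdp", axis_idx=2)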
def dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
......@@ -48,14 +81,10 @@ def dense(
Returns:
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
output = _dense(
x,
kernel,
......@@ -143,7 +172,9 @@ def _dense_fwd_rule(
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(
x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
x,
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
......@@ -151,7 +182,6 @@ def _dense_fwd_rule(
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
noop_scaled_tensor=True,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -208,7 +238,6 @@ def _dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# GEMM NT
......@@ -253,10 +282,12 @@ def grouped_dense(
group_sizes: jnp.ndarray,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
bias: jnp.ndarray = None,
kernel_amax: jnp.ndarray = None,
precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
kernel_fsdp_info: Tuple[str, int] = (None, -1),
):
"""
Perform grouped dense (linear) layer transformation with optional quantization.
......@@ -268,10 +299,15 @@ def grouped_dense(
contracting_dims: Tuple of sequences specifying which dimensions to contract
(currently only supports ((1,), (1,)))
bias: Bias tensor of shape (G, N)
kernel_amax: The amax values of the weight matrix, with shape (G,)
precision: JAX precision for the GEMM operation
preferred_element_type: Preferred data type for the output tensor
group_offset: 1D array containing offsets for each group (not yet implemented)
quantizer_set: Set of quantizers for FP8 quantization of the input and output
kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix
represented in the format (str, int). The first element is the
FSDP mesh axis, and the second element is the dimension along
which the weight is sharded.
Returns:
A jnp.ndarray containing the result of the grouped linear operation
......@@ -282,25 +318,29 @@ def grouped_dense(
group_sizes,
contracting_dims,
bias,
kernel_amax,
precision,
preferred_element_type,
group_offset,
quantizer_set,
kernel_fsdp_info,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10))
def _grouped_dense(
x,
kernel,
group_sizes,
contracting_dims,
bias,
kernel_amax,
precision,
preferred_element_type,
group_offset,
quantizer_set,
kernel_fsdp_info,
):
output, _ = _grouped_dense_fwd_rule(
x,
......@@ -308,10 +348,12 @@ def _grouped_dense(
group_sizes,
contracting_dims,
bias,
kernel_amax,
precision,
preferred_element_type,
group_offset,
quantizer_set,
kernel_fsdp_info,
)
return output
......@@ -322,21 +364,31 @@ def _grouped_dense_fwd_rule(
group_sizes,
contracting_dims,
bias,
kernel_amax,
precision,
preferred_element_type,
group_offset,
quantizer_set,
kernel_fsdp_info,
):
use_bias = bias is not None
is_noop_quantizer_set = quantizer_set == noop_quantizer_set
kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None
if is_noop_quantizer_set:
grouped_gemm_x = x
grouped_gemm_kernel = kernel
ctx_x = x
ctx_kernel = kernel
flatten_axis_k = None
if kernel_fsdp_enabled:
kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx)
else:
original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout
x_contracting_dims, k_contracting_dims = contracting_dims
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis
......@@ -352,10 +404,24 @@ def _grouped_dense_fwd_rule(
)
casted_x = tex.grouped_quantize(
x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
x,
quantizer_set.x,
group_sizes,
flatten_axis=flatten_axis_x,
)
ctx_kernel_usage = TensorUsage.RHS_TRANS
if kernel_fsdp_enabled:
assert quantizer_set.kernel.scaling_mode in [
ScalingMode.CURRENT_TENSOR_SCALING,
ScalingMode.DELAYED_TENSOR_SCALING,
]
# Perform `cast` only
ctx_kernel_usage = TensorUsage.LHS
quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE
casted_kernel = tex.grouped_quantize(
kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k
)
contracting_dims = (x_contracting_dims, k_contracting_dims)
......@@ -363,9 +429,51 @@ def _grouped_dense_fwd_rule(
# rowwise_casted_x.original_shape == (M, K)
# colwise_casted_kernel.original_shape == (G, N, K)
grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS)
ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage)
if kernel_fsdp_enabled:
ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape)
global_ctx_kernel_data = _all_gather_kernel(
ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
)
kernel_shape = global_ctx_kernel_data.shape
ctx_kernel = ScaledTensorFactory.create_1x(
global_ctx_kernel_data.reshape(-1),
ctx_kernel.scale_inv,
scaling_mode=ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype,
is_colwise=False,
data_layout="N",
flatten_axis=ctx_kernel.flatten_axis,
group_sizes=ctx_kernel.group_sizes,
original_shape=kernel_shape,
group_axis=ctx_kernel.group_axis,
)
if is_fp8_gemm_with_all_layouts_supported():
grouped_gemm_kernel = ctx_kernel
else:
grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1)
grouped_gemm_kernel = ScaledTensorFactory.create_1x(
grouped_gemm_kernel_data.reshape(-1),
ctx_kernel.scale_inv,
scaling_mode=ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype,
is_colwise=True,
data_layout="T",
flatten_axis=ctx_kernel.flatten_axis,
group_sizes=ctx_kernel.group_sizes,
original_shape=kernel_shape,
group_axis=ctx_kernel.group_axis,
)
else:
grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
# Reset quantizer_set.kernel.q_layout to align the PyTree with the given one.
# This is needed especially when kernel_fsdp_enabled == True AND FP8 is enabled.
quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout
output = tex.grouped_gemm(
grouped_gemm_x,
......@@ -393,7 +501,7 @@ def _grouped_dense_fwd_rule(
def _grouped_dense_bwd_rule(
contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad
):
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
......@@ -474,11 +582,17 @@ def _grouped_dense_bwd_rule(
preferred_element_type=preferred_element_type,
group_offset=group_offset,
)
kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
if kernel_fsdp_mesh_axis is not None:
wgrad = _psum_scatter_kernel(
wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
)
group_sizes_grad = None
dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
dkernel_amax = None
return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set
return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
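# Illustrative sketch, not part of this change: driving grouped_dense with the new
# kernel_fsdp_info argument. Shapes, group sizes, and the "fsdp" mesh-axis name are
# hypothetical; inside shard_map, `kernel` would be the local shard with dim 2
# sharded over the "fsdp" axis.
def _example_grouped_dense(x, kernel, group_sizes):
    # x: (M, K); kernel: (G, K, N); group_sizes: (G,) summing to M
    return grouped_dense(x, kernel, group_sizes, kernel_fsdp_info=("fsdp", 2))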
......@@ -32,7 +32,14 @@ from ..cpp_extensions import (
jax_scaled_masked_softmax,
jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..quantize import (
QuantizerFactory,
get_quantize_config,
QuantizeMeta,
QuantizeMetaSet,
ScalingMode,
TensorSource,
)
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -350,7 +357,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
collection_name = (
variable_collection
if variable_collection is not None
else QuantizeConfig.COLLECTION_NAME
else get_quantize_config().COLLECTION_NAME
)
scale = self.variable(
collection_name,
......@@ -363,14 +370,14 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(QuantizeConfig.AMAX_HISTORY_LEN,),
(get_quantize_config().AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
fp8_recipe, recipe.DelayedScaling
):
if get_quantize_config().get_scaling_mode(
TensorSource.X
) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
......@@ -483,7 +490,7 @@ class DenseGeneral(TransformerEngineBase):
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel = kernel.astype(input_dtype)
if self.use_bias:
......@@ -692,7 +699,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
quantizer_set = self.generate_quantizer_set()
fuse_layernorm = (
QuantizeConfig.is_fp8_enabled()
get_quantize_config().is_fp8_enabled()
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -743,7 +750,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape,
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......@@ -1005,7 +1012,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented
fuse_layernorm = (
QuantizeConfig.is_fp8_enabled()
get_quantize_config().is_fp8_enabled()
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -1088,7 +1095,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
......@@ -1100,7 +1107,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape,
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......
......@@ -26,6 +26,7 @@ from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..attention import CPStrategy
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
......@@ -274,6 +275,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
context_checkpoint_name: str = "context"
@nn.compact
......@@ -323,6 +325,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_kvpacked():
......@@ -350,6 +353,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_separate():
......@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
else:
......@@ -505,6 +510,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_strategy (str): The context parallelism strategy; one of "DEFAULT", "ALL_GATHER" (or "ALLGATHER"), or "RING", case-insensitive.
context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
Optimization parameters
......@@ -529,6 +535,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_parallel_strategy: str = "DEFAULT"
context_checkpoint_name: str = "context"
@nn.compact
......@@ -648,6 +655,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor = self.scale_factor
del self.scale_factor
# case-insensitive mapping for context parallel strategy
cp_strategy_map = {
"DEFAULT": CPStrategy.DEFAULT,
"ALL_GATHER": CPStrategy.ALL_GATHER,
"ALLGATHER": CPStrategy.ALL_GATHER, # Alternative spelling
"RING": CPStrategy.RING,
}
strategy_key = self.context_parallel_strategy.upper()
if strategy_key in cp_strategy_map:
context_parallel_strategy = cp_strategy_map[strategy_key]
else:
valid_strategies = list(cp_strategy_map.keys())
raise ValueError(
f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
f"Valid options are: {valid_strategies} (case insensitive)"
)
if not use_fused_attn:
# unfused attention only supports split query, key, value
if qkv_layout.is_qkvpacked():
......@@ -696,6 +721,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)(
query,
......
......@@ -17,7 +17,6 @@ import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import (
ScaledTensor,
Quantizer,
)
......@@ -112,7 +111,7 @@ def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, eps
output, mu, rsigma = tex.normalization_fwd(
x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer
)
if isinstance(output, ScaledTensor):
# This is a no-op for higher-precision tensors
output = output.dequantize()
return output, (x, mu, rsigma, gamma, beta, quantizer)
......
......@@ -22,6 +22,7 @@ from .quantize import (
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
get_quantize_config,
)
......@@ -68,6 +69,11 @@ def layernorm_dense(
- The function supports automatic differentiation through JAX's custom VJP
- Quantization is applied to both the normalized input and kernel
"""
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
output = _layernorm_dense(
x,
kernel,
......@@ -188,14 +194,15 @@ def _layernorm_dense_fwd_rule(
epsilon,
norm_type,
quantizer=quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
kernel,
flatten_axis=flatten_axis,
quantizer=quantizer_set.kernel,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -278,7 +285,6 @@ def _layernorm_dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......
......@@ -27,6 +27,7 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
TensorUsage,
get_quantize_config,
)
......@@ -104,6 +105,11 @@ def layernorm_mlp(
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype)
output = _layernorm_mlp(
x,
gamma,
......@@ -266,12 +272,13 @@ def _layernorm_mlp_fwd_rule(
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
)
# NN GEMM
......@@ -289,17 +296,27 @@ def _layernorm_mlp_fwd_rule(
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
# This sharding constraint is needed to correct the Shardy sharding propagation
if dot_2_input_axes is not None:
dot_1_output_axes = (
dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
) # add the act_num axis
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
# (batch..., hidden_in) -> (batch..., hidden)
casted_act_out = tex.act_lu(
dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
dot_1_output,
activation_type,
quantizer=ffn2_quantizer_set.x,
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
casted_kernel_2 = tex.quantize(
kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
kernel_2,
quantizer=ffn2_quantizer_set.kernel,
)
# NN GEMM
......@@ -397,7 +414,9 @@ def _layernorm_mlp_bwd_rule(
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, dbias_2 = tex.quantize_dbias(
grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
grad,
is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -438,7 +457,6 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type,
is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......
......@@ -7,9 +7,11 @@ Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple, Dict, Union, Sequence
from typing import Optional, Tuple, Dict, Union, Sequence, Type
from functools import reduce
import operator
......@@ -26,7 +28,7 @@ from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability
__all__ = [
"QuantizeConfig",
"get_quantize_config",
"fp8_autocast",
"is_fp8_available",
"update_collections",
......@@ -34,12 +36,15 @@ __all__ = [
"apply_padding_to_scale_inv",
"remove_padding_from_scale_inv",
"NVTE_FP8_COLLECTION_NAME",
"TensorSource",
]
_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]
NVTE_FP8_COLLECTION_NAME = "fp8_metas"
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture.
......@@ -154,6 +159,17 @@ def _format2dtypes(format_: recipe.Format):
return jnp.bfloat16, jnp.bfloat16
class TensorSource(Enum):
"""Enumeration for where a tensor's data comes from."""
# Input data
X = 0
# Model parameters
KERNEL = 1
# Gradients in the backward pass
DGRAD = 2
class AmaxComputeAlgo(Enum):
"""Enumeration for AMAX computation algorithms.
......@@ -166,28 +182,8 @@ class AmaxComputeAlgo(Enum):
MOST_RECENT = "most_recent"
def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
"""Convert recipe.Recipe to ScalingMode.
Args:
fp8_recipe: The FP8 recipe to convert
Returns:
The corresponding ScalingMode
Raises:
ValueError: If the recipe type is not supported
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.MXFP8_1D_SCALING
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return ScalingMode.CURRENT_TENSOR_SCALING
raise ValueError("Invalid fp8_recipe!")
class QuantizeConfig:
@dataclass
class BaseQuantizeConfig(ABC):
"""Configuration class for quantization settings.
This class manages global quantization settings including FP8 formats,
......@@ -204,14 +200,13 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
INFERENCE_MODE: Whether to enable optimization for inference
SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
"""
INITIALIZED = False
MARGIN: float = 0.0
COLLECTION_NAME: str = "fp8_metas"
COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
......@@ -219,61 +214,82 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
INFERENCE_MODE: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@staticmethod
def is_fp8_enabled():
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize the quantization configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
self.INITIALIZED = True
self.MARGIN = getattr(fp8_recipe, "margin", 0.0)
self.FP8_FORMAT = fp8_recipe.fp8_format
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT)
def is_fp8_enabled(self) -> bool:
"""Check if FP8 quantization is enabled.
Returns:
bool: True if quantization is enabled, False otherwise
"""
return QuantizeConfig.INITIALIZED
return self.INITIALIZED
@classmethod
def initialize(cls, fp8_recipe: recipe.Recipe) -> None:
"""Initialize the quantization configuration.
@abstractmethod
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type.
Args:
fp8_recipe: The FP8 recipe to use for initialization
tensor_source: The usage type for which to get the scaling mode.
Returns:
The scaling mode for the specified usage type.
"""
cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
@classmethod
def finalize(cls) -> None:
"""Reset the quantization configuration to default values."""
cls.INITIALIZED = False
cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.FP8_2X_ACC_FPROP = False
cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.INFERENCE_MODE = False
# DelayedScaling
cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
def is_supported(self) -> tuple[bool, str]:
"""Check if this QuantizeConfig class is supported on the available devices.
Returns:
bool: True if the class is supported, False otherwise
str: Reason for being unsupported, if applicable.
"""
class DelayedScalingQuantizeConfig:
x_scaling_mode = self.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD)
for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]:
is_supported, reason = is_fp8_available(scaling_mode=scaling_mode)
if not is_supported:
return is_supported, reason
return True, None
class NoOpQuantizeConfig(BaseQuantizeConfig):
"""Configuration class higher-precision non-quantized operation."""
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize no-op configuration."""
raise NotImplementedError(
"NoOpQuantizeConfig cannot be initialized from a recipe, as it represents"
" higher-precision operation when no quantization recipe is set."
)
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.NO_SCALING
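# Sketch of the new per-tensor query surface (illustrative only; assumes the
# names defined in this module):
_noop_cfg = NoOpQuantizeConfig()
for _src in (TensorSource.X, TensorSource.KERNEL, TensorSource.DGRAD):
    assert _noop_cfg.get_scaling_mode(_src) == ScalingMode.NO_SCALING
assert not _noop_cfg.is_fp8_enabled()  # never initialized from a recipe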
class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for delayed scaling FP8 recipe.
This class provides specific initialization for the delayed scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize delayed scaling FP8 configuration.
Args:
......@@ -282,6 +298,8 @@ class DelayedScalingQuantizeConfig:
Raises:
AssertionError: If recipe parameters are not supported
"""
super().initialize_from_recipe(fp8_recipe)
assert fp8_recipe.amax_compute_algo in [
"max",
"most_recent",
......@@ -291,71 +309,88 @@ class DelayedScalingQuantizeConfig:
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
self.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
string_to_amax_compute_algo = {
"max": AmaxComputeAlgo.MAX,
"most_recent": AmaxComputeAlgo.MOST_RECENT,
}
cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
cls.FP8_2X_ACC_DGRAD = True
cls.FP8_2X_ACC_WGRAD = True
self.FP8_2X_ACC_DGRAD = True
self.FP8_2X_ACC_WGRAD = True
@staticmethod
def finalize() -> None:
"""Reset the delayed scaling configuration."""
QuantizeConfig.finalize()
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.DELAYED_TENSOR_SCALING
class CurrentScalingQuantizeConfig:
class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for current scaling FP8 recipe.
This class provides specific initialization for the current scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize current scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
super().initialize_from_recipe(fp8_recipe)
self.AMAX_HISTORY_LEN = 0
@staticmethod
def finalize() -> None:
"""Reset the current scaling configuration."""
QuantizeConfig.finalize()
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.CURRENT_TENSOR_SCALING
class BlockScalingQuantizeConfig:
class BlockScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for block scaling FP8 recipe.
This class provides specific initialization for the block scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize block scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
super().initialize_from_recipe(fp8_recipe)
self.AMAX_HISTORY_LEN = 0
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.MXFP8_1D_SCALING
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
@staticmethod
def finalize() -> None:
"""Reset the block scaling configuration."""
QuantizeConfig.finalize()
def get_quantize_config():
"""Global instance of BaseQuantizeConfig set by fp8_autocast context."""
return _QUANTIZE_CONFIG
def get_quantize_config_class(
fp8_recipe: recipe.Recipe,
) -> Type[BaseQuantizeConfig]:
"""Get the quantization configuration based on the FP8 recipe.
Args:
fp8_recipe: The FP8 recipe to use for initialization
Returns:
The quantization config class corresponding to the given recipe.
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return CurrentScalingQuantizeConfig
raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}")
@contextmanager
......@@ -404,25 +439,22 @@ def fp8_autocast(
if fp8_recipe is None:
fp8_recipe = recipe.DelayedScaling()
if mesh_resource is None:
mesh_resource = MeshResource()
global _QUANTIZE_CONFIG
Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
Config = CurrentScalingQuantizeConfig
old_quantize_config = _QUANTIZE_CONFIG
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
try:
with global_shard_guard(mesh_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe))
assert fp8_available, reason_for_no_fp8
Config.initialize(fp8_recipe)
_QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)()
is_supported, reason = _QUANTIZE_CONFIG.is_supported()
assert is_supported, reason
_QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe)
yield
finally:
Config.finalize()
_QUANTIZE_CONFIG = old_quantize_config
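# Usage sketch: the global config object is swapped in, validated, and restored
# around the context, e.g.
#
#     from transformer_engine.common import recipe
#     with fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
#         cfg = get_quantize_config()   # BlockScalingQuantizeConfig instance
#         assert cfg.is_fp8_enabled()
#     assert not get_quantize_config().is_fp8_enabled()  # NoOp config restored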
def get_delayed_scaling():
......@@ -440,12 +472,12 @@ def get_delayed_scaling():
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo = (
"max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
"max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
)
return recipe.DelayedScaling(
margin=int(QuantizeConfig.MARGIN),
fp8_format=QuantizeConfig.FP8_FORMAT,
amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN,
margin=int(get_quantize_config().MARGIN),
fp8_format=get_quantize_config().FP8_FORMAT,
amax_history_len=get_quantize_config().AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo,
)
......@@ -584,6 +616,3 @@ def apply_padding_to_scale_inv(
# Pad the scales with the lowest representable value (2^-127) and return
pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape))
return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
......@@ -19,11 +19,18 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .tensor import (
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
ScaledTensorFactory,
NoScaleTensor,
)
from .helper import (
QuantizeConfig,
get_quantize_config,
get_quantize_config_class,
AmaxComputeAlgo,
_get_scaling_mode,
TensorSource,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
......@@ -56,7 +63,7 @@ def compute_scale_from_amax(
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None:
scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
return sf
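# Worked example of the rule above (sketch with MARGIN = 0; e4m3 so fp8_max = 448.0):
#   amax = 3.5        -> sf = (448.0 / 3.5) / 2**0 = 128.0
#   amax = 0.0        -> sf falls back to the previous scale
#   amax = inf / nan  -> sf falls back to the previous scale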
......@@ -216,7 +223,11 @@ class CurrentScaleQuantizer(Quantizer):
data_layout: str = "NT"
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
self,
x: Union[jnp.ndarray, NoScaleTensor],
is_colwise=False,
dq_dtype=None,
flatten_axis=-1,
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
......@@ -228,14 +239,17 @@ class CurrentScaleQuantizer(Quantizer):
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = jnp.max(jnp.abs(x)).reshape((1,))
# Explicit None check; `or` would force a boolean on a traced array
amax = x.amax if x.amax is not None else jnp.max(jnp.abs(x.data)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
scaled_x = x.astype(compute_dtype) * scale
scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scaled_x = x.data.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / scale
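# Self-contained sketch of the current-scaling math above (margin 0, rowwise only;
# assumes a float8-capable JAX build; not the library implementation):
def _current_scale_quantize_sketch(x: jnp.ndarray):
    dtype_max = jnp.asarray(jnp.finfo(jnp.float8_e4m3fn).max, jnp.float32)
    amax = jnp.max(jnp.abs(x)).reshape((1,))
    scale = dtype_max / amax
    x_q = jnp.clip(x.astype(jnp.float32) * scale, -dtype_max, dtype_max)
    return x_q.astype(jnp.float8_e4m3fn), 1.0 / scale  # (data, scale_inv)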
......@@ -262,7 +276,10 @@ class CurrentScaleQuantizer(Quantizer):
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
......@@ -320,7 +337,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
......@@ -346,11 +363,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
scaled_x = x.data.astype(compute_dtype) * self.scale
# quantize() in the old dot.py do this way, leave this code block here for future debugging
# compute_dtype = x.dtype
......@@ -359,7 +379,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
# Explicit None check; `or` would force a boolean on a traced array
amax = x.amax if x.amax is not None else jnp.max(jnp.abs(x.data)).reshape((1,))
self.update(amax)
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
......@@ -397,7 +418,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value
"""
# 2. Calculate the current scale
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True)
else:
amax = amax_history[0:1]
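# Example of the two reduction modes above, with the newest amax stored at index 0:
#   amax_history = [0.5, 2.0, 1.0]
#   MAX         -> jnp.max(amax_history, axis=-1, keepdims=True) == [2.0]
#   MOST_RECENT -> amax_history[0:1]                             == [0.5]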
......@@ -459,6 +480,10 @@ class BlockScaleQuantizer(Quantizer):
Returns:
A ScaledTensor1x containing the quantized data
"""
if isinstance(x, NoScaleTensor):
# No need for amax in MXFP8 block scaling, so simply extract the jnp.ndarray data tensor from the NoScaleTensor x.
x = x.data
# TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
......@@ -494,7 +519,7 @@ class BlockScaleQuantizer(Quantizer):
return ScaledTensorFactory.create_1x(
x_q,
scales_q,
self.scaling_mode,
scaling_mode=self.scaling_mode,
is_colwise=is_colwise,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
......@@ -640,11 +665,11 @@ class GroupedQuantizer(Quantizer):
return ScaledTensorFactory.create_1x(
grouped_data,
grouped_scale_inv,
self.scaling_mode,
tensor_list[0].dq_dtype,
tensor_list[0].is_colwise,
tensor_list[0].data_layout,
tensor_list[0].flatten_axis,
scaling_mode=self.scaling_mode,
dq_dtype=tensor_list[0].dq_dtype,
is_colwise=tensor_list[0].is_colwise,
data_layout=tensor_list[0].data_layout,
flatten_axis=tensor_list[0].flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
......@@ -827,12 +852,21 @@ class QuantizerFactory:
@staticmethod
def _create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
x_scaling_mode,
kernel_scaling_mode,
grad_scaling_mode,
fwd_dtype,
bwd_dtype,
is_2x2x,
n_groups,
**kwargs,
) -> QuantizerSet:
"""Create a set of quantizers for forward and backward passes.
Args:
scaling_mode: Scaling mode to use
x_scaling_mode: Scaling mode to use for input tensor 'x'
kernel_scaling_mode: Scaling mode to use for kernel tensor
grad_scaling_mode: Scaling mode to use for gradient tensor
fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
......@@ -846,9 +880,9 @@ class QuantizerFactory:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if scaling_mode.is_1d_block_scaling():
if kernel_scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE
if QuantizeConfig.INFERENCE_MODE:
if get_quantize_config().INFERENCE_MODE:
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
......@@ -868,12 +902,12 @@ class QuantizerFactory:
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
q_kernel = QuantizerFactory.create(
1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
)
q_dgrad = QuantizerFactory.create(
1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
......@@ -892,10 +926,10 @@ class QuantizerFactory:
Args:
n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode
fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE
bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X
n_groups:
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization
......@@ -912,27 +946,44 @@ class QuantizerFactory:
)
if fp8_recipe is not None:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic
scaling_mode = _get_scaling_mode(fp8_recipe)
quantize_config = get_quantize_config_class(fp8_recipe)()
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
elif scaling_mode is not None:
x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode
else:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None:
if scaling_mode.is_1d_block_scaling():
# TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling():
is_2x2x = True
elif scaling_mode.is_tensor_scaling():
elif x_scaling_mode.is_tensor_scaling():
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = QuantizeConfig.INFERENCE_MODE
is_inference_mode = get_quantize_config().INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = []
for _ in range(n_quantizer_sets):
q_set.append(
QuantizerFactory._create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
x_scaling_mode=x_scaling_mode,
kernel_scaling_mode=kernel_scaling_mode,
grad_scaling_mode=grad_scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x,
n_groups=n_groups,
**kwargs,
)
)
......
......@@ -166,6 +166,90 @@ class ScalingModeMetadataImpl(ABC):
"""
class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for no scaling mode.
This implementation provides metadata for the no-scaling mode, used for non-quantized higher-precision datatypes such as bf16.
"""
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling.
Returns:
The data type used for scale tensors (float32)
"""
return jnp.float32
def get_scale_shape(
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.
Args:
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors - (0,)
"""
del data_shape, is_colwise, is_padded, flatten_axis
return (0,)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
return QuantizeLayout.ROWWISE
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Original shape of the data tensor
n_groups: Number of groups in the grouped tensor
group_axis: The axis along which grouping is performed
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors - (n_groups,)
"""
del data_shape, group_axis, is_colwise
assert isinstance(n_groups, int)
return (n_groups,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for current scaling mode.
......@@ -396,7 +480,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
The quantize layout for the tensor usage
"""
# If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE:
# if get_quantize_config().INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
# if usage == TensorUsage.LHS:
# return QuantizeLayout.ROWWISE
......@@ -740,5 +824,5 @@ SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
}
......@@ -25,6 +25,8 @@ from ..sharding import (
__all__ = [
"TensorUsage",
"AbstractBaseTensor",
"NoScaleTensor",
"ScaledTensor",
"ScaledTensor1x",
"ScaledTensor2x",
......@@ -34,14 +36,9 @@ __all__ = [
]
@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
"""Abstract base class for scaled tensors.
This class defines the interface for all scaled tensor implementations,
providing methods for dequantization and accessing row/column-wise components.
"""
class AbstractBaseTensor(ABC):
"""Abstract base class for all tensor types."""
@classmethod
def tree_unflatten(cls, aux_data, children):
......@@ -93,9 +90,76 @@ class ScaledTensor(ABC):
"""
@dataclass
class AbstractBaseTensor1x(AbstractBaseTensor):
"""Abstract base class for single layout tensors."""
data: jnp.ndarray
amax: jnp.ndarray
@register_pytree_node_class
@dataclass
class NoScaleTensor(AbstractBaseTensor1x):
"""Higher-precision tensor."""
def __post_init__(self):
assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray."
def tree_flatten(self):
"""Flattens the tensor for JAX tree operations.
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.amax)
aux_data = ()
return (children, aux_data)
@property
def ndim(self):
"""Number of dimensions of the underlying array."""
return self.data.ndim
def dequantize(self):
"""This is a no-op for a higher-precision tensor so this simply returns the tensor's data."""
return self.data
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
assert (
q_layout == QuantizeLayout.ROWWISE
), "Only ROWWISE layout is supported for NoScaleTensor"
return self
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names)
return NoScaleTensor(
data=data,
amax=self.amax,
)
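# Usage sketch (illustrative; assumes this class is in scope):
#     t = NoScaleTensor(data=jnp.ones((2, 4), jnp.bfloat16), amax=None)
#     t.ndim          # 2
#     t.dequantize()  # identity: returns t.data unchanged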
class ScaledTensor(ABC):
"""Abstract base class for scaled tensors."""
@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
"""Single-scale quantized tensor implementation.
This class represents a tensor quantized with a single scaling factor,
......@@ -104,6 +168,7 @@ class ScaledTensor1x(ScaledTensor):
Attributes:
data: The quantized tensor data
scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode used for quantization
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
......@@ -112,7 +177,6 @@ class ScaledTensor1x(ScaledTensor):
flatten_axis: The quantization axis for the tensor
"""
data: jnp.ndarray
scale_inv: jnp.ndarray
scaling_mode: ScalingMode
dq_dtype: jnp.dtype
......@@ -152,7 +216,7 @@ class ScaledTensor1x(ScaledTensor):
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv)
children = (self.data, self.amax, self.scale_inv)
aux_data = (
self.scaling_mode,
self.dq_dtype,
......@@ -224,6 +288,7 @@ class ScaledTensor1x(ScaledTensor):
return ScaledTensor1x(
data=data,
scale_inv=scale_inv,
amax=self.amax,
scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype,
_dq_func=self._dq_func,
......@@ -255,6 +320,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self,
data,
scale_inv,
amax,
group_sizes,
scaling_mode,
dq_dtype,
......@@ -270,7 +336,15 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.original_shape = original_shape
self.group_axis = group_axis
super().__init__(
data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis
data=data,
scale_inv=scale_inv,
amax=amax,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=_dq_func,
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
)
def __post_init__(self):
......@@ -308,7 +382,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv, self.group_sizes)
children = (self.data, self.scale_inv, self.amax, self.group_sizes)
aux_data = (
self.scaling_mode,
self.dq_dtype,
......@@ -327,7 +401,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
"""Double-scale quantized tensor implementation.
This class represents a tensor quantized with both row-wise and column-wise scaling factors.
......@@ -413,7 +487,8 @@ class ScaledTensorFactory:
def create_1x(
data,
scale_inv,
scaling_mode,
amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16,
is_colwise=False,
data_layout="N",
......@@ -427,18 +502,22 @@ class ScaledTensorFactory:
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
group_sizes: Arra of ints containing the size of each group (default: None)
group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
"""
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None:
......@@ -468,6 +547,7 @@ class ScaledTensorFactory:
return GroupedScaledTensor1x(
data=data,
scale_inv=scale_inv,
amax=amax,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=dequantizer.grouped_dequantize,
......@@ -485,14 +565,15 @@ class ScaledTensorFactory:
flatten_axis = data.ndim - flatten_axis
return ScaledTensor1x(
data,
scale_inv,
scaling_mode,
dq_dtype,
dequantizer.dequantize,
is_colwise,
data_layout,
flatten_axis,
data=data,
scale_inv=scale_inv,
amax=amax,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=dequantizer.dequantize,
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
)
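# Sketch: amax is now optional and defaults to an empty (1,) float32 placeholder:
#     t = ScaledTensorFactory.create_1x(
#         data, scale_inv, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING
#     )
#     t.amax.shape  # (1,)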
@staticmethod
......@@ -501,7 +582,8 @@ class ScaledTensorFactory:
scale_inv,
colwise_data,
colwise_scale_inv,
scaling_mode,
amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16,
data_layout="NN",
flatten_axis=-1,
......@@ -516,6 +598,7 @@ class ScaledTensorFactory:
scale_inv: The row-wise inverse scaling factors
colwise_data: The column-wise quantized data
colwise_scale_inv: The column-wise inverse scaling factors
amax: The maximum absolute value of the tensor
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN")
......@@ -527,10 +610,14 @@ class ScaledTensorFactory:
Returns:
A ScaledTensor2x instance
"""
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensorFactory.create_1x(
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=False,
......@@ -543,6 +630,7 @@ class ScaledTensorFactory:
colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=True,
......@@ -560,7 +648,8 @@ class ScaledTensorFactory:
scale_inv: jnp.ndarray,
colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode,
amax=None,
scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
dq_dtype: jnp.dtype = jnp.bfloat16,
data_layout: str = "NN",
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
......@@ -594,6 +683,7 @@ class ScaledTensorFactory:
scale_inv,
colwise_data,
colwise_scale_inv,
amax,
scaling_mode,
dq_dtype,
data_layout=data_layout,
......@@ -608,6 +698,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
......@@ -621,6 +712,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x(
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
......@@ -645,7 +737,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
if isinstance(x, GroupedScaledTensor1x):
raise NotImplementedError
if isinstance(x, ScaledTensor):
if isinstance(x, AbstractBaseTensor):
return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
......@@ -9,10 +9,8 @@ tensor parallelism (TP), pipeline parallelism (PP), and full-sharded data
parallelism (FSDP). It includes functions for sharding constraints, mesh management,
and collective operations.
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import warnings
import jax
......@@ -43,44 +41,56 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
return mesh.shape[resource], resource
def _validate_mesh_resource_configuration(mesh_resource):
"""Validate that the mesh resource configuration is consistent and conflict-free."""
is_dp_enabled = (
mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1
)
is_tp_enabled = (
mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1
)
is_tpsp_enabled = (
mesh_resource.tpsp_resource is not None
and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1
)
is_fsdp_enabled = (
mesh_resource.fsdp_resource is not None
and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1
)
assert not (is_dp_enabled and is_fsdp_enabled), (
"Data parallelism and full-sharded data parallelism cannot be enabled at the same time."
f" Got dp_resource={mesh_resource.dp_resource} and"
f" fsdp_resource={mesh_resource.fsdp_resource}"
)
assert not (is_tp_enabled and is_tpsp_enabled), (
"Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
f" Got tp_resource={mesh_resource.tp_resource} and"
f" tpsp_resource={mesh_resource.tpsp_resource}"
)
def get_sharding_map_logic_axis_to_mesh_axis():
"""
Generate a dict to map logical axes to mesh axes.
"""
gsr = global_mesh_resource()
IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False)))
batch_resources = (
[gsr.fsdp_resource, gsr.dp_resource]
if IS_FSDP_OUTER
else [gsr.dp_resource, gsr.fsdp_resource]
)
batch_dim_rule = []
for resource in batch_resources:
if resource is not None and resource not in batch_dim_rule:
batch_dim_rule.append(resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1
is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1
te_logical_axis_to_mesh_axis = {
BATCH_AXES: batch_dim_rule,
BATCH_AXES: gsr.fsdp_resource if is_fsdp_enabled else gsr.dp_resource,
SEQLEN_AXES: None,
SEQLEN_TP_AXES: gsr.tp_resource,
SEQLEN_TP_AXES: gsr.tpsp_resource,
SEQLEN_CP_AXES: gsr.cp_resource,
HEAD_AXES: gsr.tp_resource,
HEAD_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource,
HIDDEN_AXES: None,
HIDDEN_TP_AXES: gsr.tp_resource,
HIDDEN_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource,
JOINED_AXES: None,
W_NO_SHARD_AXES: None,
W_FSDP_AXES: gsr.fsdp_resource,
W_TP_AXES: gsr.tp_resource,
W_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource,
W_JOINED_AXES: None,
}
return te_logical_axis_to_mesh_axis
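# Simplified sketch of the selection rule above: an enabled tensor-sequence-parallel
# (tpsp) resource takes over the mesh axes that tensor parallelism would otherwise own.
def _pick_tp_axis_sketch(tp_resource, tpsp_resource, tpsp_enabled):
    return tpsp_resource if tpsp_enabled else tp_resource

assert _pick_tp_axis_sketch("tp", "tpsp", tpsp_enabled=True) == "tpsp"
assert _pick_tp_axis_sketch("tp", None, tpsp_enabled=False) == "tp"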
......@@ -155,7 +165,7 @@ def with_sharding_constraint_by_logical_axes(
flax_rules = flax.linen.get_logical_axis_rules()
if len(flax_rules) > 0:
return flax.linen.with_logical_constraint(
x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT
x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.AXIS_IS_UNSHARDED
)
except ImportError:
pass
......@@ -274,6 +284,7 @@ class MeshResource:
Attributes:
dp_resource: Axis name for data parallelism (batch sharding), default is None
tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding), default is None
fsdp_resource: Axis name for full-sharded data parallelism, default is None
pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
cp_resource: Axis name for context parallelism (sequence sharding), default is None
......@@ -281,12 +292,13 @@ class MeshResource:
dp_resource: str = None
tp_resource: str = None
tpsp_resource: str = None
fsdp_resource: str = None
pp_resource: str = None
cp_resource: str = None
_GLOBAL_MESH_RESOURCE = MeshResource()
_GLOBAL_MESH_RESOURCE = None
@contextmanager
......@@ -314,6 +326,12 @@ def global_mesh_resource() -> MeshResource:
Returns:
The current MeshResource instance
"""
assert _GLOBAL_MESH_RESOURCE is not None, (
"Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
" context. If you are not using multiple GPUs, you can use an empty MeshResource by"
" wrapping your program in 'with global_shard_guard(MeshResource()):'"
)
_validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE)
return _GLOBAL_MESH_RESOURCE
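# Usage sketch: reading the global resource now requires an active guard, e.g.
#     with global_shard_guard(MeshResource()):   # empty resource for single-GPU runs
#         resource = global_mesh_resource()      # validated before being returned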
......@@ -346,52 +364,3 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
# Deprecating Items ---------------------------------------------------------------
ShardingResource = MeshResource
global_shard_resource = global_mesh_resource
class MajorShardingType(Enum):
"""Enumeration of major sharding types for distributed training.
This enum defines the basic sharding patterns available for distributed
training. Note that this class is deprecated and will be removed in the future.
Values:
SINGLE: Single process training
DP: Data parallel training
TP: Standard tensor parallel training
DPTP: Data and standard tensor parallel training
"""
SINGLE = 0
DP = 1
TP = 2
DPTP = 3
class ShardingType(Enum):
"""Enumeration of detailed sharding types for distributed training.
This enum defines specific sharding patterns for distributed training,
including combinations of data parallelism and different tensor parallelism
strategies. Note that this class is deprecated and will be removed in the future.
Values:
SINGLE: No sharding
DP: Sharding along data parallelism
TP_COL: Sharding along column-split tensor parallelism
TP_ROW: Sharding along row-split tensor parallelism
DP_TP_COL: Sharding along data and column-split tensor parallelism
DP_TP_ROW: Sharding along data and row-split tensor parallelism
"""
SINGLE = (MajorShardingType.SINGLE, "single")
DP = (MajorShardingType.DP, "dp")
TP_COL = (MajorShardingType.TP, "tp_col")
TP_ROW = (MajorShardingType.TP, "tp_row")
DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
......@@ -33,6 +33,7 @@ from transformer_engine.pytorch.module import GroupedLinear, BatchedLinear
from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.module import UserBufferQuantizationMode
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.attention import InferenceParams
......