Commit c1a1c04e authored by wenjh's avatar wenjh
Browse files

Merge nv_main(2.10) to main


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
...@@ -1784,6 +1784,9 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1784,6 +1784,9 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[4] = seed_sharding
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -1991,7 +1994,13 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1991,7 +1994,13 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
...@@ -2265,6 +2274,9 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2265,6 +2274,9 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[4] = seed_sharding
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -2403,7 +2415,11 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2403,7 +2415,11 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
if not is_context_parallel: if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
arg_shardings = tuple(arg.sharding for arg in arg_infos) arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding # dq, dk, dv, dbias sharding = q, k, v, bias sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) out_shardings = tuple(arg.sharding for arg in arg_infos[:4])
...@@ -2739,10 +2755,13 @@ def fused_attn_bwd( ...@@ -2739,10 +2755,13 @@ def fused_attn_bwd(
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
if 100 in get_all_device_compute_capability(): # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
compute_capabilities = get_all_device_compute_capability()
if any(x >= 100 for x in compute_capabilities):
assert not ( assert not (
attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
fused_config = _FusedAttnConfig( fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
......
...@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F ...@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
""" """
Helper function to manage primitive states by name without modifying environment variables. Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the get_quantize_config().initialize() methods. This helper is used in the get_quantize_config_with_recipe().initialize() methods.
Args: Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
...@@ -38,13 +38,13 @@ from ..quantize import ( ...@@ -38,13 +38,13 @@ from ..quantize import (
ScalingMode, ScalingMode,
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
get_quantize_config,
QuantizerSet, QuantizerSet,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
should_use_rht, get_quantize_config_with_recipe,
get_global_quantize_recipe,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
...@@ -169,16 +169,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -169,16 +169,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool: def has_rht_applied(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht( return isinstance(q, ScaledTensor1x) and q.has_rht_applied
q.scaling_mode, is_colwise=q.is_colwise
)
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
assert uses_rht(lhs_q) == uses_rht(rhs_q), ( "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise" " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in" " GEMM."
" the GEMM."
) )
return lhs_q, rhs_q return lhs_q, rhs_q
...@@ -474,29 +471,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -474,29 +471,6 @@ class GemmPrimitive(BasePrimitive):
f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
) )
lhs_axis_boundary = get_lhs_axis_boundary(lhs_contracting_dims, lhs_is_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs.shape[lhs_axis_boundary:])
if lhs_is_transposed
else reduce(operator.mul, lhs.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_contracting_dims, rhs_is_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs.shape[:rhs_axis_boundary])
if rhs_is_transposed
else reduce(operator.mul, rhs.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
# Determine output shape and dtype # Determine output shape and dtype
assert ( assert (
dtypes.canonicalize_dtype(out_dtype).itemsize > 1 dtypes.canonicalize_dtype(out_dtype).itemsize > 1
...@@ -560,6 +534,9 @@ class GemmPrimitive(BasePrimitive): ...@@ -560,6 +534,9 @@ class GemmPrimitive(BasePrimitive):
# Declare cuBLAS workspace # Declare cuBLAS workspace
workspace_size = get_cublas_workspace_size_bytes() workspace_size = get_cublas_workspace_size_bytes()
# NVFP4 swizzling happen in via nvte kernel instead of JAX transposes
if scaling_mode.is_nvfp4_scaling:
workspace_size += lhs_scale_inv.size + rhs_scale_inv.size
if not collective_op.is_none: if not collective_op.is_none:
workspace_size *= get_cgemm_num_max_streams() workspace_size *= get_cgemm_num_max_streams()
# cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
...@@ -605,6 +582,29 @@ class GemmPrimitive(BasePrimitive): ...@@ -605,6 +582,29 @@ class GemmPrimitive(BasePrimitive):
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
) )
lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:])
if lhs_transposed
else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary])
if rhs_transposed
else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
kwargs = { kwargs = {
"scaling_mode": int(scaling_mode.value), "scaling_mode": int(scaling_mode.value),
...@@ -666,6 +666,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -666,6 +666,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
) )
# Only perform JAX-based swizzle for MXFP8, NVFP4 swizzle will go though nvte kernel
if scaling_mode.is_mxfp8_scaling:
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
...@@ -1245,7 +1247,7 @@ def _te_gemm( ...@@ -1245,7 +1247,7 @@ def _te_gemm(
fuse_bias: bool = False, fuse_bias: bool = False,
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, use_split_accumulator: bool = None,
transpose_batch_sequence: bool = False, transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE, collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
...@@ -1257,6 +1259,13 @@ def _te_gemm( ...@@ -1257,6 +1259,13 @@ def _te_gemm(
DeprecationWarning, DeprecationWarning,
) )
if use_split_accumulator is None:
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
lhs_data = lhs lhs_data = lhs
rhs_data = rhs rhs_data = rhs
...@@ -1719,10 +1728,15 @@ def _jax_gemm( ...@@ -1719,10 +1728,15 @@ def _jax_gemm(
assert ( assert (
rhs.scaling_mode == lhs.scaling_mode rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
precision = ( precision = (
jax.lax.Precision.HIGHEST jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
) )
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
...@@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant ...@@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break break
# _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
# but this fails when bias fusion is turned on with arch < 100. # but this fails when bias fusion is turned on with arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise
)
return ( return (
(force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and arch_l_100 and arch_l_100
...@@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, ...@@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
@return: the output of 'f' with the colwise output calculated @return: the output of 'f' with the colwise output calculated
""" """
should_apply_war = ( should_apply_war = (
quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() quantizer is not None
and quantizer.scaling_mode.is_tensor_scaling()
and quantizer.q_layout.is_rowwise_colwise
) )
if not should_apply_war: if not should_apply_war:
return None return None
......
...@@ -11,7 +11,7 @@ from typing import Optional, Union ...@@ -11,7 +11,7 @@ from typing import Optional, Union
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, ffi from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
...@@ -27,7 +27,7 @@ from .misc import ( ...@@ -27,7 +27,7 @@ from .misc import (
NamedSharding, NamedSharding,
get_cudnn_version, get_cudnn_version,
) )
from .quantization import _quantize_dbias_impl, AmaxScope from .quantization import quantize, AmaxScope
from ..sharding import ( from ..sharding import (
all_reduce_max_along_all_axes_except_PP, all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp_tpsp, all_reduce_sum_along_dp_fsdp_tpsp,
...@@ -112,7 +112,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -112,7 +112,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -148,6 +148,13 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -148,6 +148,13 @@ class NormFwdPrimitive(BasePrimitive):
"Current tensor scaling is not supported for fused norm and quantization. Please do" "Current tensor scaling is not supported for fused norm and quantization. Please do"
" norm in higher-precision then quantize with current tensor scaling." " norm in higher-precision then quantize with current tensor scaling."
) )
assert not ScalingMode(scaling_mode).is_nvfp4_scaling, (
"NVFP4 block scaling is not yet supported for fused norm and quantization."
" Please do norm in higher-precision then quantize with current tensor scaling."
)
assert (
not quantize_layout.is_colwise_only
), "Fused norm with colwise-only quantization is not supported."
mu_rsigama_dtype = jnp.float32 mu_rsigama_dtype = jnp.float32
...@@ -165,7 +172,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -165,7 +172,7 @@ class NormFwdPrimitive(BasePrimitive):
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,) colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -173,7 +180,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -173,7 +180,7 @@ class NormFwdPrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray( colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype shape=colwise_scale_inv_shape, dtype=scale_dtype
) )
...@@ -189,7 +196,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -189,7 +196,7 @@ class NormFwdPrimitive(BasePrimitive):
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
get_forward_sm_margin(), get_forward_sm_margin(),
is_2x, True, # is_training
) )
wkspace_aval = jax.core.ShapedArray( wkspace_aval = jax.core.ShapedArray(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -245,7 +252,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -245,7 +252,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -287,7 +294,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -287,7 +294,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
sm_margin=sm_margin, sm_margin=sm_margin,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
is_2x=is_2x, quantize_layout=quantize_layout.value.value,
output_amax_when_no_scaling=output_amax_when_no_scaling, output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
...@@ -303,7 +310,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -303,7 +310,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -335,7 +342,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -335,7 +342,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -349,7 +356,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -349,7 +356,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape rowwise_scale_inv_shape
) )
if is_2x: if quantize_layout.has_colwise:
colwise_scale_inv = colwise_scale_inv.flatten()[ colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape, 1) : reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape) ].reshape(colwise_scale_inv_shape)
...@@ -373,7 +380,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -373,7 +380,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -409,7 +416,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -409,7 +416,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -426,7 +433,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -426,7 +433,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -450,7 +457,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -450,7 +457,7 @@ class NormFwdPrimitive(BasePrimitive):
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,) colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
) )
...@@ -488,7 +495,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -488,7 +495,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -524,7 +531,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -524,7 +531,7 @@ class NormFwdPrimitive(BasePrimitive):
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,) colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
) )
...@@ -586,7 +593,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -586,7 +593,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -623,7 +630,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -623,7 +630,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -646,25 +653,29 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -646,25 +653,29 @@ class NormFwdPrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "NormFwd_" prefix = "NormFwd"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1 value_types[0].shape,
unique_var=prefix,
flatten_axis=-1,
q_layout=quantize_layout,
) )
x_axes = scale_rules.input_spec input_spec = scale_rules.input_spec
out = x_axes rsigma = input_spec[:-1]
colwise_out = out if is_2x else (prefix + "out_colwise",) mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
rsigma = x_axes[:-1] amax = (BATCHING + prefix + "_amax",)
mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma scale = (BATCHING + prefix + "_scale",)
amax = (prefix + "amax",) gamma = (BATCHING + prefix + "_gamma",)
beta = (BATCHING + prefix + "_beta",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), amax, ("…2",), ("…3",)), (input_spec, scale, amax, gamma, beta),
( (
out, scale_rules.rowwise_out_spec,
colwise_out, scale_rules.colwise_out_spec,
scale_rules.rowwise_rule, scale_rules.rowwise_scale_spec,
scale_rules.colwise_rule, scale_rules.colwise_scale_spec,
amax, amax,
mu, mu,
rsigma, rsigma,
...@@ -945,7 +956,7 @@ def layernorm_fwd( ...@@ -945,7 +956,7 @@ def layernorm_fwd(
beta: jnp.ndarray, beta: jnp.ndarray,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False, transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False, output_amax_when_no_scaling: bool = False,
...@@ -975,10 +986,19 @@ def layernorm_fwd( ...@@ -975,10 +986,19 @@ def layernorm_fwd(
- Reciprocal of the standard deviation of the input tensor. Shape: (..., 1) - Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
""" """
if not NormFwdPrimitive.enabled(): if not NormFwdPrimitive.enabled():
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) output, mu, rsigma = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
if quantizer is not None:
output = quantize(
output,
quantizer,
flatten_axis=-1,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return (output, mu, rsigma)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -999,7 +1019,7 @@ def layernorm_fwd( ...@@ -999,7 +1019,7 @@ def layernorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=False, transpose_batch_sequence=False,
...@@ -1029,7 +1049,7 @@ def layernorm_fwd( ...@@ -1029,7 +1049,7 @@ def layernorm_fwd(
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False, output_amax_when_no_scaling=False,
) )
out, _ = _quantize_dbias_impl( out, _ = quantize(
out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence
) )
return out, mu, rsigma return out, mu, rsigma
...@@ -1050,20 +1070,19 @@ def layernorm_fwd( ...@@ -1050,20 +1070,19 @@ def layernorm_fwd(
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True, output_amax_when_no_scaling=True,
) )
out, _ = _quantize_dbias_impl( out = quantize(
out, out,
is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
) )
return out, mu, rsigma return out, mu, rsigma
is_2x2x = quantizer.is_2x2x() # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
# TE/common normalization doesn't support 2x delayed scaling q_layout = quantizer.q_layout
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False q_layout = QuantizeLayout.ROWWISE
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -1083,7 +1102,7 @@ def layernorm_fwd( ...@@ -1083,7 +1102,7 @@ def layernorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -1092,8 +1111,7 @@ def layernorm_fwd( ...@@ -1092,8 +1111,7 @@ def layernorm_fwd(
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
...@@ -1219,10 +1237,19 @@ def rmsnorm_fwd( ...@@ -1219,10 +1237,19 @@ def rmsnorm_fwd(
Shape: (..., 1) Shape: (..., 1)
""" """
if not NormFwdPrimitive.enabled(): if not NormFwdPrimitive.enabled():
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) output, rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon)
if quantizer is not None:
output = quantize(
output,
quantizer,
flatten_axis=-1,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return (output, rsigma)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -1245,7 +1272,7 @@ def rmsnorm_fwd( ...@@ -1245,7 +1272,7 @@ def rmsnorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -1274,7 +1301,7 @@ def rmsnorm_fwd( ...@@ -1274,7 +1301,7 @@ def rmsnorm_fwd(
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False, output_amax_when_no_scaling=False,
) )
out, _ = _quantize_dbias_impl( out = quantize(
out.data, out.data,
quantizer, quantizer,
amax_scope=amax_scope, amax_scope=amax_scope,
...@@ -1297,20 +1324,19 @@ def rmsnorm_fwd( ...@@ -1297,20 +1324,19 @@ def rmsnorm_fwd(
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True, output_amax_when_no_scaling=True,
) )
out, _ = _quantize_dbias_impl( out = quantize(
out, out,
is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
) )
return out, rsigma return out, rsigma
is_2x2x = quantizer.is_2x2x() # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
# TE/common normalization doesn't support 2x delayed scaling q_layout = quantizer.q_layout
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False q_layout = QuantizeLayout.ROWWISE
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -1330,7 +1356,7 @@ def rmsnorm_fwd( ...@@ -1330,7 +1356,7 @@ def rmsnorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -1339,8 +1365,7 @@ def rmsnorm_fwd( ...@@ -1339,8 +1365,7 @@ def rmsnorm_fwd(
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
......
...@@ -11,7 +11,7 @@ import math ...@@ -11,7 +11,7 @@ import math
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, ffi from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import transformer_engine_jax import transformer_engine_jax
...@@ -31,7 +31,7 @@ from .misc import ( ...@@ -31,7 +31,7 @@ from .misc import (
from ..sharding import ( from ..sharding import (
all_reduce_max_along_all_axes_except_PP, all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp, all_reduce_sum_along_dp_fsdp,
num_of_devices, get_num_devices_in_mesh,
) )
from ..quantize import ( from ..quantize import (
ScaledTensor2x, ScaledTensor2x,
...@@ -45,7 +45,6 @@ from ..quantize import ( ...@@ -45,7 +45,6 @@ from ..quantize import (
compute_scale_from_amax, compute_scale_from_amax,
NoScaleTensor, NoScaleTensor,
get_rht_matrix, get_rht_matrix,
should_use_rht,
) )
...@@ -108,21 +107,22 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -108,21 +107,22 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
"sr_rng_state must be a uint32 array when stochastic_rounding is True but" "sr_rng_state must be a uint32 array when stochastic_rounding is True but"
f" received {sr_rng_state_aval}" f" received {sr_rng_state_aval}"
) )
if is_outer: if is_outer and get_num_devices_in_mesh() > 1:
assert ( assert (
sr_rng_state_aval.shape[0] == num_of_devices() sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
and sr_rng_state_aval.shape[1] == 4 and sr_rng_state_aval.shape[1] == 4
), ( ), (
"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is" "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}" f" True and is_outer is True but received {sr_rng_state_aval.shape}"
) )
else: else:
assert sr_rng_state_aval.shape == (4,), ( # We cannot assert the shape is exactly (4,) here because if the quantized data is not perfectly sharded across all devices then we will have extra rng state here. For example, this could occur when the weights are not sharded when using data parallelism. However, this is okay because the extra rng state will simply not be used and each device still has a unique rng state.
"Sharded sr_rng_state must be of shape (4,) per device when" assert sr_rng_state_aval.size >= 4, (
"Sharded sr_rng_state must have at least 4 elements per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
) )
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if QuantizeLayout(q_layout).has_rowwise:
rowwise_out_shape = out_shape rowwise_out_shape = out_shape
else: else:
rowwise_out_shape = (1,) rowwise_out_shape = (1,)
...@@ -170,7 +170,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -170,7 +170,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
broadcast_2d_scale_shape_to_1d=True, broadcast_2d_scale_shape_to_1d=True,
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if QuantizeLayout(q_layout).has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed: if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else: else:
...@@ -194,9 +194,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -194,9 +194,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(scale_dtype), jax_dtype_to_te_dtype(scale_dtype),
scaling_mode, scaling_mode,
QuantizeLayout( q_layout.value,
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
) )
wkspace_shape = wkspace_info[0] wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -272,7 +270,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -272,7 +270,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
post_rht_amax, post_rht_amax,
rht_matrix, rht_matrix,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
q_layout=q_layout, q_layout=q_layout.value.value,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
is_dbias=is_dbias, is_dbias=is_dbias,
stochastic_rounding=stochastic_rounding, stochastic_rounding=stochastic_rounding,
...@@ -335,7 +333,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -335,7 +333,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
colwise_scale_inv = jax.lax.slice( colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
...@@ -424,7 +422,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -424,7 +422,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
desc="BaseDBiasQuantizePrimitive.out_sharding", desc="BaseDBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed: if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
...@@ -448,7 +446,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -448,7 +446,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling: if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ( if (
ScalingMode(scaling_mode).is_block_scaling ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed and ScalingMode(scaling_mode).is_colwise_transposed
...@@ -505,7 +503,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -505,7 +503,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.out_sharding", desc="BaseDBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed: if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
...@@ -529,7 +527,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -529,7 +527,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling: if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ( if (
ScalingMode(scaling_mode).is_block_scaling ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed and ScalingMode(scaling_mode).is_colwise_transposed
...@@ -552,8 +550,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -552,8 +550,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
) )
# TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -564,6 +567,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -564,6 +567,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) )
def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
if sr_rng_state.size > 4:
# See comment in abstract method for explanation of why we cannot assert exact shape
sr_rng_state = sr_rng_state.flatten()[:4]
( (
local_x, local_x,
local_colwise_x, local_colwise_x,
...@@ -635,39 +641,37 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -635,39 +641,37 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "DBiasQuantize_" prefix = "DBiasQuantize"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape, value_types[0].shape,
unique_var=prefix + "x", unique_var=prefix,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
q_layout=q_layout,
broadcast_2d_scale_shape_to_1d=True, broadcast_2d_scale_shape_to_1d=True,
) )
x_axes = scale_rules.input_spec input_spec = scale_rules.input_spec
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
out = x_axes amax = (BATCHING + prefix + "_amax",)
colwise_out = (prefix + "out_colwise",) scale = (BATCHING + prefix + "_scale",)
colwise_scale_inv = (prefix + "colwise_scale_inv",) sr_rng_state = (
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): BATCHING + prefix + "_sr_rng_state_partition_axis",
colwise_scale_inv = scale_rules.colwise_rule BATCHING + prefix + "sr_rng_state_data_axis",
if ScalingMode(scaling_mode).is_colwise_transposed: )
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
colwise_scale_inv = tuple(
multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis)
)
else:
colwise_out = x_axes
dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis")
post_rht_amax = (prefix + "post_rht_amax",) post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2") rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix), (input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
dbias,
),
**scale_rules.factor_sizes, **scale_rules.factor_sizes,
) )
...@@ -754,9 +758,10 @@ def _quantize_dbias_impl( ...@@ -754,9 +758,10 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation # fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = ( is_unsupported = quantizer.q_layout.is_colwise_only and not (
quantizer.q_layout == QuantizeLayout.COLWISE quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING and hasattr(quantizer, "use_rht")
and quantizer.use_rht
) )
if is_unsupported or not PrimitiveClass.enabled(): if is_unsupported or not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
...@@ -792,7 +797,7 @@ def _quantize_dbias_impl( ...@@ -792,7 +797,7 @@ def _quantize_dbias_impl(
rht_matrix = jnp.empty((1, 1), jnp.bfloat16) rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
amax = x.amax amax = x.amax
if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout): if hasattr(quantizer, "use_rht") and quantizer.use_rht:
use_rht = True use_rht = True
rht_matrix = get_rht_matrix() rht_matrix = get_rht_matrix()
...@@ -815,7 +820,7 @@ def _quantize_dbias_impl( ...@@ -815,7 +820,7 @@ def _quantize_dbias_impl(
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
) )
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale scale = quantizer.scale
# Make sure to reset amax to zeros for DelayedScaling # Make sure to reset amax to zeros for DelayedScaling
...@@ -836,7 +841,7 @@ def _quantize_dbias_impl( ...@@ -836,7 +841,7 @@ def _quantize_dbias_impl(
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = ( force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x() and quantizer.q_layout.is_rowwise_colwise
and is_1x_kernel_supported and is_1x_kernel_supported
) )
q_layout = quantizer.q_layout q_layout = quantizer.q_layout
...@@ -861,12 +866,16 @@ def _quantize_dbias_impl( ...@@ -861,12 +866,16 @@ def _quantize_dbias_impl(
x.data, x.data,
scale, scale,
amax, amax,
sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32), (
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix, rht_matrix,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
...@@ -875,10 +884,10 @@ def _quantize_dbias_impl( ...@@ -875,10 +884,10 @@ def _quantize_dbias_impl(
use_rht=use_rht, use_rht=use_rht,
) )
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
if q_layout == QuantizeLayout.ROWWISE: if q_layout.is_rowwise_only:
# Quantizer requires 2x quantization, but we are using 1x quantization # Quantizer requires 2x quantization, but we are using 1x quantization
# for performance reasons, so we need to generate the colwise data in JAX # for performance reasons, so we need to generate the colwise data in JAX
if flatten_axis < 0: if flatten_axis < 0:
...@@ -902,6 +911,7 @@ def _quantize_dbias_impl( ...@@ -902,6 +911,7 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout, q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
colwise_has_rht_applied=use_rht,
) )
return out, dbias.astype(dq_dtype) return out, dbias.astype(dq_dtype)
...@@ -1029,7 +1039,7 @@ class GroupedQuantizePrimitive(BasePrimitive): ...@@ -1029,7 +1039,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_rowwise:
rowwise_out_shape = out_shape rowwise_out_shape = out_shape
else: else:
rowwise_out_shape = (1,) rowwise_out_shape = (1,)
...@@ -1038,7 +1048,7 @@ class GroupedQuantizePrimitive(BasePrimitive): ...@@ -1038,7 +1048,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
colwise_out_shape = out_shape colwise_out_shape = out_shape
else: else:
colwise_out_shape = (1,) colwise_out_shape = (1,)
...@@ -1103,7 +1113,7 @@ class GroupedQuantizePrimitive(BasePrimitive): ...@@ -1103,7 +1113,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
scale, scale,
group_sizes, group_sizes,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
q_layout=q_layout, q_layout=q_layout.value.value,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
...@@ -1217,7 +1227,7 @@ def grouped_quantize( ...@@ -1217,7 +1227,7 @@ def grouped_quantize(
) )
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups): for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0)
scale = scale.at[i].set(tmp_scale[0]) scale = scale.at[i].set(tmp_scale[0])
is_tensor_scaling = quantizer.scaling_mode in ( is_tensor_scaling = quantizer.scaling_mode in (
...@@ -1226,7 +1236,7 @@ def grouped_quantize( ...@@ -1226,7 +1236,7 @@ def grouped_quantize(
) )
# WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
# So we performance ROWWISE_COLWISE and use the colwise_tensor_output # So we performance ROWWISE_COLWISE and use the colwise_tensor_output
apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only
q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
( (
rowwise_casted_output, rowwise_casted_output,
...@@ -1240,7 +1250,7 @@ def grouped_quantize( ...@@ -1240,7 +1250,7 @@ def grouped_quantize(
group_sizes, group_sizes,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_axis=group_axis, group_axis=group_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
...@@ -1248,7 +1258,7 @@ def grouped_quantize( ...@@ -1248,7 +1258,7 @@ def grouped_quantize(
# For DelayedScaling2x and CurrentScaling2x, the scale buffer # For DelayedScaling2x and CurrentScaling2x, the scale buffer
# is shared between rowwise and colwise # is shared between rowwise and colwise
if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war: if is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise or apply_colwise_war:
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
# TODO(Phuong): store the whole updated_amax in the grouped_quantize instead? # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
......
...@@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); ...@@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x); JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout);
// Normalization // Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler);
...@@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); ...@@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype, DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode, JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout); JAXX_Quantize_Layout quantize_layout);
// Softmax // Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
...@@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ...@@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout);
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
...@@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) { JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
...@@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back(); auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2]; auto act_len = input_dims[input_dims.size() - 2];
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)}; auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
...@@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
} }
} }
if (is_2x) { if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape ? output_trans_shape
: output_shape; : output_shape;
...@@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // updated_amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"), .Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
...@@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer ...@@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer
Buffer_Type amax_buf, Result_Type output_buf, Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
ActivationConfig act_params, bool output_amax_when_no_scaling) { JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf, return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params, updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params,
output_amax_when_no_scaling); output_amax_when_no_scaling);
} }
...@@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, ...@@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // updated_amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling")); .Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size}; auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
...@@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
std::vector<size_t>{1}); std::vector<size_t>{1});
} }
if (is_2x) { if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape; : output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape); output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
...@@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
bool is_dbias, ActivationConfig act_params, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
bool output_amax_when_no_scaling) { ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
...@@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
} }
} }
if (is_2x) { if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape ? output_trans_shape
: output_shape; : output_shape;
...@@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING &&
is_quantize_2x2x(quantize_layout) && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations."); "TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) { if (is_dbias) {
...@@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"), .Attr<bool>("output_amax_when_no_scaling"),
...@@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI( ...@@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI(
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params, int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
bool output_amax_when_no_scaling) { ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params, workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias,
output_amax_when_no_scaling); act_params, output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
...@@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, ...@@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling")); .Attr<bool>("output_amax_when_no_scaling"));
......
...@@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy ...@@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
return backend; return backend;
} }
...@@ -122,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -122,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
...@@ -155,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -155,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
...@@ -173,36 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -173,36 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto ragged_offset_tensor = auto ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { nvte_fused_attn_fwd(
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
nvte_fused_attn_fwd_qkvpacked( dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
} }
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
...@@ -276,7 +245,8 @@ static void FusedAttnForwardImpl( ...@@ -276,7 +245,8 @@ static void FusedAttnForwardImpl(
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */ /* Auxiliary tensors (to be propagated to the backward pass later) */
...@@ -288,47 +258,57 @@ static void FusedAttnForwardImpl( ...@@ -288,47 +258,57 @@ static void FusedAttnForwardImpl(
/* Call the underlying NVTE API */ /* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
// Prepare Q, K, V pointers and shapes based on layout
// Python passes dummy tensors for unused slots, so we extract from the actual packed data
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; // QKV packed in q: [batch*seqlen, 3, heads, dim]
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); // Python passes: q=packed_qkv, k=dummy, v=dummy
nvte_fused_attn_fwd_qkvpacked( // Extract K and V pointers from the packed q data
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), NVTE_CHECK(qk_head_dim == v_head_dim,
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, "For QKV packed layout, qk_head_dim must equal v_head_dim");
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
window_size_left, window_size_right, workspace_tensor.data(), stream); q_ptr = q;
k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
// For packed QKV, all have same shape since they're views into the same packed tensor
k_shape = q_shape;
v_shape = q_shape;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
auto kv_shape = // Python passes: q=query, k=packed_kv, v=dummy
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; // Extract V pointer from the packed k data
auto q_tensor = TensorWrapper(q, q_shape, dtype); NVTE_CHECK(qk_head_dim == v_head_dim,
auto kv_tensor = TensorWrapper(k, kv_shape, dtype); "For KV packed layout, qk_head_dim must equal v_head_dim");
nvte_fused_attn_fwd_kvpacked( size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), q_ptr = q;
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), k_ptr = k;
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), // V has same shape as K since they're packed together
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, v_shape = k_shape;
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
} }
// else NVTE_HD_HD_HD: pointers and shapes already correct
auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
...@@ -411,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -411,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) { int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
...@@ -447,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -447,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
...@@ -468,42 +436,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -468,42 +436,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto dummy_ragged_offset_tensor = auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
s_tensor.data(), // not used for F16 doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), s_tensor.data(), // not used for F16
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
deterministic, query_workspace_tensor.data(), nullptr); dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
nvte_fused_attn_bwd_kvpacked( window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
} }
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
...@@ -542,82 +486,89 @@ static void FusedAttnBackwardImpl( ...@@ -542,82 +486,89 @@ static void FusedAttnBackwardImpl(
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias); softmax_aux, rng_state, bias);
/* Call the underly NVTE API */ /* Call the underly NVTE API */
// Prepare Q, K, V pointers and shapes based on layout
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
void *dq_ptr = dq;
void *dk_ptr = dk;
void *dv_ptr = dv;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; // QKV packed in q: [batch*seqlen, 3, heads, dim]
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); NVTE_CHECK(qk_head_dim == v_head_dim,
if (is_ragged) { "For QKV packed layout, qk_head_dim must equal v_head_dim");
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
stream); q_ptr = q;
} k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
s_tensor.data(), // not used for F16 dq_ptr = dq;
s_tensor.data(), // not used for F16 dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), k_shape = q_shape;
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, v_shape = q_shape;
dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
auto kv_shape = NVTE_CHECK(qk_head_dim == v_head_dim,
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; "For KV packed layout, qk_head_dim must equal v_head_dim");
auto q_tensor = TensorWrapper(q, q_shape, dtype); size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype); q_ptr = q;
auto dq_tensor = TensorWrapper(dq, q_shape, dtype); k_ptr = k;
auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
if (is_ragged) { dq_ptr = dq;
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); dk_ptr = dk;
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
stream); // V has same shape as K since they're packed together
} v_shape = k_shape;
nvte_fused_attn_bwd_kvpacked( }
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
s_tensor.data(), // not used for F16 auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream); if (is_ragged) {
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { size_t dtype_size = typeToSize(dtype);
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; // For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream);
auto q_tensor = TensorWrapper(q, q_shape, dtype); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto k_tensor = TensorWrapper(k, k_shape, dtype); // Clear dq
auto v_tensor = TensorWrapper(v, v_shape, dtype); cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype); // For packed KV, dk contains both dk and dv - clear all at once
auto dk_tensor = TensorWrapper(dk, k_shape, dtype); cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype); } else {
if (is_ragged) { // All separate - clear each individually
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream);
cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream);
} }
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
} }
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
......
...@@ -34,8 +34,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { ...@@ -34,8 +34,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
} }
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand( std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode, cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, uint8_t *swizzle_scale_ptr,
size_t axis_boundary, bool rowwise) { JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape // Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions(); auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary), std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
...@@ -56,17 +56,32 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand( ...@@ -56,17 +56,32 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1}; std::vector<size_t> scale_shape = {1};
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { auto is_nvfp4 = is_nvfp4_scaling(scaling_mode);
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) {
// Block scaling also needs to be collapsed to match 2D data // Block scaling also needs to be collapsed to match 2D data
scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
NVTE_CHECK(typeToSize(scale_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
} }
if (!is_nvfp4) {
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); if (rowwise) {
if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
} else { // Swizzle for NVFP4
NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS");
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else { // Create tensor to hold swizzled scale factor
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
} }
} }
...@@ -145,16 +160,34 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i ...@@ -145,16 +160,34 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) { bool use_split_accumulator, JAXX_Collective_Op collective_op) {
// cuBLAS workspace + 256 alignment enforcement (+ swizzle scales)
uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr;
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
size_t workspace_size = static_cast<size_t>(workspace->element_count()) - 256;
if (is_nvfp4_scaling(scaling_mode)) {
auto lhs_scale_size = product(lhs_scale_inv.dimensions());
auto rhs_scale_size = product(rhs_scale_inv.dimensions());
workspace_size = workspace_size - lhs_scale_size - rhs_scale_size;
lhs_swizzle_scale_ptr = workspace_ptr;
rhs_swizzle_scale_ptr = workspace_ptr + lhs_scale_size;
workspace_ptr = rhs_swizzle_scale_ptr + rhs_scale_size;
}
auto workspace_ = TensorWrapper(workspace_ptr, std::vector<size_t>{workspace_size}, DType::kByte);
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode,
lhs_axis_boundary, make_lhs_rowwise); auto [lhs_, lhs_shape] =
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, lhs_swizzle_scale_ptr,
rhs_axis_boundary, make_rhs_rowwise); scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] =
xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, rhs_swizzle_scale_ptr,
scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
...@@ -191,11 +224,6 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i ...@@ -191,11 +224,6 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
} }
auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);
// cuBLAS workspace + 256 alignment enforcement
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0); auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
float one = 1.; float one = 1.;
......
...@@ -34,12 +34,24 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -34,12 +34,24 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret; return ret;
} }
enum class QuantizeLayout { enum class JAXX_Quantize_Layout : int64_t {
ROWWISE, ROWWISE,
COLWISE, COLWISE,
ROWWISE_COLWISE, ROWWISE_COLWISE,
}; };
inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::COLWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
enum class JAXX_Scaling_Mode : int64_t { enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0, NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1, DELAYED_TENSOR_SCALING = 1,
......
...@@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) { JAXX_Quantize_Layout quantize_layout, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
...@@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type); auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
auto x_size = product(x_buf.dimensions()); auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions()); auto gamma_size = product(gamma_buf.dimensions());
...@@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
} }
if (_is_2x) { if (is_quantize_2x2x(quantize_layout)) {
output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(), output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(),
static_cast<DType>(out_dtype), input_shape); static_cast<DType>(out_dtype), input_shape);
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
...@@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling"), .Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Error_Type NormForwardInitializeFFI(
Buffer_Type amax_buf, Buffer_Type gamma_buf, cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Buffer_Type amax_buf,
Buffer_Type beta_buf, Result_Type output_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon,
Result_Type wkspace_buf, int norm_type, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
bool zero_centered_gamma, double epsilon, int64_t sm_margin, bool output_amax_when_no_scaling) {
JAXX_Scaling_Mode scaling_mode, bool is_2x,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf, return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf, gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf, colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling); scaling_mode, quantize_layout, output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
...@@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ ...@@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling")); .Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
......
...@@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
.export_values(); .export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout", pybind11::enum_<JAXX_Quantize_Layout>(m, "JAXX_Quantize_Layout", pybind11::module_local())
pybind11::module_local()) .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE)
.value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) .value("COLWISE", JAXX_Quantize_Layout::COLWISE)
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) .value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values(); .export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local()) pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
......
...@@ -20,7 +20,7 @@ namespace jax { ...@@ -20,7 +20,7 @@ namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype, DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode, JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) { JAXX_Quantize_Layout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
...@@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto scale_shape = std::vector<size_t>{1}; auto scale_shape = std::vector<size_t>{1};
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { if (is_quantize_rowwise(q_layout)) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape); output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4) if (is_nvfp4)
...@@ -52,7 +52,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -52,7 +52,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
} }
} }
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { if (is_quantize_colwise(q_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape; : output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape); output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
...@@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type updated_amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
bool stochastic_rounding, bool use_rht) { int64_t flatten_axis, bool stochastic_rounding, bool use_rht) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data();
...@@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4."); NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling"); NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling");
if (quantize_layout == QuantizeLayout::ROWWISE || if (is_quantize_rowwise(quantize_layout)) {
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_tensor_scaling) { if (is_tensor_scaling) {
...@@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
quant_config.set_rng_state(sr_rng_state_tensor.data()); quant_config.set_rng_state(sr_rng_state_tensor.data());
} }
if (quantize_layout == QuantizeLayout::COLWISE || if (is_quantize_colwise(quantize_layout)) {
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_nvfp4 && use_rht) { if (is_nvfp4 && use_rht) {
if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { if (is_quantize_2x2x(quantize_layout)) {
// Do regular rowwise quantization without RHT // Do regular rowwise quantization without RHT
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
} }
...@@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout") .Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis") .Attr<int64_t>("flatten_axis")
.Attr<bool>("stochastic_rounding") .Attr<bool>("stochastic_rounding")
...@@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
Buffer_Type group_sizes, Result_Type outputs, Buffer_Type group_sizes, Result_Type outputs,
Result_Type colwise_outputs, Result_Type scale_invs, Result_Type colwise_outputs, Result_Type scale_invs,
Result_Type colwise_scale_invs, Result_Type amaxs, Result_Type colwise_scale_invs, Result_Type amaxs,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
int64_t flatten_axis) { int64_t flatten_axis) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode)); "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
...@@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type()); auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type());
auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type()); auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type());
auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type()); auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type());
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data()); auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data());
auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data()); auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data());
...@@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data()); auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data());
auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data()); auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data());
bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool has_colwise = quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
...@@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
size_t output_dtype_bytes = te_dtype_bytes(out_dtype); size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype); size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype);
size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype); size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype);
size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0; size_t colwise_output_dtype_bytes = is_quantize_colwise(quantize_layout) ? output_dtype_bytes : 0;
size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0; size_t colwise_sinv_dtype_bytes = is_quantize_colwise(quantize_layout) ? sinv_dtype_bytes : 0;
size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0; size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0;
size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0; size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0;
...@@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype); auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype);
auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (has_rowwise) { if (is_quantize_rowwise(quantize_layout)) {
out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i); out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
} }
} }
if (has_colwise) { if (is_quantize_colwise(quantize_layout)) {
auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i; auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i;
out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape); out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
...@@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, ...@@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout") .Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<int64_t>("flatten_axis")); .Attr<int64_t>("flatten_axis"));
} // namespace jax } // namespace jax
......
...@@ -19,6 +19,7 @@ from . import cpp_extensions as tex ...@@ -19,6 +19,7 @@ from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope from .cpp_extensions.amax import AmaxScope
from .quantize import ( from .quantize import (
ScaledTensorFactory, ScaledTensorFactory,
ScaledTensor,
ScalingMode, ScalingMode,
QuantizeLayout, QuantizeLayout,
QuantizerSet, QuantizerSet,
...@@ -26,7 +27,6 @@ from .quantize import ( ...@@ -26,7 +27,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -94,7 +94,7 @@ def dense( ...@@ -94,7 +94,7 @@ def dense(
if transpose_batch_sequence: if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!") warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype input_dtype = x.dtype
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
...@@ -227,8 +227,8 @@ def _dense_fwd_rule( ...@@ -227,8 +227,8 @@ def _dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
ctx = ( ctx = (
casted_x.get_tensor(usage=TensorUsage.LHS_TRANS), casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS), casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
x.shape, x.shape,
kernel.shape, kernel.shape,
use_bias, use_bias,
...@@ -529,8 +529,12 @@ def _grouped_dense_fwd_rule( ...@@ -529,8 +529,12 @@ def _grouped_dense_fwd_rule(
ctx = ( ctx = (
group_sizes, group_sizes,
ctx_x, ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x,
ctx_kernel, (
ctx_kernel.checkpoint(quantizer_set.kernel)
if isinstance(ctx_kernel, ScaledTensor)
else ctx_kernel
),
x.shape, x.shape,
kernel.shape, kernel.shape,
use_bias, use_bias,
......
...@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support. ...@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
""" """
from functools import reduce from functools import reduce
import operator import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
...@@ -33,10 +33,11 @@ from ..cpp_extensions import ( ...@@ -33,10 +33,11 @@ from ..cpp_extensions import (
) )
from ..quantize import ( from ..quantize import (
QuantizerFactory, QuantizerFactory,
get_quantize_config, get_global_quantize_recipe,
QuantizeMetaSet, QuantizeMetaSet,
TensorSource, TensorSource,
get_quantize_config_with_recipe, get_quantize_config_with_recipe,
noop_quantizer_set,
) )
PRNGKey = Any PRNGKey = Any
...@@ -345,23 +346,27 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -345,23 +346,27 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
""" """
def generate_quantizer_set( def generate_quantizer_set(
self, postfix: str = "", variable_collection: str = None, fp8_recipe=None self,
postfix: str = "",
variable_collection: str = None,
quantization_checkpoint_name: Optional[str] = None,
fp8_recipe=None,
): ):
""" """
Generate a set of FP8 meta for a GEMM. Generate a set of FP8 meta for a GEMM.
""" """
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
collection_name = ( collection_name = (
variable_collection variable_collection
if variable_collection is not None if variable_collection is not None
else get_quantize_config().COLLECTION_NAME else quantize_config.COLLECTION_NAME
) )
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_meta = quantize_config.get_quantize_flax_meta( x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x" self, collection_name, postfix, TensorSource.X, "x"
) )
...@@ -375,7 +380,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -375,7 +380,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set fp8_recipe=fp8_recipe,
quantize_meta_set=quantize_meta_set,
checkpoint_name=quantization_checkpoint_name,
) )
return quantizer_set return quantizer_set
...@@ -424,6 +431,8 @@ class DenseGeneral(TransformerEngineBase): ...@@ -424,6 +431,8 @@ class DenseGeneral(TransformerEngineBase):
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor. Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
The name for checkpointing quantizations.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -439,6 +448,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -439,6 +448,7 @@ class DenseGeneral(TransformerEngineBase):
dtype: DType = jnp.float32 dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -483,7 +493,11 @@ class DenseGeneral(TransformerEngineBase): ...@@ -483,7 +493,11 @@ class DenseGeneral(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
if self.use_bias: if self.use_bias:
...@@ -496,7 +510,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -496,7 +510,6 @@ class DenseGeneral(TransformerEngineBase):
else: else:
bias = None bias = None
quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
y = dense( y = dense(
inputs, inputs,
...@@ -597,7 +610,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -597,7 +610,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias_axes: Tuple[str, ...], default = () bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh, The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`. only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization. Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs. If set False, return None as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False enable_low_rank_adaptation: bool, default = False
...@@ -628,6 +641,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -628,6 +641,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
value or None. When None is set, then no scaling is applied. value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor. Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
The name for checkpointing quantizations.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -644,7 +659,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -644,7 +659,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
use_bias: bool = False use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = () bias_axes: Tuple[str, ...] = ()
return_layernorm_output: bool = True return_layernorm_output: bool = False
enable_low_rank_adaptation: bool = False enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32 low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
...@@ -654,6 +669,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -654,6 +669,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -693,10 +709,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -693,10 +709,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
input_dtype = inputs.dtype input_dtype = inputs.dtype
ln_output = None ln_output = None
quantizer_set = self.generate_quantizer_set() quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
fuse_layernorm = ( fuse_layernorm = (
get_quantize_config().is_fp8_enabled() quantizer_set != noop_quantizer_set
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -747,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -747,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape, kernel_shape,
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
...@@ -891,10 +909,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -891,10 +909,10 @@ class LayerNormMLP(TransformerEngineBase):
The name of axes used to shard bias with a corresponding mesh for The name of axes used to shard bias with a corresponding mesh for
the weight of the second dense layer transformation. the weight of the second dense layer transformation.
Only used when :attr:`use_bias=True`. Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization. Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs. If set False, return None as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('relu',) activations: Sequence[Union[str, Callable]], default = ('gelu',)
The sequence of activation functions to apply after the first dense layer transformation. The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
activation_params: dict, default = None activation_params: dict, default = None
...@@ -903,7 +921,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -903,7 +921,7 @@ class LayerNormMLP(TransformerEngineBase):
need additional parameters. need additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout' intermediate_dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
intermediate_dropout_rate: float, default = 0.1 intermediate_dropout_rate: float, default = 0.0
Dropout probability for the dropout op after the :attr:`activations`. Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = () intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden Dimensions that will share the same dropout mask for hidden
...@@ -941,6 +959,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -941,6 +959,8 @@ class LayerNormMLP(TransformerEngineBase):
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor. Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
The name for checkpointing quantizations.
""" """
intermediate_dim: int = 2048 intermediate_dim: int = 2048
...@@ -959,11 +979,11 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -959,11 +979,11 @@ class LayerNormMLP(TransformerEngineBase):
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_1: Tuple[str, ...] = ("act", "mlp")
bias_axes_2: Tuple[str, ...] = ("embed",) bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True return_layernorm_output: bool = False
activations: Sequence[Union[str, Callable]] = ("relu",) activations: Sequence[Union[str, Callable]] = ("gelu",)
activation_params: dict = None activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1 intermediate_dropout_rate: float = 0.0
intermediate_hidden_dropout_dims: Sequence[int] = () intermediate_hidden_dropout_dims: Sequence[int] = ()
enable_low_rank_adaptation: bool = False enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32 low_rank_adaptation_dim: int = 32
...@@ -976,6 +996,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -976,6 +996,7 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name: str = "ffn1" ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2" ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -1010,8 +1031,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1010,8 +1031,12 @@ class LayerNormMLP(TransformerEngineBase):
""" """
assert self.axis == -1, "Only support axis == -1 at this moment" assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0") ffn1_quantizer_set = self.generate_quantizer_set(
ffn2_quantizer_set = self.generate_quantizer_set("_1") "_0", quantization_checkpoint_name=self.quantization_checkpoint_name
)
ffn2_quantizer_set = self.generate_quantizer_set(
"_1", quantization_checkpoint_name=self.quantization_checkpoint_name
)
input_dtype = inputs.dtype input_dtype = inputs.dtype
ln_output = None ln_output = None
...@@ -1019,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1019,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision # TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented # when NoOpQuantizer and Tensor are implemented
fuse_layernorm = ( fuse_layernorm = (
get_quantize_config().is_fp8_enabled() ffn1_quantizer_set != noop_quantizer_set
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -1105,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1105,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if ffn1_quantizer_set == noop_quantizer_set:
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
...@@ -1117,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1117,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape, kernel_2_shape,
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if ffn2_quantizer_set == noop_quantizer_set:
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
......
...@@ -197,6 +197,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -197,6 +197,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
fused_scale_factor = scale_factor fused_scale_factor = scale_factor
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias attn_weights += bias
bias = None
def apply_swa_mask(original_mask: Array) -> Array: def apply_swa_mask(original_mask: Array) -> Array:
"""Apply the sliding window mask to a given mask""" """Apply the sliding window mask to a given mask"""
...@@ -406,10 +407,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -406,10 +407,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
variable: variable:
* Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default). * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
* Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
kernel is not available on the system, a warning will be issued, and the module will attention kernel is not available on the system, a warning will be issued, and the module
automatically fall back to the unfused backend. will automatically fall back to the unfused backend.
.. note:: .. note::
The DotProductAttention default setting enables non-deterministic kernels for reduced The DotProductAttention default setting enables non-deterministic kernels for reduced
...@@ -601,7 +602,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -601,7 +602,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
else: else:
assert bias is not None assert bias is not None
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) # Use fused attn (if kernel check below passes) by default
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
seqlen_q = query.shape[sequence_dim] seqlen_q = query.shape[sequence_dim]
...@@ -1618,7 +1620,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1618,7 +1620,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Dimensions that will share the same dropout mask for hidden Dimensions that will share the same dropout mask for hidden
attention_dropout: float, default = 0.1 attention_dropout: float, default = 0.1
Dropout probability for the dropout op during multi-head attention. Dropout probability for the dropout op during multi-head attention.
intermediate_dropout: float, default = 0.1 intermediate_dropout: float, default = 0.0
Dropout probability for the dropout op after FC1 layer. Dropout probability for the dropout op after FC1 layer.
intermediate_dropout_dims: Sequence[int], default = () intermediate_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden after FC1 layer. Dimensions that will share the same dropout mask for hidden after FC1 layer.
...@@ -1633,7 +1635,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1633,7 +1635,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights of FC1 and FC2 layers. Used for initializing weights of FC1 and FC2 layers.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
mlp_activations: Sequence[str], default = ('relu', ) mlp_activations: Sequence[str], default = ('gelu', )
The sequence of activation functions to apply after the first linear transformation. The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
mlp_activation_params: dict = None mlp_activation_params: dict = None
...@@ -1753,12 +1755,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1753,12 +1755,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_dropout: float = 0.1 hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = () hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1 attention_dropout: float = 0.1
intermediate_dropout: float = 0.1 intermediate_dropout: float = 0.0
intermediate_dropout_dims: Sequence[int] = () intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = "dropout" dropout_rng_name: str = "dropout"
mha_kernel_init: Initializer = None mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ("relu",) mlp_activations: Sequence[str] = ("gelu",)
mlp_activation_params: dict = None mlp_activation_params: dict = None
use_bias: bool = False use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
......
...@@ -23,7 +23,6 @@ from .quantize import ( ...@@ -23,7 +23,6 @@ from .quantize import (
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -73,7 +72,7 @@ def layernorm_dense( ...@@ -73,7 +72,7 @@ def layernorm_dense(
- Quantization is applied to both the normalized input and kernel - Quantization is applied to both the normalized input and kernel
""" """
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype input_dtype = x.dtype
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
...@@ -236,8 +235,8 @@ def _layernorm_dense_fwd_rule( ...@@ -236,8 +235,8 @@ def _layernorm_dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
ctx = ( ctx = (
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
casted_kernel.get_tensor(TensorUsage.RHS_TRANS), casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
x.shape, x.shape,
kernel.shape, kernel.shape,
mu, mu,
......
...@@ -28,7 +28,6 @@ from .quantize import ( ...@@ -28,7 +28,6 @@ from .quantize import (
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -114,7 +113,7 @@ def layernorm_mlp( ...@@ -114,7 +113,7 @@ def layernorm_mlp(
not zero_centered_gamma not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled(): if quantizer_sets == (noop_quantizer_set, noop_quantizer_set):
input_dtype = x.dtype input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
...@@ -390,11 +389,11 @@ def _layernorm_mlp_fwd_rule( ...@@ -390,11 +389,11 @@ def _layernorm_mlp_fwd_rule(
rsigma, rsigma,
gamma, gamma,
beta, beta,
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn1_quantizer_set.x),
casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS), casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn1_quantizer_set.kernel),
dot_1_output, dot_1_output,
casted_act_out.get_tensor(TensorUsage.LHS_TRANS), casted_act_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn2_quantizer_set.x),
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS), casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn2_quantizer_set.kernel),
x_contracting_dims, x_contracting_dims,
k_contracting_dims, k_contracting_dims,
kernel_1.shape, kernel_1.shape,
......
...@@ -17,3 +17,4 @@ from .metadata import * ...@@ -17,3 +17,4 @@ from .metadata import *
from .hadamard import * from .hadamard import *
from .helper import * from .helper import *
from .device_utils import * from .device_utils import *
from .misc import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment