Commit c1a1c04e authored by wenjh

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
......@@ -1784,6 +1784,9 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
# Ensure segment_pos gets the same sharding as the corresponding segment IDs.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
......@@ -1991,7 +1994,13 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets the same sharding as the corresponding segment IDs.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
helper = _FusedAttnCPWithP2PHelper(mesh, config)
......@@ -2265,6 +2274,9 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
# Ensure segment_pos gets the same sharding as the corresponding segment IDs.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
......@@ -2403,7 +2415,11 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
arg_shardings = tuple(arg.sharding for arg in arg_infos)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets the same sharding as the corresponding segment IDs.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4])
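The same two-line remap appears in all four ring-attention partition rules above. A minimal sketch of the pattern, assuming the trailing arguments are ordered (…, q_segment_ids, kv_segment_ids, q_segment_pos, kv_segment_pos) as the negative indices suggest:

```python
# Hedged sketch only; the argument names and their order are assumptions inferred from
# the negative indices in the diff, not taken from the actual primitive signature.
def copy_segment_pos_shardings(arg_shardings):
    arg_shardings = list(arg_shardings)
    arg_shardings[-1] = arg_shardings[-3]  # (assumed) kv_segment_pos <- kv_segment_ids
    arg_shardings[-2] = arg_shardings[-4]  # (assumed) q_segment_pos  <- q_segment_ids
    return tuple(arg_shardings)
```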
......@@ -2739,10 +2755,13 @@ def fused_attn_bwd(
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
if 100 in get_all_device_compute_capability():
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
compute_capabilities = get_all_device_compute_capability()
if any(x >= 100 for x in compute_capabilities):
assert not (
attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
......
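For context on the capability check above: the old membership test only matched devices reporting exactly 100, while the new `any` form also covers later sm100+ parts. A minimal illustration (the value 103 is only an example):

```python
compute_capabilities = [103]  # hypothetical device in the sm100 family reporting 103
assert 100 not in compute_capabilities               # the old check would miss it
assert any(x >= 100 for x in compute_capabilities)   # the new check covers all sm100+
```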
......@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
"""
Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the get_quantize_config().initialize() methods.
This helper is used in the get_quantize_config_with_recipe().initialize() methods.
Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
......@@ -38,13 +38,13 @@ from ..quantize import (
ScalingMode,
Quantizer,
GroupedQuantizer,
get_quantize_config,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
should_use_rht,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......@@ -169,16 +169,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht(
q.scaling_mode, is_colwise=q.is_colwise
)
def has_rht_applied(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and q.has_rht_applied
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class
assert uses_rht(lhs_q) == uses_rht(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in"
" the GEMM."
assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
" with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
" GEMM."
)
return lhs_q, rhs_q
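A small numerical check of the "RHT cancels out in the GEMM" claim from the assertion above, using a random orthogonal matrix as a stand-in for the actual Hadamard-based transform (an illustration only, not the library's RHT implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 8))
b = rng.normal(size=(8, 4))
h, _ = np.linalg.qr(rng.normal(size=(8, 8)))  # orthogonal stand-in: h @ h.T == I
# Applying the transform to both operands along the contracting dim leaves the GEMM unchanged.
np.testing.assert_allclose((a @ h) @ (h.T @ b), a @ b, atol=1e-10)
```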
......@@ -474,29 +471,6 @@ class GemmPrimitive(BasePrimitive):
f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
)
lhs_axis_boundary = get_lhs_axis_boundary(lhs_contracting_dims, lhs_is_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs.shape[lhs_axis_boundary:])
if lhs_is_transposed
else reduce(operator.mul, lhs.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_contracting_dims, rhs_is_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs.shape[:rhs_axis_boundary])
if rhs_is_transposed
else reduce(operator.mul, rhs.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
# Determine output shape and dtype
assert (
dtypes.canonicalize_dtype(out_dtype).itemsize > 1
......@@ -560,6 +534,9 @@ class GemmPrimitive(BasePrimitive):
# Declare cuBLAS workspace
workspace_size = get_cublas_workspace_size_bytes()
# NVFP4 swizzling happens via the nvte kernel instead of JAX transposes
if scaling_mode.is_nvfp4_scaling:
workspace_size += lhs_scale_inv.size + rhs_scale_inv.size
if not collective_op.is_none:
workspace_size *= get_cgemm_num_max_streams()
# cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
......@@ -605,6 +582,29 @@ class GemmPrimitive(BasePrimitive):
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
)
lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:])
if lhs_transposed
else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary])
if rhs_transposed
else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
kwargs = {
"scaling_mode": int(scaling_mode.value),
......@@ -666,6 +666,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
)
# Only perform the JAX-based swizzle for MXFP8; the NVFP4 swizzle goes through the nvte kernel
if scaling_mode.is_mxfp8_scaling:
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
......@@ -1245,7 +1247,7 @@ def _te_gemm(
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
use_split_accumulator: bool = None,
transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]:
......@@ -1257,6 +1259,13 @@ def _te_gemm(
DeprecationWarning,
)
if use_split_accumulator is None:
# TODO(jberchtold): Rework the GEMM API to provide the context here instead of relying on global state, and
# use the context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
# Prepare non-quantized GEMM operands
lhs_data = lhs
rhs_data = rhs
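The `None` default above replaces a value that was evaluated once when the function definition was first executed; it is now resolved lazily so it tracks the currently active recipe. A minimal sketch of the late-binding pattern, reusing the helpers imported in this file:

```python
# Sketch only; the import path is an assumption and the real _te_gemm has many more parameters.
from transformer_engine.jax.quantize import (  # assumed import path
    get_quantize_config_with_recipe,
    get_global_quantize_recipe,
)

def resolve_use_split_accumulator(use_split_accumulator=None):
    # Resolve the None sentinel at call time so the value follows the active recipe,
    # rather than being frozen at function-definition time.
    if use_split_accumulator is None:
        use_split_accumulator = get_quantize_config_with_recipe(
            get_global_quantize_recipe()
        ).FP8_2X_ACC_FPROP
    return use_split_accumulator
```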
......@@ -1719,10 +1728,15 @@ def _jax_gemm(
assert (
rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
# TODO(jberchtold): Rework the GEMM API to provide the context here instead of relying on global state, and
# use the context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
precision = (
jax.lax.Precision.HIGHEST
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
......@@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break
# _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
# but this fails when bias fusion is turned on with arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise
)
return (
(force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and arch_l_100
......@@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
@return: the output of 'f' with the colwise output calculated
"""
should_apply_war = (
quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
quantizer is not None
and quantizer.scaling_mode.is_tensor_scaling()
and quantizer.q_layout.is_rowwise_colwise
)
if not should_apply_war:
return None
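Much of this commit replaces `quantizer.is_2x2x()` and raw enum-value comparisons with helper properties on `QuantizeLayout` (`has_rowwise`, `has_colwise`, `is_rowwise_colwise`, `is_colwise_only`). A minimal sketch of what such helpers could look like; the actual enum lives in the quantize module and may differ:

```python
from enum import Enum

class QuantizeLayoutSketch(Enum):  # hypothetical stand-in, not the real class
    ROWWISE = 0
    COLWISE = 1
    ROWWISE_COLWISE = 2

    @property
    def has_rowwise(self) -> bool:
        return self in (QuantizeLayoutSketch.ROWWISE, QuantizeLayoutSketch.ROWWISE_COLWISE)

    @property
    def has_colwise(self) -> bool:
        return self in (QuantizeLayoutSketch.COLWISE, QuantizeLayoutSketch.ROWWISE_COLWISE)

    @property
    def is_rowwise_colwise(self) -> bool:
        return self is QuantizeLayoutSketch.ROWWISE_COLWISE

    @property
    def is_colwise_only(self) -> bool:
        return self is QuantizeLayoutSketch.COLWISE

    @property
    def is_rowwise_only(self) -> bool:
        return self is QuantizeLayoutSketch.ROWWISE
```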
......
......@@ -11,7 +11,7 @@ from typing import Optional, Union
import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec
......@@ -27,7 +27,7 @@ from .misc import (
NamedSharding,
get_cudnn_version,
)
from .quantization import _quantize_dbias_impl, AmaxScope
from .quantization import quantize, AmaxScope
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp_tpsp,
......@@ -112,7 +112,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -148,6 +148,13 @@ class NormFwdPrimitive(BasePrimitive):
"Current tensor scaling is not supported for fused norm and quantization. Please do"
" norm in higher-precision then quantize with current tensor scaling."
)
assert not ScalingMode(scaling_mode).is_nvfp4_scaling, (
"NVFP4 block scaling is not yet supported for fused norm and quantization."
" Please do norm in higher-precision then quantize with current tensor scaling."
)
assert (
not quantize_layout.is_colwise_only
), "Fused norm with colwise-only quantization is not supported."
mu_rsigama_dtype = jnp.float32
......@@ -165,7 +172,7 @@ class NormFwdPrimitive(BasePrimitive):
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,)
colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -173,7 +180,7 @@ class NormFwdPrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,)
colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
......@@ -189,7 +196,7 @@ class NormFwdPrimitive(BasePrimitive):
zero_centered_gamma,
epsilon,
get_forward_sm_margin(),
is_2x,
True, # is_training
)
wkspace_aval = jax.core.ShapedArray(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -245,7 +252,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -287,7 +294,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
sm_margin=sm_margin,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
quantize_layout=quantize_layout.value.value,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
......@@ -303,7 +310,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -335,7 +342,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -349,7 +356,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
)
if is_2x:
if quantize_layout.has_colwise:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape)
......@@ -373,7 +380,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -409,7 +416,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -426,7 +433,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -450,7 +457,7 @@ class NormFwdPrimitive(BasePrimitive):
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
......@@ -488,7 +495,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -524,7 +531,7 @@ class NormFwdPrimitive(BasePrimitive):
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
......@@ -586,7 +593,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -623,7 +630,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -646,25 +653,29 @@ class NormFwdPrimitive(BasePrimitive):
result_types,
)
prefix = "NormFwd_"
prefix = "NormFwd"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1
value_types[0].shape,
unique_var=prefix,
flatten_axis=-1,
q_layout=quantize_layout,
)
x_axes = scale_rules.input_spec
input_spec = scale_rules.input_spec
out = x_axes
colwise_out = out if is_2x else (prefix + "out_colwise",)
rsigma = x_axes[:-1]
mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = (prefix + "amax",)
rsigma = input_spec[:-1]
mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
gamma = (BATCHING + prefix + "_gamma",)
beta = (BATCHING + prefix + "_beta",)
return SdyShardingRule(
(x_axes, ("…1",), amax, ("…2",), ("…3",)),
(input_spec, scale, amax, gamma, beta),
(
out,
colwise_out,
scale_rules.rowwise_rule,
scale_rules.colwise_rule,
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
mu,
rsigma,
......@@ -945,7 +956,7 @@ def layernorm_fwd(
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
quantizer: Optional[Quantizer],
quantizer: Optional[Quantizer] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
......@@ -975,10 +986,19 @@ def layernorm_fwd(
- Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
"""
if not NormFwdPrimitive.enabled():
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
output, mu, rsigma = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
if quantizer is not None:
output = quantize(
output,
quantizer,
flatten_axis=-1,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return (output, mu, rsigma)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -999,7 +1019,7 @@ def layernorm_fwd(
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=False,
......@@ -1029,7 +1049,7 @@ def layernorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out, _ = quantize(
out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence
)
return out, mu, rsigma
......@@ -1050,20 +1070,19 @@ def layernorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, _ = _quantize_dbias_impl(
out = quantize(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, mu, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
q_layout = quantizer.q_layout
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
q_layout = QuantizeLayout.ROWWISE
(
rowwise_casted_output,
colwise_casted_output,
......@@ -1083,7 +1102,7 @@ def layernorm_fwd(
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -1092,8 +1111,7 @@ def layernorm_fwd(
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -1219,10 +1237,19 @@ def rmsnorm_fwd(
Shape: (..., 1)
"""
if not NormFwdPrimitive.enabled():
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
output, rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon)
if quantizer is not None:
output = quantize(
output,
quantizer,
flatten_axis=-1,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return (output, rsigma)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -1245,7 +1272,7 @@ def rmsnorm_fwd(
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -1274,7 +1301,7 @@ def rmsnorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out = quantize(
out.data,
quantizer,
amax_scope=amax_scope,
......@@ -1297,20 +1324,19 @@ def rmsnorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, _ = _quantize_dbias_impl(
out = quantize(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
q_layout = quantizer.q_layout
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
q_layout = QuantizeLayout.ROWWISE
(
rowwise_casted_output,
colwise_casted_output,
......@@ -1330,7 +1356,7 @@ def rmsnorm_fwd(
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -1339,8 +1365,7 @@ def rmsnorm_fwd(
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......
......@@ -11,7 +11,7 @@ import math
import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec
import transformer_engine_jax
......@@ -31,7 +31,7 @@ from .misc import (
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp,
num_of_devices,
get_num_devices_in_mesh,
)
from ..quantize import (
ScaledTensor2x,
......@@ -45,7 +45,6 @@ from ..quantize import (
compute_scale_from_amax,
NoScaleTensor,
get_rht_matrix,
should_use_rht,
)
......@@ -108,21 +107,22 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
"sr_rng_state must be a uint32 array when stochastic_rounding is True but"
f" received {sr_rng_state_aval}"
)
if is_outer:
if is_outer and get_num_devices_in_mesh() > 1:
assert (
sr_rng_state_aval.shape[0] == num_of_devices()
sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
and sr_rng_state_aval.shape[1] == 4
), (
"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}"
)
else:
assert sr_rng_state_aval.shape == (4,), (
"Sharded sr_rng_state must be of shape (4,) per device when"
# We cannot assert the shape is exactly (4,) here because if the quantized data is not
# perfectly sharded across all devices then we will have extra rng state here. For example,
# this could occur when the weights are not sharded when using data parallelism. However,
# this is okay because the extra rng state will simply not be used and each device still
# has a unique rng state.
assert sr_rng_state_aval.size >= 4, (
"Sharded sr_rng_state must have at least 4 elements per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if QuantizeLayout(q_layout).has_rowwise:
rowwise_out_shape = out_shape
else:
rowwise_out_shape = (1,)
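The relaxed check above allows a per-device shard to carry more than one rng-state row when the quantized data is not sharded across every mesh axis; `sharded_impl` later just takes the first four values. A small numeric illustration (shapes are assumptions based on the diff):

```python
import numpy as np

num_devices = 8
sr_rng_state = np.arange(num_devices * 4, dtype=np.uint32).reshape(num_devices, 4)
# e.g. unsharded weights under data parallelism: one device may still see several rows
per_device_shard = sr_rng_state[:2]
assert per_device_shard.size >= 4              # the relaxed check above
local_state = per_device_shard.flatten()[:4]   # matches the slicing in sharded_impl
```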
......@@ -170,7 +170,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
broadcast_2d_scale_shape_to_1d=True,
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if QuantizeLayout(q_layout).has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
......@@ -194,9 +194,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(scale_dtype),
scaling_mode,
QuantizeLayout(
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
q_layout.value,
)
wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -272,7 +270,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
post_rht_amax,
rht_matrix,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
q_layout=q_layout.value.value,
flatten_axis=flatten_axis,
is_dbias=is_dbias,
stochastic_rounding=stochastic_rounding,
......@@ -335,7 +333,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
......@@ -424,7 +422,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*x_spec),
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
......@@ -448,7 +446,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if (
ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed
......@@ -505,7 +503,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
......@@ -529,7 +527,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if (
ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed
......@@ -552,8 +550,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
# TODO(jberchtold): Assert the sr_rng_state is sharded along all mesh axes
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -564,6 +567,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
if sr_rng_state.size > 4:
# See comment in abstract method for explanation of why we cannot assert exact shape
sr_rng_state = sr_rng_state.flatten()[:4]
(
local_x,
local_colwise_x,
......@@ -635,39 +641,37 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
result_types,
)
prefix = "DBiasQuantize_"
prefix = "DBiasQuantize"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape,
unique_var=prefix + "x",
unique_var=prefix,
flatten_axis=flatten_axis,
q_layout=q_layout,
broadcast_2d_scale_shape_to_1d=True,
)
x_axes = scale_rules.input_spec
out = x_axes
colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "colwise_scale_inv",)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv = scale_rules.colwise_rule
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
colwise_scale_inv = tuple(
multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis)
)
else:
colwise_out = x_axes
dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis")
input_spec = scale_rules.input_spec
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
sr_rng_state = (
BATCHING + prefix + "_sr_rng_state_partition_axis",
BATCHING + prefix + "sr_rng_state_data_axis",
)
post_rht_amax = (prefix + "post_rht_amax",)
rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2")
post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")
return SdyShardingRule(
(x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
(input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix),
(
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
dbias,
),
**scale_rules.factor_sizes,
)
......@@ -754,9 +758,10 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = (
quantizer.q_layout == QuantizeLayout.COLWISE
and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING
is_unsupported = quantizer.q_layout.is_colwise_only and not (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and hasattr(quantizer, "use_rht")
and quantizer.use_rht
)
if is_unsupported or not PrimitiveClass.enabled():
if is_dbias:
......@@ -792,7 +797,7 @@ def _quantize_dbias_impl(
rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
amax = x.amax
if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout):
if hasattr(quantizer, "use_rht") and quantizer.use_rht:
use_rht = True
rht_matrix = get_rht_matrix()
......@@ -815,7 +820,7 @@ def _quantize_dbias_impl(
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# Make sure to reset amax to zeros for DelayedScaling
......@@ -836,7 +841,7 @@ def _quantize_dbias_impl(
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x()
and quantizer.q_layout.is_rowwise_colwise
and is_1x_kernel_supported
)
q_layout = quantizer.q_layout
......@@ -861,12 +866,16 @@ def _quantize_dbias_impl(
x.data,
scale,
amax,
sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32),
(
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
......@@ -875,10 +884,10 @@ def _quantize_dbias_impl(
use_rht=use_rht,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
colwise_scale_inv = rowwise_scale_inv
if q_layout == QuantizeLayout.ROWWISE:
if q_layout.is_rowwise_only:
# Quantizer requires 2x quantization, but we are using 1x quantization
# for performance reasons, so we need to generate the colwise data in JAX
if flatten_axis < 0:
......@@ -902,6 +911,7 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
colwise_has_rht_applied=use_rht,
)
return out, dbias.astype(dq_dtype)
......@@ -1029,7 +1039,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
flatten_axis=flatten_axis,
)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_rowwise:
rowwise_out_shape = out_shape
else:
rowwise_out_shape = (1,)
......@@ -1038,7 +1048,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
......@@ -1103,7 +1113,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
scale,
group_sizes,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
q_layout=q_layout.value.value,
flatten_axis=flatten_axis,
)
......@@ -1217,7 +1227,7 @@ def grouped_quantize(
)
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0)
scale = scale.at[i].set(tmp_scale[0])
is_tensor_scaling = quantizer.scaling_mode in (
......@@ -1226,7 +1236,7 @@ def grouped_quantize(
)
# WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
# So we perform ROWWISE_COLWISE and use the colwise_tensor_output
apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only
q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
(
rowwise_casted_output,
......@@ -1240,7 +1250,7 @@ def grouped_quantize(
group_sizes,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
group_axis=group_axis,
scale_dtype=quantizer.get_scale_dtype(),
......@@ -1248,7 +1258,7 @@ def grouped_quantize(
# For DelayedScaling2x and CurrentScaling2x, the scale buffer
# is shared between rowwise and colwise
if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war:
if is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise or apply_colwise_war:
colwise_scale_inv = rowwise_scale_inv
# TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
......
......@@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x);
JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout);
// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler);
......@@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout);
JAXX_Quantize_Layout quantize_layout);
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
......@@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
// ENUM_ATTR and DICT_ATTR decoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout);
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
......@@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) {
JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
......@@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2];
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
......@@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
}
}
if (is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
......@@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
......@@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params,
updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params,
output_amax_when_no_scaling);
}
......@@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
......@@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
std::vector<size_t>{1});
}
if (is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
......@@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
JAXX_Quantize_Layout quantize_layout, bool is_dbias,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
......@@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
}
}
if (is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
......@@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING &&
is_quantize_2x2x(quantize_layout) && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) {
......@@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
......@@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI(
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params,
output_amax_when_no_scaling);
workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias,
act_params, output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
......@@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
......
......@@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
return backend;
}
......@@ -122,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
......@@ -155,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graphs at JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
......@@ -173,36 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
}
nvte_tensor_pack_destroy(&aux_output_tensors);
......@@ -276,7 +245,8 @@ static void FusedAttnForwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -288,47 +258,57 @@ static void FusedAttnForwardImpl(
/* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
// Prepare Q, K, V pointers and shapes based on layout
// Python passes dummy tensors for unused slots, so we extract from the actual packed data
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
// QKV packed in q: [batch*seqlen, 3, heads, dim]
// Python passes: q=packed_qkv, k=dummy, v=dummy
// Extract K and V pointers from the packed q data
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
NVTE_CHECK(qk_head_dim == v_head_dim,
"For QKV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
q_ptr = q;
k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
// For packed QKV, all have same shape since they're views into the same packed tensor
k_shape = q_shape;
v_shape = q_shape;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
// Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
// Python passes: q=query, k=packed_kv, v=dummy
// Extract V pointer from the packed k data
NVTE_CHECK(qk_head_dim == v_head_dim,
"For KV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
q_ptr = q;
k_ptr = k;
v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
// V has same shape as K since they're packed together
v_shape = k_shape;
}
// else NVTE_HD_HD_HD: pointers and shapes already correct
auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
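Worked example of the stride arithmetic used for the packed layouts above (values are illustrative; the sketch is in Python only to show the byte offsets computed by the C++ code):

```python
# Packed QKV buffer layout: [batch * seqlen, 3, attn_heads, qk_head_dim]
dtype_size, attn_heads, qk_head_dim = 2, 16, 64   # e.g. bf16, 16 heads, head dim 64
stride = dtype_size * attn_heads * qk_head_dim    # per-slot stride in bytes
q_offset, k_offset, v_offset = 0, stride, 2 * stride
assert (q_offset, k_offset, v_offset) == (0, 2048, 4096)
```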
......@@ -411,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
......@@ -447,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// This is a workaround (WAR) to pre-create all possible cuDNN graphs at JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
......@@ -468,42 +436,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
}
nvte_tensor_pack_destroy(&aux_input_tensors);
......@@ -542,82 +486,89 @@ static void FusedAttnBackwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
/* Call the underlying NVTE API */
// Prepare Q, K, V pointers and shapes based on layout
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
void *dq_ptr = dq;
void *dk_ptr = dk;
void *dv_ptr = dv;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype),
stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
// QKV packed in q: [batch*seqlen, 3, heads, dim]
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
NVTE_CHECK(qk_head_dim == v_head_dim,
"For QKV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
q_ptr = q;
k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
dq_ptr = dq;
dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
k_shape = q_shape;
v_shape = q_shape;
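// Worked example (hypothetical sizes): with bf16 (2 bytes/elem), attn_heads = 16 and
// qk_head_dim = 64, stride = 2 * 16 * 64 = 2048 bytes, so K begins 2048 bytes and V begins
// 4096 bytes after Q inside each packed [3, attn_heads, qk_head_dim] token slot.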
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype),
stream);
}
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
// Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
NVTE_CHECK(qk_head_dim == v_head_dim,
"For KV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
q_ptr = q;
k_ptr = k;
v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
dq_ptr = dq;
dk_ptr = dk;
dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
// V has same shape as K since they're packed together
v_shape = k_shape;
}
auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);
if (is_ragged) {
size_t dtype_size = typeToSize(dtype);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
// For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once
cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
// Clear dq
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
// For packed KV, dk contains both dk and dv - clear all at once
cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream);
} else {
// All separate - clear each individually
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream);
cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream);
}
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
......
......@@ -34,8 +34,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
}
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode,
size_t axis_boundary, bool rowwise) {
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, uint8_t *swizzle_scale_ptr,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
......@@ -56,17 +56,32 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
auto is_nvfp4 = is_nvfp4_scaling(scaling_mode);
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) {
// Block scaling also needs to be collapsed to match 2D data
scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
NVTE_CHECK(typeToSize(scale_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
}
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
if (rowwise) {
if (!is_nvfp4) {
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
} else { // Swizzle for NVFP4
NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS");
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
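// cuBLAS block-scaled GEMMs expect the scale factors in this swizzled layout, hence the
// conversion above.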
}
}
......@@ -145,16 +160,34 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
// cuBLAS workspace + 256 alignment enforcement (+ swizzle scales)
uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr;
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
size_t workspace_size = static_cast<size_t>(workspace->element_count()) - 256;
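// When NVFP4 swizzling is required, the front of the 256B-aligned workspace is carved into
// [lhs swizzled scales | rhs swizzled scales | remaining cuBLAS workspace] below.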
if (is_nvfp4_scaling(scaling_mode)) {
auto lhs_scale_size = product(lhs_scale_inv.dimensions());
auto rhs_scale_size = product(rhs_scale_inv.dimensions());
workspace_size = workspace_size - lhs_scale_size - rhs_scale_size;
lhs_swizzle_scale_ptr = workspace_ptr;
rhs_swizzle_scale_ptr = workspace_ptr + lhs_scale_size;
workspace_ptr = rhs_swizzle_scale_ptr + rhs_scale_size;
}
auto workspace_ = TensorWrapper(workspace_ptr, std::vector<size_t>{workspace_size}, DType::kByte);
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
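// e.g. on devices without non-TN FP8 GEMM support (such as Hopper), a transposed operand is
// supplied through its columnwise (pre-transposed) quantized copy instead of its rowwise data.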
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode,
lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
rhs_axis_boundary, make_rhs_rowwise);
auto [lhs_, lhs_shape] =
xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, lhs_swizzle_scale_ptr,
scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] =
xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, rhs_swizzle_scale_ptr,
scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
......@@ -191,11 +224,6 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
}
auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);
// cuBLAS workspace + 256 alignment enforcement
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
float one = 1.;
......
......@@ -34,12 +34,24 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
enum class QuantizeLayout {
enum class JAXX_Quantize_Layout : int64_t {
ROWWISE,
COLWISE,
ROWWISE_COLWISE,
};
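// ROWWISE_COLWISE ("2x") requests both row-wise and column-wise quantized outputs in a single pass.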
inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::COLWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1,
......
......@@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) {
JAXX_Quantize_Layout quantize_layout, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
......@@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
......@@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
if (_is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(),
static_cast<DType>(out_dtype), input_shape);
output_tensor.set_columnwise_scale_inv(
......@@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type amax_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x,
bool output_amax_when_no_scaling) {
Error_Type NormForwardInitializeFFI(
cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Buffer_Type amax_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling);
scaling_mode, quantize_layout, output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
......@@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
......
......@@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
pybind11::enum_<JAXX_Quantize_Layout>(m, "JAXX_Quantize_Layout", pybind11::module_local())
.value("ROWWISE", JAXX_Quantize_Layout::ROWWISE)
.value("COLWISE", JAXX_Quantize_Layout::COLWISE)
.value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE)
.export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
......
......@@ -20,7 +20,7 @@ namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) {
JAXX_Quantize_Layout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
......@@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto scale_shape = std::vector<size_t>{1};
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
if (is_quantize_rowwise(q_layout)) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4)
......@@ -52,7 +52,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
}
}
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) {
if (is_quantize_colwise(q_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
......@@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis,
bool stochastic_rounding, bool use_rht) {
JAXX_Quantize_Layout quantize_layout, bool is_dbias,
int64_t flatten_axis, bool stochastic_rounding, bool use_rht) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data();
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
......@@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling");
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_quantize_rowwise(quantize_layout)) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_tensor_scaling) {
......@@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
quant_config.set_rng_state(sr_rng_state_tensor.data());
}
if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_quantize_colwise(quantize_layout)) {
if (is_nvfp4 && use_rht) {
if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_quantize_2x2x(quantize_layout)) {
// Do regular rowwise quantization without RHT
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
}
......@@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis")
.Attr<bool>("stochastic_rounding")
......@@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
Buffer_Type group_sizes, Result_Type outputs,
Result_Type colwise_outputs, Result_Type scale_invs,
Result_Type colwise_scale_invs, Result_Type amaxs,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
int64_t flatten_axis) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode));
......@@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type());
auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type());
auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type());
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data());
auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data());
......@@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data());
auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data());
bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool has_colwise = quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
......@@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype);
size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype);
size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0;
size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0;
size_t colwise_output_dtype_bytes = is_quantize_colwise(quantize_layout) ? output_dtype_bytes : 0;
size_t colwise_sinv_dtype_bytes = is_quantize_colwise(quantize_layout) ? sinv_dtype_bytes : 0;
size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0;
size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0;
......@@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype);
auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (has_rowwise) {
if (is_quantize_rowwise(quantize_layout)) {
out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i);
if (is_fp8_dtype(out_dtype)) {
......@@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
}
}
if (has_colwise) {
if (is_quantize_colwise(quantize_layout)) {
auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i;
out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
......@@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<int64_t>("flatten_axis"));
} // namespace jax
......
......@@ -19,6 +19,7 @@ from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope
from .quantize import (
ScaledTensorFactory,
ScaledTensor,
ScalingMode,
QuantizeLayout,
QuantizerSet,
......@@ -26,7 +27,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported,
TensorUsage,
get_quantize_config,
)
......@@ -94,7 +94,7 @@ def dense(
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......@@ -227,8 +227,8 @@ def _dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
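# Checkpoint the quantized tensors kept for the backward pass under the quantizer set's
# checkpoint name (if one was configured).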
casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
x.shape,
kernel.shape,
use_bias,
......@@ -529,8 +529,12 @@ def _grouped_dense_fwd_rule(
ctx = (
group_sizes,
ctx_x,
ctx_kernel,
ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x,
(
ctx_kernel.checkpoint(quantizer_set.kernel)
if isinstance(ctx_kernel, ScaledTensor)
else ctx_kernel
),
x.shape,
kernel.shape,
use_bias,
......
......@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
import numpy as np
import jax.numpy as jnp
......@@ -33,10 +33,11 @@ from ..cpp_extensions import (
)
from ..quantize import (
QuantizerFactory,
get_quantize_config,
get_global_quantize_recipe,
QuantizeMetaSet,
TensorSource,
get_quantize_config_with_recipe,
noop_quantizer_set,
)
PRNGKey = Any
......@@ -345,23 +346,27 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
"""
def generate_quantizer_set(
self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
self,
postfix: str = "",
variable_collection: str = None,
quantization_checkpoint_name: Optional[str] = None,
fp8_recipe=None,
):
"""
Generate a set of FP8 quantization metadata for a GEMM.
"""
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
collection_name = (
variable_collection
if variable_collection is not None
else get_quantize_config().COLLECTION_NAME
else quantize_config.COLLECTION_NAME
)
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x"
)
......@@ -375,7 +380,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
fp8_recipe=fp8_recipe,
quantize_meta_set=quantize_meta_set,
checkpoint_name=quantization_checkpoint_name,
)
return quantizer_set
......@@ -424,6 +431,8 @@ class DenseGeneral(TransformerEngineBase):
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
The name used when checkpointing the quantized tensors saved for the backward pass.
"""
features: Union[Iterable[int], int]
......@@ -439,6 +448,7 @@ class DenseGeneral(TransformerEngineBase):
dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -483,7 +493,11 @@ class DenseGeneral(TransformerEngineBase):
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype)
if self.use_bias:
......@@ -496,7 +510,6 @@ class DenseGeneral(TransformerEngineBase):
else:
bias = None
quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis)))
y = dense(
inputs,
......@@ -597,7 +610,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True
return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False
......@@ -628,6 +641,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
The name used when checkpointing the quantized tensors saved for the backward pass.
"""
features: Union[Iterable[int], int]
......@@ -644,7 +659,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
return_layernorm_output: bool = True
return_layernorm_output: bool = False
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
......@@ -654,6 +669,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -693,10 +709,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
input_dtype = inputs.dtype
ln_output = None
quantizer_set = self.generate_quantizer_set()
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
fuse_layernorm = (
get_quantize_config().is_fp8_enabled()
quantizer_set != noop_quantizer_set
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -747,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape,
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......@@ -891,10 +909,10 @@ class LayerNormMLP(TransformerEngineBase):
The name of axes used to shard bias with a corresponding mesh for
the weight of the second dense layer transformation.
Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True
return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('relu',)
activations: Sequence[Union[str, Callable]], default = ('gelu',)
The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer.
activation_params: dict, default = None
......@@ -903,7 +921,7 @@ class LayerNormMLP(TransformerEngineBase):
need additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout'
The key in the RNGs given via flax.linen.Module.apply that is used for generating Dropout masks.
intermediate_dropout_rate: float, default = 0.1
intermediate_dropout_rate: float, default = 0.0
Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden states.
......@@ -941,6 +959,8 @@ class LayerNormMLP(TransformerEngineBase):
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
quantization_checkpoint_name: Optional[str], default = None
The name used when checkpointing the quantized tensors saved for the backward pass.
"""
intermediate_dim: int = 2048
......@@ -959,11 +979,11 @@ class LayerNormMLP(TransformerEngineBase):
bias_init: Initializer = nn.initializers.zeros
bias_axes_1: Tuple[str, ...] = ("act", "mlp")
bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",)
return_layernorm_output: bool = False
activations: Sequence[Union[str, Callable]] = ("gelu",)
activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1
intermediate_dropout_rate: float = 0.0
intermediate_hidden_dropout_dims: Sequence[int] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
......@@ -976,6 +996,7 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False
quantization_checkpoint_name: Optional[str] = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -1010,8 +1031,12 @@ class LayerNormMLP(TransformerEngineBase):
"""
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1")
ffn1_quantizer_set = self.generate_quantizer_set(
"_0", quantization_checkpoint_name=self.quantization_checkpoint_name
)
ffn2_quantizer_set = self.generate_quantizer_set(
"_1", quantization_checkpoint_name=self.quantization_checkpoint_name
)
input_dtype = inputs.dtype
ln_output = None
......@@ -1019,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented
fuse_layernorm = (
get_quantize_config().is_fp8_enabled()
ffn1_quantizer_set != noop_quantizer_set
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -1105,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if ffn1_quantizer_set == noop_quantizer_set:
kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
......@@ -1117,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape,
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if ffn2_quantizer_set == noop_quantizer_set:
kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......
......@@ -197,6 +197,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
fused_scale_factor = scale_factor
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias
bias = None
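# Bias has already been added to attn_weights above; clearing it here keeps it from being
# applied again further down.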
def apply_swa_mask(original_mask: Array) -> Array:
"""Apply the sliding window mask to a given mask"""
......@@ -406,10 +407,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
variable:
* Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
* Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
kernel is not available on the system, a warning will be issued, and the module will
automatically fall back to the unfused backend.
* Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
* Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
attention kernel is not available on the system, a warning will be issued, and the module
will automatically fall back to the unfused backend.
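For example, to force the unfused backend (illustrative; the variable is read when the module is called)::

    import os
    os.environ["NVTE_FUSED_ATTN"] = "0"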
.. note::
The DotProductAttention default setting enables non-deterministic kernels for reduced
......@@ -601,7 +602,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
else:
assert bias is not None
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
# Use fused attn (if kernel check below passes) by default
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
sequence_dim = 0 if self.transpose_batch_sequence else 1
seqlen_q = query.shape[sequence_dim]
......@@ -1618,7 +1620,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Dimensions that will share the same dropout mask for hidden states.
attention_dropout: float, default = 0.1
Dropout probability for the dropout op during multi-head attention.
intermediate_dropout: float, default = 0.1
intermediate_dropout: float, default = 0.0
Dropout probability for the dropout op after FC1 layer.
intermediate_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden after FC1 layer.
......@@ -1633,7 +1635,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights of FC1 and FC2 layers.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
mlp_activations: Sequence[str], default = ('relu', )
mlp_activations: Sequence[str], default = ('gelu', )
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
mlp_activation_params: dict = None
......@@ -1753,12 +1755,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout: float = 0.0
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = "dropout"
mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ("relu",)
mlp_activations: Sequence[str] = ("gelu",)
mlp_activation_params: dict = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
......
......@@ -23,7 +23,6 @@ from .quantize import (
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
get_quantize_config,
)
......@@ -73,7 +72,7 @@ def layernorm_dense(
- Quantization is applied to both the normalized input and kernel
"""
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......@@ -236,8 +235,8 @@ def _layernorm_dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(TensorUsage.RHS_TRANS),
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
x.shape,
kernel.shape,
mu,
......
......@@ -28,7 +28,6 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
TensorUsage,
get_quantize_config,
)
......@@ -114,7 +113,7 @@ def layernorm_mlp(
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled():
if quantizer_sets == (noop_quantizer_set, noop_quantizer_set):
input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype)
......@@ -390,11 +389,11 @@ def _layernorm_mlp_fwd_rule(
rsigma,
gamma,
beta,
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn1_quantizer_set.x),
casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn1_quantizer_set.kernel),
dot_1_output,
casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
casted_act_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn2_quantizer_set.x),
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn2_quantizer_set.kernel),
x_contracting_dims,
k_contracting_dims,
kernel_1.shape,
......
......@@ -17,3 +17,4 @@ from .metadata import *
from .hadamard import *
from .helper import *
from .device_utils import *
from .misc import *