Unverified Commit 9440b76a authored by Phuong Nguyen, committed by GitHub

[JAX] Shardy rule + QuantizeLayout Rework (#2364)



* shardy + quantize_layout rework
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add assertion for NVFP4 in fused act and fused norm primitive
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add assertions
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
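
A minimal usage sketch (not part of the diff) of the reworked layout handling: primitives now take a `QuantizeLayout` enum instead of the old `is_2x` flag, and call sites query its properties rather than `Quantizer.is_2x2x()` or raw enum comparisons. The import path below assumes the new `misc` module is re-exported from `transformer_engine.jax.quantize` via the `from .misc import *` added in this commit.

```python
# Hedged sketch; assumes the re-export path shown in the diff's __init__.py change.
from transformer_engine.jax.quantize import QuantizeLayout

layout = QuantizeLayout.ROWWISE_COLWISE

# Property-based queries replace the old boolean flag and tuple-membership checks.
assert layout.has_rowwise and layout.has_colwise
assert layout.is_rowwise_colwise       # stands in for the removed quantizer.is_2x2x()
assert not layout.is_colwise_only      # stands in for q_layout == QuantizeLayout.COLWISE

# FFI attributes take the underlying C++ enum: .value is the pybind
# JAXX_Quantize_Layout, whose .value is the integer passed as the
# "quantize_layout" / "q_layout" attribute (see q_layout.value.value in the diff).
ffi_attr = layout.value.value
```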
parent ef28c865
......@@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break
# _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
# but this fails when bias fusion is turned on with arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise
)
return (
(force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and arch_l_100
......@@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
@return: the output of 'f' with the colwise output calculated
"""
should_apply_war = (
quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
quantizer is not None
and quantizer.scaling_mode.is_tensor_scaling()
and quantizer.q_layout.is_rowwise_colwise
)
if not should_apply_war:
return None
......
......@@ -11,7 +11,7 @@ from typing import Optional, Union
import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec
......@@ -112,7 +112,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -148,6 +148,13 @@ class NormFwdPrimitive(BasePrimitive):
"Current tensor scaling is not supported for fused norm and quantization. Please do"
" norm in higher-precision then quantize with current tensor scaling."
)
assert not ScalingMode(scaling_mode).is_nvfp4_scaling, (
"NVFP4 block scaling is not yet supported for fused norm and quantization."
" Please do norm in higher-precision then quantize with current tensor scaling."
)
assert (
not quantize_layout.is_colwise_only
), "Fused norm with colwise-only quantization is not supported."
mu_rsigama_dtype = jnp.float32
......@@ -165,7 +172,7 @@ class NormFwdPrimitive(BasePrimitive):
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,)
colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -173,7 +180,7 @@ class NormFwdPrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,)
colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
......@@ -189,7 +196,7 @@ class NormFwdPrimitive(BasePrimitive):
zero_centered_gamma,
epsilon,
get_forward_sm_margin(),
is_2x,
True, # is_training
)
wkspace_aval = jax.core.ShapedArray(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -245,7 +252,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -287,7 +294,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
sm_margin=sm_margin,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
quantize_layout=quantize_layout.value.value,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
......@@ -303,7 +310,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -335,7 +342,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -349,7 +356,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
)
if is_2x:
if quantize_layout.has_colwise:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape)
......@@ -373,7 +380,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -409,7 +416,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -426,7 +433,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -450,7 +457,7 @@ class NormFwdPrimitive(BasePrimitive):
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
......@@ -488,7 +495,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -524,7 +531,7 @@ class NormFwdPrimitive(BasePrimitive):
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
......@@ -586,7 +593,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -623,7 +630,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
......@@ -646,25 +653,29 @@ class NormFwdPrimitive(BasePrimitive):
result_types,
)
prefix = "NormFwd_"
prefix = "NormFwd"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1
value_types[0].shape,
unique_var=prefix,
flatten_axis=-1,
q_layout=quantize_layout,
)
x_axes = scale_rules.input_spec
input_spec = scale_rules.input_spec
out = x_axes
colwise_out = out if is_2x else (prefix + "out_colwise",)
rsigma = x_axes[:-1]
mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = (prefix + "amax",)
rsigma = input_spec[:-1]
mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
gamma = (BATCHING + prefix + "_gamma",)
beta = (BATCHING + prefix + "_beta",)
return SdyShardingRule(
(x_axes, ("…1",), amax, ("…2",), ("…3",)),
(input_spec, scale, amax, gamma, beta),
(
out,
colwise_out,
scale_rules.rowwise_rule,
scale_rules.colwise_rule,
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
mu,
rsigma,
......@@ -987,7 +998,7 @@ def layernorm_fwd(
return (output, mu, rsigma)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -1008,7 +1019,7 @@ def layernorm_fwd(
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=False,
......@@ -1067,10 +1078,11 @@ def layernorm_fwd(
)
return out, mu, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
q_layout = quantizer.q_layout
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
q_layout = QuantizeLayout.ROWWISE
(
rowwise_casted_output,
colwise_casted_output,
......@@ -1090,7 +1102,7 @@ def layernorm_fwd(
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -1099,8 +1111,7 @@ def layernorm_fwd(
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -1238,7 +1249,7 @@ def rmsnorm_fwd(
return (output, rsigma)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -1261,7 +1272,7 @@ def rmsnorm_fwd(
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -1321,10 +1332,11 @@ def rmsnorm_fwd(
)
return out, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
q_layout = quantizer.q_layout
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
q_layout = QuantizeLayout.ROWWISE
(
rowwise_casted_output,
colwise_casted_output,
......@@ -1344,7 +1356,7 @@ def rmsnorm_fwd(
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -1353,8 +1365,7 @@ def rmsnorm_fwd(
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......
......@@ -11,7 +11,7 @@ import math
import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec
import transformer_engine_jax
......@@ -122,7 +122,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if QuantizeLayout(q_layout).has_rowwise:
rowwise_out_shape = out_shape
else:
rowwise_out_shape = (1,)
......@@ -170,7 +170,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
broadcast_2d_scale_shape_to_1d=True,
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if QuantizeLayout(q_layout).has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
......@@ -194,9 +194,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(scale_dtype),
scaling_mode,
QuantizeLayout(
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
q_layout.value,
)
wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -272,7 +270,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
post_rht_amax,
rht_matrix,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
q_layout=q_layout.value.value,
flatten_axis=flatten_axis,
is_dbias=is_dbias,
stochastic_rounding=stochastic_rounding,
......@@ -335,7 +333,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
......@@ -424,7 +422,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*x_spec),
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
......@@ -448,7 +446,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if (
ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed
......@@ -505,7 +503,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
......@@ -529,7 +527,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
if (
ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed
......@@ -643,39 +641,37 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
result_types,
)
prefix = "DBiasQuantize_"
prefix = "DBiasQuantize"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape,
unique_var=prefix + "x",
unique_var=prefix,
flatten_axis=flatten_axis,
q_layout=q_layout,
broadcast_2d_scale_shape_to_1d=True,
)
x_axes = scale_rules.input_spec
out = x_axes
colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "colwise_scale_inv",)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv = scale_rules.colwise_rule
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
colwise_scale_inv = tuple(
multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis)
)
else:
colwise_out = x_axes
dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis")
input_spec = scale_rules.input_spec
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
sr_rng_state = (
BATCHING + prefix + "_sr_rng_state_partition_axis",
BATCHING + prefix + "sr_rng_state_data_axis",
)
post_rht_amax = (prefix + "post_rht_amax",)
rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2")
post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")
return SdyShardingRule(
(x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
(input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix),
(
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
dbias,
),
**scale_rules.factor_sizes,
)
......@@ -762,7 +758,7 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not (
is_unsupported = quantizer.q_layout.is_colwise_only and not (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and hasattr(quantizer, "use_rht")
and quantizer.use_rht
......@@ -845,7 +841,7 @@ def _quantize_dbias_impl(
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x()
and quantizer.q_layout.is_rowwise_colwise
and is_1x_kernel_supported
)
q_layout = quantizer.q_layout
......@@ -879,7 +875,7 @@ def _quantize_dbias_impl(
rht_matrix,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
......@@ -888,10 +884,10 @@ def _quantize_dbias_impl(
use_rht=use_rht,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
colwise_scale_inv = rowwise_scale_inv
if q_layout == QuantizeLayout.ROWWISE:
if q_layout.is_rowwise_only:
# Quantizer requires 2x quantization, but we are using 1x quantization
# for performance reasons, so we need to generate the colwise data in JAX
if flatten_axis < 0:
......@@ -1043,7 +1039,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
flatten_axis=flatten_axis,
)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_rowwise:
rowwise_out_shape = out_shape
else:
rowwise_out_shape = (1,)
......@@ -1052,7 +1048,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if q_layout.has_colwise:
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
......@@ -1117,7 +1113,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
scale,
group_sizes,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
q_layout=q_layout.value.value,
flatten_axis=flatten_axis,
)
......@@ -1240,7 +1236,7 @@ def grouped_quantize(
)
# WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
# So we perform ROWWISE_COLWISE and use the colwise_tensor_output
apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only
q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
(
rowwise_casted_output,
......@@ -1254,7 +1250,7 @@ def grouped_quantize(
group_sizes,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
group_axis=group_axis,
scale_dtype=quantizer.get_scale_dtype(),
......@@ -1262,7 +1258,7 @@ def grouped_quantize(
# For DelayedScaling2x and CurrentScaling2x, the scale buffer
# is shared between rowwise and colwise
if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war:
if is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise or apply_colwise_war:
colwise_scale_inv = rowwise_scale_inv
# TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
......
......@@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x);
JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout);
// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler);
......@@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout);
JAXX_Quantize_Layout quantize_layout);
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
......@@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout);
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
......@@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) {
JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
......@@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2];
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
......@@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
}
}
if (is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
......@@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
......@@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params,
updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params,
output_amax_when_no_scaling);
}
......@@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
......@@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
std::vector<size_t>{1});
}
if (is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
......@@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
JAXX_Quantize_Layout quantize_layout, bool is_dbias,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
......@@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
}
}
if (is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
......@@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING &&
is_quantize_2x2x(quantize_layout) && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) {
......@@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
......@@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI(
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params,
output_amax_when_no_scaling);
workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias,
act_params, output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
......@@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
......
......@@ -34,12 +34,24 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
enum class QuantizeLayout {
enum class JAXX_Quantize_Layout : int64_t {
ROWWISE,
COLWISE,
ROWWISE_COLWISE,
};
inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::COLWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1,
......
......@@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) {
JAXX_Quantize_Layout quantize_layout, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
......@@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
......@@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
if (_is_2x) {
if (is_quantize_2x2x(quantize_layout)) {
output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(),
static_cast<DType>(out_dtype), input_shape);
output_tensor.set_columnwise_scale_inv(
......@@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type amax_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x,
bool output_amax_when_no_scaling) {
Error_Type NormForwardInitializeFFI(
cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Buffer_Type amax_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling);
scaling_mode, quantize_layout, output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
......@@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
......
......@@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
pybind11::enum_<JAXX_Quantize_Layout>(m, "JAXX_Quantize_Layout", pybind11::module_local())
.value("ROWWISE", JAXX_Quantize_Layout::ROWWISE)
.value("COLWISE", JAXX_Quantize_Layout::COLWISE)
.value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE)
.export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
......
......@@ -20,7 +20,7 @@ namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) {
JAXX_Quantize_Layout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
......@@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto scale_shape = std::vector<size_t>{1};
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
if (is_quantize_rowwise(q_layout)) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4)
......@@ -52,7 +52,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
}
}
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) {
if (is_quantize_colwise(q_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
......@@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis,
bool stochastic_rounding, bool use_rht) {
JAXX_Quantize_Layout quantize_layout, bool is_dbias,
int64_t flatten_axis, bool stochastic_rounding, bool use_rht) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data();
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
......@@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling");
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_quantize_rowwise(quantize_layout)) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_tensor_scaling) {
......@@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
quant_config.set_rng_state(sr_rng_state_tensor.data());
}
if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_quantize_colwise(quantize_layout)) {
if (is_nvfp4 && use_rht) {
if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_quantize_2x2x(quantize_layout)) {
// Do regular rowwise quantization without RHT
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
}
......@@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis")
.Attr<bool>("stochastic_rounding")
......@@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
Buffer_Type group_sizes, Result_Type outputs,
Result_Type colwise_outputs, Result_Type scale_invs,
Result_Type colwise_scale_invs, Result_Type amaxs,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
int64_t flatten_axis) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode));
......@@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type());
auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type());
auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type());
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data());
auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data());
......@@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data());
auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data());
bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool has_colwise = quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
......@@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype);
size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype);
size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0;
size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0;
size_t colwise_output_dtype_bytes = is_quantize_colwise(quantize_layout) ? output_dtype_bytes : 0;
size_t colwise_sinv_dtype_bytes = is_quantize_colwise(quantize_layout) ? sinv_dtype_bytes : 0;
size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0;
size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0;
......@@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype);
auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (has_rowwise) {
if (is_quantize_rowwise(quantize_layout)) {
out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i);
if (is_fp8_dtype(out_dtype)) {
......@@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
}
}
if (has_colwise) {
if (is_quantize_colwise(quantize_layout)) {
auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i;
out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
......@@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<int64_t>("flatten_axis"));
} // namespace jax
......
......@@ -17,3 +17,4 @@ from .metadata import *
from .hadamard import *
from .helper import *
from .device_utils import *
from .misc import *
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This module provides additional enums and utilities for quantizing tensors in JAX.
"""
from dataclasses import dataclass
from enum import Enum
from transformer_engine_jax import JAXX_Quantize_Layout
__all__ = [
"QuantizeLayout",
]
@dataclass(frozen=True)
class QuantizeLayout(Enum):
"Wrapper for JAXX_Quantize_Layout"
ROWWISE = JAXX_Quantize_Layout.ROWWISE
COLWISE = JAXX_Quantize_Layout.COLWISE
ROWWISE_COLWISE = JAXX_Quantize_Layout.ROWWISE_COLWISE
@property
def has_rowwise(self) -> bool:
"""If the layout has the rowwise component"""
return self.value in (JAXX_Quantize_Layout.ROWWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE)
@property
def has_colwise(self) -> bool:
"""If the layout has the colwise component"""
return self.value in (JAXX_Quantize_Layout.COLWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE)
@property
def is_rowwise_colwise(self) -> bool:
"""If layout is both rowwise and colwise"""
return self.value == JAXX_Quantize_Layout.ROWWISE_COLWISE
@property
def is_rowwise_only(self) -> bool:
"""If layout is rowwise only"""
return self.value == JAXX_Quantize_Layout.ROWWISE
@property
def is_colwise_only(self) -> bool:
"""If layout is colwise only"""
return self.value == JAXX_Quantize_Layout.COLWISE
def __eq__(self, other):
"""Compare this quantize layout with another.
Args:
other: The other quantize layout to compare with
Returns:
True if the layouts are equal, False otherwise
"""
if not isinstance(other, QuantizeLayout):
return False
return self.value == other.value
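
For reference, a minimal sketch (not part of the diff; the helper name is hypothetical) of the property-based dispatch this wrapper enables, mirroring the `get_data_layout` rewrite in `quantizers.py` below:

```python
from transformer_engine.jax.quantize import QuantizeLayout  # assumed re-export path


def select_outputs(layout: QuantizeLayout, rowwise_out, colwise_out):
    """Hypothetical helper: return the outputs that are meaningful for a layout."""
    if layout.is_rowwise_colwise:
        return rowwise_out, colwise_out
    if layout.is_rowwise_only:
        return rowwise_out, None
    if layout.is_colwise_only:
        return None, colwise_out
    raise ValueError(f"Invalid q_layout: {layout}")
```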
......@@ -15,10 +15,10 @@ import warnings
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .misc import QuantizeLayout
from .hadamard import apply_rht
from .tensor import (
ScaledTensor,
......@@ -37,7 +37,6 @@ from .device_utils import is_fp8_gemm_with_all_layouts_supported
from ..sharding import get_num_devices_in_mesh
__all__ = [
"QuantizeLayout",
"Quantizer",
"QuantizerSet",
"CurrentScaleQuantizer",
......@@ -118,14 +117,6 @@ class Quantizer(ABC):
"""Update quantizer state (no-op in base class)."""
del args, kwargs
def is_2x2x(self) -> bool:
"""Check if quantizer uses both row-wise and column-wise quantization.
Returns:
True if using both row-wise and column-wise quantization
"""
return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
def get_data_layout(self) -> str:
Get the data_layout string.
......@@ -135,11 +126,11 @@ class Quantizer(ABC):
Raises:
ValueError: If quantization axis is invalid
"""
if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
if self.q_layout.is_rowwise_colwise:
return self.data_layout
if self.q_layout == QuantizeLayout.ROWWISE:
if self.q_layout.is_rowwise_only:
return self.data_layout[0]
if self.q_layout == QuantizeLayout.COLWISE:
if self.q_layout.is_colwise_only:
return self.data_layout[1]
raise ValueError(f"Invalid q_layout: {self.q_layout}")
......@@ -174,18 +165,10 @@ class Quantizer(ABC):
"""
del kwargs
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise
if (is_rowwise and is_colwise) or self.is_2x2x():
if is_rowwise and is_colwise:
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
......@@ -299,16 +282,8 @@ class CurrentScaleQuantizer(Quantizer):
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None
......@@ -974,16 +949,8 @@ class GroupedQuantizer(Quantizer):
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise
assert is_rowwise or is_colwise, "No quantization layout is specified"
original_shape = x.shape
......
......@@ -21,7 +21,8 @@ from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from transformer_engine_jax import JAXX_Scaling_Mode
from .misc import QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
......@@ -72,16 +73,18 @@ class QuantizeShardyRules:
Attributes:
input_spec: Specification for the input axes
rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
the axes in `input_spec`
colwise_rule: Likewise for the column-wise scale tensor.
factor_sizes: For block scaling, contains the block size factor, which is
used in `input_spec`.
rowwise_out_spec: Sharding spec for the rowwise quantized data
rowwise_scale_spec: Sharding spec for the rowwise scale
colwise_out_spec: Sharding spec for the colwise quantized data
colwise_scale_spec: Sharding spec for the colwise scale
factor_sizes: For block scaling, contains the block size factor
"""
input_spec: Tuple[str]
rowwise_rule: Tuple[str]
colwise_rule: Tuple[str]
rowwise_out_spec: Tuple[str]
rowwise_scale_spec: Tuple[str]
colwise_out_spec: Tuple[str]
colwise_scale_spec: Tuple[str]
factor_sizes: Dict[str, int]
......@@ -166,7 +169,9 @@ class ScalingModeMetadataImpl(ABC):
input_shape,
unique_var,
flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -174,7 +179,9 @@ class ScalingModeMetadataImpl(ABC):
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
q_layout: The layout of the quantized tensor
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
is_colwise_transposed: Whether the column-wise tensors are transposed.
Returns:
The Shardy rules for the scaling mode
......@@ -268,7 +275,9 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape,
unique_var,
flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -281,10 +290,17 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis, broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
del broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}_x_{i}" for i in range(len(input_shape)))
output_spec = tuple(input_spec)
return QuantizeShardyRules(
input_spec,
output_spec,
(BATCHING + f"{unique_var}_scale",),
(BATCHING + f"{unique_var}_colwise_output",),
(BATCHING + f"{unique_var}_colwise_scale",),
{},
)
class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
......@@ -376,7 +392,9 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape,
unique_var,
flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -385,14 +403,26 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
q_layout: The layout of the quantized tensor
is_colwise_transposed: Whether the colwise scaling is transposed
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis, broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
del broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}x_{i}" for i in range(len(input_shape)))
output_spec = input_spec
colwise_output_spec = (BATCHING + f"{unique_var}_colwise_output",)
if q_layout.has_colwise:
from ..cpp_extensions.misc import multidim_transpose
colwise_output_spec = input_spec
if is_colwise_transposed:
colwise_output_spec = multidim_transpose(
colwise_output_spec, transpose_axis=flatten_axis
)
scale = (BATCHING + unique_var + "_scale_inv",)
return QuantizeShardyRules(input_spec, output_spec, scale, colwise_output_spec, scale, {})
class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
......@@ -658,7 +688,9 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape,
unique_var,
flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -666,15 +698,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
q_layout: The layout of the quantized tensor
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
is_colwise_transposed: Whether the column-wise tensors are transposed.
Returns:
The Shardy rules for the scaling mode
"""
# TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed
is_rowwise = q_layout.has_rowwise
is_colwise = q_layout.has_colwise
input_rank = len(input_shape)
input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
flatten_axis = (flatten_axis + input_rank) % input_rank
input_spec = [f"{unique_var}_x_{i}" for i in range(input_rank)]
assert (
self._block_dims[1] != 1
......@@ -690,30 +725,56 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
# No CompoundFactor is needed if the dim has the same size as the blocksize
blocksizes = {}
colwise_var = f"{unique_var}_None"
rowwise_var = f"{unique_var}_None"
if not input_shape[-1] == block_size_1d:
colwise_var = f"{unique_var}_None"
if is_rowwise and not input_shape[-1] == block_size_1d:
rowwise_var = input_spec[-1] + "_compound"
input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
blocksizes["blocksize_x"] = block_size_1d
if not input_shape[flatten_axis - 1] == block_size_1d:
if is_colwise and not input_shape[flatten_axis - 1] == block_size_1d:
colwise_var = input_spec[flatten_axis - 1] + "_compound"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
blocksizes["blocksize_y"] = block_size_1d
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var
if is_rowwise:
rowwise_out = input_spec.copy()
rowwise_scale = input_spec.copy()
rowwise_scale[-1] = rowwise_var
else:
rowwise_out = [
BATCHING + f"{unique_var}_rowwise_output",
]
rowwise_scale = [
BATCHING + f"{unique_var}_rowwise_scale_inv",
]
colwise = input_spec.copy()
colwise[flatten_axis - 1] = colwise_var
if is_colwise:
colwise_out = input_spec.copy()
colwise_scale = input_spec.copy()
colwise_scale[flatten_axis - 1] = colwise_var
if is_colwise_transposed:
from ..cpp_extensions.misc import multidim_transpose
colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis)
colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis)
else:
colwise_out = [
BATCHING + f"{unique_var}_colwise_output",
]
colwise_scale = [
BATCHING + f"{unique_var}_colwise_scale_inv",
]
return QuantizeShardyRules(
tuple(input_spec),
tuple(rowwise),
tuple(colwise),
tuple(rowwise_out),
tuple(rowwise_scale),
tuple(colwise_out),
tuple(colwise_scale),
blocksizes,
)
......@@ -850,7 +911,8 @@ class ScalingMode(Enum):
self,
input_shape,
unique_var,
flatten_axis=-1,
flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d=False,
) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -859,13 +921,19 @@ class ScalingMode(Enum):
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
q_layout: The layout of the quantized tensor
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns:
The Shardy rules for the scaling mode
"""
return self._get_impl().get_shardy_sharding_rules(
input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d
input_shape,
unique_var,
flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d,
self.is_colwise_transposed,
)
def get_grouped_scale_shape_2x(
......
......@@ -15,10 +15,10 @@ from abc import ABC, abstractmethod
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from .misc import QuantizeLayout
from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)
......@@ -128,9 +128,7 @@ class NoScaleTensor(AbstractBaseTensor1x):
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
assert (
q_layout == QuantizeLayout.ROWWISE
), "Only ROWWISE layout is supported for NoScaleTensor"
assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor"
return self
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
......@@ -264,8 +262,8 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout = self.scaling_mode.get_quantize_layout(usage)
colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise
colwise_usage_valid = q_layout.is_colwise_only and self.is_colwise
rowwise_usage_valid = q_layout.is_rowwise_only and not self.is_colwise
if colwise_usage_valid or rowwise_usage_valid:
return self
......@@ -301,16 +299,15 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# TODO(Phuong): Handle padding !?
if self.scaling_mode.is_block_scaling: # Both MXFP8 and NVFP4
scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
else:
scale_inv = self.scale_inv
return ScaledTensor1x(
data=data,
scale_inv=scale_inv,
amax=self.amax,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype,
_dq_func=self._dq_func,
......@@ -467,10 +464,10 @@ class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)
if q_layout_rowwise == QuantizeLayout.ROWWISE:
if q_layout_rowwise.is_rowwise_only:
return self.rowwise_tensor
if q_layout_colwise == QuantizeLayout.COLWISE:
if q_layout_colwise.is_colwise_only:
return self.colwise_tensor
raise ValueError(
......@@ -548,13 +545,13 @@ class ScaledTensorFactory:
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None:
flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape)
assert (
original_shape is not None
), "original_shape is not given for GroupedScaledTensor1x"
# Handling attrs of transposed tensors
group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis
group_axis = (len(original_shape) + group_axis) % len(original_shape)
if data_layout == "T":
if original_shape[0] == group_sizes.size:
original_shape = (
......@@ -587,7 +584,7 @@ class ScaledTensorFactory:
)
# Handling attrs of transposed tensors
flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
flatten_axis = (data.ndim + flatten_axis) % data.ndim
if data_layout == "T":
flatten_axis = data.ndim - flatten_axis
......@@ -669,7 +666,7 @@ class ScaledTensorFactory:
colwise_amax,
scaling_mode,
dq_dtype,
is_colwise=True, # TODO(Phuong): set this correctly
is_colwise=True,
data_layout=data_layout[1],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
......@@ -721,7 +718,7 @@ class ScaledTensorFactory:
"""
assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"
if q_layout == QuantizeLayout.ROWWISE_COLWISE:
if q_layout.is_rowwise_colwise:
return ScaledTensorFactory.create_2x(
data,
scale_inv,
......@@ -740,15 +737,14 @@ class ScaledTensorFactory:
colwise_has_rht_applied=colwise_has_rht_applied,
)
is_colwise = q_layout == QuantizeLayout.COLWISE
if is_colwise:
if q_layout.is_colwise_only:
return ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
colwise_amax if colwise_amax is not None else amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
is_colwise=True,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
......@@ -763,7 +759,7 @@ class ScaledTensorFactory:
amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
is_colwise=False,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
......