Unverified Commit 9440b76a authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Shardy rule + QuantizeLayout Rework (#2364)



* shardy + quantize_layout rework
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* add assertion for NVFP4 in fused act and fused norm primitive
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* add assertions
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent ef28c865
...@@ -10,7 +10,7 @@ from dataclasses import dataclass ...@@ -10,7 +10,7 @@ from dataclasses import dataclass
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, ffi from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import numpy as np import numpy as np
...@@ -159,7 +159,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -159,7 +159,7 @@ class ActLuPrimitive(BasePrimitive):
11, 11,
12, 12,
13, 13,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer ) # out_dtype, act_enum, act_len, scaling_mode, quantize_layout, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -173,7 +173,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -173,7 +173,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
act_params, act_params,
amax_scope, amax_scope,
...@@ -201,6 +201,13 @@ class ActLuPrimitive(BasePrimitive): ...@@ -201,6 +201,13 @@ class ActLuPrimitive(BasePrimitive):
"Current tensor scaling is not yet supported for fused activation and quantization." "Current tensor scaling is not yet supported for fused activation and quantization."
" Please do activation in higher-precision then quantize with current tensor scaling." " Please do activation in higher-precision then quantize with current tensor scaling."
) )
assert not ScalingMode(scaling_mode).is_nvfp4_scaling, (
"NVFP4 block scaling is not yet supported for fused activation and quantization."
" Please do activation in higher-precision then quantize with current tensor scaling."
)
assert (
not quantize_layout.is_colwise_only
), "Fused activation with colwise-only quantization is not supported."
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
...@@ -210,7 +217,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -210,7 +217,7 @@ class ActLuPrimitive(BasePrimitive):
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1) ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
if not is_2x: if quantize_layout.is_rowwise_only:
out_shape = (1,) out_shape = (1,)
colwise_scale_inv_shape = (1,) colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
...@@ -232,7 +239,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -232,7 +239,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
act_params, act_params,
amax_scope, amax_scope,
...@@ -259,7 +266,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -259,7 +266,7 @@ class ActLuPrimitive(BasePrimitive):
amax, amax,
act_enum=act_enum, act_enum=act_enum,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
is_2x=is_2x, quantize_layout=quantize_layout.value.value,
act_params=act_params.to_ffi_lowering_dict(), act_params=act_params.to_ffi_lowering_dict(),
output_amax_when_no_scaling=output_amax_when_no_scaling, output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
...@@ -274,7 +281,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -274,7 +281,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
act_params, act_params,
amax_scope, amax_scope,
...@@ -297,7 +304,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -297,7 +304,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope, amax_scope=amax_scope,
...@@ -313,7 +320,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -313,7 +320,7 @@ class ActLuPrimitive(BasePrimitive):
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
) )
if is_2x: if quantize_layout.is_rowwise_colwise:
colwise_scale_inv = jax.lax.slice( colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
...@@ -329,7 +336,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -329,7 +336,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
act_params, act_params,
amax_scope, amax_scope,
...@@ -356,7 +363,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -356,7 +363,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope, amax_scope=amax_scope,
...@@ -373,7 +380,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -373,7 +380,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
act_params, act_params,
amax_scope, amax_scope,
...@@ -402,7 +409,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -402,7 +409,7 @@ class ActLuPrimitive(BasePrimitive):
out_spec = (*x_spec[:-2], x_spec[-1]) out_spec = (*x_spec[:-2], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if quantize_layout.is_rowwise_colwise:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
...@@ -419,7 +426,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -419,7 +426,7 @@ class ActLuPrimitive(BasePrimitive):
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec scale_inv_spec = out_spec
if is_2x: if quantize_layout.is_rowwise_colwise:
colwise_scale_inv_spec = scale_inv_spec colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
...@@ -444,7 +451,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -444,7 +451,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
act_params, act_params,
amax_scope, amax_scope,
...@@ -462,7 +469,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -462,7 +469,7 @@ class ActLuPrimitive(BasePrimitive):
out_spec = (*x_spec[:-2], x_spec[-1]) out_spec = (*x_spec[:-2], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if quantize_layout.is_rowwise_colwise:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
...@@ -479,7 +486,10 @@ class ActLuPrimitive(BasePrimitive): ...@@ -479,7 +486,10 @@ class ActLuPrimitive(BasePrimitive):
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec scale_inv_spec = out_spec
if is_2x: if quantize_layout.is_rowwise_colwise:
assert not ScalingMode(
scaling_mode
).is_colwise_transposed, "Transpose layout scaling modes are not supported here yet"
colwise_scale_inv_spec = scale_inv_spec colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
...@@ -514,7 +524,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -514,7 +524,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope, amax_scope=amax_scope,
...@@ -550,7 +560,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -550,7 +560,7 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
act_params, act_params,
amax_scope, amax_scope,
...@@ -574,37 +584,28 @@ class ActLuPrimitive(BasePrimitive): ...@@ -574,37 +584,28 @@ class ActLuPrimitive(BasePrimitive):
mesh, mesh,
result_types, result_types,
) )
prefix = "ActLu_" prefix = "ActLu"
input_shape = value_types[0].shape input_shape = value_types[0].shape
output_shape = input_shape[:-2] + input_shape[-1:] output_shape = input_shape[:-2] + input_shape[-1:]
# Here we pass len of output so that the scales are propagated correctly # Here we pass len of output so that the scales are propagated correctly
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
output_shape, unique_var=prefix + "x", flatten_axis=-1 output_shape, unique_var=prefix, flatten_axis=-1, q_layout=quantize_layout
) )
x_axes = scale_rules.input_spec # Correct the input spec with act dim
# Correct input spec with act dim input_spec = scale_rules.input_spec
x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:] input_spec = input_spec[:-1] + (prefix + "_act_dim",) + input_spec[-1:]
out = scale_rules.input_spec amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x:
colwise_scale_inv = scale_rules.colwise_rule
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = multidim_transpose(out, transpose_axis=-1)
else:
colwise_out = out
colwise_scale_inv = scale_rules.colwise_rule
amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(tuple(input_spec), scale, amax),
( (
x_axes, scale_rules.rowwise_out_spec,
("…1",), scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax, amax,
), ),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
**scale_rules.factor_sizes, **scale_rules.factor_sizes,
) )
...@@ -612,7 +613,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -612,7 +613,6 @@ class ActLuPrimitive(BasePrimitive):
register_primitive(ActLuPrimitive) register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
DActLu DBias Cast Transpose Primitive DActLu DBias Cast Transpose Primitive
...@@ -620,7 +620,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -620,7 +620,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi" name = "te_dact_dbias_quantize_ffi"
multiple_results = True multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer # out_dtype, scaling_mode, quantize_layout, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -634,7 +634,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -634,7 +634,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
is_dbias, is_dbias,
act_enum, act_enum,
...@@ -678,7 +678,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -678,7 +678,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x: if quantize_layout.is_rowwise_colwise:
if ScalingMode(scaling_mode).is_tensor_scaling(): if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else: else:
...@@ -700,7 +700,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -700,7 +700,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
scaling_mode, scaling_mode,
is_2x, quantize_layout.value,
) )
wkspace_shape = wkspace_info[0] wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -741,7 +741,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -741,7 +741,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
is_dbias, is_dbias,
act_enum, act_enum,
...@@ -777,7 +777,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -777,7 +777,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
scale, scale,
amax, amax,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
is_2x=is_2x, quantize_layout=quantize_layout.value.value,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=int(act_enum), act_enum=int(act_enum),
act_params=act_params.to_ffi_lowering_dict(), act_params=act_params.to_ffi_lowering_dict(),
...@@ -792,7 +792,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -792,7 +792,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
amax, amax,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
is_dbias, is_dbias,
act_enum, act_enum,
...@@ -816,7 +816,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -816,7 +816,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
amax, amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
...@@ -835,7 +835,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -835,7 +835,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
) )
if is_2x: if quantize_layout.is_rowwise_colwise:
colwise_scale_inv = jax.lax.slice( colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
...@@ -848,7 +848,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -848,7 +848,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
is_dbias, is_dbias,
act_enum, act_enum,
...@@ -883,7 +883,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -883,7 +883,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
amax, amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
...@@ -901,7 +901,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -901,7 +901,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
def infer_sharding_from_operands( def infer_sharding_from_operands(
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
is_dbias, is_dbias,
act_enum, act_enum,
...@@ -928,7 +928,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -928,7 +928,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
) )
if is_2x: if quantize_layout.is_rowwise_colwise:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
...@@ -954,7 +954,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -954,7 +954,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if is_2x: if quantize_layout.is_rowwise_colwise:
colwise_scale_inv_spec = scale_inv_spec colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
...@@ -981,7 +981,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -981,7 +981,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
def partition( def partition(
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
is_dbias, is_dbias,
act_enum, act_enum,
...@@ -1003,7 +1003,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -1003,7 +1003,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
) )
if is_2x: if quantize_layout.is_rowwise_colwise:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
...@@ -1029,7 +1029,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -1029,7 +1029,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if is_2x: if quantize_layout.is_rowwise_colwise:
colwise_scale_inv_spec = scale_inv_spec colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
...@@ -1066,7 +1066,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -1066,7 +1066,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
amax, amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
...@@ -1102,7 +1102,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -1102,7 +1102,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
def shardy_sharding_rule( def shardy_sharding_rule(
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
is_dbias, is_dbias,
act_enum, act_enum,
...@@ -1132,28 +1132,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -1132,28 +1132,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
) )
prefix = "DActLuDBias_" prefix = "DActLuDBias_"
# get sharding rules base on the input shape
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 value_types[1].shape,
unique_var=prefix,
flatten_axis=-2,
q_layout=quantize_layout,
) )
x_axes = scale_rules.input_spec
dz_axes = (*x_axes[:-2], x_axes[-1])
out = x_axes
colwise_out = (prefix + "out_colwise",) input_spec = scale_rules.input_spec
colwise_scale_inv = (prefix + "scale_inv_colwise",) dz_spec = (*input_spec[:-2], input_spec[-1])
if is_2x: dbias = input_spec[-2:] if is_dbias else (prefix + "_dbias",)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: amax = (prefix + "_amax",)
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) scale = (prefix + "_scale",)
else:
colwise_out = out
colwise_scale_inv = scale_rules.colwise_rule
dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(dz_axes, x_axes, ("…2",), amax), (tuple(dz_spec), tuple(input_spec), scale, amax),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
dbias,
),
**scale_rules.factor_sizes, **scale_rules.factor_sizes,
) )
...@@ -1269,7 +1271,7 @@ def act_lu( ...@@ -1269,7 +1271,7 @@ def act_lu(
return _jax_act_lu(x, activation_type, quantizer, act_params) return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_act_lu(x, activation_type, quantizer, act_params) return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support 2x quantization for DelayedScaling yet # TE/common does not support 2x quantization for DelayedScaling yet
war_output = try_apply_delayed_scaling_2x_war( war_output = try_apply_delayed_scaling_2x_war(
...@@ -1298,7 +1300,7 @@ def act_lu( ...@@ -1298,7 +1300,7 @@ def act_lu(
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope, amax_scope=amax_scope,
...@@ -1354,7 +1356,7 @@ def act_lu( ...@@ -1354,7 +1356,7 @@ def act_lu(
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), quantize_layout=quantizer.q_layout,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
act_params=act_params, act_params=act_params,
amax_scope=amax_scope, amax_scope=amax_scope,
...@@ -1415,7 +1417,7 @@ def quantize_dact_dbias( ...@@ -1415,7 +1417,7 @@ def quantize_dact_dbias(
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled() or ( if not PrimitiveClass.enabled() or (
quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE quantizer is not None and quantizer.q_layout.is_colwise_only
): ):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params)
if quantizer is None: if quantizer is None:
...@@ -1428,7 +1430,7 @@ def quantize_dact_dbias( ...@@ -1428,7 +1430,7 @@ def quantize_dact_dbias(
out_dtype=(jnp.float32 if is_dbias else x.dtype), out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset # default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused quantize_layout=QuantizeLayout.ROWWISE, # unused
scale_dtype=jnp.float32, # unused scale_dtype=jnp.float32, # unused
is_dbias=False, is_dbias=False,
act_enum=act_type_id, act_enum=act_type_id,
...@@ -1555,7 +1557,7 @@ def quantize_dact_dbias( ...@@ -1555,7 +1557,7 @@ def quantize_dact_dbias(
amax, amax,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), quantize_layout=quantizer.q_layout,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_type_id, act_enum=act_type_id,
...@@ -1568,7 +1570,7 @@ def quantize_dact_dbias( ...@@ -1568,7 +1570,7 @@ def quantize_dact_dbias(
) )
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax) quantizer.update(updated_amax)
......
...@@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant ...@@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break break
# _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
# but this fails when bias fusion is turned on with arch < 100. # but this fails when bias fusion is turned on with arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise
)
return ( return (
(force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and arch_l_100 and arch_l_100
...@@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, ...@@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
@return: the output of 'f' with the colwise output calculated @return: the output of 'f' with the colwise output calculated
""" """
should_apply_war = ( should_apply_war = (
quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() quantizer is not None
and quantizer.scaling_mode.is_tensor_scaling()
and quantizer.q_layout.is_rowwise_colwise
) )
if not should_apply_war: if not should_apply_war:
return None return None
......
...@@ -11,7 +11,7 @@ from typing import Optional, Union ...@@ -11,7 +11,7 @@ from typing import Optional, Union
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, ffi from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
...@@ -112,7 +112,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -112,7 +112,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -148,6 +148,13 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -148,6 +148,13 @@ class NormFwdPrimitive(BasePrimitive):
"Current tensor scaling is not supported for fused norm and quantization. Please do" "Current tensor scaling is not supported for fused norm and quantization. Please do"
" norm in higher-precision then quantize with current tensor scaling." " norm in higher-precision then quantize with current tensor scaling."
) )
assert not ScalingMode(scaling_mode).is_nvfp4_scaling, (
"NVFP4 block scaling is not yet supported for fused norm and quantization."
" Please do norm in higher-precision then quantize with current tensor scaling."
)
assert (
not quantize_layout.is_colwise_only
), "Fused norm with colwise-only quantization is not supported."
mu_rsigama_dtype = jnp.float32 mu_rsigama_dtype = jnp.float32
...@@ -165,7 +172,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -165,7 +172,7 @@ class NormFwdPrimitive(BasePrimitive):
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,) colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -173,7 +180,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -173,7 +180,7 @@ class NormFwdPrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray( colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype shape=colwise_scale_inv_shape, dtype=scale_dtype
) )
...@@ -189,7 +196,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -189,7 +196,7 @@ class NormFwdPrimitive(BasePrimitive):
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
get_forward_sm_margin(), get_forward_sm_margin(),
is_2x, True, # is_training
) )
wkspace_aval = jax.core.ShapedArray( wkspace_aval = jax.core.ShapedArray(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -245,7 +252,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -245,7 +252,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -287,7 +294,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -287,7 +294,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
sm_margin=sm_margin, sm_margin=sm_margin,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
is_2x=is_2x, quantize_layout=quantize_layout.value.value,
output_amax_when_no_scaling=output_amax_when_no_scaling, output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
...@@ -303,7 +310,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -303,7 +310,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -335,7 +342,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -335,7 +342,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -349,7 +356,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -349,7 +356,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape rowwise_scale_inv_shape
) )
if is_2x: if quantize_layout.has_colwise:
colwise_scale_inv = colwise_scale_inv.flatten()[ colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape, 1) : reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape) ].reshape(colwise_scale_inv_shape)
...@@ -373,7 +380,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -373,7 +380,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -409,7 +416,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -409,7 +416,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -426,7 +433,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -426,7 +433,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -450,7 +457,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -450,7 +457,7 @@ class NormFwdPrimitive(BasePrimitive):
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,) colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
) )
...@@ -488,7 +495,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -488,7 +495,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -524,7 +531,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -524,7 +531,7 @@ class NormFwdPrimitive(BasePrimitive):
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,) colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
) )
...@@ -586,7 +593,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -586,7 +593,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon=epsilon, epsilon=epsilon,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, quantize_layout=quantize_layout,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -623,7 +630,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -623,7 +630,7 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, quantize_layout,
scale_dtype, scale_dtype,
amax_scope, amax_scope,
transpose_batch_sequence, transpose_batch_sequence,
...@@ -646,25 +653,29 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -646,25 +653,29 @@ class NormFwdPrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "NormFwd_" prefix = "NormFwd"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1 value_types[0].shape,
unique_var=prefix,
flatten_axis=-1,
q_layout=quantize_layout,
) )
x_axes = scale_rules.input_spec input_spec = scale_rules.input_spec
out = x_axes rsigma = input_spec[:-1]
colwise_out = out if is_2x else (prefix + "out_colwise",) mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
rsigma = x_axes[:-1] amax = (BATCHING + prefix + "_amax",)
mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma scale = (BATCHING + prefix + "_scale",)
amax = (prefix + "amax",) gamma = (BATCHING + prefix + "_gamma",)
beta = (BATCHING + prefix + "_beta",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), amax, ("…2",), ("…3",)), (input_spec, scale, amax, gamma, beta),
( (
out, scale_rules.rowwise_out_spec,
colwise_out, scale_rules.colwise_out_spec,
scale_rules.rowwise_rule, scale_rules.rowwise_scale_spec,
scale_rules.colwise_rule, scale_rules.colwise_scale_spec,
amax, amax,
mu, mu,
rsigma, rsigma,
...@@ -987,7 +998,7 @@ def layernorm_fwd( ...@@ -987,7 +998,7 @@ def layernorm_fwd(
return (output, mu, rsigma) return (output, mu, rsigma)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -1008,7 +1019,7 @@ def layernorm_fwd( ...@@ -1008,7 +1019,7 @@ def layernorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=False, transpose_batch_sequence=False,
...@@ -1067,10 +1078,11 @@ def layernorm_fwd( ...@@ -1067,10 +1078,11 @@ def layernorm_fwd(
) )
return out, mu, rsigma return out, mu, rsigma
is_2x2x = quantizer.is_2x2x() # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
# TE/common normalization doesn't support 2x delayed scaling q_layout = quantizer.q_layout
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False q_layout = QuantizeLayout.ROWWISE
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -1090,7 +1102,7 @@ def layernorm_fwd( ...@@ -1090,7 +1102,7 @@ def layernorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -1099,8 +1111,7 @@ def layernorm_fwd( ...@@ -1099,8 +1111,7 @@ def layernorm_fwd(
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
...@@ -1238,7 +1249,7 @@ def rmsnorm_fwd( ...@@ -1238,7 +1249,7 @@ def rmsnorm_fwd(
return (output, rsigma) return (output, rsigma)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -1261,7 +1272,7 @@ def rmsnorm_fwd( ...@@ -1261,7 +1272,7 @@ def rmsnorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -1321,10 +1332,11 @@ def rmsnorm_fwd( ...@@ -1321,10 +1332,11 @@ def rmsnorm_fwd(
) )
return out, rsigma return out, rsigma
is_2x2x = quantizer.is_2x2x() # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
# TE/common normalization doesn't support 2x delayed scaling q_layout = quantizer.q_layout
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False q_layout = QuantizeLayout.ROWWISE
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -1344,7 +1356,7 @@ def rmsnorm_fwd( ...@@ -1344,7 +1356,7 @@ def rmsnorm_fwd(
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
...@@ -1353,8 +1365,7 @@ def rmsnorm_fwd( ...@@ -1353,8 +1365,7 @@ def rmsnorm_fwd(
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
......
...@@ -11,7 +11,7 @@ import math ...@@ -11,7 +11,7 @@ import math
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, ffi from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import transformer_engine_jax import transformer_engine_jax
...@@ -122,7 +122,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -122,7 +122,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
) )
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if QuantizeLayout(q_layout).has_rowwise:
rowwise_out_shape = out_shape rowwise_out_shape = out_shape
else: else:
rowwise_out_shape = (1,) rowwise_out_shape = (1,)
...@@ -170,7 +170,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -170,7 +170,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
broadcast_2d_scale_shape_to_1d=True, broadcast_2d_scale_shape_to_1d=True,
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if QuantizeLayout(q_layout).has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed: if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else: else:
...@@ -194,9 +194,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -194,9 +194,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(scale_dtype), jax_dtype_to_te_dtype(scale_dtype),
scaling_mode, scaling_mode,
QuantizeLayout( q_layout.value,
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
) )
wkspace_shape = wkspace_info[0] wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -272,7 +270,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -272,7 +270,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
post_rht_amax, post_rht_amax,
rht_matrix, rht_matrix,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
q_layout=q_layout, q_layout=q_layout.value.value,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
is_dbias=is_dbias, is_dbias=is_dbias,
stochastic_rounding=stochastic_rounding, stochastic_rounding=stochastic_rounding,
...@@ -335,7 +333,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -335,7 +333,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
colwise_scale_inv = jax.lax.slice( colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
...@@ -424,7 +422,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -424,7 +422,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
desc="BaseDBiasQuantizePrimitive.out_sharding", desc="BaseDBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed: if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
...@@ -448,7 +446,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -448,7 +446,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling: if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ( if (
ScalingMode(scaling_mode).is_block_scaling ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed and ScalingMode(scaling_mode).is_colwise_transposed
...@@ -505,7 +503,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -505,7 +503,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.out_sharding", desc="BaseDBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ScalingMode(scaling_mode).is_colwise_transposed: if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
...@@ -529,7 +527,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -529,7 +527,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
if ScalingMode(scaling_mode).is_block_scaling: if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
if ( if (
ScalingMode(scaling_mode).is_block_scaling ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed and ScalingMode(scaling_mode).is_colwise_transposed
...@@ -643,39 +641,37 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -643,39 +641,37 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "DBiasQuantize_" prefix = "DBiasQuantize"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape, value_types[0].shape,
unique_var=prefix + "x", unique_var=prefix,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
q_layout=q_layout,
broadcast_2d_scale_shape_to_1d=True, broadcast_2d_scale_shape_to_1d=True,
) )
x_axes = scale_rules.input_spec input_spec = scale_rules.input_spec
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
out = x_axes amax = (BATCHING + prefix + "_amax",)
colwise_out = (prefix + "out_colwise",) scale = (BATCHING + prefix + "_scale",)
colwise_scale_inv = (prefix + "colwise_scale_inv",) sr_rng_state = (
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): BATCHING + prefix + "_sr_rng_state_partition_axis",
colwise_scale_inv = scale_rules.colwise_rule BATCHING + prefix + "sr_rng_state_data_axis",
if ScalingMode(scaling_mode).is_colwise_transposed: )
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
colwise_scale_inv = tuple(
multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis)
)
else:
colwise_out = x_axes
dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis")
post_rht_amax = (prefix + "post_rht_amax",) post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2") rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix), (input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
dbias,
),
**scale_rules.factor_sizes, **scale_rules.factor_sizes,
) )
...@@ -762,7 +758,7 @@ def _quantize_dbias_impl( ...@@ -762,7 +758,7 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation # fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not ( is_unsupported = quantizer.q_layout.is_colwise_only and not (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and hasattr(quantizer, "use_rht") and hasattr(quantizer, "use_rht")
and quantizer.use_rht and quantizer.use_rht
...@@ -845,7 +841,7 @@ def _quantize_dbias_impl( ...@@ -845,7 +841,7 @@ def _quantize_dbias_impl(
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = ( force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x() and quantizer.q_layout.is_rowwise_colwise
and is_1x_kernel_supported and is_1x_kernel_supported
) )
q_layout = quantizer.q_layout q_layout = quantizer.q_layout
...@@ -879,7 +875,7 @@ def _quantize_dbias_impl( ...@@ -879,7 +875,7 @@ def _quantize_dbias_impl(
rht_matrix, rht_matrix,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
...@@ -888,10 +884,10 @@ def _quantize_dbias_impl( ...@@ -888,10 +884,10 @@ def _quantize_dbias_impl(
use_rht=use_rht, use_rht=use_rht,
) )
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
if q_layout == QuantizeLayout.ROWWISE: if q_layout.is_rowwise_only:
# Quantizer requires 2x quantization, but we are using 1x quantization # Quantizer requires 2x quantization, but we are using 1x quantization
# for performance reasons, so we need to generate the colwise data in JAX # for performance reasons, so we need to generate the colwise data in JAX
if flatten_axis < 0: if flatten_axis < 0:
...@@ -1043,7 +1039,7 @@ class GroupedQuantizePrimitive(BasePrimitive): ...@@ -1043,7 +1039,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_rowwise:
rowwise_out_shape = out_shape rowwise_out_shape = out_shape
else: else:
rowwise_out_shape = (1,) rowwise_out_shape = (1,)
...@@ -1052,7 +1048,7 @@ class GroupedQuantizePrimitive(BasePrimitive): ...@@ -1052,7 +1048,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout.has_colwise:
colwise_out_shape = out_shape colwise_out_shape = out_shape
else: else:
colwise_out_shape = (1,) colwise_out_shape = (1,)
...@@ -1117,7 +1113,7 @@ class GroupedQuantizePrimitive(BasePrimitive): ...@@ -1117,7 +1113,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
scale, scale,
group_sizes, group_sizes,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
q_layout=q_layout, q_layout=q_layout.value.value,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
...@@ -1240,7 +1236,7 @@ def grouped_quantize( ...@@ -1240,7 +1236,7 @@ def grouped_quantize(
) )
# WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
# So we performance ROWWISE_COLWISE and use the colwise_tensor_output # So we performance ROWWISE_COLWISE and use the colwise_tensor_output
apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only
q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
( (
rowwise_casted_output, rowwise_casted_output,
...@@ -1254,7 +1250,7 @@ def grouped_quantize( ...@@ -1254,7 +1250,7 @@ def grouped_quantize(
group_sizes, group_sizes,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_axis=group_axis, group_axis=group_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
...@@ -1262,7 +1258,7 @@ def grouped_quantize( ...@@ -1262,7 +1258,7 @@ def grouped_quantize(
# For DelayedScaling2x and CurrentScaling2x, the scale buffer # For DelayedScaling2x and CurrentScaling2x, the scale buffer
# is shared between rowwise and colwise # is shared between rowwise and colwise
if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war: if is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise or apply_colwise_war:
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
# TODO(Phuong): store the whole updated_amax in the grouped_quantize instead? # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
......
...@@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); ...@@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x); JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout);
// Normalization // Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler);
...@@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); ...@@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype, DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode, JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout); JAXX_Quantize_Layout quantize_layout);
// Softmax // Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
...@@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ...@@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout);
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
...@@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) { JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
...@@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back(); auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2]; auto act_len = input_dims[input_dims.size() - 2];
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)}; auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
...@@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
} }
} }
if (is_2x) { if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape ? output_trans_shape
: output_shape; : output_shape;
...@@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // updated_amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"), .Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
...@@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer ...@@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer
Buffer_Type amax_buf, Result_Type output_buf, Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
ActivationConfig act_params, bool output_amax_when_no_scaling) { JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf, return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params, updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params,
output_amax_when_no_scaling); output_amax_when_no_scaling);
} }
...@@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, ...@@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // updated_amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling")); .Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size}; auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
...@@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
std::vector<size_t>{1}); std::vector<size_t>{1});
} }
if (is_2x) { if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape; : output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape); output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
...@@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
bool is_dbias, ActivationConfig act_params, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
bool output_amax_when_no_scaling) { ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
...@@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
} }
} }
if (is_2x) { if (is_quantize_2x2x(quantize_layout)) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape ? output_trans_shape
: output_shape; : output_shape;
...@@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING &&
is_quantize_2x2x(quantize_layout) && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations."); "TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) { if (is_dbias) {
...@@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"), .Attr<bool>("output_amax_when_no_scaling"),
...@@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI( ...@@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI(
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params, int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
bool output_amax_when_no_scaling) { ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params, workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias,
output_amax_when_no_scaling); act_params, output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
...@@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, ...@@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params") .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling")); .Attr<bool>("output_amax_when_no_scaling"));
......
...@@ -34,12 +34,24 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -34,12 +34,24 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret; return ret;
} }
enum class QuantizeLayout { enum class JAXX_Quantize_Layout : int64_t {
ROWWISE, ROWWISE,
COLWISE, COLWISE,
ROWWISE_COLWISE, ROWWISE_COLWISE,
}; };
inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::COLWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) {
return layout == JAXX_Quantize_Layout::ROWWISE_COLWISE;
}
enum class JAXX_Scaling_Mode : int64_t { enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0, NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1, DELAYED_TENSOR_SCALING = 1,
......
...@@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) { JAXX_Quantize_Layout quantize_layout, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
...@@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type); auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
auto x_size = product(x_buf.dimensions()); auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions()); auto gamma_size = product(gamma_buf.dimensions());
...@@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
} }
if (_is_2x) { if (is_quantize_2x2x(quantize_layout)) {
output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(), output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(),
static_cast<DType>(out_dtype), input_shape); static_cast<DType>(out_dtype), input_shape);
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
...@@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling"), .Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Error_Type NormForwardInitializeFFI(
Buffer_Type amax_buf, Buffer_Type gamma_buf, cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Buffer_Type amax_buf,
Buffer_Type beta_buf, Result_Type output_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon,
Result_Type wkspace_buf, int norm_type, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
bool zero_centered_gamma, double epsilon, int64_t sm_margin, bool output_amax_when_no_scaling) {
JAXX_Scaling_Mode scaling_mode, bool is_2x,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf, return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf, gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf, colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling); scaling_mode, quantize_layout, output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
...@@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ ...@@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<JAXX_Quantize_Layout>("quantize_layout")
.Attr<bool>("output_amax_when_no_scaling")); .Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
......
...@@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
.export_values(); .export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout", pybind11::enum_<JAXX_Quantize_Layout>(m, "JAXX_Quantize_Layout", pybind11::module_local())
pybind11::module_local()) .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE)
.value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) .value("COLWISE", JAXX_Quantize_Layout::COLWISE)
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) .value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values(); .export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local()) pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
......
...@@ -20,7 +20,7 @@ namespace jax { ...@@ -20,7 +20,7 @@ namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType scale_dtype, DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode, JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) { JAXX_Quantize_Layout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
...@@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto scale_shape = std::vector<size_t>{1}; auto scale_shape = std::vector<size_t>{1};
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { if (is_quantize_rowwise(q_layout)) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape); output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4) if (is_nvfp4)
...@@ -52,7 +52,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -52,7 +52,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
} }
} }
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { if (is_quantize_colwise(q_layout)) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape; : output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape); output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
...@@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type updated_amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
bool stochastic_rounding, bool use_rht) { int64_t flatten_axis, bool stochastic_rounding, bool use_rht) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data();
...@@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4."); NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling"); NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling");
if (quantize_layout == QuantizeLayout::ROWWISE || if (is_quantize_rowwise(quantize_layout)) {
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_tensor_scaling) { if (is_tensor_scaling) {
...@@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
quant_config.set_rng_state(sr_rng_state_tensor.data()); quant_config.set_rng_state(sr_rng_state_tensor.data());
} }
if (quantize_layout == QuantizeLayout::COLWISE || if (is_quantize_colwise(quantize_layout)) {
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
if (is_nvfp4 && use_rht) { if (is_nvfp4 && use_rht) {
if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { if (is_quantize_2x2x(quantize_layout)) {
// Do regular rowwise quantization without RHT // Do regular rowwise quantization without RHT
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
} }
...@@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout") .Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis") .Attr<int64_t>("flatten_axis")
.Attr<bool>("stochastic_rounding") .Attr<bool>("stochastic_rounding")
...@@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
Buffer_Type group_sizes, Result_Type outputs, Buffer_Type group_sizes, Result_Type outputs,
Result_Type colwise_outputs, Result_Type scale_invs, Result_Type colwise_outputs, Result_Type scale_invs,
Result_Type colwise_scale_invs, Result_Type amaxs, Result_Type colwise_scale_invs, Result_Type amaxs,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout,
int64_t flatten_axis) { int64_t flatten_axis) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode)); "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
...@@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type()); auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type());
auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type()); auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type());
auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type()); auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type());
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data()); auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data());
auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data()); auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data());
...@@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data()); auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data());
auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data()); auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data());
bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool has_colwise = quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
...@@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
size_t output_dtype_bytes = te_dtype_bytes(out_dtype); size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype); size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype);
size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype); size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype);
size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0; size_t colwise_output_dtype_bytes = is_quantize_colwise(quantize_layout) ? output_dtype_bytes : 0;
size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0; size_t colwise_sinv_dtype_bytes = is_quantize_colwise(quantize_layout) ? sinv_dtype_bytes : 0;
size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0; size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0;
size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0; size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0;
...@@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype); auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype);
auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (has_rowwise) { if (is_quantize_rowwise(quantize_layout)) {
out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i); out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
} }
} }
if (has_colwise) { if (is_quantize_colwise(quantize_layout)) {
auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i; auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i;
out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape); out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
...@@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, ...@@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout") .Attr<JAXX_Quantize_Layout>("q_layout")
.Attr<int64_t>("flatten_axis")); .Attr<int64_t>("flatten_axis"));
} // namespace jax } // namespace jax
......
...@@ -17,3 +17,4 @@ from .metadata import * ...@@ -17,3 +17,4 @@ from .metadata import *
from .hadamard import * from .hadamard import *
from .helper import * from .helper import *
from .device_utils import * from .device_utils import *
from .misc import *
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This module provides additional enum and utilities for quantizing tensors in JAX.
"""
from dataclasses import dataclass
from enum import Enum
from transformer_engine_jax import JAXX_Quantize_Layout
__all__ = [
"QuantizeLayout",
]
class QuantizeLayout(Enum):
    """Wrapper for JAXX_Quantize_Layout.

    Mirrors the C++ enum exposed by ``transformer_engine_jax`` and adds
    convenience predicates for querying which quantized components
    (row-wise / column-wise) a layout produces.
    """

    ROWWISE = JAXX_Quantize_Layout.ROWWISE
    COLWISE = JAXX_Quantize_Layout.COLWISE
    ROWWISE_COLWISE = JAXX_Quantize_Layout.ROWWISE_COLWISE

    @property
    def has_rowwise(self) -> bool:
        """True if the layout includes a row-wise component."""
        return self.value in (JAXX_Quantize_Layout.ROWWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE)

    @property
    def has_colwise(self) -> bool:
        """True if the layout includes a column-wise component."""
        return self.value in (JAXX_Quantize_Layout.COLWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE)

    @property
    def is_rowwise_colwise(self) -> bool:
        """True if the layout is both row-wise and column-wise."""
        return self.value == JAXX_Quantize_Layout.ROWWISE_COLWISE

    @property
    def is_rowwise_only(self) -> bool:
        """True if the layout is row-wise only."""
        return self.value == JAXX_Quantize_Layout.ROWWISE

    @property
    def is_colwise_only(self) -> bool:
        """True if the layout is column-wise only."""
        return self.value == JAXX_Quantize_Layout.COLWISE

    def __eq__(self, other):
        """Compare this quantize layout with another by underlying value.

        Args:
            other: The other quantize layout to compare with.

        Returns:
            True if both are ``QuantizeLayout`` members with equal underlying
            ``JAXX_Quantize_Layout`` values, False otherwise.
        """
        if not isinstance(other, QuantizeLayout):
            return False
        return self.value == other.value

    # Defining __eq__ implicitly sets __hash__ to None, which would make the
    # members unhashable (breaking dict/set usage of the enum). Restore the
    # Enum identity hash, which is consistent with the value-based __eq__
    # above because enum members are singletons with distinct values.
    __hash__ = Enum.__hash__
...@@ -15,10 +15,10 @@ import warnings ...@@ -15,10 +15,10 @@ import warnings
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe from transformer_engine.common import recipe
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .misc import QuantizeLayout
from .hadamard import apply_rht from .hadamard import apply_rht
from .tensor import ( from .tensor import (
ScaledTensor, ScaledTensor,
...@@ -37,7 +37,6 @@ from .device_utils import is_fp8_gemm_with_all_layouts_supported ...@@ -37,7 +37,6 @@ from .device_utils import is_fp8_gemm_with_all_layouts_supported
from ..sharding import get_num_devices_in_mesh from ..sharding import get_num_devices_in_mesh
__all__ = [ __all__ = [
"QuantizeLayout",
"Quantizer", "Quantizer",
"QuantizerSet", "QuantizerSet",
"CurrentScaleQuantizer", "CurrentScaleQuantizer",
...@@ -118,14 +117,6 @@ class Quantizer(ABC): ...@@ -118,14 +117,6 @@ class Quantizer(ABC):
"""Update quantizer state (no-op in base class).""" """Update quantizer state (no-op in base class)."""
del args, kwargs del args, kwargs
def is_2x2x(self) -> bool:
"""Check if quantizer uses both row-wise and column-wise quantization.
Returns:
True if using both row-wise and column-wise quantization
"""
return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
def get_data_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data data_layout string. """Get the data data_layout string.
...@@ -135,11 +126,11 @@ class Quantizer(ABC): ...@@ -135,11 +126,11 @@ class Quantizer(ABC):
Raises: Raises:
ValueError: If quantization axis is invalid ValueError: If quantization axis is invalid
""" """
if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: if self.q_layout.is_rowwise_colwise:
return self.data_layout return self.data_layout
if self.q_layout == QuantizeLayout.ROWWISE: if self.q_layout.is_rowwise_only:
return self.data_layout[0] return self.data_layout[0]
if self.q_layout == QuantizeLayout.COLWISE: if self.q_layout.is_colwise_only:
return self.data_layout[1] return self.data_layout[1]
raise ValueError(f"Invalid q_layout: {self.q_layout}") raise ValueError(f"Invalid q_layout: {self.q_layout}")
...@@ -174,18 +165,10 @@ class Quantizer(ABC): ...@@ -174,18 +165,10 @@ class Quantizer(ABC):
""" """
del kwargs del kwargs
is_rowwise = ( is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
is_rowwise is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
if (is_rowwise and is_colwise) or self.is_2x2x(): if is_rowwise and is_colwise:
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func( colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
...@@ -299,16 +282,8 @@ class CurrentScaleQuantizer(Quantizer): ...@@ -299,16 +282,8 @@ class CurrentScaleQuantizer(Quantizer):
flatten_axis += x.ndim flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = ( is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
is_rowwise is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None colwise_tensor = None
...@@ -974,16 +949,8 @@ class GroupedQuantizer(Quantizer): ...@@ -974,16 +949,8 @@ class GroupedQuantizer(Quantizer):
flatten_axis += x.ndim flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = ( is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
is_rowwise is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
assert is_rowwise or is_colwise, "No quantization layout is specified" assert is_rowwise or is_colwise, "No quantization layout is specified"
original_shape = x.shape original_shape = x.shape
......
...@@ -21,7 +21,8 @@ from jax.experimental.custom_partitioning import BATCHING, CompoundFactor ...@@ -21,7 +21,8 @@ from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout from transformer_engine_jax import JAXX_Scaling_Mode
from .misc import QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported from .device_utils import is_fp8_gemm_with_all_layouts_supported
...@@ -72,16 +73,18 @@ class QuantizeShardyRules: ...@@ -72,16 +73,18 @@ class QuantizeShardyRules:
Attributes: Attributes:
input_spec: Specification for the input axes input_spec: Specification for the input axes
rowwise_rule: Sharding rule for the row-wise scale tensor, depends on rowwise_out_spec: Sharding spec for the rowwise quantized data
the axes in `input_spec` rowwise_scale_spec: Sharding spec for the rowwise scale
colwise_rule: Likewise for the column-wise scale tensor. colwise_out_spec: Sharding spec for the colwise quantized data
factor_sizes: For block scaling, contains the block size factor, which is colwise_scale_spec: Sharding spec for the colwise scale
used in `input_spec`. factor_sizes: For block scaling, contains the block size factor
""" """
input_spec: Tuple[str] input_spec: Tuple[str]
rowwise_rule: Tuple[str] rowwise_out_spec: Tuple[str]
colwise_rule: Tuple[str] rowwise_scale_spec: Tuple[str]
colwise_out_spec: Tuple[str]
colwise_scale_spec: Tuple[str]
factor_sizes: Dict[str, int] factor_sizes: Dict[str, int]
...@@ -166,7 +169,9 @@ class ScalingModeMetadataImpl(ABC): ...@@ -166,7 +169,9 @@ class ScalingModeMetadataImpl(ABC):
input_shape, input_shape,
unique_var, unique_var,
flatten_axis, flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d, broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
...@@ -174,7 +179,9 @@ class ScalingModeMetadataImpl(ABC): ...@@ -174,7 +179,9 @@ class ScalingModeMetadataImpl(ABC):
input_shape: The shape of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization flatten_axis: Axis along which data can be flattened to 2D for quantization
q_layout: The layout of the quantized tensor
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
is_colwise_transposed: Whether the column-wise tensors are transposed.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
...@@ -268,7 +275,9 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -268,7 +275,9 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape, input_shape,
unique_var, unique_var,
flatten_axis, flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d, broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
...@@ -281,10 +290,17 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -281,10 +290,17 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis, broadcast_2d_scale_shape_to_1d del broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) input_spec = tuple(f"{unique_var}_x_{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv" output_spec = tuple(input_spec)
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) return QuantizeShardyRules(
input_spec,
output_spec,
(BATCHING + f"{unique_var}_scale",),
(BATCHING + f"{unique_var}_colwise_output",),
(BATCHING + f"{unique_var}_colwise_scale",),
{},
)
class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
...@@ -376,7 +392,9 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -376,7 +392,9 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape, input_shape,
unique_var, unique_var,
flatten_axis, flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d, broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
...@@ -385,14 +403,26 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -385,14 +403,26 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
q_layout: The layout of the quantized tensor
is_colwise_transposed: Whether the colwise scaling is transposed
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis, broadcast_2d_scale_shape_to_1d del broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) input_spec = tuple(f"{unique_var}x_{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv" output_spec = input_spec
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) colwise_output_spec = (BATCHING + f"{unique_var}_colwise_output",)
if q_layout.has_colwise:
from ..cpp_extensions.misc import multidim_transpose
colwise_output_spec = input_spec
if is_colwise_transposed:
colwise_output_spec = multidim_transpose(
colwise_output_spec, transpose_axis=flatten_axis
)
scale = (BATCHING + unique_var + "_scale_inv",)
return QuantizeShardyRules(input_spec, output_spec, scale, colwise_output_spec, scale, {})
class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
...@@ -658,7 +688,9 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -658,7 +688,9 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape, input_shape,
unique_var, unique_var,
flatten_axis, flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d, broadcast_2d_scale_shape_to_1d,
is_colwise_transposed,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
...@@ -666,15 +698,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -666,15 +698,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape: The shape of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization flatten_axis: Axis along which data can be flattened to 2D for quantization
q_layout: The layout of the quantized tensor
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
is_colwise_transposed: Whether the column-wise tensors are transposed.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
# TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed is_rowwise = q_layout.has_rowwise
is_colwise = q_layout.has_colwise
input_rank = len(input_shape) input_rank = len(input_shape)
input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
flatten_axis = (flatten_axis + input_rank) % input_rank flatten_axis = (flatten_axis + input_rank) % input_rank
input_spec = [f"{unique_var}_x_{i}" for i in range(input_rank)]
assert ( assert (
self._block_dims[1] != 1 self._block_dims[1] != 1
...@@ -690,30 +725,56 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -690,30 +725,56 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
# We have to use two different factors in the two CompoundFactors because of Shardy # We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same. # verifier requirements, even though they are the same.
# No CompoundFactor is needed if the dim has the same size as the blocksize
blocksizes = {} blocksizes = {}
colwise_var = f"{unique_var}_None"
rowwise_var = f"{unique_var}_None" rowwise_var = f"{unique_var}_None"
if not input_shape[-1] == block_size_1d: colwise_var = f"{unique_var}_None"
if is_rowwise and not input_shape[-1] == block_size_1d:
rowwise_var = input_spec[-1] + "_compound" rowwise_var = input_spec[-1] + "_compound"
input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
blocksizes["blocksize_x"] = block_size_1d blocksizes["blocksize_x"] = block_size_1d
if not input_shape[flatten_axis - 1] == block_size_1d: if is_colwise and not input_shape[flatten_axis - 1] == block_size_1d:
colwise_var = input_spec[flatten_axis - 1] + "_compound" colwise_var = input_spec[flatten_axis - 1] + "_compound"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
blocksizes["blocksize_y"] = block_size_1d blocksizes["blocksize_y"] = block_size_1d
# The rowwise and colwise scale tensors should be sharded the same way as the input. # The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies. # However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy() if is_rowwise:
rowwise[-1] = rowwise_var rowwise_out = input_spec.copy()
rowwise_scale = input_spec.copy()
rowwise_scale[-1] = rowwise_var
else:
rowwise_out = [
BATCHING + f"{unique_var}_rowwise_output",
]
rowwise_scale = [
BATCHING + f"{unique_var}_rowwise_scale_inv",
]
colwise = input_spec.copy() if is_colwise:
colwise[flatten_axis - 1] = colwise_var colwise_out = input_spec.copy()
colwise_scale = input_spec.copy()
colwise_scale[flatten_axis - 1] = colwise_var
if is_colwise_transposed:
from ..cpp_extensions.misc import multidim_transpose
colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis)
colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis)
else:
colwise_out = [
BATCHING + f"{unique_var}_colwise_output",
]
colwise_scale = [
BATCHING + f"{unique_var}_colwise_scale_inv",
]
return QuantizeShardyRules( return QuantizeShardyRules(
tuple(input_spec), tuple(input_spec),
tuple(rowwise), tuple(rowwise_out),
tuple(colwise), tuple(rowwise_scale),
tuple(colwise_out),
tuple(colwise_scale),
blocksizes, blocksizes,
) )
...@@ -850,7 +911,8 @@ class ScalingMode(Enum): ...@@ -850,7 +911,8 @@ class ScalingMode(Enum):
self, self,
input_shape, input_shape,
unique_var, unique_var,
flatten_axis=-1, flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d=False, broadcast_2d_scale_shape_to_1d=False,
) -> Tuple[Tuple[str]]: ) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
...@@ -859,13 +921,19 @@ class ScalingMode(Enum): ...@@ -859,13 +921,19 @@ class ScalingMode(Enum):
input_shape: The shape of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. flatten_axis: Axis along which data can be flattened to 2D for quantization.
q_layout: The layout of the quantized tensor
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
return self._get_impl().get_shardy_sharding_rules( return self._get_impl().get_shardy_sharding_rules(
input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d input_shape,
unique_var,
flatten_axis,
q_layout,
broadcast_2d_scale_shape_to_1d,
self.is_colwise_transposed,
) )
def get_grouped_scale_shape_2x( def get_grouped_scale_shape_2x(
......
...@@ -15,10 +15,10 @@ from abc import ABC, abstractmethod ...@@ -15,10 +15,10 @@ from abc import ABC, abstractmethod
import jax.numpy as jnp import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode, TensorUsage from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap from .dequantizer import ScalingModeToDequantizerMap
from .misc import QuantizeLayout
from ..sharding import ( from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
) )
...@@ -128,9 +128,7 @@ class NoScaleTensor(AbstractBaseTensor1x): ...@@ -128,9 +128,7 @@ class NoScaleTensor(AbstractBaseTensor1x):
def get_tensor(self, usage: TensorUsage): def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage.""" """Returns the tensor based on the tensor usage."""
q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
assert ( assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor"
q_layout == QuantizeLayout.ROWWISE
), "Only ROWWISE layout is supported for NoScaleTensor"
return self return self
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
...@@ -264,8 +262,8 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -264,8 +262,8 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
def get_tensor(self, usage: TensorUsage): def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage.""" """Returns the tensor based on the tensor usage."""
q_layout = self.scaling_mode.get_quantize_layout(usage) q_layout = self.scaling_mode.get_quantize_layout(usage)
colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise colwise_usage_valid = q_layout.is_colwise_only and self.is_colwise
rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise rowwise_usage_valid = q_layout.is_rowwise_only and not self.is_colwise
if colwise_usage_valid or rowwise_usage_valid: if colwise_usage_valid or rowwise_usage_valid:
return self return self
...@@ -301,16 +299,15 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -301,16 +299,15 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
data = with_sharding_constraint_by_logical_axes(self.data, axis_names) data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: if self.scaling_mode.is_block_scaling: # Both MXFP8 and NVFP4
# TODO(Phuong): Handle padding !?
scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
else: else:
scale_inv = self.scale_inv scale_inv = self.scale_inv
return ScaledTensor1x( return ScaledTensor1x(
data=data, data=data,
scale_inv=scale_inv,
amax=self.amax, amax=self.amax,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype, dq_dtype=self.dq_dtype,
_dq_func=self._dq_func, _dq_func=self._dq_func,
...@@ -467,10 +464,10 @@ class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): ...@@ -467,10 +464,10 @@ class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage) q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage) q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)
if q_layout_rowwise == QuantizeLayout.ROWWISE: if q_layout_rowwise.is_rowwise_only:
return self.rowwise_tensor return self.rowwise_tensor
if q_layout_colwise == QuantizeLayout.COLWISE: if q_layout_colwise.is_colwise_only:
return self.colwise_tensor return self.colwise_tensor
raise ValueError( raise ValueError(
...@@ -548,13 +545,13 @@ class ScaledTensorFactory: ...@@ -548,13 +545,13 @@ class ScaledTensorFactory:
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None: if group_sizes is not None:
flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape)
assert ( assert (
original_shape is not None original_shape is not None
), "original_shape is not given for GroupedScaledTensor1x" ), "original_shape is not given for GroupedScaledTensor1x"
# Handling attrs of transposed tensors # Handling attrs of transposed tensors
group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis group_axis = (len(original_shape) + group_axis) % len(original_shape)
if data_layout == "T": if data_layout == "T":
if original_shape[0] == group_sizes.size: if original_shape[0] == group_sizes.size:
original_shape = ( original_shape = (
...@@ -587,7 +584,7 @@ class ScaledTensorFactory: ...@@ -587,7 +584,7 @@ class ScaledTensorFactory:
) )
# Handling attrs of transposed tensors # Handling attrs of transposed tensors
flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis flatten_axis = (data.ndim + flatten_axis) % data.ndim
if data_layout == "T": if data_layout == "T":
flatten_axis = data.ndim - flatten_axis flatten_axis = data.ndim - flatten_axis
...@@ -669,7 +666,7 @@ class ScaledTensorFactory: ...@@ -669,7 +666,7 @@ class ScaledTensorFactory:
colwise_amax, colwise_amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=True, # TODO(Phuong): set this correctly is_colwise=True,
data_layout=data_layout[1], data_layout=data_layout[1],
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes, group_sizes=group_sizes,
...@@ -721,7 +718,7 @@ class ScaledTensorFactory: ...@@ -721,7 +718,7 @@ class ScaledTensorFactory:
""" """
assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet" assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"
if q_layout == QuantizeLayout.ROWWISE_COLWISE: if q_layout.is_rowwise_colwise:
return ScaledTensorFactory.create_2x( return ScaledTensorFactory.create_2x(
data, data,
scale_inv, scale_inv,
...@@ -740,15 +737,14 @@ class ScaledTensorFactory: ...@@ -740,15 +737,14 @@ class ScaledTensorFactory:
colwise_has_rht_applied=colwise_has_rht_applied, colwise_has_rht_applied=colwise_has_rht_applied,
) )
is_colwise = q_layout == QuantizeLayout.COLWISE if q_layout.is_colwise_only:
if is_colwise:
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
colwise_amax if colwise_amax is not None else amax, colwise_amax if colwise_amax is not None else amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=is_colwise, is_colwise=True,
data_layout=data_layout[0], data_layout=data_layout[0],
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes, group_sizes=group_sizes,
...@@ -763,7 +759,7 @@ class ScaledTensorFactory: ...@@ -763,7 +759,7 @@ class ScaledTensorFactory:
amax, amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=is_colwise, is_colwise=False,
data_layout=data_layout[0], data_layout=data_layout[0],
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes, group_sizes=group_sizes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment