Unverified Commit 962d9c53 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Scaling Enum Abstracting (#1655)



* scaling enum abstract

* rm NVTE_ from ScalingMode names

* rework scaling mode enum in grouped gemm

* fix norm sharding

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 9d4e11ea
...@@ -448,8 +448,8 @@ def encoder_parser(args): ...@@ -448,8 +448,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
...@@ -416,8 +416,8 @@ def encoder_parser(args): ...@@ -416,8 +416,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
...@@ -327,8 +327,8 @@ def encoder_parser(args): ...@@ -327,8 +327,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
...@@ -306,8 +306,8 @@ def mnist_parser(args): ...@@ -306,8 +306,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase): class TestMNIST(unittest.TestCase):
"""MNIST unittests""" """MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
...@@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa ...@@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py"
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES" echo "Error: some sub-tests failed: $FAILED_CASES"
......
...@@ -48,21 +48,21 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2] ...@@ -48,21 +48,21 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)] LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32] DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = helper.is_fp8_available() is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = [] supported_scaling_modes = []
""" Find supported scaling modes""" """ Find supported scaling modes"""
if is_fp8_supported: if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
if is_mxfp8_supported: if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING) supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
def is_shape_supported_by_mxfp8(input_shape): def is_shape_supported_by_mxfp8(input_shape):
try: try:
if isinstance(input_shape, type(pytest.param(0))): if isinstance(input_shape, type(pytest.param(0))):
input_shape = input_shape.values[0] input_shape = input_shape.values[0]
ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
return True return True
except: except:
# get_scale_shapes will raise an exception if the shape is not supported # get_scale_shapes will raise an exception if the shape is not supported
...@@ -170,7 +170,7 @@ class TestActivation: ...@@ -170,7 +170,7 @@ class TestActivation:
) )
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=output_type, q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
...@@ -198,7 +198,7 @@ class TestActivation: ...@@ -198,7 +198,7 @@ class TestActivation:
te_quantizer, jax_quantizer = QuantizerFactory.create( te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2, n_quantizers=2,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=output_type, q_dtype=output_type,
q_layout=q_layout, q_layout=q_layout,
) )
...@@ -223,7 +223,7 @@ class TestActivation: ...@@ -223,7 +223,7 @@ class TestActivation:
self.activation_type = activation_type self.activation_type = activation_type
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
) )
output = tex.act_lu(x, activation_type, quantizer) output = tex.act_lu(x, activation_type, quantizer)
...@@ -345,7 +345,7 @@ class TestNorm: ...@@ -345,7 +345,7 @@ class TestNorm:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!") pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=out_dtype, q_dtype=out_dtype,
q_layout=q_layout, q_layout=q_layout,
) )
...@@ -420,7 +420,7 @@ class TestNorm: ...@@ -420,7 +420,7 @@ class TestNorm:
epsilon=epsilon, epsilon=epsilon,
inp_dtype=inp_dtype, inp_dtype=inp_dtype,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_layout=q_layout, q_layout=q_layout,
) )
...@@ -437,7 +437,7 @@ class TestNorm: ...@@ -437,7 +437,7 @@ class TestNorm:
epsilon=epsilon, epsilon=epsilon,
inp_dtype=inp_dtype, inp_dtype=inp_dtype,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, scaling_mode=ScalingMode.MXFP8_1D_SCALING,
q_layout=QuantizeLayout.ROWWISE_COLWISE, q_layout=QuantizeLayout.ROWWISE_COLWISE,
) )
...@@ -493,7 +493,7 @@ class TestQuantize: ...@@ -493,7 +493,7 @@ class TestQuantize:
if flatten_axis == -2: if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:] input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): for _ in range(n_iterations):
x = jax.random.uniform(key, input_shape, in_dtype) x = jax.random.uniform(key, input_shape, in_dtype)
...@@ -533,7 +533,7 @@ class TestFusedQuantize: ...@@ -533,7 +533,7 @@ class TestFusedQuantize:
def test_quantize_dbias( def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
): ):
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
input_shape input_shape
): ):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
...@@ -618,7 +618,7 @@ class TestFusedQuantize: ...@@ -618,7 +618,7 @@ class TestFusedQuantize:
in_dtype=in_dtype, in_dtype=in_dtype,
input_shape=input_shape, input_shape=input_shape,
out_dtype=in_dtype, out_dtype=in_dtype,
scaling_mode=ScalingMode.NVTE_NO_SCALING, scaling_mode=ScalingMode.NO_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_layout=QuantizeLayout.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
...@@ -639,7 +639,7 @@ class TestFusedQuantize: ...@@ -639,7 +639,7 @@ class TestFusedQuantize:
in_dtype=in_dtype, in_dtype=in_dtype,
input_shape=input_shape, input_shape=input_shape,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_layout=q_layout, q_layout=q_layout,
...@@ -670,7 +670,7 @@ class TestFusedQuantize: ...@@ -670,7 +670,7 @@ class TestFusedQuantize:
in_dtype=in_dtype, in_dtype=in_dtype,
input_shape=input_shape, input_shape=input_shape,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, scaling_mode=ScalingMode.MXFP8_1D_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_layout=q_layout, q_layout=q_layout,
...@@ -785,7 +785,7 @@ class TestDense: ...@@ -785,7 +785,7 @@ class TestDense:
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
) )
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
...@@ -830,7 +830,7 @@ class TestFusedDense: ...@@ -830,7 +830,7 @@ class TestFusedDense:
Test layernorm_dense VJP Rule Test layernorm_dense VJP Rule
""" """
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
pytest.skip("E5M2 is not supported in normalization with TE Backend!") pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
...@@ -886,7 +886,7 @@ class TestFusedDense: ...@@ -886,7 +886,7 @@ class TestFusedDense:
x, w, gamma, beta x, w, gamma, beta
) )
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): for _ in range(n_iterations):
prim_out, ( prim_out, (
prim_x_grad, prim_x_grad,
...@@ -916,7 +916,7 @@ class TestFusedDense: ...@@ -916,7 +916,7 @@ class TestFusedDense:
Test layernorm_mlp VJP Rule Test layernorm_mlp VJP Rule
""" """
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
pytest.skip("E5M2 is not supported in normalization with TE Backend!") pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
...@@ -993,7 +993,7 @@ class TestFusedDense: ...@@ -993,7 +993,7 @@ class TestFusedDense:
value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_prim_func = value_and_grad(prim_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): for _ in range(n_iterations):
prim_out, ( prim_out, (
prim_x_grad, prim_x_grad,
......
...@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = { ...@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = {
} }
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
......
...@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory ...@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
......
...@@ -39,7 +39,7 @@ def enable_fused_attn(): ...@@ -39,7 +39,7 @@ def enable_fused_attn():
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
QUANTIZE_RECIPES = [] QUANTIZE_RECIPES = []
""" Find supported scaling modes""" """ Find supported scaling modes"""
...@@ -313,7 +313,7 @@ class BaseRunner: ...@@ -313,7 +313,7 @@ class BaseRunner:
test_others, test_others,
test_layer, test_layer,
) )
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
_, updated_quantize_meta = flax.core.pop( _, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME updated_state[0], QuantizeConfig.COLLECTION_NAME
) )
......
...@@ -162,7 +162,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -162,7 +162,7 @@ class ActLuPrimitive(BasePrimitive):
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)( out = ffi.ffi_lowering(ActLuPrimitive.name)(
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
) )
return out return out
...@@ -282,7 +282,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -282,7 +282,7 @@ class ActLuPrimitive(BasePrimitive):
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
colwise_out_spec = out_spec colwise_out_spec = out_spec
...@@ -293,9 +293,9 @@ class ActLuPrimitive(BasePrimitive): ...@@ -293,9 +293,9 @@ class ActLuPrimitive(BasePrimitive):
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec scale_inv_spec = out_spec
if is_2x: if is_2x:
...@@ -339,7 +339,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -339,7 +339,7 @@ class ActLuPrimitive(BasePrimitive):
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
colwise_out_spec = out_spec colwise_out_spec = out_spec
...@@ -350,9 +350,9 @@ class ActLuPrimitive(BasePrimitive): ...@@ -350,9 +350,9 @@ class ActLuPrimitive(BasePrimitive):
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec scale_inv_spec = out_spec
if is_2x: if is_2x:
...@@ -391,7 +391,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -391,7 +391,7 @@ class ActLuPrimitive(BasePrimitive):
) )
) )
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -463,7 +463,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -463,7 +463,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else: else:
colwise_out_shape = out_shape colwise_out_shape = out_shape
...@@ -545,7 +545,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -545,7 +545,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
dz, dz,
x, x,
scale, scale,
scaling_mode=scaling_mode, scaling_mode=scaling_mode.value,
is_2x=is_2x, is_2x=is_2x,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=int(act_enum), act_enum=int(act_enum),
...@@ -673,7 +673,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -673,7 +673,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
) )
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
colwise_x_spec = x_spec colwise_x_spec = x_spec
...@@ -691,9 +691,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -691,9 +691,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if is_2x: if is_2x:
...@@ -743,7 +743,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -743,7 +743,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
) )
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
colwise_x_spec = x_spec colwise_x_spec = x_spec
...@@ -761,9 +761,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -761,9 +761,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if is_2x: if is_2x:
...@@ -810,7 +810,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -810,7 +810,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else: else:
global_dbias = local_dbias global_dbias = local_dbias
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -928,7 +928,7 @@ def act_lu( ...@@ -928,7 +928,7 @@ def act_lu(
out_dtype=x.dtype, out_dtype=x.dtype,
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((), ()), scale_shapes=((), ()),
...@@ -1042,7 +1042,7 @@ def quantize_dact_dbias( ...@@ -1042,7 +1042,7 @@ def quantize_dact_dbias(
# outputs float32 for dbias accumulation # outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype), out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset # default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused is_2x=False, # unused
scale_dtype=jnp.float32, # unused scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused scale_shapes=((), ()), # unused
...@@ -1095,7 +1095,7 @@ def quantize_dact_dbias( ...@@ -1095,7 +1095,7 @@ def quantize_dact_dbias(
) )
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax) quantizer.update(updated_amax)
......
...@@ -98,7 +98,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -98,7 +98,7 @@ class GroupedGemmPrimitive(BasePrimitive):
bias_contig, bias_contig,
dim_list, dim_list,
num_gemms=num_gemms, num_gemms=num_gemms,
scaling_mode=int(scaling_mode), scaling_mode=scaling_mode.value,
) )
@staticmethod @staticmethod
...@@ -123,7 +123,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -123,7 +123,7 @@ class GroupedGemmPrimitive(BasePrimitive):
bias_contig, bias_contig,
dim_list, dim_list,
num_gemms=num_gemms, num_gemms=num_gemms,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode,
out_dtype=out_dtype, out_dtype=out_dtype,
out_flat_size=out_flat_size, out_flat_size=out_flat_size,
) )
...@@ -198,7 +198,7 @@ def _jax_gemm_delayed_scaling_fp8( ...@@ -198,7 +198,7 @@ def _jax_gemm_delayed_scaling_fp8(
): ):
"""FP8 GEMM for XLA pattern match""" """FP8 GEMM for XLA pattern match"""
assert ( assert (
rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
), "rhs does not have delayed tensor scaling mode" ), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
...@@ -230,7 +230,7 @@ def _jax_gemm_mxfp8_1d( ...@@ -230,7 +230,7 @@ def _jax_gemm_mxfp8_1d(
JAX GEMM for MXFP8 via scaled_matmul JAX GEMM for MXFP8 via scaled_matmul
""" """
assert ( assert (
rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode" ), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
...@@ -291,10 +291,10 @@ def _jax_gemm( ...@@ -291,10 +291,10 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs): def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
...@@ -403,7 +403,7 @@ def grouped_gemm( ...@@ -403,7 +403,7 @@ def grouped_gemm(
rhs_shape = rhs.data.shape rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
assert not ( assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2" ), "FP8 GEMM does not support E5M2 * E5M2"
...@@ -415,7 +415,7 @@ def grouped_gemm( ...@@ -415,7 +415,7 @@ def grouped_gemm(
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else: else:
# For jnp.ndarray, only consider contracting_dims, data_layout is always NN # For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING scaling_mode = ScalingMode.NO_SCALING
lhs_shape = lhs.shape lhs_shape = lhs.shape
rhs_shape = rhs.shape rhs_shape = rhs.shape
out_dtype = lhs.dtype out_dtype = lhs.dtype
...@@ -427,13 +427,13 @@ def grouped_gemm( ...@@ -427,13 +427,13 @@ def grouped_gemm(
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
if scaling_mode == ScalingMode.NVTE_NO_SCALING: if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn) lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn) lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
...@@ -470,13 +470,13 @@ def grouped_gemm( ...@@ -470,13 +470,13 @@ def grouped_gemm(
dims.append((bm, bn, k)) dims.append((bm, bn, k))
lhs_contig_.append(lhs_3d.reshape(-1)) lhs_contig_.append(lhs_3d.reshape(-1))
rhs_contig_.append(rhs_3d.reshape(-1)) rhs_contig_.append(rhs_3d.reshape(-1))
if scaling_mode == ScalingMode.NVTE_NO_SCALING: if scaling_mode == ScalingMode.NO_SCALING:
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
if bias_list is not None: if bias_list is not None:
...@@ -493,8 +493,8 @@ def grouped_gemm( ...@@ -493,8 +493,8 @@ def grouped_gemm(
# TE/common does not support NVTE_NO_SCALING yet # TE/common does not support NVTE_NO_SCALING yet
# It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
if scaling_mode == ScalingMode.NVTE_NO_SCALING: if scaling_mode == ScalingMode.NO_SCALING:
scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING
# Perform batched GEMM on flattened inputs # Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind( out_contig = GroupedGemmPrimitive.outer_primitive.bind(
...@@ -505,7 +505,7 @@ def grouped_gemm( ...@@ -505,7 +505,7 @@ def grouped_gemm(
bias_contig, bias_contig,
dim_list, dim_list,
num_gemms=num_gemms, num_gemms=num_gemms,
scaling_mode=scaling_mode, scaling_mode=scaling_mode.value,
out_dtype=out_dtype, out_dtype=out_dtype,
out_flat_size=out_flat_size, out_flat_size=out_flat_size,
) )
......
...@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, ...@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
""" """
should_apply_war = ( should_apply_war = (
quantizer is not None quantizer is not None
and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
and quantizer.is_2x2x() and quantizer.is_2x2x()
) )
if not should_apply_war: if not should_apply_war:
......
...@@ -105,6 +105,26 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -105,6 +105,26 @@ class NormFwdPrimitive(BasePrimitive):
if norm_type == NVTE_Norm_Type.LayerNorm: if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.size == beta_aval.size assert gamma_aval.size == beta_aval.size
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
(wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size gamma_aval.size, # hidden size
...@@ -112,33 +132,13 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -112,33 +132,13 @@ class NormFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
norm_type, norm_type,
scaling_mode.value, scaling_mode,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
get_forward_sm_margin(), get_forward_sm_margin(),
is_2x, is_2x,
) )
wkspace_aval = jax.core.ShapedArray(
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x_aval.shape, is_padded=not is_outer
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_out_aval = jax.core.ShapedArray(
shape=x_aval.shape if is_2x else (1,), dtype=out_dtype
)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
) )
...@@ -274,9 +274,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -274,9 +274,9 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_outer=False, is_outer=False,
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
x.shape, is_padded=False scaling_mode
) ).get_scale_shape_2x(x.shape, is_padded=False)
# slice out padding for mxfp8, noop for DelayedScaling # slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape rowwise_scale_inv_shape
...@@ -364,6 +364,8 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -364,6 +364,8 @@ class NormFwdPrimitive(BasePrimitive):
del zero_centered_gamma, epsilon, out_dtype, result_infos del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer del scale_dtype, scale_shapes, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
...@@ -371,34 +373,27 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -371,34 +373,27 @@ class NormFwdPrimitive(BasePrimitive):
"and hurt performance." "and hurt performance."
) )
out_sharding = NamedSharding( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out" colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
) )
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding( rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
) )
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
if norm_type == NVTE_Norm_Type.RMSNorm: mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
output = ( output = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -427,8 +422,11 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -427,8 +422,11 @@ class NormFwdPrimitive(BasePrimitive):
): ):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
g_spec = get_padded_spec(arg_infos[2]) g_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3]) b_spec = get_padded_spec(arg_infos[3])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
...@@ -445,43 +443,30 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -445,43 +443,30 @@ class NormFwdPrimitive(BasePrimitive):
f"{NormFwdPrimitive.name} does not support sharding of parameter beta " f"{NormFwdPrimitive.name} does not support sharding of parameter beta "
"Enforcing no sharding of parameters hidden dim! " "Enforcing no sharding of parameters hidden dim! "
) )
x_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x"
)
g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma")
b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta")
out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out")
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding( rsigma_sharding = NamedSharding(
mesh, mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]),
desc="NormFwdPrimitive.rsigma",
) )
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
if norm_type == NVTE_Norm_Type.RMSNorm: mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
scale_sharding = NamedSharding( scale_inv_spec = amax_spec = (None,)
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale" if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
) scale_inv_spec = amax_spec = scale_spec
scale_inv_sharding = scale_sharding.duplicate_with_new_description( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
"NormFwdPrimitive.scale_inv" scale_inv_spec = out_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
) )
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -517,7 +502,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -517,7 +502,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_outer=True, is_outer=True,
) )
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -824,7 +809,6 @@ def layernorm_fwd( ...@@ -824,7 +809,6 @@ def layernorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer) if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32) else jnp.ones((1,), dtype=jnp.float32)
) )
if quantizer is None: if quantizer is None:
output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
x, x,
...@@ -835,7 +819,7 @@ def layernorm_fwd( ...@@ -835,7 +819,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((1,), (1,)), scale_shapes=((1,), (1,)),
...@@ -845,7 +829,7 @@ def layernorm_fwd( ...@@ -845,7 +829,7 @@ def layernorm_fwd(
is_2x2x = quantizer.is_2x2x() is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling # TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False is_2x2x = False
( (
rowwise_casted_output, rowwise_casted_output,
...@@ -864,7 +848,7 @@ def layernorm_fwd( ...@@ -864,7 +848,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape), scale_shapes=quantizer.get_scale_shapes(x.shape),
...@@ -873,7 +857,7 @@ def layernorm_fwd( ...@@ -873,7 +857,7 @@ def layernorm_fwd(
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
...@@ -882,7 +866,7 @@ def layernorm_fwd( ...@@ -882,7 +866,7 @@ def layernorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False x.shape, is_padded=False
) )
...@@ -1017,7 +1001,7 @@ def rmsnorm_fwd( ...@@ -1017,7 +1001,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((), ()), scale_shapes=((), ()),
...@@ -1027,7 +1011,7 @@ def rmsnorm_fwd( ...@@ -1027,7 +1011,7 @@ def rmsnorm_fwd(
is_2x2x = quantizer.is_2x2x() is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling # TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False is_2x2x = False
( (
rowwise_casted_output, rowwise_casted_output,
...@@ -1046,7 +1030,7 @@ def rmsnorm_fwd( ...@@ -1046,7 +1030,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape), scale_shapes=quantizer.get_scale_shapes(x.shape),
...@@ -1055,7 +1039,7 @@ def rmsnorm_fwd( ...@@ -1055,7 +1039,7 @@ def rmsnorm_fwd(
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
...@@ -1064,7 +1048,7 @@ def rmsnorm_fwd( ...@@ -1064,7 +1048,7 @@ def rmsnorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False x.shape, is_padded=False
) )
......
...@@ -93,7 +93,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -93,7 +93,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else: else:
colwise_out_shape = out_shape colwise_out_shape = out_shape
...@@ -114,6 +114,10 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -114,6 +114,10 @@ class DBiasQuantizePrimitive(BasePrimitive):
gi_hidden_size, gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
scaling_mode,
QuantizeLayout(
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
) )
wkspace_shape = wkspace_info[0] wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -176,7 +180,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -176,7 +180,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
ctx, ctx,
x, x,
scale, scale,
scaling_mode=scaling_mode, scaling_mode=scaling_mode.value,
q_layout=q_layout, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
is_dbias=is_dbias, is_dbias=is_dbias,
...@@ -302,7 +306,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -302,7 +306,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
...@@ -322,9 +326,9 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -322,9 +326,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
...@@ -374,7 +378,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -374,7 +378,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
...@@ -394,9 +398,9 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -394,9 +398,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
...@@ -445,7 +449,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -445,7 +449,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
is_outer=True, is_outer=True,
) )
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -588,7 +592,7 @@ def _quantize_dbias_impl( ...@@ -588,7 +592,7 @@ def _quantize_dbias_impl(
is_outer=True, is_outer=True,
) )
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax) quantizer.update(updated_amax)
......
...@@ -31,6 +31,9 @@ ...@@ -31,6 +31,9 @@
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "utils.h" #include "utils.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D ...@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x);
// Normalization // Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
...@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); ...@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype, DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode, NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin, bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training); bool is_training);
...@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); ...@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype); DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode,
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); QuantizeLayout q_layout);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x);
// Softmax // Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
......
...@@ -17,7 +17,7 @@ namespace jax { ...@@ -17,7 +17,7 @@ namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) { bool is_2x_int) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
...@@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back(); auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2]; auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int); auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
...@@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m}; auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
...@@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
} }
if (is_2x) { if (is_2x) {
auto &tmp_shape = auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; ? scale_inv_buf
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { : colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1}); std::vector<size_t>{1});
...@@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size}; auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
...@@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto dact_input_tensor = auto dact_input_tensor =
TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype); TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(static_cast<NVTEScalingMode>(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape); output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -172,9 +173,8 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -172,9 +173,8 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
} }
if (is_2x) { if (is_2x) {
auto &tmp_shape = scaling_mode == static_cast<int>(NVTE_DELAYED_TENSOR_SCALING) auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
? output_trans_shape : output_shape;
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape); output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
...@@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
} }
} }
if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32, output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1}); std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32, output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
...@@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
bool is_dbias, int64_t act_enum) { int64_t act_enum, bool is_2x, bool is_dbias) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
...@@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
...@@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
} }
if (is_2x) { if (is_2x) {
auto &tmp_shape = auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; ? scale_inv_buf
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { : colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1}); std::vector<size_t>{1});
...@@ -295,9 +297,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -295,9 +297,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK( NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), "TE/common does not support delayed scaling for 2x with gated activations.");
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) { if (is_dbias) {
switch (act_type) { switch (act_type) {
...@@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias"),
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
int32_t *dim_list_ptr, const int64_t &scaling_mode, int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode,
cudaStream_t stream) { cudaStream_t stream) {
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
...@@ -90,14 +90,17 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -90,14 +90,17 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto lhs_sinv_shape = std::vector<size_t>{1, 1}; auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1}; auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto lhs_i = TensorWrapper(static_cast<void *>(lhs_ptr), lhs_shape, lhs_dtype, nullptr, auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
nullptr, reinterpret_cast<float *>(lhs_sinv_ptr)); lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
auto rhs_i = TensorWrapper(static_cast<void *>(rhs_ptr), rhs_shape, rhs_dtype, nullptr, rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
nullptr, reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i)); if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
rhs_wrapper_list.push_back(std::move(rhs_i)); lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat32,
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { std::vector<size_t>{1});
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat32,
std::vector<size_t>{1});
} else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
MXFP8_BLOCK_SIZE, k); MXFP8_BLOCK_SIZE, k);
size_t sinv_k = k / MXFP8_BLOCK_SIZE; size_t sinv_k = k / MXFP8_BLOCK_SIZE;
...@@ -107,20 +110,15 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -107,20 +110,15 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
rhs_sinv_shape[1] = sinv_k; rhs_sinv_shape[1] = sinv_k;
// Note: the scale_inv array should have been swizzled in Python before lowering // Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING);
TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING);
lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0, lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0,
lhs_sinv_shape); lhs_sinv_shape);
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0, rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
rhs_sinv_shape); rhs_sinv_shape);
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
} else { } else {
NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
} }
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype); auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
lhs_ptr += m * k * lhs_dtype_bytes; lhs_ptr += m * k * lhs_dtype_bytes;
...@@ -169,7 +167,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, ...@@ -169,7 +167,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
Buffer_Type dim_list, Result_Type out_flatten, Buffer_Type dim_list, Result_Type out_flatten,
Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) { Result_Type workspace_flatten, int64_t num_gemms,
JAXX_Scaling_Mode scaling_mode) {
// Inputs // Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data()); auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data()); auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
...@@ -207,7 +206,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, ...@@ -207,7 +206,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Ret<Buffer_Type>() // out_flatten .Ret<Buffer_Type>() // out_flatten
.Ret<Buffer_Type>() // workspace_flatten .Ret<Buffer_Type>() // workspace_flatten
.Attr<int64_t>("num_gemms") .Attr<int64_t>("num_gemms")
.Attr<int64_t>("scaling_mode"), .Attr<JAXX_Scaling_Mode>("scaling_mode"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
......
...@@ -40,5 +40,28 @@ enum class QuantizeLayout { ...@@ -40,5 +40,28 @@ enum class QuantizeLayout {
ROWWISE_COLWISE, ROWWISE_COLWISE,
}; };
// JAX-side abstraction over the NVTE scaling-mode enum, so the Python/FFI
// layer does not depend on the numeric values of NVTEScalingMode. Decoded
// straight from an XLA FFI attribute (see the
// XLA_FFI_REGISTER_ENUM_ATTR_DECODING registration and the
// .Attr<JAXX_Scaling_Mode>("scaling_mode") handler bindings); the int64_t
// underlying type matches the FFI attribute payload. Convert with
// get_nvte_scaling_mode() before calling into TE/common.
enum class JAXX_Scaling_Mode : int64_t {
  NO_SCALING = 0,              // no quantization applied
  DELAYED_TENSOR_SCALING = 1,  // per-tensor FP8 scaling; requires scale/amax buffers
  MXFP8_1D_SCALING = 2,        // MXFP8 1D block scaling (scale-inv per MXFP8_BLOCK_SIZE chunk)
};
// Converts the JAX-facing scaling mode into the NVTE core enum.
//
// NOTE: NO_SCALING maps to NVTE_DELAYED_TENSOR_SCALING, same as
// DELAYED_TENSOR_SCALING (this preserves the original mapping); presumably
// an unquantized tensor shares the plain rowwise tensor layout in
// TE/common — confirm against TE/common if a dedicated mode is ever added.
//
// Errors (via NVTE_ERROR) on any unrecognized value, e.g. a corrupt or
// out-of-range FFI attribute.
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    // Fold the two cases that share a target; drop the unreachable
    // `break` statements that followed each `return`.
    case JAXX_Scaling_Mode::NO_SCALING:
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
    case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
      return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
    default:
      NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
  }
}
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -14,7 +14,8 @@ namespace jax { ...@@ -14,7 +14,8 @@ namespace jax {
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype, DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode, NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin, bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training) { bool is_training) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
...@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto output_tensor = TensorWrapper(_scaling_mode);
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) { if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
int temp = 1; int temp = 1;
output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape); output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape);
} }
...@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr);
} else { } else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma."); "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(), nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(),
rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma,
...@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf,
int norm_type, bool zero_centered_gamma, double epsilon, int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, int scaling_mode, bool is_2x) { int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
...@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data(); auto *workspace = wkspace_buf->untyped_data();
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type); auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x); auto _is_2x = static_cast<bool>(is_2x);
...@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto output_tensor = TensorWrapper(_scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
scale_inv_buf->dimensions().back()}); scale_inv_buf->dimensions().back()});
} }
if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
...@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
workspace_tensor.data(), num_sm, zero_centered_gamma, stream); workspace_tensor.data(), num_sm, zero_centered_gamma, stream);
} else { } else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma."); "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(), nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(),
rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma,
...@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<bool>("zero_centered_gamma") .Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
......
...@@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("RMSNorm", NVTE_Norm_Type::RMSNorm) .value("RMSNorm", NVTE_Norm_Type::RMSNorm)
.export_values(); .export_values();
pybind11::enum_<NVTEScalingMode>(m, "NVTE_Scaling_Mode", pybind11::module_local()) pybind11::enum_<JAXX_Scaling_Mode>(m, "JAXX_Scaling_Mode", pybind11::module_local())
.value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
.value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.export_values(); .export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout", pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment