"src/git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "1fa56298c0bdee22fe5e38a80cde064ac67588bf"
Unverified commit 962d9c53, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Scaling Enum Abstracting (#1655)



* scaling enum abstract

* rm NVTE_ from ScalingMode names

* rework scaling mode enum in grouped gemm

* fix norm sharding

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 9d4e11ea
......@@ -448,8 +448,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -416,8 +416,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -327,8 +327,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -306,8 +306,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -48,21 +48,21 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = []
""" Find supported scaling modes"""
if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING)
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
def is_shape_supported_by_mxfp8(input_shape):
try:
if isinstance(input_shape, type(pytest.param(0))):
input_shape = input_shape.values[0]
ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
return True
except:
# get_scale_shapes will raise an exception if the shape is not supported
......@@ -170,7 +170,7 @@ class TestActivation:
)
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
......@@ -198,7 +198,7 @@ class TestActivation:
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=output_type,
q_layout=q_layout,
)
......@@ -223,7 +223,7 @@ class TestActivation:
self.activation_type = activation_type
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)
output = tex.act_lu(x, activation_type, quantizer)
......@@ -345,7 +345,7 @@ class TestNorm:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=out_dtype,
q_layout=q_layout,
)
......@@ -420,7 +420,7 @@ class TestNorm:
epsilon=epsilon,
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_layout=q_layout,
)
......@@ -437,7 +437,7 @@ class TestNorm:
epsilon=epsilon,
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
scaling_mode=ScalingMode.MXFP8_1D_SCALING,
q_layout=QuantizeLayout.ROWWISE_COLWISE,
)
......@@ -493,7 +493,7 @@ class TestQuantize:
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
x = jax.random.uniform(key, input_shape, in_dtype)
......@@ -533,7 +533,7 @@ class TestFusedQuantize:
def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
):
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
input_shape
):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
......@@ -618,7 +618,7 @@ class TestFusedQuantize:
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=in_dtype,
scaling_mode=ScalingMode.NVTE_NO_SCALING,
scaling_mode=ScalingMode.NO_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=QuantizeLayout.ROWWISE,
......@@ -639,7 +639,7 @@ class TestFusedQuantize:
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=q_layout,
......@@ -670,7 +670,7 @@ class TestFusedQuantize:
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
scaling_mode=ScalingMode.MXFP8_1D_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=q_layout,
......@@ -785,7 +785,7 @@ class TestDense:
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
)
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
......@@ -830,7 +830,7 @@ class TestFusedDense:
Test layernorm_dense VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
......@@ -886,7 +886,7 @@ class TestFusedDense:
x, w, gamma, beta
)
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
......@@ -916,7 +916,7 @@ class TestFusedDense:
Test layernorm_mlp VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
......@@ -993,7 +993,7 @@ class TestFusedDense:
value_n_grad_prim_func = value_and_grad(prim_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
......
......@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = {
}
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
......
......@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
......
......@@ -39,7 +39,7 @@ def enable_fused_attn():
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
QUANTIZE_RECIPES = []
""" Find supported scaling modes"""
......@@ -313,7 +313,7 @@ class BaseRunner:
test_others,
test_layer,
)
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
_, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME
)
......
......@@ -162,7 +162,7 @@ class ActLuPrimitive(BasePrimitive):
assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)(
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
)
return out
......@@ -282,7 +282,7 @@ class ActLuPrimitive(BasePrimitive):
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else:
colwise_out_spec = out_spec
......@@ -293,9 +293,9 @@ class ActLuPrimitive(BasePrimitive):
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
......@@ -339,7 +339,7 @@ class ActLuPrimitive(BasePrimitive):
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else:
colwise_out_spec = out_spec
......@@ -350,9 +350,9 @@ class ActLuPrimitive(BasePrimitive):
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
......@@ -391,7 +391,7 @@ class ActLuPrimitive(BasePrimitive):
)
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -463,7 +463,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else:
colwise_out_shape = out_shape
......@@ -545,7 +545,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
dz,
x,
scale,
scaling_mode=scaling_mode,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
is_dbias=is_dbias,
act_enum=int(act_enum),
......@@ -673,7 +673,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else:
colwise_x_spec = x_spec
......@@ -691,9 +691,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
......@@ -743,7 +743,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else:
colwise_x_spec = x_spec
......@@ -761,9 +761,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
......@@ -810,7 +810,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else:
global_dbias = local_dbias
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -928,7 +928,7 @@ def act_lu(
out_dtype=x.dtype,
act_enum=act_type_id,
act_len=act_len,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
......@@ -1042,7 +1042,7 @@ def quantize_dact_dbias(
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused
......@@ -1095,7 +1095,7 @@ def quantize_dact_dbias(
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......
......@@ -98,7 +98,7 @@ class GroupedGemmPrimitive(BasePrimitive):
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=int(scaling_mode),
scaling_mode=scaling_mode.value,
)
@staticmethod
......@@ -123,7 +123,7 @@ class GroupedGemmPrimitive(BasePrimitive):
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
scaling_mode=scaling_mode,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
)
......@@ -198,7 +198,7 @@ def _jax_gemm_delayed_scaling_fp8(
):
"""FP8 GEMM for XLA pattern match"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
......@@ -230,7 +230,7 @@ def _jax_gemm_mxfp8_1d(
JAX GEMM for MXFP8 via scaled_matmul
"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
......@@ -291,10 +291,10 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
......@@ -403,7 +403,7 @@ def grouped_gemm(
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
......@@ -415,7 +415,7 @@ def grouped_gemm(
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING
scaling_mode = ScalingMode.NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
out_dtype = lhs.dtype
......@@ -427,13 +427,13 @@ def grouped_gemm(
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
......@@ -470,13 +470,13 @@ def grouped_gemm(
dims.append((bm, bn, k))
lhs_contig_.append(lhs_3d.reshape(-1))
rhs_contig_.append(rhs_3d.reshape(-1))
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
if scaling_mode == ScalingMode.NO_SCALING:
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
if bias_list is not None:
......@@ -493,8 +493,8 @@ def grouped_gemm(
# TE/common does not support NVTE_NO_SCALING yet
# It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
if scaling_mode == ScalingMode.NO_SCALING:
scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING
# Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
......@@ -505,7 +505,7 @@ def grouped_gemm(
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=scaling_mode,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
)
......
......@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
"""
should_apply_war = (
quantizer is not None
and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
and quantizer.is_2x2x()
)
if not should_apply_war:
......
......@@ -105,6 +105,26 @@ class NormFwdPrimitive(BasePrimitive):
if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.size == beta_aval.size
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
(wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
......@@ -112,33 +132,13 @@ class NormFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype
jax_dtype_to_te_dtype(out_dtype),
norm_type,
scaling_mode.value,
scaling_mode,
zero_centered_gamma,
epsilon,
get_forward_sm_margin(),
is_2x,
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x_aval.shape, is_padded=not is_outer
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_out_aval = jax.core.ShapedArray(
shape=x_aval.shape if is_2x else (1,), dtype=out_dtype
)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = x_aval.update(
wkspace_aval = jax.core.ShapedArray(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
......@@ -274,9 +274,9 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
# slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
......@@ -364,6 +364,8 @@ class NormFwdPrimitive(BasePrimitive):
del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
......@@ -371,34 +373,27 @@ class NormFwdPrimitive(BasePrimitive):
"and hurt performance."
)
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out"
)
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
output = (
out_sharding,
colwise_out_sharding,
......@@ -427,8 +422,11 @@ class NormFwdPrimitive(BasePrimitive):
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
g_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
......@@ -445,43 +443,30 @@ class NormFwdPrimitive(BasePrimitive):
f"{NormFwdPrimitive.name} does not support sharding of parameter beta "
"Enforcing no sharding of parameters hidden dim! "
)
x_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x"
)
g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma")
b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta")
out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out")
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]),
desc="NormFwdPrimitive.rsigma",
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
scale_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale"
)
scale_inv_sharding = scale_sharding.duplicate_with_new_description(
"NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -517,7 +502,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes,
is_outer=True,
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -824,7 +809,6 @@ def layernorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32)
)
if quantizer is None:
output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
x,
......@@ -835,7 +819,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((1,), (1,)),
......@@ -845,7 +829,7 @@ def layernorm_fwd(
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False
(
rowwise_casted_output,
......@@ -864,7 +848,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
......@@ -873,7 +857,7 @@ def layernorm_fwd(
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -882,7 +866,7 @@ def layernorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False
)
......@@ -1017,7 +1001,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
......@@ -1027,7 +1011,7 @@ def rmsnorm_fwd(
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False
(
rowwise_casted_output,
......@@ -1046,7 +1030,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
......@@ -1055,7 +1039,7 @@ def rmsnorm_fwd(
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -1064,7 +1048,7 @@ def rmsnorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False
)
......
......@@ -93,7 +93,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
......@@ -114,6 +114,10 @@ class DBiasQuantizePrimitive(BasePrimitive):
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
scaling_mode,
QuantizeLayout(
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
)
wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -176,7 +180,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
ctx,
x,
scale,
scaling_mode=scaling_mode,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
is_dbias=is_dbias,
......@@ -302,7 +306,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -322,9 +326,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
......@@ -374,7 +378,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -394,9 +398,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
......@@ -445,7 +449,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
is_outer=True,
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -588,7 +592,7 @@ def _quantize_dbias_impl(
is_outer=True,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......
......@@ -31,6 +31,9 @@
#include "transformer_engine/activation.h"
#include "utils.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
namespace transformer_engine {
namespace jax {
......@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x);
// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
......@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode,
NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training);
......@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x);
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout);
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
......
......@@ -17,7 +17,7 @@ namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
......@@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
......@@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
......@@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
}
if (is_2x) {
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
......@@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<int64_t>("scaling_mode")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
FFI_CudaGraph_Traits);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x) {
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
......@@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto dact_input_tensor =
TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(static_cast<NVTEScalingMode>(scaling_mode));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
......@@ -172,8 +173,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
if (is_2x) {
auto &tmp_shape = scaling_mode == static_cast<int>(NVTE_DELAYED_TENSOR_SCALING)
? output_trans_shape
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
......@@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
......@@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
bool is_dbias, int64_t act_enum) {
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
......@@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
......@@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
}
if (is_2x) {
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
......@@ -295,8 +297,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(
!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) {
......@@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias")
.Attr<int64_t>("act_enum"),
.Attr<bool>("is_dbias"),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
......@@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
int32_t *dim_list_ptr, const int64_t &scaling_mode,
int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode,
cudaStream_t stream) {
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
......@@ -90,14 +90,17 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
auto lhs_i = TensorWrapper(static_cast<void *>(lhs_ptr), lhs_shape, lhs_dtype, nullptr,
nullptr, reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i = TensorWrapper(static_cast<void *>(rhs_ptr), rhs_shape, rhs_dtype, nullptr,
nullptr, reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat32,
std::vector<size_t>{1});
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat32,
std::vector<size_t>{1});
} else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t sinv_k = k / MXFP8_BLOCK_SIZE;
......@@ -107,20 +110,15 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
rhs_sinv_shape[1] = sinv_k;
// Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING);
TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING);
lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0,
lhs_sinv_shape);
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
rhs_sinv_shape);
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
} else {
NVTE_ERROR("Unsupported scaling mode: ", scaling_mode);
NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
}
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
lhs_ptr += m * k * lhs_dtype_bytes;
......@@ -169,7 +167,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
Buffer_Type dim_list, Result_Type out_flatten,
Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) {
Result_Type workspace_flatten, int64_t num_gemms,
JAXX_Scaling_Mode scaling_mode) {
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
......@@ -207,7 +206,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Ret<Buffer_Type>() // out_flatten
.Ret<Buffer_Type>() // workspace_flatten
.Attr<int64_t>("num_gemms")
.Attr<int64_t>("scaling_mode"),
.Attr<JAXX_Scaling_Mode>("scaling_mode"),
FFI_CudaGraph_Traits);
} // namespace jax
......
......@@ -40,5 +40,28 @@ enum class QuantizeLayout {
ROWWISE_COLWISE,
};
// Scaling modes exposed to the JAX/Python side. Values cross the XLA FFI
// boundary as int64 enum attributes (handlers register them via
// .Attr<JAXX_Scaling_Mode>("scaling_mode")), and Python passes
// ScalingMode.<NAME>.value — so the numeric values here must stay in sync
// with the Python ScalingMode enum.
enum class JAXX_Scaling_Mode : int64_t {
  NO_SCALING = 0,              // unquantized output, no scale/amax buffers
  DELAYED_TENSOR_SCALING = 1,  // per-tensor delayed scaling (needs scale + amax)
  MXFP8_1D_SCALING = 2,        // MXFP8 1D block scaling
};
// Map the JAX-side scaling-mode enum onto the corresponding NVTE scaling mode.
//
// NOTE(review): NO_SCALING also maps to NVTE_DELAYED_TENSOR_SCALING —
// presumably because NVTE has no dedicated "no scaling" mode and the
// delayed-scaling tensor layout doubles for unquantized outputs; confirm
// against TE/common. Any other value aborts via NVTE_ERROR.
//
// Fix: removed the unreachable `break` statements that followed each
// `return`, and merged the two cases that return the same NVTE mode.
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::NO_SCALING:
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
    case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
      return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
    default:
      NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
  }
}
} // namespace jax
} // namespace transformer_engine
......@@ -14,7 +14,8 @@ namespace jax {
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode,
NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
......@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto output_tensor = TensorWrapper(_scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) {
if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
int temp = 1;
output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape);
}
......@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr);
} else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma,
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(),
rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma,
......@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf,
int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, int scaling_mode, bool is_2x) {
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
......@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data();
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
......@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto output_tensor = TensorWrapper(_scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (is_fp8_dtype(out_dtype)) {
......@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
scale_inv_buf->dimensions().back()});
}
if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
......@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
workspace_tensor.data(), num_sm, zero_centered_gamma, stream);
} else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma,
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(),
rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma,
......@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<int64_t>("scaling_mode")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
FFI_CudaGraph_Traits);
......
......@@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("RMSNorm", NVTE_Norm_Type::RMSNorm)
.export_values();
pybind11::enum_<NVTEScalingMode>(m, "NVTE_Scaling_Mode", pybind11::module_local())
.value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING)
.value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
pybind11::enum_<JAXX_Scaling_Mode>(m, "JAXX_Scaling_Mode", pybind11::module_local())
.value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
.value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment