Unverified Commit 9a819334 authored by jberchtold-nvidia, committed by GitHub

[JAX] JAX Current Scaling (#1647)



* [JAX-Q] Single GPU current scaling for JAX
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix scale check dtype for MXFP8 scales affecting tests using assert_bitwise_scaled_tensors
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Address comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Remove cast to fp32 for norm primitives now that zero-centered gamma dtype issue is fixed
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix lint issue
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Remove unnecessary cast to fp32
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent a1c18bc8
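For context before the diff: delayed scaling derives the FP8 scale from an amax history carried across training steps, while current scaling computes the scale from the amax of the tensor being quantized right now. A minimal JAX sketch of the two recipes, with illustrative names only (not the TE API):

import jax.numpy as jnp

def current_scaling_quantize(x, q_dtype=jnp.float8_e4m3fn):
    # Current scaling: the scale is computed fresh from this tensor's amax.
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, jnp.float32)
    scale = fp8_max / jnp.max(jnp.abs(x)).astype(jnp.float32)
    data = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max).astype(q_dtype)
    return data, 1.0 / scale  # quantized data and scale_inv

def delayed_scaling_quantize(x, scale, amax_history, q_dtype=jnp.float8_e4m3fn):
    # Delayed scaling: the scale was derived from past steps' amax values;
    # this step only records a new amax into the history.
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, jnp.float32)
    data = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max).astype(q_dtype)
    new_history = jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(x)))
    return data, 1.0 / scale, new_history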
......@@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from functools import reduce
......@@ -18,7 +19,11 @@ from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm
from transformer_engine.jax.cpp_extensions.normalization import (
_jax_layernorm,
_jax_rmsnorm,
is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
_jax_quantize,
_jax_quantize_dbias,
......@@ -55,6 +60,7 @@ supported_scaling_modes = []
""" Find supported scaling modes"""
if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
......@@ -72,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape):
def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
assert_allclose(a.data, b.data)
assert a.scale_inv.dtype == b.scale_inv.dtype
if a.scale_inv.dtype == jnp.float8_e8m0fnu:
# Compare MXFP8 scales as uint8
assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
else:
assert_allclose(a.scale_inv, b.scale_inv)
assert_allclose(a.data, b.data)
elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
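A side note on the scale comparison above: float8_e8m0fnu encodes only a power-of-two exponent, so the test compares MXFP8 scales after casting to uint8. If an exact bit-pattern comparison is preferred, independent of value-cast semantics, jax.lax.bitcast_convert_type is one way to get it (a sketch, not what the test uses):

import jax
import jax.numpy as jnp

def same_e8m0_bits(a, b):
    # Reinterpret the one-byte e8m0 payload as uint8 and compare exactly.
    a_bits = jax.lax.bitcast_convert_type(a, jnp.uint8)
    b_bits = jax.lax.bitcast_convert_type(b, jnp.uint8)
    return bool((a_bits == b_bits).all())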
......@@ -160,7 +172,12 @@ class TestActivation:
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_grad_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, scaling_mode
):
x = random_inputs
x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
......@@ -171,7 +188,7 @@ class TestActivation:
)
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
......@@ -189,8 +206,11 @@ class TestActivation:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_act_forward_with_delayed_scaling_fp8(
self, random_inputs, activation_type, output_type, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_forward_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, q_layout, scaling_mode
):
x = random_inputs
x = jnp.expand_dims(x, axis=-2)
......@@ -199,7 +219,7 @@ class TestActivation:
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=q_layout,
)
......@@ -336,8 +356,20 @@ class TestNorm:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_grad_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_grad_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
):
"""
Test transformer_engine.jax.layernorm.layernorm
......@@ -346,9 +378,7 @@ class TestNorm:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=out_dtype,
q_layout=q_layout,
scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
)
self._test_norm_grad(
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
......@@ -396,12 +426,41 @@ class TestNorm:
)
ref_mu = None
if get_cudnn_version() < (9, 10, 0):
precise_comparison = True
if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# Reduce test precision: below this cuDNN version we don't use a fused norm for MXFP8 and
# instead do an unfused norm and quantize with an intermediate cast to in_dtype, which can reduce precision
precise_comparison = False
elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
# Larger tolerances because our JAX implementations (_jax_*norm) always use float32 as the
# compute dtype for zero-centered gamma
precise_comparison = False
elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
# The current tensor scaling implementation performs an unfused norm and quantize,
# writing intermediate results in the input dtype, which slightly reduces precision
# when the input dtype is not float32
precise_comparison = False
if precise_comparison:
assert_bitwise_scaled_tensors(output, ref_out)
else:
if isinstance(ref_out, ScaledTensor1x):
assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
elif isinstance(ref_out, ScaledTensor2x):
assert_allclose(
output.rowwise_tensor.dequantize(),
ref_out.rowwise_tensor.dequantize(),
dtype=out_dtype,
)
assert_allclose(
output.colwise_tensor.dequantize(),
ref_out.colwise_tensor.dequantize(),
dtype=out_dtype,
)
else:
assert_bitwise_scaled_tensors(output, ref_out)
pytest.fail("Unsupported output type")
assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
if norm_type == "layernorm":
assert_allclose(mu, ref_mu, dtype=inp_dtype)
......@@ -412,8 +471,20 @@ class TestNorm:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_forward_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_forward_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
):
if norm_type == "rmsnorm" and zero_centered_gamma is True:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
......@@ -426,7 +497,7 @@ class TestNorm:
epsilon=epsilon,
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_layout=q_layout,
)
......@@ -636,16 +707,19 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_delayed_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_quantize_dact_dbias_tensor_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
):
self._test_quantize_dact_dbias(
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=q_layout,
......@@ -836,7 +910,10 @@ class TestFusedDense:
Test layernorm_dense VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
......@@ -922,7 +999,10 @@ class TestFusedDense:
Test layernorm_mlp VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
......
......@@ -120,6 +120,11 @@ class ActLuPrimitive(BasePrimitive):
f" {x_aval.shape} and act_len {act_len}"
)
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused activation and quantization. Please"
" do activation in higher-precision then quantize with current tensor scaling."
)
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
......@@ -500,6 +505,12 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
f" {x_aval.shape} and act_len {act_len}"
)
assert scale_aval.dtype == jnp.float32
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused dact and quantization. Please do"
" dact in higher-precision then quantize with current tensor scaling."
)
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = act_len * x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size
......@@ -512,7 +523,10 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING.value,
ScalingMode.CURRENT_TENSOR_SCALING.value,
):
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else:
colwise_out_shape = out_shape
......@@ -718,6 +732,10 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Partitioned current tensor scaling is not yet supported."
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
......@@ -1026,6 +1044,16 @@ def act_lu(
out = out.reshape(output_shape)
return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Apply the activation in higher precision, then quantize after.
out = act_lu(
x=x.astype(jnp.float32),
activation_type=activation_type,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
return out
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
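The CURRENT_TENSOR_SCALING branch above is the pattern this change applies wherever a fused kernel lacks current-scaling support: run the operation unfused in float32, then quantize the result. A self-contained sketch of that shape, using jax.nn.gelu as an illustrative stand-in for act_lu:

import jax
import jax.numpy as jnp

def gelu_then_current_scale_quantize(x, q_dtype=jnp.float8_e4m3fn):
    # 1) Activation in float32; no fused activation+quantize kernel exists
    #    for current scaling yet.
    out = jax.nn.gelu(x.astype(jnp.float32))
    # 2) Quantize afterwards, with the scale computed from this output's amax.
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, jnp.float32)
    scale = fp8_max / jnp.max(jnp.abs(out))
    data = jnp.clip(out * scale, -fp8_max, fp8_max).astype(q_dtype)
    return data, 1.0 / scale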
......@@ -1101,8 +1129,12 @@ def quantize_dact_dbias(
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out = dact_lu(dz, x, activation_type, quantizer=None)
return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)
out = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
)
return _quantize_dbias_impl(
out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
)
is_gated = act_len == 2
# TE/common does not support DelayedScaling2x for gated-act yet
......@@ -1145,6 +1177,19 @@ def quantize_dact_dbias(
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Apply dact in higher precision, then quantize after.
out = dact_lu(
dz=dz.astype(jnp.float32),
x=x.astype(jnp.float32),
activation_type=activation_type,
quantizer=None,
)
out, dbias = _quantize_dbias_impl(
out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
)
return out, dbias
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
......@@ -1184,7 +1229,7 @@ def quantize_dact_dbias(
)
# For tensor scaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......
......@@ -155,7 +155,7 @@ def _dequantize(x, scale_inv, dq_dtype):
4,
),
)
def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
......@@ -173,12 +173,13 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
return out_3d
def _jax_gemm_delayed_scaling_fp8(
def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
assert (
rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
assert rhs.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
......@@ -196,7 +197,7 @@ def _jax_gemm_delayed_scaling_fp8(
precision = (
jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
)
out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
# Reshape [B, M, N] -> [..., M, N]
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
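The renamed helper above keeps the dequantize hard-coded so XLA can pattern-match the scale multiplies plus dot into a single hardware FP8 GEMM; the rename works because delayed and current scaling both carry one float32 scale per tensor. A minimal standalone version of the shape being matched (illustrative; the real code also handles dimension numbers and batching):

import jax
import jax.numpy as jnp

def fp8_gemm_pattern(lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv,
                     precision=jax.lax.Precision.DEFAULT):
    # Explicit dequantize-then-dot; XLA can rewrite this into an FP8 GEMM.
    lhs_dq = lhs_data.astype(jnp.float32) * lhs_scale_inv
    rhs_dq = rhs_data.astype(jnp.float32) * rhs_scale_inv
    return jnp.matmul(lhs_dq, rhs_dq, precision=precision)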
......@@ -271,8 +272,11 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
......@@ -378,7 +382,7 @@ def grouped_gemm(
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors with tensor scaling, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if lhs.scaling_mode.is_tensor_scaling():
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
......@@ -406,7 +410,7 @@ def grouped_gemm(
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
elif scaling_mode.is_tensor_scaling():
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
......@@ -443,7 +447,7 @@ def grouped_gemm(
if scaling_mode == ScalingMode.NO_SCALING:
lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if scaling_mode.is_tensor_scaling():
lhs_sinv_list_.append(lhs.scale_inv)
rhs_sinv_list_.append(rhs.scale_inv)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
......
......@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
from ..quantize import ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType
......@@ -215,9 +215,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
@return: the output of 'f' with the colwise output calculated
"""
should_apply_war = (
quantizer is not None
and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
and quantizer.is_2x2x()
quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
)
if not should_apply_war:
return None
......
......@@ -28,6 +28,7 @@ from .misc import (
NamedSharding,
get_cudnn_version,
)
from .quantization import _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
......@@ -36,7 +37,6 @@ from ..quantize import (
DelayedScaleQuantizer,
ScalingMode,
)
from .quantization import _quantize_dbias_impl
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -136,10 +136,19 @@ class NormFwdPrimitive(BasePrimitive):
f" {FUSED_MXFP8_NORM_CUDNN_MIN_VERSION} or higher"
)
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused norm and quantization. Please do"
" norm in higher-precision then quantize with current tensor scaling."
)
mu_rsigma_dtype = jnp.float32
if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.size == beta_aval.size
assert gamma_aval.dtype == beta_aval.dtype, (
f"gamma and beta should have the same dtype, but got {gamma_aval.dtype} and "
f"{beta_aval.dtype}"
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
......@@ -937,9 +946,22 @@ def layernorm_fwd(
out, _ = _quantize_dbias_impl(out, quantizer)
return out, mu, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform the norm in higher precision, then quantize after.
out, mu, rsigma = layernorm_fwd(
x=x,
gamma=gamma,
beta=beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
return out, mu, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x tensor scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
(
rowwise_casted_output,
......@@ -967,7 +989,10 @@ def layernorm_fwd(
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x tensor scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -1127,9 +1152,21 @@ def rmsnorm_fwd(
out, _ = _quantize_dbias_impl(out, quantizer)
return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform the norm in higher precision, then quantize after.
out, rsigma = rmsnorm_fwd(
x=x,
gamma=gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
return out, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x tensor scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
(
rowwise_casted_output,
......@@ -1157,7 +1194,10 @@ def rmsnorm_fwd(
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x tensor scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......
......@@ -94,7 +94,10 @@ class DBiasQuantizePrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING.value,
ScalingMode.CURRENT_TENSOR_SCALING.value,
):
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
......@@ -299,6 +302,11 @@ class DBiasQuantizePrimitive(BasePrimitive):
result_infos,
):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused.
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Current tensor scaling is not yet supported for multi-GPU partitioning."
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
......@@ -371,6 +379,11 @@ class DBiasQuantizePrimitive(BasePrimitive):
result_infos,
):
del result_infos, is_outer
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Current tensor scaling is not yet supported for multi-GPU partitioning."
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
......@@ -635,7 +648,7 @@ def _quantize_dbias_impl(
is_outer=True,
)
# For tensor scaling 2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......
......@@ -44,6 +44,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher-precision then quantize with current scaling.");
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
......@@ -152,6 +157,11 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher-precision then quantize with current scaling.");
// Evil hack to specify TE impl
// Note: nvte_quantize_dbias_dgelu chooses its internal impl based
// on what pointers are allocated, e.g. whether to output with
......@@ -219,6 +229,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher-precision then quantize with current scaling.");
auto *output = output_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
......
......@@ -108,7 +108,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
float *amax_dptr = nullptr;
float *scale_dptr = nullptr;
auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
......
......@@ -44,6 +44,7 @@ enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1,
MXFP8_1D_SCALING = 2,
CURRENT_TENSOR_SCALING = 3,
};
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
......@@ -57,6 +58,9 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
break;
case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
default:
NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
break;
......
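Note that CURRENT_TENSOR_SCALING maps onto NVTE_DELAYED_TENSOR_SCALING at the kernel level: both recipes store a single float32 scale per tensor, and only how that scale is produced differs. An illustrative Python restatement of the visible cases (not the full mapping):

NVTE_MODE_FOR_JAXX_MODE = {
    "MXFP8_1D_SCALING": "NVTE_MXFP8_1D_SCALING",
    "CURRENT_TENSOR_SCALING": "NVTE_DELAYED_TENSOR_SCALING",  # same kernel-side tensor layout
}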
......@@ -24,7 +24,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
// empty tensor wrappers are okay just to get workspace size
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
......@@ -98,7 +98,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto workspace_shape = std::vector<size_t>{workspace_size};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(gamma, gamma_shape, in_dtype);
auto gamma_tensor = TensorWrapper(gamma, gamma_shape, w_dtype);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
......@@ -107,6 +107,11 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher-precision then quantize with current scaling.");
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
......@@ -134,6 +139,8 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
}
if (_norm_type == NVTE_Norm_Type::LayerNorm) {
NVTE_CHECK(w_dtype == convert_ffi_datatype_to_te_dtype(beta_buf.element_type()),
"gamma and beta must have the same data type.");
auto beta_tensor = TensorWrapper(beta, gamma_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
......
......@@ -142,6 +142,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
.value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
......
......@@ -7,6 +7,7 @@
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/recipe.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......@@ -107,12 +108,15 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
......@@ -142,11 +146,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
: output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x tensor scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
if (is_tensor_scaling) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
......@@ -159,6 +161,21 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
}
}
if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
nvte_compute_amax(input_tensor.data(), // input data
output_tensor.data(), // output data (for amax)
stream);
QuantizationConfigWrapper quant_config;
// Defaults for now; TODO(Jeremy): expose as parameters
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
nvte_compute_scale_from_amax(output_tensor.data(), quant_config, stream);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
......
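In Python terms, the amax-to-scale step that nvte_compute_amax and nvte_compute_scale_from_amax perform in the branch above is roughly the following, under the defaults set in this diff (no power-of-two forcing, zero amax epsilon); the exact semantics live in TE/common:

import jax.numpy as jnp

def scale_from_amax(amax, q_dtype=jnp.float8_e4m3fn,
                    force_pow_2_scales=False, amax_epsilon=0.0):
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, jnp.float32)
    amax = jnp.maximum(amax, amax_epsilon)  # guard against tiny amax values
    scale = fp8_max / amax
    if force_pow_2_scales:
        # Round the scale down to a power of two (illustrative rounding).
        scale = jnp.exp2(jnp.floor(jnp.log2(scale)))
    return scale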
......@@ -85,6 +85,7 @@ class Dequantizer:
funcs = {
ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.CURRENT_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
}
......
......@@ -94,7 +94,7 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
A tuple of (bool, str) indicating support and any error message
"""
gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if scaling_mode.is_tensor_scaling():
return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _check_block_scaling_fp8_support(gpu_arch)
......
......@@ -27,6 +27,7 @@ __all__ = [
"QuantizeLayout",
"Quantizer",
"QuantizerSet",
"CurrentScaleQuantizer",
"DelayedScaleQuantizer",
"BlockScaleQuantizer",
"QuantizerFactory",
......@@ -159,37 +160,19 @@ class Quantizer(ABC):
@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(Quantizer):
"""Quantizer implementation using delayed scaling.
class CurrentScaleQuantizer(Quantizer):
"""Quantizer implementation using current scaling.
This quantizer uses delayed scaling mode with float32 scales and maintains
a history of maximum absolute values for dynamic scaling.
This quantizer uses current scaling mode with float32 scales.
Attributes:
scaling_mode: Set to NVTE_CURRENT_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
def get_data_layout(self) -> str:
"""Get the data data_layout string.
......@@ -217,15 +200,18 @@ class DelayedScaleQuantizer(Quantizer):
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
compute_dtype = self.scale.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
amax = jnp.max(jnp.abs(x)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
scaled_x = x.astype(compute_dtype) * scale
# quantize() in the old dot.py did it this way; keeping this code block here for future debugging
# compute_dtype = x.dtype
......@@ -233,8 +219,7 @@ class DelayedScaleQuantizer(Quantizer):
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
scale_inv = 1.0 / scale
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
......@@ -294,6 +279,75 @@ class DelayedScaleQuantizer(Quantizer):
return colwise_tensor
return rowwise_tensor
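Usage-wise, the reworked CurrentScaleQuantizer above carries no persistent state: every _quantize_func call derives its scale from the tensor at hand, so there is nothing to update between steps. A hedged usage sketch; the import path and the public quantize() entry point are assumptions mirroring the factory calls in the tests:

import jax.numpy as jnp
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, QuantizeLayout

quantizer = QuantizerFactory.create(
    scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
    q_dtype=jnp.float8_e4m3fn,
    q_layout=QuantizeLayout.ROWWISE,
)
x = jnp.ones((128, 512), jnp.bfloat16)
scaled = quantizer.quantize(x)  # assumed entry point; scale comes from this x, no amax history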
@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(CurrentScaleQuantizer):
"""Quantizer implementation using delayed scaling.
This quantizer uses delayed scaling mode with float32 scales and maintains
a history of maximum absolute values for dynamic scaling.
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
# quantize() in the old dot.py did it this way; keeping this code block here for future debugging
# compute_dtype = x.dtype
# dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
@staticmethod
@jax.jit
def _update_amax_history(amax_history, new_amax):
......@@ -531,6 +585,7 @@ class QuantizerFactory:
quantizer_type_map = {
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
}
......
......@@ -95,10 +95,10 @@ class ScalingModeMetadataImpl(ABC):
"""
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for current scaling mode.
This implementation provides metadata for delayed scaling mode, including scale data type and shape.
This implementation provides metadata for current scaling mode, including scale data type and shape.
"""
def get_scale_dtype(self) -> jnp.dtype:
......@@ -148,6 +148,13 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})
class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
This implementation provides metadata for delayed scaling mode, including scale data type and shape.
"""
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode.
......@@ -317,12 +324,14 @@ class ScalingMode(Enum):
This class defines the available scaling modes for tensor quantization:
- DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
- NO_SCALING: No scaling applied
"""
NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
......@@ -395,6 +404,25 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def is_tensor_scaling(self) -> bool:
"""Check if this scaling mode is per-tensor scaling.
Returns:
True if the scaling mode is tensor scaling, False otherwise
"""
return self in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
)
def is_1d_block_scaling(self) -> bool:
"""Check if this scaling mode is 1D block scaling.
Returns:
True if the scaling mode is 1D block scaling, False otherwise
"""
return self == ScalingMode.MXFP8_1D_SCALING
def __eq__(self, other):
"""Compare this scaling mode with another.
......@@ -434,5 +462,6 @@ SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
}
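Finally, a quick illustration of the new ScalingMode predicates introduced above (hypothetical REPL usage; import path assumed):

from transformer_engine.jax.quantize import ScalingMode

assert ScalingMode.DELAYED_TENSOR_SCALING.is_tensor_scaling()
assert ScalingMode.CURRENT_TENSOR_SCALING.is_tensor_scaling()
assert not ScalingMode.MXFP8_1D_SCALING.is_tensor_scaling()
assert ScalingMode.MXFP8_1D_SCALING.is_1d_block_scaling()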