Unverified commit c47f329b, authored by jberchtold-nvidia, committed by GitHub

[JAX] NoScaleTensor wrapper for non-quantized data (#2136)



* Custom call tests passing
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix shardy issue with amax being shape (1, 1, 1) instead of shape (1,)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Cast non-quantized kernels to input dtype in VJPs
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Rename HighPrecisionTensor to NoScaleTensor
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use NoScaleTensor in pure JAX impls where it was missing
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix tests
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent b10f436a
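For orientation before reading the diff: a minimal, hypothetical re-creation of the NoScaleTensor idea this commit introduces. The real class (added in quantize/tensor.py later in this diff) wraps non-quantized data plus an optional cached amax and registers as a JAX pytree; the sketch below is illustrative only, not the TE implementation.

```python
# Hypothetical miniature of NoScaleTensor, for illustration only.
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass
class TinyNoScaleTensor:
    """Wraps non-quantized (e.g. bf16) data plus an optional cached amax."""

    data: jnp.ndarray
    amax: Optional[jnp.ndarray]

    def tree_flatten(self):
        # Both fields are traced arrays (or None), so both are pytree children.
        return (self.data, self.amax), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        del aux_data
        return cls(*children)

    def dequantize(self):
        # No scaling was applied, so dequantization is the identity.
        return self.data


x = TinyNoScaleTensor(data=jnp.ones((4, 4), jnp.bfloat16), amax=None)
print(x.dequantize().dtype)  # bfloat16
```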
@@ -31,6 +31,7 @@ from transformer_engine.jax.cpp_extensions.quantization import (
 from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
 from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
+    NoScaleTensor,
     ScaledTensor,
     ScaledTensor1x,
     ScaledTensor2x,

@@ -182,7 +183,7 @@ ACTIVATION_TYPES = {
 class TestActivation:
     def ref_act(self, x, activation_type):
-        return _jax_act_lu(x, activation_type)
+        return _jax_act_lu(x, activation_type).data

     def value_n_grad_ref_func(self, x, activation_type):
         jitted_reference = jit(

@@ -337,8 +338,8 @@ class TestNorm:
             ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
         else:
             ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
-        if isinstance(ln_out, ScaledTensor):
-            ln_out = ln_out.dequantize()
+        # This is a no-op for non-quantized data
+        ln_out = ln_out.dequantize()
         return ln_out

     key = jax.random.PRNGKey(0)

@@ -765,7 +766,9 @@ class TestFusedQuantize:
                 te_output, jax_output, precise_comparison=precise_comparison
             )
         else:
-            assert_allclose(te_output, jax_output)
+            assert isinstance(te_output, NoScaleTensor)
+            assert isinstance(jax_output, NoScaleTensor)
+            assert_allclose(te_output.data, jax_output.data)

         if is_dbias:
             # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.

@@ -1020,7 +1023,6 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
         ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
     else:
         ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
-    if isinstance(ln_out, ScaledTensor):
-        ln_out = ln_out.dequantize()
+    ln_out = ln_out.dequantize()
     return ln_out

@@ -1177,7 +1179,7 @@ class TestFusedDense:
             bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
             linear_1_out += jnp.reshape(bias_1, bias_1_shape)
-        x = _jax_act_lu(linear_1_out, activation_type)
+        x = _jax_act_lu(linear_1_out, activation_type).data
         linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
         if use_bias:
             bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
...
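The test updates above follow one pattern: reference helpers now return a wrapper, so assertions unwrap `.data` on both sides. A self-contained sketch of that pattern (stand-in `NoScale` type; `assert_allclose` here is NumPy's, standing in for the TE test utility):

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp
from numpy.testing import assert_allclose


class NoScale(NamedTuple):  # illustrative stand-in for NoScaleTensor
    data: jnp.ndarray
    amax: Optional[jnp.ndarray]


te_output = NoScale(data=jnp.ones(4), amax=None)
jax_output = NoScale(data=jnp.ones(4), amax=None)

# Both sides of the comparison are wrappers now; compare the raw arrays.
assert isinstance(te_output, NoScale) and isinstance(jax_output, NoScale)
assert_allclose(te_output.data, jax_output.data)
```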
@@ -173,7 +173,9 @@ class TestDistributedLayernormMLP:
         )

         # Single GPU
-        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
+        with fp8_autocast(
+            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()
+        ):
             single_jitter = jax.jit(
                 value_and_grad_func,
                 static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),

@@ -184,7 +186,7 @@ class TestDistributedLayernormMLP:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, fp8_autocast(
-            enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
         ):
             k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
             k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))

@@ -226,7 +228,12 @@ class TestDistributedLayernormMLP:
         fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
         bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
-        assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
+        if fwd_test_type == jnp.float16 and use_bias:
+            assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5)
+        else:
+            assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)

         for i in range(len(inputs)):
             if multi_grads[i] is not None:
                 if isinstance(multi_grads[i], list):

@@ -252,7 +259,7 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
-    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad(
         self,

@@ -281,7 +288,7 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
-    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad_shardy(
         self,
...
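The distributed tests gain a non-quantized path by prepending `None` to the recipe list and gating `fp8_autocast` on it. A minimal pytest sketch of the same idea (names here are illustrative stand-ins, not the TE test helpers):

```python
import pytest

# Stand-ins for the real recipe objects in SUPPORTED_RECIPES.
SUPPORTED_RECIPES = ["DelayedScaling", "MXFP8BlockScaling"]


@pytest.mark.parametrize("fp8_recipe", [None] + SUPPORTED_RECIPES)
def test_runs_with_and_without_fp8(fp8_recipe):
    # Mirrors fp8_autocast(enabled=fp8_recipe is not None, ...) above:
    # None exercises the higher-precision (NoScaleTensor) code path.
    enabled = fp8_recipe is not None
    assert enabled == (fp8_recipe in SUPPORTED_RECIPES)
```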
@@ -14,7 +14,7 @@ import jax.numpy as jnp

 from . import cpp_extensions as tex
-from .quantize.tensor import ScaledTensor
+from .quantize.tensor import NoScaleTensor
 from .quantize.quantizer import Quantizer

@@ -22,7 +22,7 @@ def activation(
     x: jnp.ndarray,
     activation_type: Sequence[Union[str, Callable]],
     quantizer: Optional[Quantizer] = None,
-) -> Union[jnp.ndarray, ScaledTensor]:
+) -> jnp.ndarray:
     """Apply activation functions to input tensor with optional quantization.

     This function applies a sequence of activation functions to the input tensor.

@@ -72,7 +72,7 @@ def _activation_fwd_rule(x, activation_type, quantizer):
         Tuple of (output, context) for backward pass
     """
     fwd_output = tex.act_lu(x, activation_type, quantizer)
-    if isinstance(fwd_output, ScaledTensor):
-        fwd_output = fwd_output.dequantize()
+    # This is a no-op for higher-precision tensors
+    fwd_output = fwd_output.dequantize()
     return fwd_output, (x, quantizer)

@@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g):
     (x, _) = ctx
     assert x.dtype == g.dtype
     dx = tex.dact_lu(g, x, activation_type)
+    # No quantization is used in this VJP backward, so the output should
+    # always be a NoScaleTensor
+    assert isinstance(dx, NoScaleTensor)
+    dx = dx.data
     return (dx, None)
...
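The backward rule above receives a wrapped tensor from the non-quantized dact path and must hand a raw array back to JAX. A self-contained sketch of that custom-VJP pattern, with a stand-in `Wrapped` type instead of the real NoScaleTensor:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Wrapped:  # illustrative stand-in for NoScaleTensor
    data: jnp.ndarray


@jax.custom_vjp
def gelu(x):
    return jax.nn.gelu(x)


def gelu_fwd(x):
    return gelu(x), x


def gelu_bwd(x, g):
    # The real dact_lu returns a NoScaleTensor when no quantizer is given;
    # the backward rule asserts that and unwraps .data before returning.
    dx = Wrapped(data=g * jax.vmap(jax.grad(jax.nn.gelu))(x))
    assert isinstance(dx, Wrapped)
    return (dx.data,)


gelu.defvjp(gelu_fwd, gelu_bwd)
print(jax.grad(lambda x: gelu(x).sum())(jnp.ones(3)))
```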
@@ -29,7 +29,7 @@ from .misc import (
 )
 from .quantization import _jax_dbias, _quantize_dbias_impl
 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
-from ..quantize import ScaledTensor, ScaledTensorFactory
+from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
     Quantizer,
     QuantizeLayout,

@@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
     """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""

-def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
+def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]:
     """
     JAX native activation implementation
     """

@@ -941,11 +941,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None):
     x = jnp.squeeze(x, axis=-2)
     if quantizer:
         return quantizer.quantize(x, flatten_axis=-1)
-    return x
+    return NoScaleTensor(data=x, amax=None)


 def _jax_quantize_dact_dbias(
-    dz: jnp.ndarray,
+    dz: Union[jnp.ndarray, NoScaleTensor],
     x: jnp.ndarray,
     activation_type: Sequence[Union[str, Callable]],
     is_dbias: bool = True,

@@ -963,7 +963,9 @@ def _jax_quantize_dact_dbias(
     _, vjp_func = jax.vjp(
         partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
     )
-    (dx,) = vjp_func(dz.astype(jnp.float32))
+    # VJP is using non-quantized backward for dact, so the input should always be wrapped
+    # in NoScaleTensor regardless of whether the forward pass used quantization or this
+    # dact will quantize afterwards.
+    dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
+    (dx,) = vjp_func(dz)

     dbias = None
     if is_dbias:

@@ -973,6 +975,7 @@ def _jax_quantize_dact_dbias(
         dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
     else:
         dx = dx.astype(x.dtype)
+        dx = NoScaleTensor(data=dx, amax=None)

     return dx, dbias

@@ -981,7 +984,6 @@ def act_lu(
     x: jnp.ndarray,
     activation_type: Sequence[Union[str, Callable]],
     quantizer: Optional[Quantizer] = None,
-    noop_scaled_tensor: bool = False,
 ) -> Union[jnp.ndarray, ScaledTensor]:
     """Activation with optional quantization.

@@ -990,7 +992,6 @@ def act_lu(
             Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
         activation_type: Type of activation function to apply.
         quantizer: Optional quantizer for FP8 quantization of the output.
-        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

     Returns:
         If quantizer is None:

@@ -1035,9 +1036,9 @@ def act_lu(
             is_outer=True,
         )
         out = out.reshape(output_shape)
-        if noop_scaled_tensor:
-            return ScaledTensorFactory.create_2x(
-                out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype
-            )
+        out = NoScaleTensor(
+            data=out,
+            amax=None,
+        )
         return out

@@ -1092,7 +1093,6 @@ def quantize_dact_dbias(
     activation_type: Sequence[Union[str, Callable]] = ("gelu",),
     is_dbias: bool = True,
     quantizer: Optional[Quantizer] = None,
-    noop_scaled_tensor: bool = False,
 ) -> Tuple[ScaledTensor, jnp.ndarray]:
     """Compute gradients of activation and bias with optional quantization.

@@ -1103,7 +1103,6 @@ def quantize_dact_dbias(
         activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
         is_dbias: If True, compute bias gradient. Defaults to True.
         quantizer: Optional quantizer for FP8 quantization of the output.
-        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

     Returns:
         Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:

@@ -1146,19 +1145,10 @@ def quantize_dact_dbias(
         if is_dbias:
             dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
-        if noop_scaled_tensor:
-            return (
-                ScaledTensorFactory.create_2x(
-                    output,
-                    None,
-                    output,
-                    None,
-                    ScalingMode.NO_SCALING,
-                    dq_dtype=output.dtype,
-                ),
-                dbias,
-            )
+        output = NoScaleTensor(
+            data=output,
+            amax=None,
+        )
         return output, dbias

     # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet

@@ -1167,7 +1157,7 @@ def quantize_dact_dbias(
             dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
         )
         return _quantize_dbias_impl(
-            out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
+            out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
         )

     is_gated = act_len == 2

@@ -1194,7 +1184,7 @@ def quantize_dact_dbias(
             quantizer=None,
         )
         out, dbias = _quantize_dbias_impl(
-            out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
+            out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
         return out, dbias

@@ -1258,7 +1248,6 @@ def dact_lu(
     x: jnp.ndarray,
     activation_type: Sequence[Union[str, Callable]],
     quantizer: Optional[Quantizer] = None,
-    noop_scale_tensor: bool = False,
 ) -> Union[jnp.ndarray, ScaledTensor]:
     """
     Backward pass for activation with optional quantization.

@@ -1268,7 +1257,6 @@ def dact_lu(
         x: Input tensor that was used in forward pass.
         activation_type: Type of activation function that was applied.
         quantizer: Optional quantizer for FP8 quantization of the output gradient.
-        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

     Returns:
         The gradient of the activation with respect to the input.

@@ -1279,6 +1267,5 @@ def dact_lu(
         activation_type=activation_type,
         is_dbias=False,
         quantizer=quantizer,
-        noop_scaled_tensor=noop_scale_tensor,
     )
     return output
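The recurring shape of these changes: pure-JAX implementations now return the wrapper whenever no quantizer is supplied, so both the quantized and non-quantized paths hand back an object with a uniform interface. A minimal sketch under that assumption (`NoScale` is a stand-in type; the real quantizer API is not reproduced here):

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp


class NoScale(NamedTuple):  # illustrative stand-in for NoScaleTensor
    data: jnp.ndarray
    amax: Optional[jnp.ndarray]


def act_lu_sketch(x, quantizer=None):
    y = jnp.maximum(x, 0.0)  # stand-in for the fused activation kernel
    if quantizer is not None:
        return quantizer.quantize(y, flatten_axis=-1)
    # Non-quantized outputs are wrapped instead of returned as raw arrays,
    # replacing the old noop_scaled_tensor=True dummy-ScaledTensor2x path.
    return NoScale(data=y, amax=None)


out = act_lu_sketch(jnp.linspace(-1.0, 1.0, 8))
assert out.amax is None and out.data.shape == (8,)
```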
@@ -22,6 +22,8 @@ from transformer_engine_jax import get_num_compute_streams
 from .base import BasePrimitive, register_primitive
 from .quantization import grouped_quantize
 from ..quantize import (
+    AbstractBaseTensor,
+    NoScaleTensor,
     ScaledTensor,
     ScaledTensor2x,
     GroupedScaledTensor1x,

@@ -228,6 +230,11 @@ class GemmPrimitive(BasePrimitive):
                 "require non-transposed LHS and transposed RHS operands "
                 "(`contracting_dims=((-1, ), (-1, ))`)."
             )
+        else:
+            assert lhs.dtype == rhs.dtype, (
+                "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal."
+                f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
+            )

         # Determine output shape and dtype
         assert (

@@ -1134,8 +1141,8 @@ def _jax_gemm(

 def gemm(
-    lhs: Union[jnp.ndarray, ScaledTensor],
-    rhs: Union[jnp.ndarray, ScaledTensor],
+    lhs: Union[jnp.ndarray, AbstractBaseTensor],
+    rhs: Union[jnp.ndarray, AbstractBaseTensor],
     contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
     lhs_quantizer: Quantizer = None,
     rhs_quantizer: Quantizer = None,

@@ -1191,6 +1198,11 @@ def gemm(
         compute the GeLU contribution to the gradient. Only supported with TE's custom call to
         cuBLAS GEMM.
     """
+    if isinstance(lhs, NoScaleTensor):
+        lhs = lhs.data
+    if isinstance(rhs, NoScaleTensor):
+        rhs = rhs.data
+
     # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
     if lhs_quantizer is None or rhs_quantizer is None:
         quantizer_set = kwargs.get("quantizer_set", None)
...
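A sketch (stand-in types, not the real TE API) of the `gemm()` entry-point behaviour added above: NoScaleTensor operands are unwrapped to raw arrays first, and the new cuBLAS-path assertion requires equal dtypes for non-quantized operands.

```python
from typing import NamedTuple, Optional, Union

import jax.numpy as jnp


class NoScale(NamedTuple):  # illustrative stand-in for NoScaleTensor
    data: jnp.ndarray
    amax: Optional[jnp.ndarray]


def gemm_sketch(lhs: Union[jnp.ndarray, NoScale], rhs: Union[jnp.ndarray, NoScale]):
    # Unwrap wrappers before dispatching to the underlying matmul.
    if isinstance(lhs, NoScale):
        lhs = lhs.data
    if isinstance(rhs, NoScale):
        rhs = rhs.data
    # Mirrors the new GemmPrimitive assertion for the non-quantized path.
    assert lhs.dtype == rhs.dtype, f"lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
    return lhs @ rhs


out = gemm_sketch(NoScale(jnp.ones((2, 3), jnp.bfloat16), None), jnp.ones((3, 4), jnp.bfloat16))
print(out.shape)  # (2, 4)
```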
@@ -30,7 +30,7 @@ from .misc import (
 )
 from .quantization import _quantize_dbias_impl
 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
-from ..quantize import ScaledTensor, ScaledTensorFactory
+from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
     Quantizer,
     QuantizeLayout,

@@ -845,6 +845,7 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None):
         ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
     else:
         ln_out = jnp.asarray(output).astype(x.dtype)
+        ln_out = NoScaleTensor(data=ln_out, amax=None)

     return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1)

@@ -869,6 +870,7 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
         ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
     else:
         ln_out = jnp.asarray(output).astype(x.dtype)
+        ln_out = NoScaleTensor(data=ln_out, amax=None)

     return ln_out, jnp.squeeze(rsigma, axis=-1)

@@ -930,7 +932,7 @@ def layernorm_fwd(
             scale_dtype=jnp.float32,
             is_outer=True,
         )
-        return output, mu, rsigma
+        return NoScaleTensor(data=output, amax=None), mu, rsigma

     if (
         quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING

@@ -1064,7 +1066,7 @@ def layernorm_bwd(
         )
         mu_empty = jnp.zeros(mu.shape, mu.dtype)
         rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
-        return vjp_func((dz, mu_empty, rsigma_empty))
+        return vjp_func((NoScaleTensor(data=dz, amax=None), mu_empty, rsigma_empty))

     return NormBwdPrimitive.outer_primitive.bind(
         dz,
         x,

@@ -1133,14 +1135,14 @@ def rmsnorm_fwd(
             scale_dtype=jnp.float32,
             is_outer=True,
         )
-        return output, rsigma
+        return NoScaleTensor(data=output, amax=None), rsigma

     if (
         quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
         and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
     ):
         out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None)
-        out, _ = _quantize_dbias_impl(out, quantizer)
+        out, _ = _quantize_dbias_impl(out.data, quantizer)
         return out, rsigma

     if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:

@@ -1152,7 +1154,9 @@ def rmsnorm_fwd(
             epsilon=epsilon,
             quantizer=None,
         )
-        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
+        out, _ = _quantize_dbias_impl(
+            out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype
+        )
         return out, rsigma

     is_2x2x = quantizer.is_2x2x()

@@ -1254,7 +1258,7 @@ def rmsnorm_bwd(
             gamma,
         )
         rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
-        return vjp_func((dz, rsigma_empty))
+        return vjp_func((NoScaleTensor(data=dz, amax=None), rsigma_empty))

     mu = jnp.empty(())
     dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind(
         dz,

@@ -1276,7 +1280,6 @@ def normalization_fwd(
     epsilon: float,
     norm_type: str,
     quantizer: Optional[Quantizer],
-    noop_scaled_tensor: bool = False,
 ):
     """Common wrapper for normalization forward pass.

@@ -1293,7 +1296,6 @@ def normalization_fwd(
             - 'layernorm': Layer normalization
             - 'rmsnorm': Root mean square normalization
         quantizer: Optional quantizer for FP8 quantization of the output.
-        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

     Returns:
         A tuple containing:

@@ -1321,20 +1323,6 @@ def normalization_fwd(
     else:
         raise ValueError(f"{norm_type=} is not supported.")

-    if quantizer is None and noop_scaled_tensor:
-        return (
-            ScaledTensorFactory.create_2x(
-                output,
-                None,
-                output,
-                None,
-                scaling_mode=ScalingMode.NO_SCALING,
-                dq_dtype=output.dtype,
-            ),
-            mu,
-            rsigma,
-        )
     return output, mu, rsigma
...
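In the norm forward paths above, only the normalized output gets wrapped; the `mu`/`rsigma` statistics stay raw arrays. A minimal layernorm-forward sketch under that assumption (stand-in `NoScale` type, simplified math):

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp


class NoScale(NamedTuple):  # illustrative stand-in for NoScaleTensor
    data: jnp.ndarray
    amax: Optional[jnp.ndarray]


def layernorm_fwd_sketch(x, gamma, beta, eps=1e-5):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    rsigma = 1.0 / jnp.sqrt(jnp.var(x, axis=-1, keepdims=True) + eps)
    y = (x - mu) * rsigma * gamma + beta
    # Wrap the normalized output; statistics are returned as plain arrays.
    return NoScale(data=y.astype(x.dtype), amax=None), jnp.squeeze(mu, -1), jnp.squeeze(rsigma, -1)


out, mu, rsigma = layernorm_fwd_sketch(jnp.ones((2, 8)), jnp.ones(8), jnp.zeros(8))
print(type(out).__name__, mu.shape, rsigma.shape)
```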
@@ -4,7 +4,7 @@
 """JAX/TE custom ops for quantization"""
 import operator
 from functools import reduce
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 import math
 from packaging import version

@@ -38,6 +38,7 @@ from ..quantize import (
     QuantizeLayout,
     ScalingMode,
     compute_scale_from_amax,
+    NoScaleTensor,
 )

 if version.parse(jax.__version__) >= version.parse("0.5.0"):

@@ -64,7 +65,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         7,
         8,
         9,
-    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval
+    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
     inner_primitive = None
     outer_primitive = None

@@ -535,11 +536,15 @@ def _jax_quantize(
     x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
 ):
     if quantizer is None:
-        return x
+        if isinstance(x, NoScaleTensor):
+            return x
+        return NoScaleTensor(data=x, amax=None)
     return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


-def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
+def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
+    if isinstance(dx, NoScaleTensor):
+        dx = dx.data
     sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
     assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
     dtype = dtype or dx.dtype

@@ -558,7 +563,9 @@ def _jax_quantize_dbias(
     flatten_axis: int = -1,
 ):
     if quantizer is None:
-        return x, None
+        if isinstance(x, NoScaleTensor):
+            return x, None
+        return NoScaleTensor(data=x, amax=None), None
     return (
         quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
         _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),

@@ -566,12 +573,11 @@ def _jax_quantize_dbias(

 def _quantize_dbias_impl(
-    x: jnp.ndarray,
+    x: Union[jnp.ndarray, NoScaleTensor],
     quantizer: Quantizer,
     is_dbias: bool = False,
     dq_dtype: Optional[jnp.dtype] = None,
     flatten_axis: int = -1,
-    noop_scaled_tensor: bool = False,
 ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
     """
     Cast wrapper

@@ -581,28 +587,15 @@ def _quantize_dbias_impl(
         quantizer is not None
     ), "quantizer must be provided if dq_dtype is provided"

+    if isinstance(x, jnp.ndarray):
+        x = NoScaleTensor(data=x, amax=None)
+
     # Early-exit for non-quantized call
-    dq_dtype = dq_dtype or x.dtype
+    dq_dtype = dq_dtype or x.data.dtype
     if quantizer is None:
         dbias = None
         if is_dbias:
-            dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
-        if noop_scaled_tensor:
-            # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor()
-            # always works.
-            return (
-                ScaledTensorFactory.create_2x(
-                    x,
-                    None,
-                    x,
-                    None,
-                    scaling_mode=ScalingMode.NO_SCALING,
-                    dq_dtype=x.dtype,
-                    data_layout="NN",
-                    flatten_axis=flatten_axis,
-                ),
-                dbias,
-            )
+            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
         return x, dbias

     # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,

@@ -630,20 +623,24 @@ def _quantize_dbias_impl(
             dq_dtype=dq_dtype,
             flatten_axis=flatten_axis,
         )
-        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
+        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
         return out, dbias

     scale = jnp.empty((), jnp.float32)
+    amax = None
     if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
         # Globally reduce amax across all devices for current scaling so we have a single global scale.
         # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
         # until the tensor is dequantized (e.g. in the GEMM).
-        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
+        amax = x.amax
+        if amax is None:
+            amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,))
         scale = compute_scale_from_amax(amax, quantizer.q_dtype)
     elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
         scale = quantizer.scale
         # Make sure amax is init with zero
-        amax = jnp.zeros((1,), jnp.float32)
+        if amax is None:
+            amax = jnp.zeros((1,), jnp.float32)

     # It is faster to use 1x quantization for tensor scaling

@@ -665,7 +662,7 @@ def _quantize_dbias_impl(
         updated_amax,
         dbias,
     ) = PrimitiveClass.outer_primitive.bind(
-        x,
+        x.data,
         scale,
         amax,
         out_dtype=quantizer.q_dtype,

@@ -706,10 +703,9 @@ def _quantize_dbias_impl(

 def quantize(
-    x: jnp.ndarray,
+    x: Union[jnp.ndarray, NoScaleTensor],
     quantizer: Quantizer,
     flatten_axis: int = -1,
-    noop_scaled_tensor: bool = False,
 ) -> Tuple[ScaledTensor]:
     """Quantize input tensor according to the quantizer.

@@ -719,7 +715,6 @@ def quantize(
         quantizer: Quantizer for FP8 quantization of the output.
         flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
             Defaults to -1.
-        noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
-            is None.

     Returns:

@@ -729,17 +724,15 @@ def quantize(
         x,
         quantizer=quantizer,
         flatten_axis=flatten_axis,
-        noop_scaled_tensor=noop_scaled_tensor,
     )
     return out


 def quantize_dbias(
-    dz: jnp.ndarray,
+    dz: Union[jnp.ndarray, NoScaleTensor],
     quantizer: Quantizer,
     is_dbias: bool = True,
     flatten_axis: int = -1,
-    noop_scaled_tensor: bool = False,
 ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
     """Quantize input tensor and compute bias gradient.

@@ -750,8 +743,6 @@ def quantize_dbias(
         is_dbias: If True, compute bias gradient. Defaults to True.
         flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
             Defaults to -1.
-        noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
-            quantizer is None.

     Returns:
         A tuple containing:

@@ -765,7 +756,6 @@ def quantize_dbias(
         quantizer=quantizer,
         is_dbias=is_dbias,
         flatten_axis=flatten_axis,
-        noop_scaled_tensor=noop_scaled_tensor,
     )

@@ -968,7 +958,9 @@ def grouped_quantize(
     """
     if quantizer is None:
-        return x
+        if isinstance(x, NoScaleTensor):
+            return x
+        return NoScaleTensor(data=x, amax=None)

     # TODO(Phuong): add support for flatten_axis = -2
     assert flatten_axis in (
...
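The key behavioural change in `_quantize_dbias_impl` above: for current tensor scaling, a cached amax carried on the wrapper (e.g. one already computed by a fused kernel) is used instead of re-reducing the data. A minimal sketch of that scale computation, assuming the standard fp8_max / amax formula in place of `compute_scale_from_amax`:

```python
import jax.numpy as jnp


def current_scale_sketch(data, cached_amax=None):
    # Prefer the wrapper's cached amax (x.amax above) when present;
    # otherwise reduce over the data, matching the fallback path.
    amax = cached_amax
    if amax is None:
        amax = jnp.amax(jnp.abs(data)).astype(jnp.float32).reshape((1,))
    fp8_max = float(jnp.finfo(jnp.float8_e4m3fn).max)  # 448.0 for e4m3
    return fp8_max / jnp.maximum(amax, 1e-12)


x = jnp.array([[1.0, -3.0], [0.5, 2.0]])
print(current_scale_sketch(x))                    # recomputes amax from the data
print(current_scale_sketch(x, jnp.array([3.0])))  # reuses the cached amax
```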
@@ -24,6 +24,7 @@ from .quantize import (
     with_sharding_constraint_by_logical_axes,
     is_fp8_gemm_with_all_layouts_supported,
     TensorUsage,
+    get_quantize_config,
 )

@@ -80,14 +81,10 @@ def dense(
     Returns:
         Transformed output tensor
     """
-    # Remove when tex.quantize() can handle quantizer=None
-    if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
-        x = with_sharding_constraint_by_logical_axes(x, input_axes)
-        output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
-        if bias is not None:
-            bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
-            output += jnp.reshape(bias, bias_new_shape)
-    else:
-        output = _dense(
+    if not get_quantize_config().is_fp8_enabled():
+        input_dtype = x.dtype
+        kernel = kernel.astype(input_dtype)
+
+    output = _dense(
         x,
         kernel,

@@ -175,7 +172,9 @@ def _dense_fwd_rule(
     flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)

     casted_x = tex.quantize(
-        x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
+        x,
+        flatten_axis=flatten_axis_x,
+        quantizer=quantizer_set.x,
     )
     casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

@@ -183,7 +182,6 @@ def _dense_fwd_rule(
         kernel,
         flatten_axis=flatten_axis_k,
         quantizer=quantizer_set.kernel,
-        noop_scaled_tensor=True,
     )
     casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

@@ -240,7 +238,6 @@ def _dense_bwd_rule(
         is_dbias=use_bias,
         flatten_axis=flatten_axis_k,
         quantizer=quantizer_set.dgrad,
-        noop_scaled_tensor=True,
     )

     # GEMM NT
...
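The new non-FP8 gating in `dense()` (also applied in layernorm_dense and layernorm_mlp below) simply casts the kernel to the input dtype when quantization is disabled, so both GEMM operands match and the cuBLAS-path dtype assertion holds. A sketch of the gating, with a plain boolean standing in for `get_quantize_config().is_fp8_enabled()`:

```python
import jax.numpy as jnp


def dense_sketch(x, kernel, fp8_enabled: bool):
    if not fp8_enabled:
        # Non-quantized kernels are cast to the input dtype in the VJPs,
        # matching the commit message ("Cast non-quantized kernels to input
        # dtype in VJPs").
        kernel = kernel.astype(x.dtype)
    return x @ kernel


x = jnp.ones((8, 16), jnp.bfloat16)
w = jnp.ones((16, 32), jnp.float32)
print(dense_sketch(x, w, fp8_enabled=False).dtype)  # bfloat16
```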
@@ -17,7 +17,6 @@ import jax.numpy as jnp

 from . import cpp_extensions as tex
 from .quantize import (
-    ScaledTensor,
     Quantizer,
 )

@@ -112,7 +111,7 @@ def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
     output, mu, rsigma = tex.normalization_fwd(
         x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer
     )
-    if isinstance(output, ScaledTensor):
-        output = output.dequantize()
+    # This is a no-op for higher-precision tensors
+    output = output.dequantize()

     return output, (x, mu, rsigma, gamma, beta, quantizer)
...
@@ -22,6 +22,7 @@ from .quantize import (
     noop_quantizer_set,
     with_sharding_constraint_by_logical_axes,
     TensorUsage,
+    get_quantize_config,
 )

@@ -68,6 +69,11 @@ def layernorm_dense(
         - The function supports automatic differentiation through JAX's custom VJP
         - Quantization is applied to both the normalized input and kernel
     """
+    if not get_quantize_config().is_fp8_enabled():
+        input_dtype = x.dtype
+        kernel = kernel.astype(input_dtype)
+
     output = _layernorm_dense(
         x,
         kernel,

@@ -188,14 +194,15 @@ def _layernorm_dense_fwd_rule(
         epsilon,
         norm_type,
         quantizer=quantizer_set.x,
-        noop_scaled_tensor=True,
     )
     casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)

     # Kernel in (hidden_in, hidden_out...)
     flatten_axis = 1 - len(kernel.shape)
     casted_kernel = tex.quantize(
-        kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
+        kernel,
+        flatten_axis=flatten_axis,
+        quantizer=quantizer_set.kernel,
     )
     casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

@@ -278,7 +285,6 @@ def _layernorm_dense_bwd_rule(
         is_dbias=use_bias,
         flatten_axis=flatten_axis,
         quantizer=quantizer_set.dgrad,
-        noop_scaled_tensor=True,
     )

     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...
@@ -27,6 +27,7 @@ from .quantize import (
     QuantizerSet,
     noop_quantizer_set,
     TensorUsage,
+    get_quantize_config,
 )

@@ -104,6 +105,11 @@ def layernorm_mlp(
         not zero_centered_gamma
     ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"

+    if not get_quantize_config().is_fp8_enabled():
+        input_dtype = x.dtype
+        kernel_1 = kernel_1.astype(input_dtype)
+        kernel_2 = kernel_2.astype(input_dtype)
+
     output = _layernorm_mlp(
         x,
         gamma,

@@ -266,12 +272,13 @@ def _layernorm_mlp_fwd_rule(
         epsilon,
         norm_type,
         quantizer=ffn1_quantizer_set.x,
-        noop_scaled_tensor=True,
     )
     casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

     casted_kernel_1 = tex.quantize(
-        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
+        kernel_1,
+        flatten_axis=-2,
+        quantizer=ffn1_quantizer_set.kernel,
     )

     # NN GEMM

@@ -300,13 +307,16 @@ def _layernorm_mlp_fwd_rule(
     # (batch..., hidden_in) -> (batch..., hidden)
     casted_act_out = tex.act_lu(
-        dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
+        dot_1_output,
+        activation_type,
+        quantizer=ffn2_quantizer_set.x,
     )

     casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)

     casted_kernel_2 = tex.quantize(
-        kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
+        kernel_2,
+        quantizer=ffn2_quantizer_set.kernel,
     )

     # NN GEMM

@@ -404,7 +414,9 @@ def _layernorm_mlp_bwd_rule(
     grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

     casted_grad, dbias_2 = tex.quantize_dbias(
-        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
+        grad,
+        is_dbias=use_bias_2,
+        quantizer=ffn1_quantizer_set.dgrad,
    )

     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim

@@ -445,7 +457,6 @@ def _layernorm_mlp_bwd_rule(
         activation_type=activation_type,
         is_dbias=use_bias_1,
         quantizer=ffn2_quantizer_set.dgrad,
-        noop_scaled_tensor=True,
     )

     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...
@@ -19,7 +19,13 @@ from transformer_engine_jax import QuantizeLayout
 from transformer_engine.common import recipe
 from .scaling_modes import ScalingMode
-from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
+from .tensor import (
+    ScaledTensor,
+    ScaledTensor1x,
+    ScaledTensor2x,
+    ScaledTensorFactory,
+    NoScaleTensor,
+)
 from .helper import (
     get_quantize_config,
     get_quantize_config_class,

@@ -217,7 +223,11 @@ class CurrentScaleQuantizer(Quantizer):
     data_layout: str = "NT"

     def _quantize_func(
-        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
+        self,
+        x: Union[jnp.ndarray, NoScaleTensor],
+        is_colwise=False,
+        dq_dtype=None,
+        flatten_axis=-1,
     ) -> ScaledTensor1x:
         """Quantize function helper for delayed scaling FP8.

@@ -229,14 +239,17 @@ class CurrentScaleQuantizer(Quantizer):
         Returns:
             A ScaledTensor1x containing the quantized data
         """
-        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
+        if isinstance(x, jnp.ndarray):
+            x = NoScaleTensor(data=x, amax=None)
+        dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
         compute_dtype = jnp.float32
         dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
-        amax = jnp.max(jnp.abs(x)).reshape((1,))
+        amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
         fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
         scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
-        scaled_x = x.astype(compute_dtype) * scale
+        scaled_x = x.data.astype(compute_dtype) * scale
         clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
         scale_inv = 1.0 / scale

@@ -263,7 +276,10 @@ class CurrentScaleQuantizer(Quantizer):
         Returns:
             A ScaledTensor1x or ScaledTensor2x containing the quantized data
         """
-        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
+        if isinstance(x, jnp.ndarray):
+            x = NoScaleTensor(data=x, amax=None)
+        dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
+
         if flatten_axis < 0:
             flatten_axis += x.ndim
         assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

@@ -347,11 +363,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
         Returns:
             A ScaledTensor1x containing the quantized data
         """
-        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
+        if isinstance(x, jnp.ndarray):
+            x = NoScaleTensor(data=x, amax=None)
+        dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
         compute_dtype = jnp.float32
         dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
-        scaled_x = x.astype(compute_dtype) * self.scale
+        scaled_x = x.data.astype(compute_dtype) * self.scale

         # quantize() in the old dot.py do this way, leave this code block here for future debugging
         # compute_dtype = x.dtype

@@ -360,7 +379,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
         clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
         scale_inv = 1.0 / self.scale
-        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
+        amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
+        self.update(amax)

         return ScaledTensorFactory.create_1x(
             data=clipped_scaled_x,
             scale_inv=scale_inv,

@@ -460,6 +480,10 @@ class BlockScaleQuantizer(Quantizer):
         Returns:
             A ScaledTensor1x containing the quantized data
         """
+        if isinstance(x, NoScaleTensor):
+            # No need for amax in MXFP8 block scaling, so simply extract the
+            # jnp.ndarray data tensor from the NoScaleTensor x.
+            x = x.data
+
         # TODO(Phuong): use quantize_func from JAX
         if flatten_axis < 0:
             flatten_axis = x.ndim + flatten_axis
...
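The quantizer changes above thread the cached amax through both scaling flavors: current scaling uses it for the scale itself, while delayed scaling pushes it into the amax history via `self.update(...)`. A simplified stand-in for the delayed-scaling path (not the real DelayedScaleQuantizer API):

```python
import jax.numpy as jnp


class DelayedScaleSketch:
    """Stand-in: scale comes from history; each quantize pushes an amax."""

    def __init__(self):
        self.scale = jnp.ones((1,), jnp.float32)
        self.amax_history = []

    def quantize(self, data, cached_amax=None):
        q = (data.astype(jnp.float32) * self.scale).astype(jnp.float8_e4m3fn)
        # Prefer the wrapper's cached amax (x.amax above) over recomputing it.
        amax = cached_amax if cached_amax is not None else jnp.max(jnp.abs(data)).reshape((1,))
        self.amax_history.append(amax)  # stands in for self.update(amax)
        return q, 1.0 / self.scale


quant = DelayedScaleSketch()
q, scale_inv = quant.quantize(jnp.linspace(-1.0, 1.0, 8))
print(q.dtype, scale_inv)
```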
@@ -166,6 +166,90 @@ class ScalingModeMetadataImpl(ABC):
     """

+class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
+    """Implementation for no scaling mode.
+
+    This implementation provides metadata for no scaling mode, for using non-quantized higher-precision datatypes such as bf16.
+    """
+
+    def get_scale_dtype(self) -> jnp.dtype:
+        """Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling.
+
+        Returns:
+            The data type used for scale tensors (float32)
+        """
+        return jnp.float32
+
+    def get_scale_shape(
+        self,
+        data_shape: Tuple[int, ...],
+        is_colwise: bool = False,
+        is_padded: bool = True,
+        flatten_axis: int = -1,
+    ) -> Tuple[int, ...]:
+        """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.
+
+        Args:
+            data_shape: The shape of the tensor being scaled
+            is_colwise: Whether the scaling is column-wise
+            is_padded: Whether to return padded shape
+            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
+
+        Returns:
+            The shape for scale tensors - (0,)
+        """
+        del data_shape, is_colwise, is_padded, flatten_axis
+        return (0,)
+
+    @lru_cache(maxsize=4)
+    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
+        """Get the quantize layout for the tensor usage.
+
+        Args:
+            usage: The usage of the tensor
+
+        Returns:
+            The quantize layout for the tensor usage
+        """
+        return QuantizeLayout.ROWWISE
+
+    def get_grouped_scale_shape(
+        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
+    ) -> Tuple[int]:
+        """Get the shape for scale tensors in this mode.
+
+        Args:
+            data_shape: Original shape of the data tensor
+            is_colwise: Whether to use column-wise scaling
+            is_padded: Whether to use padded shapes
+            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
+
+        Returns:
+            The shape for scale tensors
+        """
+        del data_shape, group_axis, is_colwise
+        assert isinstance(n_groups, int)
+        return (n_groups,)
+
+    def get_shardy_sharding_rules(
+        self, input_rank, unique_var, flatten_axis
+    ) -> QuantizeShardyRules:
+        """Sharding rules for the input and (row, col)wise scale tensors.
+
+        Args:
+            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            unique_var: An otherwise unused Shardy variable name prefix
+            flatten_axis: Axis along which data can be flattened to 2D for quantization.
+
+        Returns:
+            The Shardy rules for the scaling mode
+        """
+        del flatten_axis
+        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
+        scale_var = BATCHING + unique_var + "_scale_inv"
+        return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
+
+
 class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
     """Implementation for current scaling mode.

@@ -740,5 +824,5 @@ SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
     ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
     # WAR
     ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
-    ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
+    ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
}
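The last hunk above swaps NO_SCALING from piggybacking on delayed-scaling metadata to its own implementation with an empty scale shape. A toy version of that dispatch-table design (stand-in enum and classes, not the TE code):

```python
from enum import Enum, auto


class ScalingMode(Enum):
    DELAYED_TENSOR_SCALING = auto()
    NO_SCALING = auto()


class DelayedImpl:  # stand-in: a single per-tensor scale
    def get_scale_shape(self, data_shape):
        return (1,)


class NoScalingImpl:  # stand-in: no scale tensor at all
    def get_scale_shape(self, data_shape):
        return (0,)


SCALING_MODES_TO_IMPL = {
    ScalingMode.DELAYED_TENSOR_SCALING: DelayedImpl(),
    ScalingMode.NO_SCALING: NoScalingImpl(),
}

# NO_SCALING now reports an empty scale shape instead of a dummy (1,) scale.
assert SCALING_MODES_TO_IMPL[ScalingMode.NO_SCALING].get_scale_shape((4, 4)) == (0,)
```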
@@ -25,6 +25,8 @@ from ..sharding import (

 __all__ = [
     "TensorUsage",
+    "AbstractBaseTensor",
+    "NoScaleTensor",
     "ScaledTensor",
     "ScaledTensor1x",
     "ScaledTensor2x",

@@ -34,14 +36,9 @@ __all__ = [
 ]

-@register_pytree_node_class
 @dataclass
-class ScaledTensor(ABC):
-    """Abstract base class for scaled tensors.
-
-    This class defines the interface for all scaled tensor implementations,
-    providing methods for dequantization and accessing row/column-wise components.
-    """
+class AbstractBaseTensor(ABC):
+    """Abstract base class for all tensor types."""

     @classmethod
     def tree_unflatten(cls, aux_data, children):

@@ -93,9 +90,76 @@ class AbstractBaseTensor(ABC):
     """

+@dataclass
+class AbstractBaseTensor1x(AbstractBaseTensor):
+    """Abstract base class for single layout tensors."""
+
+    data: jnp.ndarray
+    amax: jnp.ndarray
+
+
+@register_pytree_node_class
+@dataclass
+class NoScaleTensor(AbstractBaseTensor1x):
+    """Higher-precision tensor."""
+
+    def __post_init__(self):
+        assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray."
+
+    def tree_flatten(self):
+        """Flattens the tensor for JAX tree operations.
+
+        Returns:
+            A tuple containing (children, aux_data) for tree operations
+        """
+        children = (self.data, self.amax)
+        aux_data = ()
+        return (children, aux_data)
+
+    @property
+    def ndim(self):
+        """Number of dimensions of the underlying array."""
+        return self.data.ndim
+
+    def dequantize(self):
+        """This is a no-op for a higher-precision tensor so this simply returns the tensor's data."""
+        return self.data
+
+    def get_tensor(self, usage: TensorUsage):
+        """Returns the tensor based on the tensor usage."""
+        q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
+        assert (
+            q_layout == QuantizeLayout.ROWWISE
+        ), "Only ROWWISE layout is supported for NoScaleTensor"
+        return self
+
+    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
+        """Applies sharding constraints to a tensor based on logical axis names.
+
+        Args:
+            logical_axis_names: Tuple of logical axis names for sharding
+
+        Returns:
+            The tensor with applied sharding constraints
+        """
+        if not logical_axis_names:
+            return self
+
+        data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names)
+        return NoScaleTensor(
+            data=data,
+            amax=self.amax,
+        )
+
+
+class ScaledTensor(ABC):
+    """Abstract base class for scaled tensors."""
+
+
 @register_pytree_node_class
 @dataclass
-class ScaledTensor1x(ScaledTensor):
+class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
     """Single-scale quantized tensor implementation.

     This class represents a tensor quantized with a single scaling factor,

@@ -113,9 +177,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
         flatten_axis: The quantization axis for the tensor
     """

-    data: jnp.ndarray
     scale_inv: jnp.ndarray
-    amax: jnp.ndarray
     scaling_mode: ScalingMode
     dq_dtype: jnp.dtype
     _dq_func: Callable

@@ -154,7 +216,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
         Returns:
             A tuple containing (children, aux_data) for tree operations
         """
-        children = (self.data, self.scale_inv, self.amax)
+        children = (self.data, self.amax, self.scale_inv)
         aux_data = (
             self.scaling_mode,
             self.dq_dtype,

@@ -274,15 +336,15 @@ class GroupedScaledTensor1x(ScaledTensor1x):
         self.original_shape = original_shape
         self.group_axis = group_axis
         super().__init__(
-            data,
-            scale_inv,
-            amax,
-            scaling_mode,
-            dq_dtype,
-            _dq_func,
-            is_colwise,
-            data_layout,
-            flatten_axis,
+            data=data,
+            scale_inv=scale_inv,
+            amax=amax,
+            scaling_mode=scaling_mode,
+            dq_dtype=dq_dtype,
+            _dq_func=_dq_func,
+            is_colwise=is_colwise,
+            data_layout=data_layout,
+            flatten_axis=flatten_axis,
         )

     def __post_init__(self):

@@ -339,7 +401,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):

 @register_pytree_node_class
 @dataclass
-class ScaledTensor2x(ScaledTensor):
+class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
     """Double-scale quantized tensor implementation.

     This class represents a tensor quantized with both row-wise and column-wise scaling factors.

@@ -503,15 +565,15 @@ class ScaledTensorFactory:
             flatten_axis = data.ndim - flatten_axis

         return ScaledTensor1x(
-            data,
-            scale_inv,
-            amax,
-            scaling_mode,
-            dq_dtype,
-            dequantizer.dequantize,
-            is_colwise,
-            data_layout,
-            flatten_axis,
+            data=data,
+            scale_inv=scale_inv,
+            amax=amax,
+            scaling_mode=scaling_mode,
+            dq_dtype=dq_dtype,
+            _dq_func=dequantizer.dequantize,
+            is_colwise=is_colwise,
+            data_layout=data_layout,
+            flatten_axis=flatten_axis,
         )

     @staticmethod

@@ -675,7 +737,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
     if isinstance(x, GroupedScaledTensor1x):
         raise NotImplementedError

-    if isinstance(x, ScaledTensor):
+    if isinstance(x, AbstractBaseTensor):
         return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

     return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
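To summarize the tensor-class refactor above: `AbstractBaseTensor` becomes the root, `AbstractBaseTensor1x` hoists the shared `(data, amax)` fields, `NoScaleTensor` covers plain higher-precision tensors, and `ScaledTensor1x` inherits the shared fields plus `scale_inv`. A reduced, illustrative re-creation of that hierarchy (class names match the diff, bodies trimmed to essentials; not the real implementation):

```python
from abc import ABC
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class AbstractBaseTensor(ABC):
    """Root of the tensor hierarchy."""


@dataclass
class AbstractBaseTensor1x(AbstractBaseTensor):
    """Single-layout tensors share (data, amax)."""

    data: jnp.ndarray
    amax: Optional[jnp.ndarray]


@dataclass
class NoScaleTensor(AbstractBaseTensor1x):
    def dequantize(self):
        return self.data  # identity: nothing was scaled


class ScaledTensor(ABC):
    """Marker base for quantized tensors."""


@dataclass
class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
    scale_inv: jnp.ndarray

    def dequantize(self):
        return self.data.astype(jnp.float32) * self.scale_inv


t = NoScaleTensor(data=jnp.ones(2, jnp.bfloat16), amax=None)
# Dispatch like with_sharding_constraint_by_logical_axes can now test the root type.
assert isinstance(t, AbstractBaseTensor) and not isinstance(t, ScaledTensor)
```

This mirrors the diff's design choice of keeping `ScaledTensor` as a thin marker so `isinstance(x, AbstractBaseTensor)` handles both quantized and non-quantized wrappers uniformly.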