Unverified commit c47f329b, authored by jberchtold-nvidia and committed by GitHub

[JAX] NoScaleTensor wrapper for non-quantized data (#2136)



* Custom call tests passing
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Support using the amax stored on a HighPrecisionTensor, if present, instead of recomputing it for current scaling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix shardy issue with amax being shape (1, 1, 1) instead of shape (1,)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Cast non-quantized kernels to input dtype in VJPs
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Rename HighPrecisionTensor to NoScaleTensor
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use NoScaleTensor in pure JAX impls where it was missing
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix tests
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent b10f436a
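For orientation, here is a minimal sketch of the wrapper this commit introduces, assuming only what the diff below shows (a `data` array, an optional cached `amax`, and a no-op `dequantize()`). The `Sketch` suffix marks the class as illustrative; the real class lives in `transformer_engine/jax/quantize/tensor.py`.

```python
# Illustrative sketch of the NoScaleTensor wrapper added by this commit.
# Field names mirror the diff below; the Sketch suffix marks this as a stand-in.
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass
class NoScaleTensorSketch:
    data: jnp.ndarray            # non-quantized payload (e.g. bf16)
    amax: Optional[jnp.ndarray]  # cached abs-max, reusable by current scaling

    def tree_flatten(self):
        # Arrays are pytree children; there is no static aux data.
        return (self.data, self.amax), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        del aux_data
        return cls(*children)

    def dequantize(self) -> jnp.ndarray:
        # No scale was applied, so dequantization simply unwraps the data.
        return self.data


x = NoScaleTensorSketch(data=jnp.ones((4, 8), jnp.bfloat16), amax=None)
assert x.dequantize().dtype == jnp.bfloat16
```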
......@@ -31,6 +31,7 @@ from transformer_engine.jax.cpp_extensions.quantization import (
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
NoScaleTensor,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
......@@ -182,7 +183,7 @@ ACTIVATION_TYPES = {
class TestActivation:
def ref_act(self, x, activation_type):
return _jax_act_lu(x, activation_type)
return _jax_act_lu(x, activation_type).data
def value_n_grad_ref_func(self, x, activation_type):
jitted_reference = jit(
......@@ -337,8 +338,8 @@ class TestNorm:
ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
else:
ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
# if isinstance(ln_out, ScaledTensor):
# ln_out = ln_out.dequantize()
# This is a no-op for non-quantized data
ln_out = ln_out.dequantize()
return ln_out
key = jax.random.PRNGKey(0)
......@@ -765,7 +766,9 @@ class TestFusedQuantize:
te_output, jax_output, precise_comparison=precise_comparison
)
else:
assert_allclose(te_output, jax_output)
assert isinstance(te_output, NoScaleTensor)
assert isinstance(jax_output, NoScaleTensor)
assert_allclose(te_output.data, jax_output.data)
if is_dbias:
# TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
......@@ -1020,7 +1023,6 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
else:
ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
if isinstance(ln_out, ScaledTensor):
ln_out = ln_out.dequantize()
return ln_out
......@@ -1177,7 +1179,7 @@ class TestFusedDense:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = _jax_act_lu(linear_1_out, activation_type)
x = _jax_act_lu(linear_1_out, activation_type).data
linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
if use_bias:
bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
......
......@@ -173,7 +173,9 @@ class TestDistributedLayernormMLP:
)
# Single GPU
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
with fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()
):
single_jitter = jax.jit(
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
......@@ -184,7 +186,7 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
......@@ -226,7 +228,12 @@ class TestDistributedLayernormMLP:
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
if fwd_test_type == jnp.float16 and use_bias:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5)
else:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)):
if multi_grads[i] is not None:
if isinstance(multi_grads[i], list):
......@@ -252,7 +259,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self,
......@@ -281,7 +288,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy(
self,
......
......@@ -14,7 +14,7 @@ import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.tensor import NoScaleTensor
from .quantize.quantizer import Quantizer
......@@ -22,7 +22,7 @@ def activation(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
) -> jnp.ndarray:
"""Apply activation functions to input tensor with optional quantization.
This function applies a sequence of activation functions to the input tensor.
......@@ -72,7 +72,7 @@ def _activation_fwd_rule(x, activation_type, quantizer):
Tuple of (output, context) for backward pass
"""
fwd_output = tex.act_lu(x, activation_type, quantizer)
if isinstance(fwd_output, ScaledTensor):
# This is a no-op for higher-precision tensors
fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer)
......@@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
# No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor)
dx = dx.data
return (dx, None)
......
......@@ -29,7 +29,7 @@ from .misc import (
)
from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
......@@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]:
"""
JAX native activation implementation
"""
......@@ -941,11 +941,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
x = jnp.squeeze(x, axis=-2)
if quantizer:
return quantizer.quantize(x, flatten_axis=-1)
return x
return NoScaleTensor(data=x, amax=None)
def _jax_quantize_dact_dbias(
dz: jnp.ndarray,
dz: Union[jnp.ndarray, NoScaleTensor],
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
is_dbias: bool = True,
......@@ -963,7 +963,9 @@ def _jax_quantize_dact_dbias(
_, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
)
(dx,) = vjp_func(dz.astype(jnp.float32))
# VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards.
dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
(dx,) = vjp_func(dz)
dbias = None
if is_dbias:
......@@ -973,6 +975,7 @@ def _jax_quantize_dact_dbias(
dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
else:
dx = dx.astype(x.dtype)
dx = NoScaleTensor(data=dx, amax=None)
return dx, dbias
......@@ -981,7 +984,6 @@ def act_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
......@@ -990,7 +992,6 @@ def act_lu(
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
If quantizer is None:
......@@ -1035,9 +1036,9 @@ def act_lu(
is_outer=True,
)
out = out.reshape(output_shape)
if noop_scaled_tensor:
return ScaledTensorFactory.create_2x(
out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype
out = NoScaleTensor(
data=out,
amax=None,
)
return out
......@@ -1092,7 +1093,6 @@ def quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization.
......@@ -1103,7 +1103,6 @@ def quantize_dact_dbias(
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
......@@ -1146,19 +1145,10 @@ def quantize_dact_dbias(
if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
if noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output,
None,
output,
None,
ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
),
dbias,
output = NoScaleTensor(
data=output,
amax=None,
)
return output, dbias
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
......@@ -1167,7 +1157,7 @@ def quantize_dact_dbias(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
)
return _quantize_dbias_impl(
out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
)
is_gated = act_len == 2
......@@ -1194,7 +1184,7 @@ def quantize_dact_dbias(
quantizer=None,
)
out, dbias = _quantize_dbias_impl(
out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
)
return out, dbias
......@@ -1258,7 +1248,6 @@ def dact_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
noop_scale_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""
Backward pass for activation with optional quantization.
......@@ -1268,7 +1257,6 @@ def dact_lu(
x: Input tensor that was used in forward pass.
activation_type: Type of activation function that was applied.
quantizer: Optional quantizer for FP8 quantization of the output gradient.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
The gradient of the activation with respect to the input.
......@@ -1279,6 +1267,5 @@ def dact_lu(
activation_type=activation_type,
is_dbias=False,
quantizer=quantizer,
noop_scaled_tensor=noop_scale_tensor,
)
return output
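The pure-JAX activation path above illustrates the pattern this commit applies across the JAX implementations: compute in higher precision, then either quantize or wrap the raw result. A hedged sketch, using the illustrative `NoScaleTensorSketch` class from the top of this page:

```python
# Hedged sketch of the wrap-at-the-boundary pattern in _jax_act_lu above:
# quantize when a quantizer is given, otherwise return a NoScaleTensor-style
# wrapper instead of a bare array. act_fn stands in for the activation lookup.
import jax
import jax.numpy as jnp

def jax_act_sketch(x, act_fn, quantizer=None):
    y = act_fn(x.astype(jnp.float32))
    if quantizer is not None:
        return quantizer.quantize(y, flatten_axis=-1)   # ScaledTensor path
    return NoScaleTensorSketch(data=y, amax=None)       # non-quantized path

out = jax_act_sketch(jnp.ones((4, 8)), jax.nn.gelu)
assert isinstance(out, NoScaleTensorSketch)
```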
......@@ -22,6 +22,8 @@ from transformer_engine_jax import get_num_compute_streams
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
AbstractBaseTensor,
NoScaleTensor,
ScaledTensor,
ScaledTensor2x,
GroupedScaledTensor1x,
......@@ -228,6 +230,11 @@ class GemmPrimitive(BasePrimitive):
"require non-transposed LHS and transposed RHS operands "
"(`contracting_dims=((-1, ), (-1, ))`)."
)
else:
assert lhs.dtype == rhs.dtype, (
"For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal."
f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
)
# Determine output shape and dtype
assert (
......@@ -1134,8 +1141,8 @@ def _jax_gemm(
def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
lhs: Union[jnp.ndarray, AbstractBaseTensor],
rhs: Union[jnp.ndarray, AbstractBaseTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
......@@ -1191,6 +1198,11 @@ def gemm(
compute the GeLU contribution to the gradient. Only supported with TE's custom call to
cuBLAS GEMM.
"""
if isinstance(lhs, NoScaleTensor):
lhs = lhs.data
if isinstance(rhs, NoScaleTensor):
rhs = rhs.data
# Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
if lhs_quantizer is None or rhs_quantizer is None:
quantizer_set = kwargs.get("quantizer_set", None)
......
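A usage note on the `gemm()` change above: `NoScaleTensor` operands are unwrapped to their raw arrays up front, so the rest of the dispatch logic only ever sees `jnp.ndarray` or `ScaledTensor` inputs. A hypothetical helper mirroring that unwrap, reusing the illustrative sketch class:

```python
# Hypothetical helper mirroring the unwrap at the top of gemm() above.
import jax.numpy as jnp

def unwrap_operand(operand):
    """Return the raw array for NoScaleTensor-style wrappers, else pass through."""
    if isinstance(operand, NoScaleTensorSketch):
        return operand.data
    return operand

lhs = NoScaleTensorSketch(data=jnp.ones((16, 32), jnp.bfloat16), amax=None)
rhs = jnp.ones((32, 8), jnp.bfloat16)
out = jnp.dot(unwrap_operand(lhs), unwrap_operand(rhs))  # stand-in for tex.gemm
assert out.shape == (16, 8)
```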
......@@ -30,7 +30,7 @@ from .misc import (
)
from .quantization import _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
......@@ -845,6 +845,7 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
ln_out = NoScaleTensor(data=ln_out, amax=None)
return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1)
......@@ -869,6 +870,7 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
ln_out = NoScaleTensor(data=ln_out, amax=None)
return ln_out, jnp.squeeze(rsigma, axis=-1)
......@@ -930,7 +932,7 @@ def layernorm_fwd(
scale_dtype=jnp.float32,
is_outer=True,
)
return output, mu, rsigma
return NoScaleTensor(data=output, amax=None), mu, rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
......@@ -1064,7 +1066,7 @@ def layernorm_bwd(
)
mu_empty = jnp.zeros(mu.shape, mu.dtype)
rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
return vjp_func((dz, mu_empty, rsigma_empty))
return vjp_func((NoScaleTensor(data=dz, amax=None), mu_empty, rsigma_empty))
return NormBwdPrimitive.outer_primitive.bind(
dz,
x,
......@@ -1133,14 +1135,14 @@ def rmsnorm_fwd(
scale_dtype=jnp.float32,
is_outer=True,
)
return output, rsigma
return NoScaleTensor(data=output, amax=None), rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None)
out, _ = _quantize_dbias_impl(out, quantizer)
out, _ = _quantize_dbias_impl(out.data, quantizer)
return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
......@@ -1152,7 +1154,9 @@ def rmsnorm_fwd(
epsilon=epsilon,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
out, _ = _quantize_dbias_impl(
out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype
)
return out, rsigma
is_2x2x = quantizer.is_2x2x()
......@@ -1254,7 +1258,7 @@ def rmsnorm_bwd(
gamma,
)
rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
return vjp_func((dz, rsigma_empty))
return vjp_func((NoScaleTensor(data=dz, amax=None), rsigma_empty))
mu = jnp.empty(())
dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind(
dz,
......@@ -1276,7 +1280,6 @@ def normalization_fwd(
epsilon: float,
norm_type: str,
quantizer: Optional[Quantizer],
noop_scaled_tensor: bool = False,
):
"""Common wrapper for normalization forward pass.
......@@ -1293,7 +1296,6 @@ def normalization_fwd(
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
A tuple containing:
......@@ -1321,20 +1323,6 @@ def normalization_fwd(
else:
raise ValueError(f"{norm_type=} is not supported.")
if quantizer is None and noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output,
None,
output,
None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
),
mu,
rsigma,
)
return output, mu, rsigma
......
......@@ -4,7 +4,7 @@
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from typing import Tuple, Optional, Union
import math
from packaging import version
......@@ -38,6 +38,7 @@ from ..quantize import (
QuantizeLayout,
ScalingMode,
compute_scale_from_amax,
NoScaleTensor,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
......@@ -64,7 +65,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
7,
8,
9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
inner_primitive = None
outer_primitive = None
......@@ -535,11 +536,15 @@ def _jax_quantize(
x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
if quantizer is None:
if isinstance(x, NoScaleTensor):
return x
return NoScaleTensor(data=x, amax=None)
return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
if isinstance(dx, NoScaleTensor):
dx = dx.data
sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
dtype = dtype or dx.dtype
......@@ -558,7 +563,9 @@ def _jax_quantize_dbias(
flatten_axis: int = -1,
):
if quantizer is None:
if isinstance(x, NoScaleTensor):
return x, None
return NoScaleTensor(data=x, amax=None), None
return (
quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
_jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
......@@ -566,12 +573,11 @@ def _jax_quantize_dbias(
def _quantize_dbias_impl(
x: jnp.ndarray,
x: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer,
is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -581,28 +587,15 @@ def _quantize_dbias_impl(
quantizer is not None
), "quantizer must be provided if dq_dtype is provided"
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
# Early-exit for non-quantized call
dq_dtype = dq_dtype or x.dtype
dq_dtype = dq_dtype or x.data.dtype
if quantizer is None:
dbias = None
if is_dbias:
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
if noop_scaled_tensor:
# Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor()
# always works.
return (
ScaledTensorFactory.create_2x(
x,
None,
x,
None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=x.dtype,
data_layout="NN",
flatten_axis=flatten_axis,
),
dbias,
)
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, dbias
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
......@@ -630,20 +623,24 @@ def _quantize_dbias_impl(
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
scale = jnp.empty((), jnp.float32)
amax = None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
# until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
amax = x.amax
if amax is None:
amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,))
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# Make sure amax is init with zero
if amax is None:
amax = jnp.zeros((1,), jnp.float32)
# It is faster to use 1x quantization for tensor scaling
......@@ -665,7 +662,7 @@ def _quantize_dbias_impl(
updated_amax,
dbias,
) = PrimitiveClass.outer_primitive.bind(
x,
x.data,
scale,
amax,
out_dtype=quantizer.q_dtype,
......@@ -706,10 +703,9 @@ def _quantize_dbias_impl(
def quantize(
x: jnp.ndarray,
x: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -719,7 +715,6 @@ def quantize(
quantizer: Quantizer for FP8 quantization of the output.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
is None.
Returns:
......@@ -729,17 +724,15 @@ def quantize(
x,
quantizer=quantizer,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
)
return out
def quantize_dbias(
dz: jnp.ndarray,
dz: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer,
is_dbias: bool = True,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -750,8 +743,6 @@ def quantize_dbias(
is_dbias: If True, compute bias gradient. Defaults to True.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
quantizer is None.
Returns:
A tuple containing:
......@@ -765,7 +756,6 @@ def quantize_dbias(
quantizer=quantizer,
is_dbias=is_dbias,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
)
......@@ -968,7 +958,9 @@ def grouped_quantize(
"""
if quantizer is None:
if isinstance(x, NoScaleTensor):
return x
return NoScaleTensor(data=x, amax=None)
# TODO(Phuong): add support for flatten_axis = -2
assert flatten_axis in (
......
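The `_quantize_dbias_impl` changes above replace the old dummy-`ScaledTensor2x` path with a uniform early exit: raw arrays are wrapped, dbias (if requested) is computed from the raw data, and the wrapper is returned unchanged. A hedged sketch of that early exit, with `_jax_dbias` simplified to a reduction over the leading axes:

```python
# Hedged sketch of the quantizer=None early exit in _quantize_dbias_impl above,
# using the illustrative NoScaleTensorSketch. The dbias reduction is a
# simplification of _jax_dbias from the diff.
import jax.numpy as jnp

def quantize_dbias_sketch(x, quantizer=None, is_dbias=False, flatten_axis=-1):
    if isinstance(x, jnp.ndarray):
        x = NoScaleTensorSketch(data=x, amax=None)
    if quantizer is None:
        dbias = None
        if is_dbias:
            # Reduce over the axes before flatten_axis (the "batch" axes).
            lead_axes = tuple(range(x.data.ndim + flatten_axis))
            dbias = jnp.sum(x.data, axis=lead_axes)
        return x, dbias
    raise NotImplementedError("quantized path not sketched here")

out, dbias = quantize_dbias_sketch(jnp.ones((4, 8)), is_dbias=True)
assert isinstance(out, NoScaleTensorSketch) and dbias.shape == (8,)
```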
......@@ -24,6 +24,7 @@ from .quantize import (
with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported,
TensorUsage,
get_quantize_config,
)
......@@ -80,14 +81,10 @@ def dense(
Returns:
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
output = _dense(
x,
kernel,
......@@ -175,7 +172,9 @@ def _dense_fwd_rule(
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(
x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
x,
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
......@@ -183,7 +182,6 @@ def _dense_fwd_rule(
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
noop_scaled_tensor=True,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -240,7 +238,6 @@ def _dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# GEMM NT
......
......@@ -17,7 +17,6 @@ import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import (
ScaledTensor,
Quantizer,
)
......@@ -112,7 +111,7 @@ def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, eps
output, mu, rsigma = tex.normalization_fwd(
x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer
)
if isinstance(output, ScaledTensor):
# This is a no-op for higher-precision tensors
output = output.dequantize()
return output, (x, mu, rsigma, gamma, beta, quantizer)
......
......@@ -22,6 +22,7 @@ from .quantize import (
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
get_quantize_config,
)
......@@ -68,6 +69,11 @@ def layernorm_dense(
- The function supports automatic differentiation through JAX's custom VJP
- Quantization is applied to both the normalized input and kernel
"""
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
output = _layernorm_dense(
x,
kernel,
......@@ -188,14 +194,15 @@ def _layernorm_dense_fwd_rule(
epsilon,
norm_type,
quantizer=quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
kernel,
flatten_axis=flatten_axis,
quantizer=quantizer_set.kernel,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -278,7 +285,6 @@ def _layernorm_dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......
......@@ -27,6 +27,7 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
TensorUsage,
get_quantize_config,
)
......@@ -104,6 +105,11 @@ def layernorm_mlp(
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype)
output = _layernorm_mlp(
x,
gamma,
......@@ -266,12 +272,13 @@ def _layernorm_mlp_fwd_rule(
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
)
# NN GEMM
......@@ -300,13 +307,16 @@ def _layernorm_mlp_fwd_rule(
# (batch..., hidden_in) -> (batch..., hidden)
casted_act_out = tex.act_lu(
dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
dot_1_output,
activation_type,
quantizer=ffn2_quantizer_set.x,
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
casted_kernel_2 = tex.quantize(
kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
kernel_2,
quantizer=ffn2_quantizer_set.kernel,
)
# NN GEMM
......@@ -404,7 +414,9 @@ def _layernorm_mlp_bwd_rule(
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, dbias_2 = tex.quantize_dbias(
grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
grad,
is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -445,7 +457,6 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type,
is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......
......@@ -19,7 +19,13 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .tensor import (
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
ScaledTensorFactory,
NoScaleTensor,
)
from .helper import (
get_quantize_config,
get_quantize_config_class,
......@@ -217,7 +223,11 @@ class CurrentScaleQuantizer(Quantizer):
data_layout: str = "NT"
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
self,
x: Union[jnp.ndarray, NoScaleTensor],
is_colwise=False,
dq_dtype=None,
flatten_axis=-1,
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
......@@ -229,14 +239,17 @@ class CurrentScaleQuantizer(Quantizer):
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = jnp.max(jnp.abs(x)).reshape((1,))
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scaled_x = x.astype(compute_dtype) * scale
scaled_x = x.data.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / scale
......@@ -263,7 +276,10 @@ class CurrentScaleQuantizer(Quantizer):
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
......@@ -347,11 +363,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if isinstance(x, jnp.ndarray):
x = NoScaleTensor(data=x, amax=None)
dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
scaled_x = x.data.astype(compute_dtype) * self.scale
# quantize() in the old dot.py do this way, leave this code block here for future debugging
# compute_dtype = x.dtype
......@@ -360,7 +379,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
self.update(amax)
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
......@@ -460,6 +480,10 @@ class BlockScaleQuantizer(Quantizer):
Returns:
A ScaledTensor1x containing the quantized data
"""
if isinstance(x, NoScaleTensor):
# No need for amax in MXFP8 block scaling, so simply extract the jnp.ndarray data tensor from the NoScaleTensor x.
x = x.data
# TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
......
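The quantizer changes above let current and delayed scaling reuse an `amax` already cached on the incoming `NoScaleTensor` (for example, one produced by a fused kernel) instead of recomputing it from the data. A hedged sketch of that preference, spelled out with an explicit `None` check and the FP8-max scale formula from the diff; `margin` stands in for `get_quantize_config().MARGIN`:

```python
# Hedged sketch of the amax-reuse logic added to the quantizers above.
import jax.numpy as jnp

def current_scale_sketch(x, q_dtype=jnp.float8_e4m3fn, margin=0):
    amax = x.amax
    if amax is None:
        # Fall back to recomputing the abs-max from the raw data.
        amax = jnp.max(jnp.abs(x.data)).reshape((1,))
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, jnp.float32)
    scale = (fp8_max / amax) / (2**margin)
    return amax, scale

amax, scale = current_scale_sketch(
    NoScaleTensorSketch(data=jnp.ones((4, 8), jnp.bfloat16), amax=None)
)
assert amax.shape == (1,)
```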
......@@ -166,6 +166,90 @@ class ScalingModeMetadataImpl(ABC):
"""
class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for no scaling mode.
This implementation provides metadata for the no-scaling mode, used for non-quantized higher-precision datatypes such as bf16.
"""
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling.
Returns:
The data type used for scale tensors (float32)
"""
return jnp.float32
def get_scale_shape(
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.
Args:
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors: (0,), an empty placeholder, since this mode has no scales
"""
del data_shape, is_colwise, is_padded, flatten_axis
return (0,)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
return QuantizeLayout.ROWWISE
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Original shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
del data_shape, group_axis, is_colwise
assert isinstance(n_groups, int)
return (n_groups,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for current scaling mode.
......@@ -740,5 +824,5 @@ SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
}
......@@ -25,6 +25,8 @@ from ..sharding import (
__all__ = [
"TensorUsage",
"AbstractBaseTensor",
"NoScaleTensor",
"ScaledTensor",
"ScaledTensor1x",
"ScaledTensor2x",
......@@ -34,14 +36,9 @@ __all__ = [
]
@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
"""Abstract base class for scaled tensors.
This class defines the interface for all scaled tensor implementations,
providing methods for dequantization and accessing row/column-wise components.
"""
class AbstractBaseTensor(ABC):
"""Abstract base class for all tensor types."""
@classmethod
def tree_unflatten(cls, aux_data, children):
......@@ -93,9 +90,76 @@ class ScaledTensor(ABC):
"""
@dataclass
class AbstractBaseTensor1x(AbstractBaseTensor):
"""Abstract base class for single layout tensors."""
data: jnp.ndarray
amax: jnp.ndarray
@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
class NoScaleTensor(AbstractBaseTensor1x):
"""Higher-precision tensor."""
def __post_init__(self):
assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray."
def tree_flatten(self):
"""Flattens the tensor for JAX tree operations.
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.amax)
aux_data = ()
return (children, aux_data)
@property
def ndim(self):
"""Number of dimensions of the underlying array."""
return self.data.ndim
def dequantize(self):
"""This is a no-op for a higher-precision tensor so this simply returns the tensor's data."""
return self.data
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
assert (
q_layout == QuantizeLayout.ROWWISE
), "Only ROWWISE layout is supported for NoScaleTensor"
return self
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names)
return NoScaleTensor(
data=data,
amax=self.amax,
)
class ScaledTensor(ABC):
"""Abstract base class for scaled tensors."""
@register_pytree_node_class
@dataclass
class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
"""Single-scale quantized tensor implementation.
This class represents a tensor quantized with a single scaling factor,
......@@ -113,9 +177,7 @@ class ScaledTensor1x(ScaledTensor):
flatten_axis: The quantization axis for the tensor
"""
data: jnp.ndarray
scale_inv: jnp.ndarray
amax: jnp.ndarray
scaling_mode: ScalingMode
dq_dtype: jnp.dtype
_dq_func: Callable
......@@ -154,7 +216,7 @@ class ScaledTensor1x(ScaledTensor):
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv, self.amax)
children = (self.data, self.amax, self.scale_inv)
aux_data = (
self.scaling_mode,
self.dq_dtype,
......@@ -274,15 +336,15 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.original_shape = original_shape
self.group_axis = group_axis
super().__init__(
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
_dq_func,
is_colwise,
data_layout,
flatten_axis,
data=data,
scale_inv=scale_inv,
amax=amax,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=_dq_func,
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
)
def __post_init__(self):
......@@ -339,7 +401,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
"""Double-scale quantized tensor implementation.
This class represents a tensor quantized with both row-wise and column-wise scaling factors.
......@@ -503,15 +565,15 @@ class ScaledTensorFactory:
flatten_axis = data.ndim - flatten_axis
return ScaledTensor1x(
data,
scale_inv,
amax,
scaling_mode,
dq_dtype,
dequantizer.dequantize,
is_colwise,
data_layout,
flatten_axis,
data=data,
scale_inv=scale_inv,
amax=amax,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=dequantizer.dequantize,
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
)
@staticmethod
......@@ -675,7 +737,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
if isinstance(x, GroupedScaledTensor1x):
raise NotImplementedError
if isinstance(x, ScaledTensor):
if isinstance(x, AbstractBaseTensor):
return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
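Because the new wrapper is registered as a pytree node (children `(data, amax)`), it flows through `jax.jit` and `jax.vjp` transparently, which is what lets the VJP rules above pass `NoScaleTensor` cotangents into `vjp_func`. A quick round-trip check with the illustrative class; note that JAX treats a `None` child as an empty subtree, so only `data` appears among the leaves when `amax` is `None`:

```python
# Pytree round-trip for the illustrative wrapper: flatten, inspect leaves,
# rebuild. With amax=None, JAX registers no leaf for it, so leaves == [data].
import jax
import jax.numpy as jnp

t = NoScaleTensorSketch(data=jnp.zeros((2, 3)), amax=None)
leaves, treedef = jax.tree_util.tree_flatten(t)
assert len(leaves) == 1 and leaves[0].shape == (2, 3)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert isinstance(restored, NoScaleTensorSketch) and restored.amax is None
```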