Unverified Commit 3a298e6b authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] TensorUsage + FP8 GEMM with all layouts handling on BW (#1844)



* TensorUsage + FP8 GEMM with all layouts handling on BW
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent ae572af0
...@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): ...@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
else: else:
assert_allclose(a.dequantize(), b, dtype=a.data.dtype) assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x): elif isinstance(a, ScaledTensor2x):
assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b) assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b) assert_dequantized_scaled_tensor(a.colwise_tensor, b)
else: else:
pytest.fail("a must be a ScaledTensor object") pytest.fail("a must be a ScaledTensor object")
...@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor( ...@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor(
dq_a_i = dq_a_i.reshape(b_i.shape) dq_a_i = dq_a_i.reshape(b_i.shape)
assert_allclose(dq_a_i, b_i, dtype=a.data.dtype) assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x): elif isinstance(a, ScaledTensor2x):
assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x) assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x) assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b) assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b) assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
else: else:
pytest.fail("a must be a GroupedScaledTensor object") pytest.fail("a must be a GroupedScaledTensor object")
......
...@@ -24,10 +24,11 @@ from ..quantize import ( ...@@ -24,10 +24,11 @@ from ..quantize import (
QuantizerSet, QuantizerSet,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
) )
__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] __all__ = ["gemm", "grouped_gemm"]
num_cublas_streams = get_num_compute_streams() num_cublas_streams = get_num_compute_streams()
...@@ -40,11 +41,6 @@ def get_cublas_workspace_size_bytes() -> None: ...@@ -40,11 +41,6 @@ def get_cublas_workspace_size_bytes() -> None:
return 4_194_304 return 4_194_304
def is_gemm_with_all_layouts_supported() -> bool:
    """Return True if the current GPU (Blackwell, compute capability >= 100)
    supports FP8 GEMM with all operand layouts, False otherwise.

    Note: the return annotation was previously the literal ``False``; the
    correct annotation for a predicate is ``bool``.
    """
    return get_device_compute_capability(0) >= 100
class GroupedGemmPrimitive(BasePrimitive): class GroupedGemmPrimitive(BasePrimitive):
""" """
Primitive for grouped GEMM Primitive for grouped GEMM
...@@ -338,10 +334,15 @@ def _jax_gemm( ...@@ -338,10 +334,15 @@ def _jax_gemm(
if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor):
if quantizer_set != noop_quantizer_set: if quantizer_set != noop_quantizer_set:
assert type(quantizer_set.x) is type(quantizer_set.kernel) assert type(quantizer_set.x) is type(quantizer_set.kernel)
if (
quantizer_set.x.scaling_mode.is_tensor_scaling()
and is_fp8_gemm_with_all_layouts_supported()
):
lhs_is_rowwise = rhs_is_rowwise = True
else:
(((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
# Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
lhs_q = quantizer_set.x.quantize( lhs_q = quantizer_set.x.quantize(
lhs, lhs,
is_rowwise=lhs_is_rowwise, is_rowwise=lhs_is_rowwise,
...@@ -491,16 +492,13 @@ def grouped_gemm( ...@@ -491,16 +492,13 @@ def grouped_gemm(
assert type(quantizer_set.x) is type(quantizer_set.kernel) assert type(quantizer_set.x) is type(quantizer_set.kernel)
scaling_mode = quantizer_set.x.scaling_mode scaling_mode = quantizer_set.x.scaling_mode
if ( if (
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later quantizer_set.x.scaling_mode.is_tensor_scaling()
# scaling_mode.is_tensor_scaling() and is_fp8_gemm_with_all_layouts_supported()
# and is_gemm_with_all_layouts_supported()
scaling_mode.is_1d_block_scaling()
): ):
lhs_is_rowwise = True lhs_is_rowwise = rhs_is_rowwise = True
rhs_is_rowwise = False
else: else:
lhs_is_rowwise = not lhs_is_trans lhs_is_rowwise = not lhs_is_trans
rhs_is_rowwise = lhs_is_trans rhs_is_rowwise = rhs_is_trans
quantizer_set.x.q_layout = ( quantizer_set.x.q_layout = (
QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
) )
...@@ -515,6 +513,8 @@ def grouped_gemm( ...@@ -515,6 +513,8 @@ def grouped_gemm(
rhs_data = rhs_q.data rhs_data = rhs_q.data
lhs_scale_inv = lhs_q.scale_inv lhs_scale_inv = lhs_q.scale_inv
rhs_scale_inv = rhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv
lhs_shape = lhs_q.original_shape
rhs_shape = rhs_q.original_shape
assert not ( assert not (
lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
...@@ -522,24 +522,35 @@ def grouped_gemm( ...@@ -522,24 +522,35 @@ def grouped_gemm(
# Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs
# thus additional transpose is required # thus additional transpose is required
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported():
if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported():
lhs_is_trans = False
rhs_is_trans = True
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
lhs_layout_is_T = lhs.data_layout == "T" lhs_layout_is_T = lhs.data_layout == "T"
rhs_layout_is_T = rhs.data_layout == "T" rhs_layout_is_T = rhs.data_layout == "T"
else: else:
lhs_layout_is_T = lhs_q.data_layout == "T" lhs_layout_is_T = lhs_q.data_layout == "T"
rhs_layout_is_T = rhs_q.data_layout == "T" rhs_layout_is_T = rhs_q.data_layout == "T"
# we can't apply _shape_normalization on the grouped input
# thus we need to ensure that lhs is in N and rhs is in T
assert (
lhs_is_trans == lhs_layout_is_T
), "lhs input must be transposed before calling grouped_gemm"
assert (
not rhs_is_trans == rhs_layout_is_T
), "rhs input must be transposed before calling grouped_gemm"
lhs_is_trans = False
rhs_is_trans = True
lhs_ndim = len(lhs_shape) lhs_ndim = len(lhs_shape)
rhs_ndim = len(rhs_shape) rhs_ndim = len(rhs_shape)
if lhs_layout_is_T: if lhs_layout_is_T:
lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
if rhs_layout_is_T: if rhs_layout_is_T:
# For rhs [G, K, N], need to exclude the G dim from contract_dim
if group_sizes.size == rhs_shape[0]:
rhs_contract_dim = tuple(
(rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim
)
else:
rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T)
rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T)
# Calling GroupedGEMM Custom Call # Calling GroupedGEMM Custom Call
K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
......
...@@ -19,6 +19,7 @@ from .quantize import ( ...@@ -19,6 +19,7 @@ from .quantize import (
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage,
) )
...@@ -105,8 +106,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, ...@@ -105,8 +106,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes,
# GEMM NN # GEMM NN
output = tex.gemm( output = tex.gemm(
casted_x.get_rowwise_tensor(), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_colwise_tensor(), casted_kernel.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
...@@ -116,8 +117,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, ...@@ -116,8 +117,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes,
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
ctx = ( ctx = (
casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
x.shape, x.shape,
kernel.shape, kernel.shape,
use_bias, use_bias,
...@@ -138,8 +139,8 @@ def _dense_bwd_rule( ...@@ -138,8 +139,8 @@ def _dense_bwd_rule(
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
( (
colwise_casted_x, casted_x_lhs,
rowwise_casted_kernel, casted_kernel_rhs,
x_shape, x_shape,
kernel_shape, kernel_shape,
use_bias, use_bias,
...@@ -161,8 +162,8 @@ def _dense_bwd_rule( ...@@ -161,8 +162,8 @@ def _dense_bwd_rule(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
) )
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_tensor(usage=TensorUsage.LHS),
rowwise_casted_kernel, casted_kernel_rhs,
(g_contracting_dim, k_contracting_dim), (g_contracting_dim, k_contracting_dim),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
...@@ -174,7 +175,9 @@ def _dense_bwd_rule( ...@@ -174,7 +175,9 @@ def _dense_bwd_rule(
) )
wgrad = tex.gemm( wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dim, g_contracting_dim),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
...@@ -287,13 +290,6 @@ def _grouped_dense_fwd_rule( ...@@ -287,13 +290,6 @@ def _grouped_dense_fwd_rule(
"and k_contracting_dims=(1,) for now, " "and k_contracting_dims=(1,) for now, "
f"got {x_contracting_dims=} and {k_contracting_dims=}" f"got {x_contracting_dims=} and {k_contracting_dims=}"
) )
scaling_mode = quantizer_set.x.scaling_mode
if scaling_mode.is_tensor_scaling():
k_contracting_dims = (0,)
elif scaling_mode.is_1d_block_scaling():
k_contracting_dims = (1,)
else:
raise ValueError(f"Unsupported scaling mode {scaling_mode.value} for grouped_dense")
casted_x = tex.grouped_quantize( casted_x = tex.grouped_quantize(
x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
...@@ -306,11 +302,10 @@ def _grouped_dense_fwd_rule( ...@@ -306,11 +302,10 @@ def _grouped_dense_fwd_rule(
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
# rowwise_casted_x.original_shape == (M, K) # rowwise_casted_x.original_shape == (M, K)
# colwise_casted_kernel.original_shape == (G, N, K) # colwise_casted_kernel.original_shape == (G, N, K)
grouped_gemm_x = casted_x.get_rowwise_tensor() grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
grouped_gemm_kernel = casted_kernel.get_colwise_tensor() grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
# TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()? ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS)
ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
output = tex.grouped_gemm( output = tex.grouped_gemm(
grouped_gemm_x, grouped_gemm_x,
...@@ -388,7 +383,7 @@ def _grouped_dense_bwd_rule( ...@@ -388,7 +383,7 @@ def _grouped_dense_bwd_rule(
g_contracting_dim = (1,) g_contracting_dim = (1,)
k_contracting_dim = (2,) k_contracting_dim = (2,)
dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
dgrad_grad = casted_grad.get_rowwise_tensor() dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
dgrad_kernel_T = ctx_kernel dgrad_kernel_T = ctx_kernel
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
...@@ -398,7 +393,7 @@ def _grouped_dense_bwd_rule( ...@@ -398,7 +393,7 @@ def _grouped_dense_bwd_rule(
x_contracting_dim = (0,) x_contracting_dim = (0,)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_x_T = ctx_x wgrad_x_T = ctx_x
wgrad_grad = casted_grad.get_colwise_tensor() wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)
dgrad = tex.grouped_gemm( dgrad = tex.grouped_gemm(
dgrad_grad, dgrad_grad,
......
...@@ -21,6 +21,7 @@ from .quantize import ( ...@@ -21,6 +21,7 @@ from .quantize import (
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage,
) )
...@@ -198,8 +199,8 @@ def _layernorm_dense_fwd_rule( ...@@ -198,8 +199,8 @@ def _layernorm_dense_fwd_rule(
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...) # (batch..., hidden_in) x (hidden_in, hidden_out...)
output = tex.gemm( output = tex.gemm(
casted_ln_out.get_rowwise_tensor(), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_colwise_tensor(), casted_kernel.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
...@@ -209,8 +210,8 @@ def _layernorm_dense_fwd_rule( ...@@ -209,8 +210,8 @@ def _layernorm_dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
ctx = ( ctx = (
casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, casted_kernel.get_tensor(TensorUsage.RHS_TRANS),
x.shape, x.shape,
kernel.shape, kernel.shape,
mu, mu,
...@@ -250,8 +251,8 @@ def _layernorm_dense_bwd_rule( ...@@ -250,8 +251,8 @@ def _layernorm_dense_bwd_rule(
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
( (
colwise_casted_ln_out, casted_ln_out,
rowwise_casted_kernel, casted_kernel,
x_shape, x_shape,
kernel_shape, kernel_shape,
mu, mu,
...@@ -281,8 +282,8 @@ def _layernorm_dense_bwd_rule( ...@@ -281,8 +282,8 @@ def _layernorm_dense_bwd_rule(
# NT GEMM # NT GEMM
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_tensor(TensorUsage.LHS),
rowwise_casted_kernel, casted_kernel,
(g_constracting_dim, k_constracting_dim), (g_constracting_dim, k_constracting_dim),
) )
...@@ -294,8 +295,8 @@ def _layernorm_dense_bwd_rule( ...@@ -294,8 +295,8 @@ def _layernorm_dense_bwd_rule(
# TN GEMM # TN GEMM
wgrad = tex.gemm( wgrad = tex.gemm(
colwise_casted_ln_out, casted_ln_out,
casted_grad.get_colwise_tensor(), casted_grad.get_tensor(TensorUsage.RHS),
(x_constracting_dim, g_constracting_dim), (x_constracting_dim, g_constracting_dim),
) )
......
...@@ -22,7 +22,12 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -22,7 +22,12 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set from .quantize import (
with_sharding_constraint_by_logical_axes,
QuantizerSet,
noop_quantizer_set,
TensorUsage,
)
from .sharding import get_non_contracting_logical_axes from .sharding import get_non_contracting_logical_axes
...@@ -270,8 +275,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -270,8 +275,8 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out) # (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm( dot_1_output = tex.gemm(
casted_ln_out.get_rowwise_tensor(), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_colwise_tensor(), casted_kernel_1.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
...@@ -299,8 +304,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -299,8 +304,8 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in) # (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = tex.gemm( dot_2_output = tex.gemm(
casted_act_out.get_rowwise_tensor(), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_colwise_tensor(), casted_kernel_2.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
...@@ -317,11 +322,11 @@ def _layernorm_mlp_fwd_rule( ...@@ -317,11 +322,11 @@ def _layernorm_mlp_fwd_rule(
rsigma, rsigma,
gamma, gamma,
beta, beta,
casted_ln_out.get_colwise_tensor(), casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_1.get_rowwise_tensor(), casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
dot_1_output, dot_1_output,
casted_act_out.get_colwise_tensor(), casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_2.get_rowwise_tensor(), casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
x_contracting_dims, x_contracting_dims,
k_contracting_dims, k_contracting_dims,
kernel_1.shape, kernel_1.shape,
...@@ -369,11 +374,11 @@ def _layernorm_mlp_bwd_rule( ...@@ -369,11 +374,11 @@ def _layernorm_mlp_bwd_rule(
rsigma, rsigma,
gamma, gamma,
beta, beta,
colwise_casted_ln_out, casted_ln_out,
rowwise_casted_kernel_1, casted_kernel_1,
dot_1_output, dot_1_output,
colwise_casted_act_out, casted_act_out,
rowwise_casted_kernel_2, casted_kernel_2,
x_contracting_dims_in_fwd, x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd, k_contracting_dims_in_fwd,
kernel_1_shape, kernel_1_shape,
...@@ -404,8 +409,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -404,8 +409,8 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM # NT GEMM
# (batch..., hidden_out) x (hidden_in, hidden_out) # (batch..., hidden_out) x (hidden_in, hidden_out)
dgrad_2 = tex.gemm( dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_tensor(TensorUsage.LHS),
rowwise_casted_kernel_2, casted_kernel_2,
(g_contracting_dims_2, k_contracting_dims_2), (g_contracting_dims_2, k_contracting_dims_2),
) )
...@@ -418,8 +423,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -418,8 +423,8 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM # TN GEMM
# (hidden, batch...,) x (hidden, batch...) # (hidden, batch...,) x (hidden, batch...)
wgrad_2 = tex.gemm( wgrad_2 = tex.gemm(
colwise_casted_act_out, casted_act_out,
casted_grad.get_colwise_tensor(), casted_grad.get_tensor(TensorUsage.RHS),
(x_contracting_dims, g_contracting_dims), (x_contracting_dims, g_contracting_dims),
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -433,7 +438,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -433,7 +438,7 @@ def _layernorm_mlp_bwd_rule(
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
g_contracting_dims_1 = tuple( g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
) )
...@@ -444,8 +449,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -444,8 +449,8 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM # NT GEMM
dgrad_1 = tex.gemm( dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(), casted_dact_out.get_tensor(TensorUsage.LHS),
rowwise_casted_kernel_1, casted_kernel_1,
(g_contracting_dims_1, k_contracting_dims_1), (g_contracting_dims_1, k_contracting_dims_1),
) )
...@@ -454,8 +459,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -454,8 +459,8 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM # TN GEMM
# (hidden, batch...) x (hidden, batch...) # (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm( wgrad_1 = tex.gemm(
colwise_casted_ln_out, casted_ln_out,
casted_dact_out.get_colwise_tensor(), casted_dact_out.get_tensor(TensorUsage.RHS),
(x_contracting_dims, g_contracting_dims), (x_contracting_dims, g_contracting_dims),
) )
......
...@@ -15,3 +15,4 @@ from .dequantizer import * ...@@ -15,3 +15,4 @@ from .dequantizer import *
from .scaling_modes import * from .scaling_modes import *
from .metadata import * from .metadata import *
from .helper import * from .helper import *
from .device_utils import *
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Device utility functions for JAX quantization.
This module provides utility functions for checking device capabilities and compatibility
for quantization operations in JAX.
"""
import functools
import transformer_engine_jax

# Public API: device-introspection helpers used to choose FP8 GEMM quantize
# layouts (pre-Blackwell GPUs only support the NT layout, so extra colwise
# copies are needed there).
__all__ = [
    "get_device_compute_capability",
    "is_fp8_gemm_with_all_layouts_supported",
]
@functools.lru_cache(maxsize=None)
def get_device_compute_capability(gpu_id: int = 0) -> int:
    """Return the CUDA compute capability of device ``gpu_id`` (e.g. 90, 100).

    The value is memoized for the lifetime of the process, so repeated calls
    do not re-query the runtime.
    """
    capability = transformer_engine_jax.get_device_compute_capability(gpu_id)
    return capability
@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported() -> bool:
    """Return True on Blackwell-class GPUs (compute capability 100-119),
    which support FP8 GEMM with all operand layouts; False otherwise."""
    sm = get_device_compute_capability()
    return sm in range(100, 120)
...@@ -15,17 +15,13 @@ import jax ...@@ -15,17 +15,13 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.jax.sharding import global_shard_guard, MeshResource from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability
__all__ = [ __all__ = [
"QuantizeConfig", "QuantizeConfig",
...@@ -203,7 +199,7 @@ class QuantizeConfig: ...@@ -203,7 +199,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
IF_QUANTIZE_2X: Whether 2x quantization is enabled INFERENCE_MODE: Whether to enable optimization for inference
SCALING_MODE: Scaling mode SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
...@@ -218,7 +214,7 @@ class QuantizeConfig: ...@@ -218,7 +214,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False INFERENCE_MODE: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling # DelayedScaling
...@@ -246,7 +242,6 @@ class QuantizeConfig: ...@@ -246,7 +242,6 @@ class QuantizeConfig:
cls.FP8_FORMAT = fp8_recipe.fp8_format cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
cls.IF_QUANTIZE_2X = True
@classmethod @classmethod
def finalize(cls) -> None: def finalize(cls) -> None:
...@@ -260,7 +255,7 @@ class QuantizeConfig: ...@@ -260,7 +255,7 @@ class QuantizeConfig:
cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NO_SCALING cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.IF_QUANTIZE_2X = False cls.INFERENCE_MODE = False
# DelayedScaling # DelayedScaling
cls.AMAX_HISTORY_LEN = 1024 cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
......
...@@ -23,6 +23,7 @@ from .helper import ( ...@@ -23,6 +23,7 @@ from .helper import (
QuantizeConfig, QuantizeConfig,
AmaxComputeAlgo, AmaxComputeAlgo,
) )
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = [ __all__ = [
"QuantizeLayout", "QuantizeLayout",
...@@ -607,9 +608,10 @@ class GroupedQuantizer(Quantizer): ...@@ -607,9 +608,10 @@ class GroupedQuantizer(Quantizer):
def __post_init__(self): def __post_init__(self):
if self.quantizers[0] is None: if self.quantizers[0] is None:
self.quantizers = QuantizerFactory.create( quantizers = QuantizerFactory.create(
self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
) )
self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers
self.data_layout = self.quantizers[0].data_layout self.data_layout = self.quantizers[0].data_layout
def _create_grouped_tensor_from_tensor_list( def _create_grouped_tensor_from_tensor_list(
...@@ -841,8 +843,10 @@ class QuantizerFactory: ...@@ -841,8 +843,10 @@ class QuantizerFactory:
if is_2x2x: if is_2x2x:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else: else:
q_layout_x = QuantizeLayout.ROWWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE q_layout_kernel = QuantizeLayout.COLWISE
if QuantizeConfig.INFERENCE_MODE:
q_layout_dgrad = None q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
...@@ -898,7 +902,15 @@ class QuantizerFactory: ...@@ -898,7 +902,15 @@ class QuantizerFactory:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X if is_2x2x is None:
if scaling_mode.is_1d_block_scaling():
is_2x2x = True
elif scaling_mode.is_tensor_scaling():
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = QuantizeConfig.INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = [] q_set = []
for _ in range(n_quantizer_sets): for _ in range(n_quantizer_sets):
...@@ -911,4 +923,4 @@ class QuantizerFactory: ...@@ -911,4 +923,4 @@ class QuantizerFactory:
return q_set[0] if len(q_set) == 1 else tuple(q_set) return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING, is_2x2x=False)
...@@ -13,7 +13,7 @@ from abc import ABC, abstractmethod ...@@ -13,7 +13,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Tuple, Dict from typing import Tuple, Dict
from functools import reduce from functools import reduce, lru_cache
import operator import operator
import numpy as np import numpy as np
...@@ -21,10 +21,44 @@ from jax.experimental.custom_partitioning import CompoundFactor ...@@ -21,10 +21,44 @@ from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = ["QuantizeShardyRules", "ScalingMode"] __all__ = [
"QuantizeShardyRules",
"ScalingMode",
"TensorUsage",
]
class TensorUsage(Enum):
    """How a tensor participates in a GEMM operation C = A * B.

    A (the left operand) and B (the right operand) may each appear in normal
    or transposed form:

    - ``LHS``: A in normal form
    - ``LHS_TRANS``: A transposed
    - ``RHS``: B in normal form
    - ``RHS_TRANS``: B transposed

    Passed to ``ScaledTensor.get_tensor()`` to request the appropriate copy.
    """

    LHS = 0
    LHS_TRANS = 1
    RHS = 2
    RHS_TRANS = 3

    def __eq__(self, other):
        # Value-based equality; any non-TensorUsage compares unequal.
        return isinstance(other, TensorUsage) and self.value == other.value

    def __hash__(self):
        # Keep hash consistent with the value-based __eq__ above.
        return hash(self.value)
def DIVUP(a, b): def DIVUP(a, b):
...@@ -104,6 +138,18 @@ class ScalingModeMetadataImpl(ABC): ...@@ -104,6 +138,18 @@ class ScalingModeMetadataImpl(ABC):
The shape for scale tensors The shape for scale tensors
""" """
@abstractmethod
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
    """Get the quantize layout for the tensor usage.

    Note: the ``@lru_cache`` previously stacked on this abstract stub was
    removed — decorators are not inherited by overriding implementations and
    the stub body never executes, so it had no effect (and ``lru_cache`` on a
    method is flagged by ruff B019 for pinning instances).

    Args:
        usage: The usage of the tensor (LHS/RHS, normal or transposed)

    Returns:
        The quantize layout for the tensor usage
    """
@abstractmethod @abstractmethod
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self, input_rank, unique_var, flatten_axis
...@@ -157,6 +203,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -157,6 +203,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (0,) return (0,)
return (1,) return (1,)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
if is_fp8_gemm_with_all_layouts_supported():
return QuantizeLayout.ROWWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape( def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]: ) -> Tuple[int]:
...@@ -321,6 +384,27 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -321,6 +384,27 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape) return (*first_dim_scale_shape, *last_dim_scale_shape)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
# If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
# if usage == TensorUsage.LHS:
# return QuantizeLayout.ROWWISE
# return QuantizeLayout.COLWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape( def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]: ) -> Tuple[int]:
...@@ -506,6 +590,17 @@ class ScalingMode(Enum): ...@@ -506,6 +590,17 @@ class ScalingMode(Enum):
""" """
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1 self, input_rank, unique_var, flatten_axis=-1
) -> Tuple[Tuple[str]]: ) -> Tuple[Tuple[str]]:
......
...@@ -17,13 +17,14 @@ from jax.tree_util import register_pytree_node_class ...@@ -17,13 +17,14 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import ( from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
) )
__all__ = [ __all__ = [
"TensorUsage",
"ScaledTensor", "ScaledTensor",
"ScaledTensor1x", "ScaledTensor1x",
"ScaledTensor2x", "ScaledTensor2x",
...@@ -64,25 +65,15 @@ class ScaledTensor(ABC): ...@@ -64,25 +65,15 @@ class ScaledTensor(ABC):
""" """
@abstractmethod @abstractmethod
def get_rowwise_tensor(self): def get_tensor(self, usage: TensorUsage):
"""Returns the row-wise component of the tensor. """Returns the appropriate tensor based on the tensor usage and the scaling mode.
If the tensor usage is not valid for the scaling mode, an error is raised.
Returns: Args:
The row-wise tensor component usage: The usage of the tensor
Raises:
ValueError: If called on a tensor that doesn't support row-wise access
"""
@abstractmethod
def get_colwise_tensor(self):
"""Returns the column-wise component of the tensor.
Returns: Returns:
The column-wise tensor component The tensor based on the usage
Raises:
ValueError: If called on a tensor that doesn't support column-wise access
""" """
@abstractmethod @abstractmethod
...@@ -181,33 +172,19 @@ class ScaledTensor1x(ScaledTensor): ...@@ -181,33 +172,19 @@ class ScaledTensor1x(ScaledTensor):
""" """
return self._dq_func(self) return self._dq_func(self)
def get_rowwise_tensor(self): def get_tensor(self, usage: TensorUsage):
"""Returns the tensor if it's row-wise quantized. """Returns the tensor based on the tensor usage."""
q_layout = self.scaling_mode.get_quantize_layout(usage)
colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise
Returns: if colwise_usage_valid or rowwise_usage_valid:
The row-wise tensor
Raises:
ValueError: If called on a column-wise quantized tensor
"""
if not self.is_colwise:
return self return self
raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!") raise ValueError(
f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
def get_colwise_tensor(self): f" self.is_colwise={self.is_colwise}!"
"""Returns the tensor if it's column-wise quantized. )
Returns:
The column-wise tensor
Raises:
ValueError: If called on a row-wise quantized tensor
"""
if self.is_colwise:
return self
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names. """Applies sharding constraints to a tensor based on logical axis names.
...@@ -378,22 +355,22 @@ class ScaledTensor2x(ScaledTensor): ...@@ -378,22 +355,22 @@ class ScaledTensor2x(ScaledTensor):
""" """
return self.rowwise_tensor.dequantize() return self.rowwise_tensor.dequantize()
def get_rowwise_tensor(self): def get_tensor(self, usage: TensorUsage):
"""Returns the row-wise quantized component. """Returns the tensor based on the tensor usage."""
q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)
Returns: if q_layout_rowwise == QuantizeLayout.ROWWISE:
The row-wise tensor component
"""
return self.rowwise_tensor return self.rowwise_tensor
def get_colwise_tensor(self): if q_layout_colwise == QuantizeLayout.COLWISE:
"""Returns the column-wise quantized component.
Returns:
The column-wise tensor component
"""
return self.colwise_tensor return self.colwise_tensor
raise ValueError(
f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
)
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names. """Applies sharding constraints to a tensor based on logical axis names.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment