# If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE:
#     assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
A wrapper around jax.lax.with_sharding_constraint that accepts logical axes.
A wrapper around flax.linen.with_logical_constraint.
DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint, using a hardcoded logical axis rule table built from TE rules such as BATCH_AXES. This fallback will be removed in the future.
If logical_axis_names is None, no sharding constraint is applied.
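
The behavior described above can be illustrated with a minimal sketch. It is not TE's implementation: the rule table `EXAMPLE_RULES` and the helpers `logical_to_partition_spec` and `constrain_by_logical_axes` are hypothetical names introduced only for this example, and the axis names "batch", "hidden", and "data" are assumptions. The sketch resolves logical axis names to mesh axes through a hardcoded table, then calls jax.lax.with_sharding_constraint, returning the input untouched when the logical axes are None.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical rule table: logical axis name -> mesh axis name (or None).
# Illustrative only; these are not TE's actual rules.
EXAMPLE_RULES = {"batch": "data", "hidden": None}


def logical_to_partition_spec(logical_axis_names, rules=EXAMPLE_RULES):
    """Translate logical axis names into a PartitionSpec of mesh axis names.

    Names missing from the table resolve to None, i.e. unsharded.
    """
    return PartitionSpec(*(rules.get(name) for name in logical_axis_names))


def constrain_by_logical_axes(x, logical_axis_names, mesh, rules=EXAMPLE_RULES):
    """Apply a sharding constraint described by logical axes, or do nothing."""
    if logical_axis_names is None:
        # No logical axes given: no sharding constraint is applied.
        return x
    spec = logical_to_partition_spec(logical_axis_names, rules)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


# A one-device mesh so the sketch runs on any machine.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
x = jnp.ones((4, 8))

# None means "leave the array alone".
unconstrained = constrain_by_logical_axes(x, None, mesh)

# Under jit, the constraint is attached to the traced value.
constrained = jax.jit(
    lambda a: constrain_by_logical_axes(a, ("batch", "hidden"), mesh)
)(x)
```

In a Flax program the table lookup would normally be handled by flax.linen.with_logical_constraint together with the active logical axis rules; the explicit table here only mirrors the deprecated fallback path.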