Unverified commit b362a6e0, authored by Phuong Nguyen, committed by GitHub
Browse files

Removing NVTE_NO_SCALING (#1650)



* rm no scaling enum
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* update jax enum
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c84d1708
......@@ -86,8 +86,7 @@ enum NVTEScalingMode {
*/
NVTE_BLOCK_SCALING_1D = 2,
NVTE_BLOCK_SCALING_2D = 3,
NVTE_INVALID_SCALING = 4,
NVTE_NO_SCALING = 5
NVTE_INVALID_SCALING = 100
};
/*! \brief TE Tensor type
......
......@@ -491,6 +491,11 @@ def grouped_gemm(
bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
dim_list = jnp.array(dims, dtype=jnp.int32)
# TE/common does not support NVTE_NO_SCALING yet
# It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
# Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
lhs_contig,
......
......@@ -90,7 +90,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
auto lhs_i = TensorWrapper(static_cast<void *>(lhs_ptr), lhs_shape, lhs_dtype, nullptr,
nullptr, reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i = TensorWrapper(static_cast<void *>(rhs_ptr), rhs_shape, rhs_dtype, nullptr,
......
......@@ -233,8 +233,8 @@ class ScalingMode(Enum):
NVTE_DELAYED_TENSOR_SCALING = 0
NVTE_MXFP8_1D_SCALING = 1
NVTE_INVALID_SCALING = 4
NVTE_NO_SCALING = 5
NVTE_INVALID_SCALING = 100
NVTE_NO_SCALING = 1000
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment