Unverified Commit b362a6e0 authored by Phuong Nguyen, committed by GitHub

Removing NVTE_NO_SCALING (#1650)



* rm no scaling enum
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* update jax enum
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c84d1708
@@ -86,8 +86,7 @@ enum NVTEScalingMode {
    */
   NVTE_BLOCK_SCALING_1D = 2,
   NVTE_BLOCK_SCALING_2D = 3,
-  NVTE_INVALID_SCALING = 4,
-  NVTE_NO_SCALING = 5
+  NVTE_INVALID_SCALING = 100
 };
 
 /*! \brief TE Tensor type
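For reference, the sketch below mirrors in Python what the C-side enum looks like after this hunk; bumping the invalid sentinel from 4 to 100 presumably leaves room to append new scaling modes without renumbering. The class name is hypothetical, but the member values all come from the hunks in this commit.

    # Hypothetical Python mirror of the C-side NVTEScalingMode after this commit;
    # the authoritative definition lives in the TE/common header.
    from enum import IntEnum

    class NVTEScalingModeMirror(IntEnum):
        NVTE_DELAYED_TENSOR_SCALING = 0
        NVTE_MXFP8_1D_SCALING = 1
        NVTE_BLOCK_SCALING_1D = 2
        NVTE_BLOCK_SCALING_2D = 3
        NVTE_INVALID_SCALING = 100  # moved from 4 so future modes can slot in below it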
@@ -491,6 +491,11 @@ def grouped_gemm(
     bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
     dim_list = jnp.array(dims, dtype=jnp.int32)
 
+    # TE/common does not support NVTE_NO_SCALING yet
+    # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
+    if scaling_mode == ScalingMode.NVTE_NO_SCALING:
+        scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
+
     # Perform batched GEMM on flattened inputs
     out_contig = GroupedGemmPrimitive.outer_primitive.bind(
         lhs_contig,
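Taken in isolation, the remapping added above behaves like the sketch below (normalize_scaling_mode is an illustrative helper, not TE API): a caller that still passes the JAX-only NVTE_NO_SCALING for FP32/BF16/FP16 inputs is silently redirected to delayed tensor scaling before the primitive is bound.

    from enum import Enum

    class ScalingMode(Enum):  # trimmed mirror of the JAX-side enum
        NVTE_DELAYED_TENSOR_SCALING = 0
        NVTE_NO_SCALING = 1000

    def normalize_scaling_mode(mode: ScalingMode) -> ScalingMode:
        # TE/common no longer defines NVTE_NO_SCALING, so translate it
        # before crossing into the C/C++ layer.
        if mode == ScalingMode.NVTE_NO_SCALING:
            return ScalingMode.NVTE_DELAYED_TENSOR_SCALING
        return mode

    assert normalize_scaling_mode(ScalingMode.NVTE_NO_SCALING) is ScalingMode.NVTE_DELAYED_TENSOR_SCALING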
@@ -90,7 +90,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
   auto lhs_sinv_shape = std::vector<size_t>{1, 1};
   auto rhs_sinv_shape = std::vector<size_t>{1, 1};
-  if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
+  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
     auto lhs_i = TensorWrapper(static_cast<void *>(lhs_ptr), lhs_shape, lhs_dtype, nullptr,
                                nullptr, reinterpret_cast<float *>(lhs_sinv_ptr));
     auto rhs_i = TensorWrapper(static_cast<void *>(rhs_ptr), rhs_shape, rhs_dtype, nullptr,
                                nullptr, reinterpret_cast<float *>(rhs_sinv_ptr));
@@ -233,8 +233,8 @@ class ScalingMode(Enum):
     NVTE_DELAYED_TENSOR_SCALING = 0
     NVTE_MXFP8_1D_SCALING = 1
-    NVTE_INVALID_SCALING = 4
-    NVTE_NO_SCALING = 5
+    NVTE_INVALID_SCALING = 100
+    NVTE_NO_SCALING = 1000
 
     def _get_impl(self) -> ScalingModeMetadataImpl:
         """Get the implementation for this scaling mode.
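Note the deliberate divergence this hunk creates: the JAX enum keeps NVTE_NO_SCALING but parks it at 1000, well above the C-side sentinel of 100, so it can never alias a real TE/common mode. A guard like the hypothetical one below could catch any JAX-only value that is about to cross the boundary unmapped; the helper name and structure are illustrative, only the sentinel value 100 comes from this commit.

    # Hypothetical guard, assuming the C-side sentinel NVTE_INVALID_SCALING == 100.
    C_SIDE_SENTINEL = 100

    def to_c_enum_value(mode) -> int:
        # JAX-only modes (value >= sentinel) must be normalized first,
        # e.g. NVTE_NO_SCALING -> NVTE_DELAYED_TENSOR_SCALING.
        if mode.value >= C_SIDE_SENTINEL:
            raise ValueError(f"{mode!r} is JAX-side only; normalize it before calling TE/common")
        return mode.value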